diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
@@ -408,6 +408,107 @@
   let assemblyFormat = "attr-dict `:` type($res)";
 }
 
+/// Create a call to the llvm.vector.insert intrinsic, which inserts `srcvec`
+/// into `dstvec` starting at element `pos`. LLVM requires both vectors to
+/// have the same element type and `pos` to be a multiple of the known minimum
+/// number of elements of `srcvec`; both are checked here.
+def LLVM_vector_insert
+    : LLVM_Op<"intr.vector.insert",
+                 [NoSideEffect, AllTypesMatch<["dstvec", "res"]>,
+                  PredOpTrait<"vectors are not bigger than 2^17 bits.", And<[
+                    CPred<"getSrcVectorBitWidth() <= 131072">,
+                    CPred<"getDstVectorBitWidth() <= 131072">
+                  ]>>,
+                  PredOpTrait<"it is not inserting scalable into fixed-length vectors.",
+                    CPred<"!isScalableVectorType($srcvec.getType()) || "
+                          "isScalableVectorType($dstvec.getType())">>,
+                  PredOpTrait<"source and destination vectors have the same element type.",
+                    CPred<"getVectorElementType($srcvec.getType()) == "
+                          "getVectorElementType($dstvec.getType())">>,
+                  PredOpTrait<"position is a multiple of the source vector length.",
+                    CPred<"getPos() % getVectorNumElements($srcvec.getType())"
+                          ".getKnownMinValue() == 0">>]> {
+  let arguments = (ins LLVM_AnyVector:$srcvec, LLVM_AnyVector:$dstvec,
+                       I64Attr:$pos);
+  let results = (outs LLVM_AnyVector:$res);
+  let builders = [LLVM_OneResultOpBuilder];
+  string llvmBuilder = [{
+    $res = builder.CreateInsertVector(
+        $_resultType, $dstvec, $srcvec, builder.getInt64($pos));
+  }];
+  let assemblyFormat = "$srcvec `,` $dstvec `[` $pos `]` attr-dict `:` "
+    "type($srcvec) `into` type($res)";
+  let extraClassDeclaration = [{
+    /// Returns the bit width of `vector`, using the known minimum element
+    /// count for scalable vectors. Returns 0 for element types that are not
+    /// integers or floats (e.g. pointers): their width is unknown without a
+    /// data layout, so the size limit is not enforced for such vectors.
+    uint64_t getVectorBitWidth(Type vector) {
+      Type elementType = getVectorElementType(vector);
+      if (!elementType.isIntOrFloat())
+        return 0;
+      return getVectorNumElements(vector).getKnownMinValue() *
+             elementType.getIntOrFloatBitWidth();
+    }
+    uint64_t getSrcVectorBitWidth() {
+      return getVectorBitWidth(getSrcvec().getType());
+    }
+    uint64_t getDstVectorBitWidth() {
+      return getVectorBitWidth(getDstvec().getType());
+    }
+  }];
+}
+
+/// Create a call to the llvm.vector.extract intrinsic, which extracts a
+/// `res`-typed subvector from `srcvec` starting at element `pos`. LLVM
+/// requires both vectors to have the same element type and `pos` to be a
+/// multiple of the known minimum number of elements of `res`.
+def LLVM_vector_extract
+    : LLVM_Op<"intr.vector.extract",
+                 [NoSideEffect,
+                  PredOpTrait<"vectors are not bigger than 2^17 bits.", And<[
+                    CPred<"getSrcVectorBitWidth() <= 131072">,
+                    CPred<"getResVectorBitWidth() <= 131072">
+                  ]>>,
+                  PredOpTrait<"it is not extracting scalable from fixed-length vectors.",
+                    CPred<"!isScalableVectorType($res.getType()) || "
+                          "isScalableVectorType($srcvec.getType())">>,
+                  PredOpTrait<"source and result vectors have the same element type.",
+                    CPred<"getVectorElementType($srcvec.getType()) == "
+                          "getVectorElementType($res.getType())">>,
+                  PredOpTrait<"position is a multiple of the result vector length.",
+                    CPred<"getPos() % getVectorNumElements($res.getType())"
+                          ".getKnownMinValue() == 0">>]> {
+  let arguments = (ins LLVM_AnyVector:$srcvec, I64Attr:$pos);
+  let results = (outs LLVM_AnyVector:$res);
+  let builders = [LLVM_OneResultOpBuilder];
+  string llvmBuilder = [{
+    $res = builder.CreateExtractVector(
+        $_resultType, $srcvec, builder.getInt64($pos));
+  }];
+  let assemblyFormat = "$srcvec `[` $pos `]` attr-dict `:` "
+    "type($res) `from` type($srcvec)";
+  let extraClassDeclaration = [{
+    /// Returns the bit width of `vector`, using the known minimum element
+    /// count for scalable vectors. Returns 0 for element types that are not
+    /// integers or floats (e.g. pointers): their width is unknown without a
+    /// data layout, so the size limit is not enforced for such vectors.
+    uint64_t getVectorBitWidth(Type vector) {
+      Type elementType = getVectorElementType(vector);
+      if (!elementType.isIntOrFloat())
+        return 0;
+      return getVectorNumElements(vector).getKnownMinValue() *
+             elementType.getIntOrFloatBitWidth();
+    }
+    uint64_t getSrcVectorBitWidth() {
+      return getVectorBitWidth(getSrcvec().getType());
+    }
+    uint64_t getResVectorBitWidth() {
+      return getVectorBitWidth(getRes().getType());
+    }
+  }];
+}
+
 //
 // LLVM Vector Predication operations.
 //
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
@@ -162,6 +162,18 @@
 def LLVM_AnyVector : Type<CPred<"::mlir::LLVM::isCompatibleVectorType($_self)">,
                          "LLVM dialect-compatible vector type">;
 
+// Type constraint accepting any LLVM fixed-length vector type.
+def LLVM_AnyFixedVector : Type<CPred<
+    "::mlir::LLVM::isCompatibleVectorType($_self) && "
+    "!::mlir::LLVM::isScalableVectorType($_self)">,
+    "LLVM dialect-compatible fixed-length vector type">;
+
+// Type constraint accepting any LLVM scalable vector type.
+def LLVM_AnyScalableVector : Type<CPred<
+    "::mlir::LLVM::isCompatibleVectorType($_self) && "
+    "::mlir::LLVM::isScalableVectorType($_self)">,
+    "LLVM dialect-compatible scalable vector type">;
+
 // Type constraint accepting an LLVM vector type with an additional constraint
 // on the vector element type.
 class LLVM_VectorOf<Type element> : Type<
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -1363,3 +1363,73 @@
 func.func @invalid_res_struct_attr_size(%arg0 : !llvm.struct<(i32)>) -> (!llvm.struct<(i32)> {llvm.struct_attrs = []}) {
   return %arg0 : !llvm.struct<(i32)>
 }
+
+// -----
+
+func.func @insert_vector_invalid_source_vector_size(%arg0 : vector<16385xi8>, %arg1 : vector<[16]xi8>) {
+  // expected-error@+1 {{op failed to verify that vectors are not bigger than 2^17 bits.}}
+  %0 = llvm.intr.vector.insert %arg0, %arg1[0] : vector<16385xi8> into vector<[16]xi8>
+}
+
+// -----
+
+func.func @insert_vector_invalid_dest_vector_size(%arg0 : vector<16xi8>, %arg1 : vector<[16385]xi8>) {
+  // expected-error@+1 {{op failed to verify that vectors are not bigger than 2^17 bits.}}
+  %0 = llvm.intr.vector.insert %arg0, %arg1[0] : vector<16xi8> into vector<[16385]xi8>
+}
+
+// -----
+
+func.func @insert_scalable_into_fixed_length_vector(%arg0 : vector<[8]xf32>, %arg1 : vector<16xf32>) {
+  // expected-error@+1 {{op failed to verify that it is not inserting scalable into fixed-length vectors.}}
+  %0 = llvm.intr.vector.insert %arg0, %arg1[0] : vector<[8]xf32> into vector<16xf32>
+}
+
+// -----
+
+func.func @extract_vector_invalid_source_vector_size(%arg0 : vector<[16385]xi8>) {
+  // expected-error@+1 {{op failed to verify that vectors are not bigger than 2^17 bits.}}
+  %0 = llvm.intr.vector.extract %arg0[0] : vector<16xi8> from vector<[16385]xi8>
+}
+
+// -----
+
+func.func @extract_vector_invalid_result_vector_size(%arg0 : vector<[16]xi8>) {
+  // expected-error@+1 {{op failed to verify that vectors are not bigger than 2^17 bits.}}
+  %0 = llvm.intr.vector.extract %arg0[0] : vector<16385xi8> from vector<[16]xi8>
+}
+
+// -----
+
+func.func @extract_scalable_from_fixed_length_vector(%arg0 : vector<16xf32>) {
+  // expected-error@+1 {{op failed to verify that it is not extracting scalable from fixed-length vectors.}}
+  %0 = llvm.intr.vector.extract %arg0[0] : vector<[8]xf32> from vector<16xf32>
+}
+
+// -----
+
+func.func @insert_vector_mismatched_element_types(%arg0 : vector<4xf32>, %arg1 : vector<[4]xi32>) {
+  // expected-error@+1 {{op failed to verify that source and destination vectors have the same element type.}}
+  %0 = llvm.intr.vector.insert %arg0, %arg1[0] : vector<4xf32> into vector<[4]xi32>
+}
+
+// -----
+
+func.func @insert_vector_misaligned_position(%arg0 : vector<4xf32>, %arg1 : vector<[4]xf32>) {
+  // expected-error@+1 {{op failed to verify that position is a multiple of the source vector length.}}
+  %0 = llvm.intr.vector.insert %arg0, %arg1[2] : vector<4xf32> into vector<[4]xf32>
+}
+
+// -----
+
+func.func @extract_vector_mismatched_element_types(%arg0 : vector<[4]xi32>) {
+  // expected-error@+1 {{op failed to verify that source and result vectors have the same element type.}}
+  %0 = llvm.intr.vector.extract %arg0[0] : vector<4xf32> from vector<[4]xi32>
+}
+
+// -----
+
+func.func @extract_vector_misaligned_position(%arg0 : vector<8xf32>) {
+  // expected-error@+1 {{op failed to verify that position is a multiple of the result vector length.}}
+  %0 = llvm.intr.vector.extract %arg0[1] : vector<2xf32> from vector<8xf32>
+}
diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -305,6 +305,31 @@
   return
 }
 
+// CHECK-LABEL: @mixed_vect
+func.func @mixed_vect(%arg0: vector<8xf32>, %arg1: vector<4xf32>, %arg2: vector<[4]xf32>) {
+  // CHECK: = llvm.intr.vector.insert {{.*}} : vector<8xf32> into vector<[4]xf32>
+  %0 = llvm.intr.vector.insert %arg0, %arg2[0] : vector<8xf32> into vector<[4]xf32>
+  // CHECK: = llvm.intr.vector.insert {{.*}} : vector<4xf32> into vector<[4]xf32>
+  %1 = llvm.intr.vector.insert %arg1, %arg2[0] : vector<4xf32> into vector<[4]xf32>
+  // CHECK: = llvm.intr.vector.insert {{.*}} : vector<4xf32> into vector<[4]xf32>
+  %2 = llvm.intr.vector.insert %arg1, %1[4] : vector<4xf32> into vector<[4]xf32>
+  // CHECK: = llvm.intr.vector.insert {{.*}} : vector<4xf32> into vector<8xf32>
+  %3 = llvm.intr.vector.insert %arg1, %arg0[4] : vector<4xf32> into vector<8xf32>
+  // CHECK: = llvm.intr.vector.extract {{.*}} : vector<8xf32> from vector<[4]xf32>
+  %4 = llvm.intr.vector.extract %2[0] : vector<8xf32> from vector<[4]xf32>
+  // CHECK: = llvm.intr.vector.extract {{.*}} : vector<2xf32> from vector<8xf32>
+  %5 = llvm.intr.vector.extract %arg0[6] : vector<2xf32> from vector<8xf32>
+  return
+}
+
+// CHECK-LABEL: @ptr_vect
+func.func @ptr_vect(%arg0: !llvm.vec<2 x ptr<i32>>, %arg1: !llvm.vec<4 x ptr<i32>>) {
+  // CHECK: = llvm.intr.vector.insert {{.*}} : !llvm.vec<2 x ptr<i32>> into !llvm.vec<4 x ptr<i32>>
+  %0 = llvm.intr.vector.insert %arg0, %arg1[2] : !llvm.vec<2 x ptr<i32>> into !llvm.vec<4 x ptr<i32>>
+  // CHECK: = llvm.intr.vector.extract {{.*}} : !llvm.vec<2 x ptr<i32>> from !llvm.vec<4 x ptr<i32>>
+  %1 = llvm.intr.vector.extract %0[2] : !llvm.vec<2 x ptr<i32>> from !llvm.vec<4 x ptr<i32>>
+  return
+}
+
 // CHECK-LABEL: @alloca
 func.func @alloca(%size : i64) {
   // CHECK: llvm.alloca %{{.*}} x i32 : (i64) -> !llvm.ptr<i32>
diff --git a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
--- a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
@@ -680,6 +680,33 @@
   llvm.return
 }
 
+// CHECK-LABEL: @vector_insert_extract
+llvm.func @vector_insert_extract(%f256: vector<8xi32>, %f128: vector<4xi32>,
+                                 %sv: vector<[4]xi32>) {
+  // CHECK: call <vscale x 4 x i32> @llvm.vector.insert.nxv4i32.v8i32
+  %0 = llvm.intr.vector.insert %f256, %sv[0] :
+              vector<8xi32> into vector<[4]xi32>
+  // CHECK: call <vscale x 4 x i32> @llvm.vector.insert.nxv4i32.v4i32
+  %1 = llvm.intr.vector.insert %f128, %sv[0] :
+              vector<4xi32> into vector<[4]xi32>
+  // CHECK: call <vscale x 4 x i32> @llvm.vector.insert.nxv4i32.v4i32
+  %2 = llvm.intr.vector.insert %f128, %1[4] :
+              vector<4xi32> into vector<[4]xi32>
+  // CHECK: call <8 x i32> @llvm.vector.insert.v8i32.v4i32
+  %3 = llvm.intr.vector.insert %f128, %f256[4] :
+              vector<4xi32> into vector<8xi32>
+  // CHECK: call <8 x i32> @llvm.vector.extract.v8i32.nxv4i32
+  %4 = llvm.intr.vector.extract %2[0] :
+              vector<8xi32> from vector<[4]xi32>
+  // CHECK: call <4 x i32> @llvm.vector.extract.v4i32.nxv4i32
+  %5 = llvm.intr.vector.extract %2[0] :
+              vector<4xi32> from vector<[4]xi32>
+  // CHECK: call <2 x i32> @llvm.vector.extract.v2i32.v8i32
+  %6 = llvm.intr.vector.extract %f256[6] :
+              vector<2xi32> from vector<8xi32>
+  llvm.return
+}
+
 // Check that intrinsics are declared with appropriate types.
 // CHECK-DAG: declare float @llvm.fma.f32(float, float, float)
 // CHECK-DAG: declare <8 x float> @llvm.fma.v8f32(<8 x float>, <8 x float>, <8 x float>) #0
@@ -781,3 +808,9 @@
 // CHECK-DAG: declare <8 x i64> @llvm.vp.fptosi.v8i64.v8f64(<8 x double>, <8 x i1>, i32) #2
 // CHECK-DAG: declare <8 x i64> @llvm.vp.ptrtoint.v8i64.v8p0(<8 x ptr>, <8 x i1>, i32) #2
 // CHECK-DAG: declare <8 x ptr> @llvm.vp.inttoptr.v8p0.v8i64(<8 x i64>, <8 x i1>, i32) #2
+// CHECK-DAG: declare <vscale x 4 x i32> @llvm.vector.insert.nxv4i32.v8i32(<vscale x 4 x i32>, <8 x i32>, i64 immarg) #2
+// CHECK-DAG: declare <vscale x 4 x i32> @llvm.vector.insert.nxv4i32.v4i32(<vscale x 4 x i32>, <4 x i32>, i64 immarg) #2
+// CHECK-DAG: declare <8 x i32> @llvm.vector.insert.v8i32.v4i32(<8 x i32>, <4 x i32>, i64 immarg) #2
+// CHECK-DAG: declare <8 x i32> @llvm.vector.extract.v8i32.nxv4i32(<vscale x 4 x i32>, i64 immarg) #2
+// CHECK-DAG: declare <4 x i32> @llvm.vector.extract.v4i32.nxv4i32(<vscale x 4 x i32>, i64 immarg) #2
+// CHECK-DAG: declare <2 x i32> @llvm.vector.extract.v2i32.v8i32(<8 x i32>, i64 immarg) #2