[mlir][LLVM] Support scalable vectors in llvm.shufflevector

The result type of llvm.shufflevector now inherits scalability from its
first operand, both when the op is built programmatically and when it is
parsed. Scalability is queried with LLVM::isScalableVectorType, the same
helper the verifier uses, instead of casting the operand type to the
builtin VectorType: the parser only checks LLVM::isCompatibleVectorType,
which may also admit LLVM dialect vector types (e.g. !llvm.vec) that
such a cast cannot handle.

For scalable vectors the verifier only accepts splat masks (every mask
entry is 0); any other mask is rejected with "on scalable vectors only
admits splat operations". New tests cover round-tripping, translation
to LLVM IR, and the verifier error.

diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -1881,8 +1881,9 @@
                                   Value v1, Value v2, ArrayAttr mask,
                                   ArrayRef<NamedAttribute> attrs) {
   auto containerType = v1.getType();
-  auto vType = LLVM::getFixedVectorType(
-      LLVM::getVectorElementType(containerType), mask.size());
+  auto vType = LLVM::getVectorType(LLVM::getVectorElementType(containerType),
+                                   mask.size(),
+                                   LLVM::isScalableVectorType(containerType));
   build(b, result, vType, v1, v2, mask);
   result.addAttributes(attrs);
 }
@@ -1914,8 +1915,9 @@
   if (!LLVM::isCompatibleVectorType(typeV1))
     return parser.emitError(
         loc, "expected LLVM IR dialect vector type for operand #1");
-  auto vType = LLVM::getFixedVectorType(LLVM::getVectorElementType(typeV1),
-                                        maskAttr.size());
+  auto vType =
+      LLVM::getVectorType(LLVM::getVectorElementType(typeV1), maskAttr.size(),
+                          LLVM::isScalableVectorType(typeV1));
   result.addTypes(vType);
   return success();
 }
@@ -1925,6 +1927,13 @@
   Type type2 = getV2().getType();
   if (LLVM::getVectorElementType(type1) != LLVM::getVectorElementType(type2))
     return emitOpError("expected matching LLVM IR Dialect element types");
+  // Scalable vectors only support splat masks here: every mask entry must
+  // select lane 0 of the first operand.
+  if (LLVM::isScalableVectorType(type1) &&
+      llvm::any_of(getMask(), [](Attribute attr) {
+        return attr.cast<IntegerAttr>().getInt() != 0;
+      }))
+    return emitOpError("on scalable vectors only admits splat operations");
   return success();
 }
 
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -1250,3 +1250,11 @@
   llvm.getelementptr %ptr[%idx, 1, 3] : (!llvm.ptr<struct<(i32, struct<(i32, f32)>)>>, i64) -> !llvm.ptr<i32>
   return
 }
+
+// -----
+
+func @non_splat_shuffle_on_scalable_vector(%arg0: vector<[4]xf32>) {
+  // expected-error@+1 {{op on scalable vectors only admits splat operations}}
+  %0 = llvm.shufflevector %arg0, %arg0 [0 : i32, 0 : i32, 0 : i32, 1 : i32] : vector<[4]xf32>, vector<[4]xf32>
+  return
+}
diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -281,6 +281,19 @@
   return
 }
 
+// CHECK-LABEL: @scalable_vect
+func @scalable_vect(%arg0: vector<[4]xf32>, %arg1: i32, %arg2: f32) {
+  // CHECK: = llvm.extractelement {{.*}} : vector<[4]xf32>
+  %0 = llvm.extractelement %arg0[%arg1 : i32] : vector<[4]xf32>
+  // CHECK: = llvm.insertelement {{.*}} : vector<[4]xf32>
+  %1 = llvm.insertelement %arg2, %arg0[%arg1 : i32] : vector<[4]xf32>
+  // CHECK: = llvm.shufflevector {{.*}} [0 : i32, 0 : i32, 0 : i32, 0 : i32] : vector<[4]xf32>, vector<[4]xf32>
+  %2 = llvm.shufflevector %arg0, %arg0 [0 : i32, 0 : i32, 0 : i32, 0 : i32] : vector<[4]xf32>, vector<[4]xf32>
+  // CHECK: = llvm.mlir.constant(dense<1.000000e+00> : vector<[4]xf32>) : vector<[4]xf32>
+  %3 = llvm.mlir.constant(dense<1.0> : vector<[4]xf32>) : vector<[4]xf32>
+  return
+}
+
 // CHECK-LABEL: @alloca
 func @alloca(%size : i64) {
   // CHECK: llvm.alloca %{{.*}} x i32 : (i64) -> !llvm.ptr<i32>
diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -1168,6 +1168,26 @@
   llvm.return
 }
 
+// CHECK-LABEL: @scalable_vect
+llvm.func @scalable_vect(%arg0: vector<[4]xf32>, %arg1: i32, %arg2: f32) {
+  // CHECK-NEXT: extractelement <vscale x 4 x float> {{.*}}, i32
+  // CHECK-NEXT: insertelement <vscale x 4 x float> {{.*}}, float %2, i32
+  // CHECK-NEXT: shufflevector <vscale x 4 x float> %0, <vscale x 4 x float> %0, <vscale x 4 x i32> zeroinitializer
+  %0 = llvm.extractelement %arg0[%arg1 : i32] : vector<[4]xf32>
+  %1 = llvm.insertelement %arg2, %arg0[%arg1 : i32] : vector<[4]xf32>
+  %2 = llvm.shufflevector %arg0, %arg0 [0 : i32, 0 : i32, 0 : i32, 0 : i32] : vector<[4]xf32>, vector<[4]xf32>
+  llvm.return
+}
+
+// CHECK-LABEL: @scalable_vect_i64idx
+llvm.func @scalable_vect_i64idx(%arg0: vector<[4]xf32>, %arg1: i64, %arg2: f32) {
+  // CHECK-NEXT: extractelement <vscale x 4 x float> {{.*}}, i64
+  // CHECK-NEXT: insertelement <vscale x 4 x float> {{.*}}, float %2, i64
+  %0 = llvm.extractelement %arg0[%arg1 : i64] : vector<[4]xf32>
+  %1 = llvm.insertelement %arg2, %arg0[%arg1 : i64] : vector<[4]xf32>
+  llvm.return
+}
+
 // CHECK-LABEL: @alloca
 llvm.func @alloca(%size : i64) {
   // Alignment automatically set by the LLVM IR builder when alignment attribute