// Patch summary (descriptive preamble only; everything from the first
// `diff --git` header onwards is unchanged).
//
// mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
//   * Adds an early `notifyMatchFailure` ("Cannot unroll leading scalable dim
//     in dstType") when `dstType.getScalableDims().front()` is set.
//   * Drops the local `eltType` and builds `lowType` with
//     `VectorType::Builder(dstType).dropDim(0)` instead of
//     `VectorType::get(dstType.getShape().drop_front(), eltType)`.
//     NOTE(review): presumably so the scalable-dim flags of the remaining dims
//     carry over to `lowType` -- confirm against the VectorType::Builder API.
//
// mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
//   * Builds `unbroadcastedVectorType` with `read.getVectorType().cloneWith(...)`
//     instead of `VectorType::get(...)`, passing the same shape and element type.
//
// mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
//   * Adds four test functions: transfer_read_1d_scalable_mask,
//     transfer_write_1d_scalable_mask, create_mask_2d_trailing_scalable and
//     cannot_create_mask_2d_leading_scalable. The last one checks that the
//     `vector.create_mask` op is still present in the output.
//
// NOTE(review): in this copy the patch's line breaks are collapsed, and some
// angle-bracketed tokens appear to be missing (e.g. `rewriter.create(`,
// `-> vector {`, `dense : vector<1x[4]xi1>`). Verify against the original
// patch before attempting to apply it.
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp @@ -58,13 +58,15 @@ return rewriter.notifyMatchFailure( op, "0-D and 1-D vectors are handled separately"); + if (dstType.getScalableDims().front()) + return rewriter.notifyMatchFailure( + op, "Cannot unroll leading scalable dim in dstType"); + auto loc = op.getLoc(); - auto eltType = dstType.getElementType(); int64_t dim = dstType.getDimSize(0); Value idx = op.getOperand(0); - VectorType lowType = - VectorType::get(dstType.getShape().drop_front(), eltType); + VectorType lowType = VectorType::Builder(dstType).dropDim(0); Value trueVal = rewriter.create( loc, lowType, op.getOperands().drop_front()); Value falseVal = rewriter.create( diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp @@ -434,7 +434,7 @@ vectorShape.end()); for (unsigned i : broadcastedDims) unbroadcastedVectorShape[i] = 1; - VectorType unbroadcastedVectorType = VectorType::get( + VectorType unbroadcastedVectorType = read.getVectorType().cloneWith( unbroadcastedVectorShape, read.getVectorType().getElementType()); // `vector.load` supports vector types as memref's elements only when the diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -1743,6 +1743,28 @@ // ----- +// CHECK-LABEL: func @transfer_read_1d_scalable_mask +// CHECK: %[[passtru:.*]] = arith.constant dense<0.000000e+00> : vector<[4]xf32> +// CHECK: %[[r:.*]] = llvm.intr.masked.load %{{.*}}, %{{.*}}, 
%[[passtru]] {alignment = 4 : i32} : (!llvm.ptr, vector<[4]xi1>, vector<[4]xf32>) -> vector<[4]xf32> +// CHECK: return %[[r]] : vector<[4]xf32> +func.func @transfer_read_1d_scalable_mask(%arg0: memref<1x?xf32>, %mask: vector<[4]xi1>) -> vector<[4]xf32> { + %c0 = arith.constant 0 : index + %pad = arith.constant 0.0 : f32 + %vec = vector.transfer_read %arg0[%c0, %c0], %pad, %mask {in_bounds = [true]} : memref<1x?xf32>, vector<[4]xf32> + return %vec : vector<[4]xf32> +} + +// ----- +// CHECK-LABEL: func @transfer_write_1d_scalable_mask +// CHECK: llvm.intr.masked.store %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : vector<[4]xf32>, vector<[4]xi1> into !llvm.ptr +func.func @transfer_write_1d_scalable_mask(%arg0: memref<1x?xf32>, %vec: vector<[4]xf32>, %mask: vector<[4]xi1>) { + %c0 = arith.constant 0 : index + vector.transfer_write %vec, %arg0[%c0, %c0], %mask {in_bounds = [true]} : vector<[4]xf32>, memref<1x?xf32> + return +} + +// ----- + func.func @genbool_0d_f() -> vector { %0 = vector.constant_mask [0] : vector return %0 : vector @@ -1847,6 +1869,42 @@ // ----- +// CHECK-LABEL: func.func @create_mask_2d_trailing_scalable( +// CHECK-SAME: %[[arg:.*]]: index) -> vector<1x[4]xi1> { +// CHECK: %[[zero_mask2d:.*]] = arith.constant dense : vector<1x[4]xi1> +// CHECK: %[[zero_llvm_mask2d:.*]] = builtin.unrealized_conversion_cast %[[zero_mask2d]] : vector<1x[4]xi1> to !llvm.array<1 x vector<[4]xi1>> +// CHECK: %[[indices:.*]] = llvm.intr.experimental.stepvector : vector<[4]xi32> +// CHECK: %[[arg_i32:.*]] = arith.index_cast %[[arg]] : index to i32 +// CHECK: %[[undef:.*]] = llvm.mlir.undef : vector<[4]xi32> +// CHECK: %[[c0:.*]] = llvm.mlir.constant(0 : i32) : i32 +// CHECK: %[[arg_vec:.*]] = llvm.insertelement %[[arg_i32]], %[[undef]]{{\[}}%[[c0]] : i32] : vector<[4]xi32> +// CHECK: %[[splat_arg:.*]] = llvm.shufflevector %[[arg_vec]], %[[undef]] [0, 0, 0, 0] : vector<[4]xi32> +// CHECK: %[[llvm_mask1d:.*]] = arith.cmpi slt, %[[indices]], %[[splat_arg]] : 
vector<[4]xi32> +// CHECK: %[[llvm_mask2d:.*]] = llvm.insertvalue %[[llvm_mask1d]], %[[zero_llvm_mask2d]][0] : !llvm.array<1 x vector<[4]xi1>> +// CHECK: %[[mask2d:.*]] = builtin.unrealized_conversion_cast %[[llvm_mask2d]] : !llvm.array<1 x vector<[4]xi1>> to vector<1x[4]xi1> +// CHECK: return %[[mask2d]] : vector<1x[4]xi1> +func.func @create_mask_2d_trailing_scalable(%a: index) -> vector<1x[4]xi1> { + %c1 = arith.constant 1 : index + %mask = vector.create_mask %c1, %a : vector<1x[4]xi1> + return %mask : vector<1x[4]xi1> +} + +// ----- + +/// The following cannot be lowered as the current lowering requires unrolling +/// the leading dim. + +// CHECK-LABEL: func.func @cannot_create_mask_2d_leading_scalable( +// CHECK-SAME: %[[arg:.*]]: index) -> vector<[4]x4xi1> { +// CHECK: %{{.*}} = vector.create_mask %[[arg]], %{{.*}} : vector<[4]x4xi1> +func.func @cannot_create_mask_2d_leading_scalable(%a: index) -> vector<[4]x4xi1> { + %c1 = arith.constant 1 : index + %mask = vector.create_mask %a, %c1 : vector<[4]x4xi1> + return %mask : vector<[4]x4xi1> +} + +// ----- + func.func @transpose_0d(%arg0: vector) -> vector { %0 = vector.transpose %arg0, [] : vector to vector return %0 : vector