diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td --- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td +++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td @@ -122,6 +122,17 @@ }]; } +def ApplyLowerCreateMaskPatternsOp : Op]> { + let description = [{ + Indicates that vector create_mask-like operations should be lowered to + finer-grained vector primitives. + }]; + + let assemblyFormat = "attr-dict"; +} + def ApplyLowerMasksPatternsOp : Op]> { diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp --- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp +++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp @@ -64,6 +64,11 @@ vector::populateVectorReductionToContractPatterns(patterns); } +void transform::ApplyLowerCreateMaskPatternsOp::populatePatterns( + RewritePatternSet &patterns) { + vector::populateVectorMaskOpLoweringPatterns(patterns); +} + void transform::ApplyRankReducingSubviewPatternsOp::populatePatterns( RewritePatternSet &patterns) { vector::populateVectorTransferDropUnitDimsPatterns(patterns); diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp @@ -58,13 +58,15 @@ return rewriter.notifyMatchFailure( op, "0-D and 1-D vectors are handled separately"); + if (dstType.getScalableDims().front()) + return rewriter.notifyMatchFailure( + op, "Cannot unroll leading scalable dim in dstType"); + auto loc = op.getLoc(); - auto eltType = dstType.getElementType(); int64_t dim = dstType.getDimSize(0); Value idx = op.getOperand(0); - VectorType lowType = - VectorType::get(dstType.getShape().drop_front(), eltType); + VectorType lowType = VectorType::Builder(dstType).dropDim(0); Value trueVal = rewriter.create( loc, lowType, op.getOperands().drop_front()); Value falseVal = rewriter.create( diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp @@ -434,7 +434,7 @@ vectorShape.end()); for (unsigned i : broadcastedDims) unbroadcastedVectorShape[i] = 1; - VectorType unbroadcastedVectorType = VectorType::get( + VectorType unbroadcastedVectorType = read.getVectorType().cloneWith( unbroadcastedVectorShape, read.getVectorType().getElementType()); // `vector.load` supports vector types as memref's elements only when the diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -1743,6 +1743,28 @@ // ----- +// CHECK-LABEL: func @transfer_read_1d_scalable_mask +// CHECK: %[[passtru:.*]] = arith.constant dense<0.000000e+00> : vector<[4]xf32> +// CHECK: %[[r:.*]] = llvm.intr.masked.load %{{.*}}, %{{.*}}, %[[passtru]] {alignment = 4 : i32} : (!llvm.ptr, vector<[4]xi1>, vector<[4]xf32>) -> vector<[4]xf32> +// CHECK: return %[[r]] : vector<[4]xf32> +func.func @transfer_read_1d_scalable_mask(%arg0: memref<1x?xf32>, %mask: vector<[4]xi1>) -> vector<[4]xf32> { + %c0 = arith.constant 0 : index + %pad = arith.constant 0.0 : f32 + %vec = vector.transfer_read %arg0[%c0, %c0], %pad, %mask {in_bounds = [true]} : memref<1x?xf32>, vector<[4]xf32> + return %vec : vector<[4]xf32> +} + +// ----- +// CHECK-LABEL: func @transfer_write_1d_scalable_mask +// CHECK: llvm.intr.masked.store %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : vector<[4]xf32>, vector<[4]xi1> into !llvm.ptr +func.func @transfer_write_1d_scalable_mask(%arg0: memref<1x?xf32>, %vec: vector<[4]xf32>, %mask: vector<[4]xi1>) { + %c0 = arith.constant 0 : index + vector.transfer_write %vec, %arg0[%c0, %c0], %mask {in_bounds = [true]} : vector<[4]xf32>, memref<1x?xf32> + return +} + +// ----- + func.func @genbool_0d_f() -> vector { %0 = vector.constant_mask [0] : vector return %0 : vector diff --git a/mlir/test/Dialect/Vector/vector-scalable-create-mask-lowering.mlir b/mlir/test/Dialect/Vector/vector-scalable-create-mask-lowering.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Vector/vector-scalable-create-mask-lowering.mlir @@ -0,0 +1,40 @@ +// RUN: mlir-opt %s --test-transform-dialect-interpreter | FileCheck %s + +// CHECK-LABEL: func.func @create_mask_2d_trailing_scalable( +// CHECK-SAME: %[[arg:.*]]: index) -> vector<3x[4]xi1> { +// CHECK-NEXT: %[[zero_mask_1d:.*]] = arith.constant dense : vector<[4]xi1> +// CHECK-NEXT: %[[zero_mask_2d:.*]] = arith.constant dense : vector<3x[4]xi1> +// CHECK-NEXT: %[[create_mask_1d:.*]] = vector.create_mask %[[arg]] : vector<[4]xi1> +// CHECK-NEXT: %[[res_0:.*]] = vector.insert %[[create_mask_1d]], %[[zero_mask_2d]] [0] : vector<[4]xi1> into vector<3x[4]xi1> +// CHECK-NEXT: %[[res_1:.*]] = vector.insert %[[create_mask_1d]], %[[res_0]] [1] : vector<[4]xi1> into vector<3x[4]xi1> +// CHECK-NEXT: %[[res_2:.*]] = vector.insert %[[zero_mask_1d]], %[[res_1]] [2] : vector<[4]xi1> into vector<3x[4]xi1> +// CHECK-NEXT: return %[[res_2]] : vector<3x[4]xi1> +func.func @create_mask_2d_trailing_scalable(%a: index) -> vector<3x[4]xi1> { + %c2 = arith.constant 2 : index + %mask = vector.create_mask %c2, %a : vector<3x[4]xi1> + return %mask : vector<3x[4]xi1> +} + +// ----- + +/// The following cannot be lowered as the current lowering requires unrolling +/// the leading dim. + +// CHECK-LABEL: func.func @cannot_create_mask_2d_leading_scalable( +// CHECK-SAME: %[[arg:.*]]: index) -> vector<[4]x4xi1> { +// CHECK: %{{.*}} = vector.create_mask %[[arg]], %{{.*}} : vector<[4]x4xi1> +func.func @cannot_create_mask_2d_leading_scalable(%a: index) -> vector<[4]x4xi1> { + %c1 = arith.constant 1 : index + %mask = vector.create_mask %a, %c1 : vector<[4]x4xi1> + return %mask : vector<[4]x4xi1> +} + +transform.sequence failures(suppress) { +^bb1(%module_op: !transform.any_op): + %f = transform.structured.match ops{["func.func"]} in %module_op + : (!transform.any_op) -> !transform.any_op + + transform.apply_patterns to %f { + transform.apply_patterns.vector.lower_create_mask + } : !transform.any_op +}