diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td --- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td +++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td @@ -122,10 +122,31 @@ }]; } -def LowerMaskOp : TransformWithPatternsOp<"vector.lower_mask"> { +def LowerMasksOp : TransformWithPatternsOp<"vector.lower_masks"> { let description = [{ - Indicates that the vector mask operations nested under the isolated from - above op `target` should be lowered to finer-grained vector primitives. + Indicates that the vector.create_mask and vector.constant_mask operations + nested under the isolated from above op `target` should be lowered to + finer-grained vector primitives. + + This is usually a late step that is run after bufferization as part of the + process of lowering to e.g. LLVM or NVVM. + }]; + + let arguments = (ins TransformHandleTypeInterface:$target); + let results = (outs TransformHandleTypeInterface:$results); + + let assemblyFormat = [{ + $target + attr-dict + `:` functional-type($target, results) + }]; +} + +def LowerMaskedTransfersOp : TransformWithPatternsOp<"vector.lower_masked_transfers"> { + let description = [{ + Indicates that masked vector.transfer and vector.gather operations nested + under the isolated from above op `target` should be lowered to finer-grained + vector primitives. This is usually a late step that is run after bufferization as part of the process of lowering to e.g. LLVM or NVVM. 
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp --- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp +++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp @@ -63,11 +63,19 @@ } //===----------------------------------------------------------------------===// -// LowerMaskOp +// LowerMasksOp //===----------------------------------------------------------------------===// -void transform::LowerMaskOp::populatePatterns(RewritePatternSet &patterns) { +void transform::LowerMasksOp::populatePatterns(RewritePatternSet &patterns) { populateVectorMaskOpLoweringPatterns(patterns); +} + +//===----------------------------------------------------------------------===// +// LowerMaskedTransfersOp +//===----------------------------------------------------------------------===// + +void transform::LowerMaskedTransfersOp::populatePatterns( + RewritePatternSet &patterns) { populateVectorMaskLoweringPatternsForSideEffectingOps(patterns); } diff --git a/mlir/test/Dialect/Vector/vector-mask-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-mask-lowering-transforms.mlir --- a/mlir/test/Dialect/Vector/vector-mask-lowering-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-mask-lowering-transforms.mlir @@ -96,6 +96,36 @@ %f = transform.structured.match ops{["func.func"]} in %module_op : (!pdl.operation) -> !pdl.operation - transform.vector.lower_mask %f + transform.vector.lower_masks %f + : (!pdl.operation) -> !pdl.operation +} + +// ----- + +// CHECK-LABEL: func @transfer_read_3d( +func.func @transfer_read_3d( + %t: tensor<?x?x?xf32>, %arg0: index, %arg1: index, %arg2: index) + -> vector<2x1x7xf32> { + %c0 = arith.constant 0 : index + %f0 = arith.constant 0.0 : f32 + // CHECK: %[[mask:.*]] = vector.create_mask + // CHECK-NOT: vector.mask + // CHECK: vector.transfer_read {{.*}}, %[[mask]] {in_bounds = [true, true, true]} + // CHECK-SAME: : tensor<?x?x?xf32>, vector<2x1x7xf32> + %0 =
vector.create_mask %arg0, %arg1, %arg2 : vector<2x1x7xi1> + %1 = vector.mask %0 { + vector.transfer_read %t[%c0, %c0, %c0], %f0 {in_bounds = [true, true, true]} + : tensor<?x?x?xf32>, vector<2x1x7xf32> + } : vector<2x1x7xi1> -> vector<2x1x7xf32> + + return %1: vector<2x1x7xf32> +} + +transform.sequence failures(propagate) { +^bb1(%module_op: !pdl.operation): + %f = transform.structured.match ops{["func.func"]} in %module_op + : (!pdl.operation) -> !pdl.operation + + transform.vector.lower_masked_transfers %f : (!pdl.operation) -> !pdl.operation }