diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -15,6 +15,21 @@
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/IR/OpBase.td"
 
+def ApplyCastAwayVectorLeadingOneDimPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.vector.cast_away_vector_leading_one_dim",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Collect a set of leading one dimension removal patterns.
+
+    These patterns insert vector.shape_cast to remove leading one dimensions
+    to expose more canonical forms of read/write/insert/extract operations.
+    With them, there are more chances that we can cancel out extract-insert
+    pairs or forward write-read pairs.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
 def ApplyRankReducingSubviewPatternsOp : Op<Transform_Dialect,
     "apply_patterns.vector.rank_reducing_subview_patterns",
     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -27,6 +27,11 @@
 // Apply...PatternsOp
 //===----------------------------------------------------------------------===//
 
+void transform::ApplyCastAwayVectorLeadingOneDimPatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
+}
+
 void transform::ApplyRankReducingSubviewPatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
   vector::populateVectorTransferDropUnitDimsPatterns(patterns);