diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h @@ -74,54 +74,6 @@ void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns, PatternBenefit benefit = 1); -/// Collect a set of vector.shape_cast folding patterns. -void populateShapeCastFoldingPatterns(RewritePatternSet &patterns, - PatternBenefit benefit = 1); - -/// Cast away the leading unit dim, if exists, for the given contract op. -/// Return success if the transformation applies; return failure otherwise. -LogicalResult castAwayContractionLeadingOneDim(vector::ContractionOp contractOp, - RewriterBase &rewriter); - -/// Collect a set of leading one dimension removal patterns. -/// -/// These patterns insert vector.shape_cast to remove leading one dimensions -/// to expose more canonical forms of read/write/insert/extract operations. -/// With them, there are more chances that we can cancel out extract-insert -/// pairs or forward write-read pairs. -void populateCastAwayVectorLeadingOneDimPatterns(RewritePatternSet &patterns, - PatternBenefit benefit = 1); - -/// Collect a set of one dimension removal patterns. -/// -/// These patterns insert rank-reducing memref.subview ops to remove one -/// dimensions. With them, there are more chances that we can avoid -/// potentially exensive vector.shape_cast operations. -void populateVectorTransferDropUnitDimsPatterns(RewritePatternSet &patterns, - PatternBenefit benefit = 1); - -/// Collect a set of patterns to flatten n-D vector transfers on contiguous -/// memref. -/// -/// These patterns insert memref.collapse_shape + vector.shape_cast patterns -/// to transform multiple small n-D transfers into a larger 1-D transfer where -/// the memref contiguity properties allow it. 
-void populateFlattenVectorTransferPatterns(RewritePatternSet &patterns, - PatternBenefit benefit = 1); - -/// Collect a set of patterns that bubble up/down bitcast ops. -/// -/// These patterns move vector.bitcast ops to be before insert ops or after -/// extract ops where suitable. With them, bitcast will happen on smaller -/// vectors and there are more chances to share extract/insert ops. -void populateBubbleVectorBitCastOpPatterns(RewritePatternSet &patterns, - PatternBenefit benefit = 1); - -/// These patterns materialize masks for various vector ops such as transfers. -void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns, - bool force32BitVectorIndices, - PatternBenefit benefit = 1); - /// Returns the integer type required for subscripts in the vector dialect. IntegerType getVectorSubscriptType(Builder &builder); diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h --- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h +++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h @@ -13,12 +13,12 @@ #include #include "mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/Dialect/Vector/Transforms/VectorTransformsEnums.h.inc" #include "mlir/Dialect/Vector/Utils/VectorUtils.h" -#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Support/LogicalResult.h" +#include "mlir/Dialect/Vector/Transforms/VectorTransformsEnums.h.inc" + namespace mlir { class RewritePatternSet; @@ -57,7 +57,7 @@ } /// Function that returns the traversal order (in terms of "for loop order", - /// i.e. slowest varying dimension to fastest varying dimension) that shoudl + /// i.e. slowest varying dimension to fastest varying dimension) that should /// be used when unrolling the given operation into units of the native vector /// size. 
using UnrollTraversalOrderFnType = @@ -70,10 +70,6 @@ } }; -//===----------------------------------------------------------------------===// -// Vector transformation exposed as populate functions over rewrite patterns. -//===----------------------------------------------------------------------===// - /// Canonicalization of a `vector.contraction %a, %b, %c` with row-major matmul /// semantics to a contraction with MMT semantics (matrix matrix multiplication /// with the RHS transposed). This specific form is meant to have the vector @@ -134,10 +130,6 @@ void populateVectorTransferFullPartialPatterns( RewritePatternSet &patterns, const VectorTransformsOptions &options); -//===----------------------------------------------------------------------===// -// Vector.transfer patterns. -//===----------------------------------------------------------------------===// - /// Collect a set of patterns to reduce the rank of the operands of vector /// transfer ops to operate on the largest contigious vector. /// These patterns are useful when lowering to dialects with 1d vector type @@ -263,6 +255,49 @@ const UnrollVectorOptions &options, PatternBenefit benefit = 1); +/// Collect a set of vector.shape_cast folding patterns. +void populateShapeCastFoldingPatterns(RewritePatternSet &patterns, + PatternBenefit benefit = 1); + +/// Collect a set of leading one dimension removal patterns. +/// +/// These patterns insert vector.shape_cast to remove leading one dimensions +/// to expose more canonical forms of read/write/insert/extract operations. +/// With them, there are more chances that we can cancel out extract-insert +/// pairs or forward write-read pairs. +void populateCastAwayVectorLeadingOneDimPatterns(RewritePatternSet &patterns, + PatternBenefit benefit = 1); + +/// Collect a set of one dimension removal patterns. +/// +/// These patterns insert rank-reducing memref.subview ops to remove one +/// dimensions. 
With them, there are more chances that we can avoid +/// potentially expensive vector.shape_cast operations. +void populateVectorTransferDropUnitDimsPatterns(RewritePatternSet &patterns, + PatternBenefit benefit = 1); + +/// Collect a set of patterns to flatten n-D vector transfers on contiguous +/// memref. +/// +/// These patterns insert memref.collapse_shape + vector.shape_cast patterns +/// to transform multiple small n-D transfers into a larger 1-D transfer where +/// the memref contiguity properties allow it. +void populateFlattenVectorTransferPatterns(RewritePatternSet &patterns, + PatternBenefit benefit = 1); + +/// Collect a set of patterns that bubble up/down bitcast ops. +/// +/// These patterns move vector.bitcast ops to be before insert ops or after +/// extract ops where suitable. With them, bitcast will happen on smaller +/// vectors and there are more chances to share extract/insert ops. +void populateBubbleVectorBitCastOpPatterns(RewritePatternSet &patterns, + PatternBenefit benefit = 1); + +/// These patterns materialize masks for various vector ops such as transfers. +void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns, + bool force32BitVectorIndices, + PatternBenefit benefit = 1); + } // namespace vector } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h --- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h +++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h @@ -27,6 +27,7 @@ //===----------------------------------------------------------------------===// // Vector transformation options exposed as auxiliary structs. //===----------------------------------------------------------------------===// + /// Structure to control the behavior of vector transform patterns. struct VectorTransformsOptions { /// Option to control the lowering of vector.contract. 
@@ -63,6 +64,7 @@ //===----------------------------------------------------------------------===// // Standalone transformations and helpers. //===----------------------------------------------------------------------===// + /// Split a vector.transfer operation into an in-bounds (i.e., no /// out-of-bounds masking) fastpath and a slowpath. If `ifOp` is not null and /// the result is `success, the `ifOp` points to the newly created conditional @@ -106,6 +108,11 @@ /// optimizations. void transferOpflowOpt(RewriterBase &rewriter, Operation *rootOp); +/// Cast away the leading unit dim, if it exists, for the given contract op. +/// Return success if the transformation applies; return failure otherwise. +LogicalResult castAwayContractionLeadingOneDim(vector::ContractionOp contractOp, + RewriterBase &rewriter); + } // namespace vector } // namespace mlir diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp @@ -9,9 +9,9 @@ #include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" +#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" #include "mlir/Dialect/Vector/Utils/VectorUtils.h" #include "mlir/IR/Builders.h" -#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/TypeUtilities.h" #define DEBUG_TYPE "vector-drop-unit-dim"