diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h @@ -62,11 +62,12 @@ std::pair *mismatchingDims = nullptr); /// Collect a set of vector-to-vector canonicalization patterns. -void populateVectorToVectorCanonicalizationPatterns( - RewritePatternSet &patterns); +void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns, + PatternBenefit benefit = 1); /// Collect a set of vector.shape_cast folding patterns. -void populateShapeCastFoldingPatterns(RewritePatternSet &patterns); +void populateShapeCastFoldingPatterns(RewritePatternSet &patterns, + PatternBenefit benefit = 1); /// Collect a set of leading one dimension removal patterns. /// @@ -74,14 +75,16 @@ /// to expose more canonical forms of read/write/insert/extract operations. /// With them, there are more chances that we can cancel out extract-insert /// pairs or forward write-read pairs. -void populateCastAwayVectorLeadingOneDimPatterns(RewritePatternSet &patterns); +void populateCastAwayVectorLeadingOneDimPatterns(RewritePatternSet &patterns, + PatternBenefit benefit = 1); /// Collect a set of one dimension removal patterns. /// /// These patterns insert rank-reducing memref.subview ops to remove one /// dimensions. With them, there are more chances that we can avoid /// potentially exensive vector.shape_cast operations. -void populateVectorTransferDropUnitDimsPatterns(RewritePatternSet &patterns); +void populateVectorTransferDropUnitDimsPatterns(RewritePatternSet &patterns, + PatternBenefit benefit = 1); /// Collect a set of patterns to flatten n-D vector transfers on contiguous /// memref. @@ -89,14 +92,16 @@ /// These patterns insert memref.collapse_shape + vector.shape_cast patterns /// to transform multiple small n-D transfers into a larger 1-D transfer where /// the memref contiguity properties allow it. 
-void populateFlattenVectorTransferPatterns(RewritePatternSet &patterns); +void populateFlattenVectorTransferPatterns(RewritePatternSet &patterns, + PatternBenefit benefit = 1); /// Collect a set of patterns that bubble up/down bitcast ops. /// /// These patterns move vector.bitcast ops to be before insert ops or after /// extract ops where suitable. With them, bitcast will happen on smaller /// vectors and there are more chances to share extract/insert ops. -void populateBubbleVectorBitCastOpPatterns(RewritePatternSet &patterns); +void populateBubbleVectorBitCastOpPatterns(RewritePatternSet &patterns, + PatternBenefit benefit = 1); /// Collect a set of transfer read/write lowering patterns. /// @@ -106,28 +111,34 @@ /// VectorToSCF, which reduces the rank of vector transfer ops. void populateVectorTransferLoweringPatterns( RewritePatternSet &patterns, - llvm::Optional maxTransferRank = llvm::None); + llvm::Optional maxTransferRank = llvm::None, + PatternBenefit benefit = 1); /// These patterns materialize masks for various vector ops such as transfers. void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns, - bool force32BitVectorIndices); + bool force32BitVectorIndices, + PatternBenefit benefit = 1); /// Collect a set of patterns to propagate insert_map/extract_map in the ssa /// chain. -void populatePropagateVectorDistributionPatterns(RewritePatternSet &patterns); +void populatePropagateVectorDistributionPatterns(RewritePatternSet &patterns, + PatternBenefit benefit = 1); /// Collects patterns to progressively lower vector.broadcast ops on high-D /// vectors to low-D vector ops. -void populateVectorBroadcastLoweringPatterns(RewritePatternSet &patterns); +void populateVectorBroadcastLoweringPatterns(RewritePatternSet &patterns, + PatternBenefit benefit = 1); /// Collects patterns to progressively lower vector mask ops into elementary /// selection and insertion ops. 
-void populateVectorMaskOpLoweringPatterns(RewritePatternSet &patterns); +void populateVectorMaskOpLoweringPatterns(RewritePatternSet &patterns, + PatternBenefit benefit = 1); /// Collects patterns to progressively lower vector.shape_cast ops on high-D /// vectors into 1-D/2-D vector ops by generating data movement extract/insert /// ops. -void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns); +void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns, + PatternBenefit benefit = 1); /// Returns the integer type required for subscripts in the vector dialect. IntegerType getVectorSubscriptType(Builder &builder); diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h --- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h +++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h @@ -37,7 +37,8 @@ void populateWarpExecuteOnLane0OpToScfForPattern( RewritePatternSet &patterns, - const WarpExecuteOnLane0LoweringOptions &options); + const WarpExecuteOnLane0LoweringOptions &options, + PatternBenefit benefit = 1); using DistributionMapFn = std::function; @@ -59,7 +60,8 @@ /// } /// vector.transfer_write %v, %A[%id] : vector<1xf32>, memref<128xf32> void populateDistributeTransferWriteOpPatterns( - RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn); + RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn, + PatternBenefit benefit = 1); /// Move scalar operations with no dependency on the warp op outside of the /// region. @@ -67,7 +69,7 @@ /// Collect patterns to propagate warp distribution. void populatePropagateWarpVectorDistributionPatterns( - RewritePatternSet &pattern); + RewritePatternSet &pattern, PatternBenefit benefit = 1); /// Lambda signature to compute a reduction of a distributed value for the given /// reduction kind and size. @@ -78,7 +80,8 @@ /// distribute reduction op. 
void populateDistributeReduction( RewritePatternSet &pattern, - const DistributedReductionFn &distributedReductionFn); + const DistributedReductionFn &distributedReductionFn, + PatternBenefit benefit = 1); } // namespace vector } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h --- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h +++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h @@ -150,7 +150,8 @@ /// Insert TransposeLowering patterns into extraction/insertion. void populateVectorTransposeLoweringPatterns( RewritePatternSet &patterns, - VectorTransformsOptions options = VectorTransformsOptions()); + VectorTransformsOptions options = VectorTransformsOptions(), + PatternBenefit benefit = 1); /// Collect a set of patterns to convert vector.multi_reduction op into /// a sequence of vector.reduction ops. The patterns comprise: @@ -175,20 +176,24 @@ /// the other patterns can kick in, thus fully exiting out of the /// vector.multi_reduction abstraction. void populateVectorMultiReductionLoweringPatterns( - RewritePatternSet &patterns, VectorMultiReductionLowering options); + RewritePatternSet &patterns, VectorMultiReductionLowering options, + PatternBenefit benefit = 1); /// Collects patterns to progressively lower vector contraction ops on high-D /// into low-D reduction and product ops. void populateVectorContractLoweringPatterns( RewritePatternSet &patterns, - VectorTransformsOptions options = VectorTransformsOptions()); + VectorTransformsOptions options = VectorTransformsOptions(), + PatternBenefit benefit = 1); /// Collect patterns to convert reduction op to vector.contract and fold /// transpose/broadcast ops into the contract. 
-void populateVectorReductionToContractPatterns(RewritePatternSet &patterns); +void populateVectorReductionToContractPatterns(RewritePatternSet &patterns, + PatternBenefit benefit = 1); /// Collect patterns to convert scan op -void populateVectorScanLoweringPatterns(RewritePatternSet &patterns); +void populateVectorScanLoweringPatterns(RewritePatternSet &patterns, + PatternBenefit benefit = 1); //===----------------------------------------------------------------------===// // Vector.transfer patterns. @@ -246,14 +251,14 @@ /// permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3) /// vector.broadcast %v void populateVectorTransferPermutationMapLoweringPatterns( - RewritePatternSet &patterns); + RewritePatternSet &patterns, PatternBenefit benefit = 1); /// Collect a set of patterns to reduce the rank of the operands of vector /// transfer ops to operate on the largest contigious vector. /// These patterns are useful when lowering to dialects with 1d vector type /// such as llvm and it will result fewer memory reads. void populateVectorTransferCollapseInnerMostContiguousDimsPatterns( - RewritePatternSet &patterns); + RewritePatternSet &patterns, PatternBenefit benefit = 1); /// Populate `patterns` with the following patterns. /// @@ -278,7 +283,7 @@ /// For such cases, we can rewrite it to ExtractOp/ExtractElementOp + lower /// rank ExtractStridedSliceOp + InsertOp/InsertElementOp for the n-D case. void populateVectorInsertExtractStridedSliceDecompositionPatterns( - RewritePatternSet &patterns); + RewritePatternSet &patterns, PatternBenefit benefit = 1); /// Populate `patterns` with the following patterns. /// @@ -299,7 +304,7 @@ /// ========================================= /// For such cases, we can lower it to a ShuffleOp. void populateVectorInsertExtractStridedSliceTransforms( - RewritePatternSet &patterns); + RewritePatternSet &patterns, PatternBenefit benefit = 1); /// Collect a set of pattern to unroll vector operations to a smaller shapes. 
/// `options` structure controls which operations are unrolled and the target @@ -332,7 +337,8 @@ /// Other local patterns then kick in iteratively (including DCE) and compose /// to combine the ExtractStridedSlice/InsertStridedSlice. void populateVectorUnrollPatterns(RewritePatternSet &patterns, - const UnrollVectorOptions &options); + const UnrollVectorOptions &options, + PatternBenefit benefit = 1); //===----------------------------------------------------------------------===// // Finer-grained patterns exposed for more control over individual lowerings. @@ -377,7 +383,8 @@ class ContractionOpToMatmulOpLowering : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; + using FilterConstraintType = std::function; @@ -387,8 +394,9 @@ ContractionOpToMatmulOpLowering( vector::VectorTransformsOptions vectorTransformOptions, - MLIRContext *context, FilterConstraintType constraint = defaultFilter) - : OpRewritePattern(context), + MLIRContext *context, PatternBenefit benefit = 1, + FilterConstraintType constraint = defaultFilter) + : OpRewritePattern(context, benefit), vectorTransformOptions(vectorTransformOptions), filter(std::move(constraint)) {} @@ -419,7 +427,8 @@ class ContractionOpToOuterProductOpLowering : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; + using FilterConstraintType = std::function; @@ -429,8 +438,9 @@ ContractionOpToOuterProductOpLowering( vector::VectorTransformsOptions vectorTransformOptions, - MLIRContext *context, FilterConstraintType constraint = defaultFilter) - : OpRewritePattern(context), + MLIRContext *context, PatternBenefit benefit = 1, + FilterConstraintType constraint = defaultFilter) + : OpRewritePattern(context, benefit), vectorTransformOptions(vectorTransformOptions), filter(std::move(constraint)) {} @@ -464,7 +474,8 @@ class ContractionOpToDotLowering : public OpRewritePattern { public: - 
using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; + using FilterConstraintType = std::function; @@ -474,9 +485,9 @@ ContractionOpToDotLowering( vector::VectorTransformsOptions vectorTransformOptions, - MLIRContext *context, + MLIRContext *context, PatternBenefit benefit = 1, const FilterConstraintType &constraint = defaultFilter) - : OpRewritePattern(context), - vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {} + : OpRewritePattern(context, benefit), + vectorTransformOptions(vectorTransformOptions), filter(constraint) {} LogicalResult matchAndRewrite(vector::ContractionOp op, @@ -504,7 +515,7 @@ /// to Dot or when other contraction patterns fail. class ContractionOpLowering : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; using FilterConstraintType = std::function; @@ -513,9 +524,9 @@ } ContractionOpLowering(vector::VectorTransformsOptions vectorTransformOptions, - MLIRContext *context, + MLIRContext *context, PatternBenefit benefit = 1, FilterConstraintType constraint = defaultFilter) - : OpRewritePattern(context), + : OpRewritePattern(context, benefit), vectorTransformOptions(vectorTransformOptions), filter(std::move(constraint)) {} diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -1464,7 +1464,7 @@ // Pattern to rewrite a ExtractOp(Broadcast) -> Broadcast. class ExtractOpFromBroadcast final : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(ExtractOp extractOp, PatternRewriter &rewriter) const override { @@ -1494,7 +1494,7 @@ // Pattern to rewrite a ExtractOp(splat ConstantOp) -> ConstantOp.
class ExtractOpConstantFolder final : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(ExtractOp extractOp, PatternRewriter &rewriter) const override { @@ -1681,7 +1681,7 @@ // Fold broadcast1(broadcast2(x)) into broadcast1(x). struct BroadcastFolder : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(BroadcastOp broadcastOp, PatternRewriter &rewriter) const override { @@ -1828,7 +1828,7 @@ // Pattern to rewrite a 0-D shuffle with [0] or [1] mask returning a 1-D vector // to a broadcast. struct Canonicalize0DShuffleOp : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(ShuffleOp shuffleOp, PatternRewriter &rewriter) const override { @@ -1852,7 +1852,7 @@ /// Pattern to rewrite a ShuffleOp(SplatOp, SplatOp) to SplatOp. class ShuffleSplat final : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(ShuffleOp op, PatternRewriter &rewriter) const override { @@ -1979,7 +1979,7 @@ // broadcast. class InsertToBroadcast final : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(InsertOp insertOp, PatternRewriter &rewriter) const override { @@ -1996,7 +1996,7 @@ /// Pattern to rewrite a InsertOp(SplatOp, SplatOp) to SplatOp. 
class InsertSplatToSplat final : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(InsertOp op, PatternRewriter &rewriter) const override { @@ -2202,7 +2202,7 @@ class FoldInsertStridedSliceSplat final : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp, PatternRewriter &rewriter) const override { @@ -2227,7 +2227,7 @@ class FoldInsertStridedSliceOfExtract final : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp, PatternRewriter &rewriter) const override { @@ -2587,7 +2587,7 @@ class StridedSliceConstantMaskFolder final : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp, PatternRewriter &rewriter) const override { @@ -2640,7 +2640,7 @@ class StridedSliceConstantFolder final : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp, PatternRewriter &rewriter) const override { @@ -2666,7 +2666,7 @@ class StridedSliceBroadcast final : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(ExtractStridedSliceOp op, PatternRewriter &rewriter) const override { @@ -2709,7 +2709,7 @@ /// Pattern to rewrite an ExtractStridedSliceOp(SplatOp) to SplatOp. 
class StridedSliceSplat final : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(ExtractStridedSliceOp op, PatternRewriter &rewriter) const override { @@ -3182,7 +3182,7 @@ struct FoldExtractSliceIntoTransferRead : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(TransferReadOp xferOp, PatternRewriter &rewriter) const override { @@ -3279,7 +3279,7 @@ /// ``` struct TransferReadAfterWriteToBroadcast : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(TransferReadOp readOp, PatternRewriter &rewriter) const override { @@ -3628,7 +3628,7 @@ /// any other uses. class FoldWaw final : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(TransferWriteOp writeOp, PatternRewriter &rewriter) const override { if (!writeOp.getShapedType().isa()) @@ -3674,7 +3674,7 @@ struct FoldInsertSliceIntoTransferWrite : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp, PatternRewriter &rewriter) const override { @@ -3768,7 +3768,7 @@ struct SwapExtractSliceOfTransferWrite : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp, PatternRewriter &rewriter) const override { @@ -3947,7 +3947,7 @@ namespace { class MaskedLoadFolder final : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(MaskedLoadOp load, PatternRewriter &rewriter) const 
override { switch (getMaskFormat(load.getMask())) { @@ -3998,7 +3998,7 @@ namespace { class MaskedStoreFolder final : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(MaskedStoreOp store, PatternRewriter &rewriter) const override { switch (getMaskFormat(store.getMask())) { @@ -4056,7 +4056,7 @@ namespace { class GatherFolder final : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(GatherOp gather, PatternRewriter &rewriter) const override { switch (getMaskFormat(gather.getMask())) { @@ -4102,7 +4102,7 @@ namespace { class ScatterFolder final : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(ScatterOp scatter, PatternRewriter &rewriter) const override { switch (getMaskFormat(scatter.getMask())) { @@ -4148,7 +4148,7 @@ namespace { class ExpandLoadFolder final : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(ExpandLoadOp expand, PatternRewriter &rewriter) const override { switch (getMaskFormat(expand.getMask())) { @@ -4193,7 +4193,7 @@ namespace { class CompressStoreFolder final : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(CompressStoreOp compress, PatternRewriter &rewriter) const override { switch (getMaskFormat(compress.getMask())) { @@ -4333,7 +4333,7 @@ // Pattern to rewrite a ShapeCast(splat ConstantOp) -> ConstantOp. 
class ShapeCastConstantFolder final : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp, PatternRewriter &rewriter) const override { @@ -4359,7 +4359,7 @@ /// enough to capture the result in a single op). class ShapeCastBroadcastFolder final : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp, PatternRewriter &rewriter) const override { @@ -4589,7 +4589,7 @@ // Rewrites two back-to-back TransposeOp operations into a single TransposeOp. class TransposeFolder final : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(vector::TransposeOp transposeOp, PatternRewriter &rewriter) const override { @@ -4651,7 +4651,7 @@ // Folds transpose(splat x : src_type) : res_type into splat x : res_type. class FoldTransposeSplat final : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(TransposeOp transposeOp, PatternRewriter &rewriter) const override { @@ -4751,7 +4751,7 @@ // Pattern to rewrite a CreateMaskOp with a ConstantMaskOp. 
class CreateMaskFolder final : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(CreateMaskOp createMaskOp, PatternRewriter &rewriter) const override { @@ -4850,12 +4850,12 @@ } void mlir::vector::populateVectorToVectorCanonicalizationPatterns( - RewritePatternSet &patterns) { + RewritePatternSet &patterns, PatternBenefit benefit) { patterns .add( - patterns.getContext()); + patterns.getContext(), benefit); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -1056,26 +1056,30 @@ void mlir::vector::populateWarpExecuteOnLane0OpToScfForPattern( RewritePatternSet &patterns, - const WarpExecuteOnLane0LoweringOptions &options) { - patterns.add(patterns.getContext(), options); + const WarpExecuteOnLane0LoweringOptions &options, PatternBenefit benefit) { + patterns.add(patterns.getContext(), options, benefit); } void mlir::vector::populateDistributeTransferWriteOpPatterns( - RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn) { - patterns.add(patterns.getContext(), distributionMapFn); + RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn, + PatternBenefit benefit) { + patterns.add(patterns.getContext(), distributionMapFn, + benefit); } void mlir::vector::populatePropagateWarpVectorDistributionPatterns( - RewritePatternSet &patterns) { + RewritePatternSet &patterns, PatternBenefit benefit) { patterns.add(patterns.getContext()); + WarpOpScfForOp, WarpOpConstant>(patterns.getContext(), benefit); } void mlir::vector::populateDistributeReduction( RewritePatternSet &patterns, - const DistributedReductionFn &distributedReductionFn) { - 
patterns.add(patterns.getContext(), distributedReductionFn); + const DistributedReductionFn &distributedReductionFn, + PatternBenefit benefit) { + patterns.add(patterns.getContext(), distributedReductionFn, + benefit); } void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) { diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp @@ -401,8 +401,9 @@ class CastAwayElementwiseLeadingOneDim : public RewritePattern { public: - CastAwayElementwiseLeadingOneDim(MLIRContext *context) - : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} + CastAwayElementwiseLeadingOneDim(MLIRContext *context, + PatternBenefit benefit = 1) + : RewritePattern(MatchAnyOpTypeTag(), benefit, context) {} LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { @@ -436,12 +437,12 @@ } // namespace void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns( - RewritePatternSet &patterns) { + RewritePatternSet &patterns, PatternBenefit benefit) { patterns .add(patterns.getContext()); - populateShapeCastFoldingPatterns(patterns); + CastAwayContractionLeadingOneDim>(patterns.getContext(), benefit); + populateShapeCastFoldingPatterns(patterns, benefit); } diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp @@ -286,15 +286,17 @@ }; void mlir::vector::populateVectorInsertExtractStridedSliceDecompositionPatterns( - RewritePatternSet &patterns) { + RewritePatternSet &patterns, PatternBenefit benefit) { 
patterns.add(patterns.getContext()); + DecomposeNDExtractStridedSlice>(patterns.getContext(), benefit); } /// Populate the given list with patterns that convert from Vector to LLVM. void mlir::vector::populateVectorInsertExtractStridedSliceTransforms( - RewritePatternSet &patterns) { - populateVectorInsertExtractStridedSliceDecompositionPatterns(patterns); + RewritePatternSet &patterns, PatternBenefit benefit) { + populateVectorInsertExtractStridedSliceDecompositionPatterns(patterns, + benefit); patterns.add(patterns.getContext()); + Convert1DExtractStridedSliceIntoShuffle>(patterns.getContext(), + benefit); } diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp @@ -29,11 +29,12 @@ class InnerOuterDimReductionConversion : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; explicit InnerOuterDimReductionConversion( - MLIRContext *context, vector::VectorMultiReductionLowering options) - : mlir::OpRewritePattern(context), + MLIRContext *context, vector::VectorMultiReductionLowering options, + PatternBenefit benefit = 1) + : mlir::OpRewritePattern(context, benefit), useInnerDimsForReduction( options == vector::VectorMultiReductionLowering::InnerReduction) {} @@ -101,11 +102,12 @@ class ReduceMultiDimReductionRank : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; explicit ReduceMultiDimReductionRank( - MLIRContext *context, vector::VectorMultiReductionLowering options) - : mlir::OpRewritePattern(context), + MLIRContext *context, vector::VectorMultiReductionLowering options, + PatternBenefit benefit = 1) + : mlir::OpRewritePattern(context, benefit), 
useInnerDimsForReduction( options == vector::VectorMultiReductionLowering::InnerReduction) {} @@ -224,7 +226,7 @@ /// and combines results struct TwoDimMultiReductionToElementWise : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp, PatternRewriter &rewriter) const override { @@ -261,7 +263,7 @@ /// a sequence of vector.reduction ops. struct TwoDimMultiReductionToReduction : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp, PatternRewriter &rewriter) const override { @@ -301,7 +303,7 @@ /// separately. struct OneDimMultiReductionToTwoDim : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp, PatternRewriter &rewriter) const override { @@ -338,12 +340,15 @@ }; void mlir::vector::populateVectorMultiReductionLoweringPatterns( - RewritePatternSet &patterns, VectorMultiReductionLowering options) { + RewritePatternSet &patterns, VectorMultiReductionLowering options, + PatternBenefit benefit) { patterns.add( - patterns.getContext(), options); - patterns.add(patterns.getContext()); + patterns.getContext(), options, benefit); + patterns.add(patterns.getContext(), benefit); if (options == VectorMultiReductionLowering ::InnerReduction) - patterns.add(patterns.getContext()); + patterns.add(patterns.getContext(), + benefit); else - patterns.add(patterns.getContext()); + patterns.add(patterns.getContext(), + benefit); } diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp +++ 
b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp @@ -257,7 +257,7 @@ /// inserting a memref.subview dropping those unit dims. class TransferReadDropUnitDimsPattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp, PatternRewriter &rewriter) const override { @@ -300,7 +300,7 @@ /// unit dims, by inserting a memref.subview dropping those unit dims. class TransferWriteDropUnitDimsPattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp, PatternRewriter &rewriter) const override { @@ -412,7 +412,7 @@ /// already reduced i.e. without unit dims. class FlattenContiguousRowMajorTransferReadPattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp, PatternRewriter &rewriter) const override { @@ -470,7 +470,7 @@ /// already reduced i.e. without unit dims. 
class FlattenContiguousRowMajorTransferWritePattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp, PatternRewriter &rewriter) const override { @@ -543,17 +543,17 @@ } void mlir::vector::populateVectorTransferDropUnitDimsPatterns( - RewritePatternSet &patterns) { + RewritePatternSet &patterns, PatternBenefit benefit) { patterns .add( - patterns.getContext()); - populateShapeCastFoldingPatterns(patterns); + patterns.getContext(), benefit); + populateShapeCastFoldingPatterns(patterns, benefit); } void mlir::vector::populateFlattenVectorTransferPatterns( - RewritePatternSet &patterns) { + RewritePatternSet &patterns, PatternBenefit benefit) { patterns.add( - patterns.getContext()); - populateShapeCastFoldingPatterns(patterns); + patterns.getContext(), benefit); + populateShapeCastFoldingPatterns(patterns, benefit); } diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp @@ -53,7 +53,7 @@ /// vector.transfer_read to do the transpose in memory instead.
struct TransferReadPermutationLowering : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(vector::TransferReadOp op, PatternRewriter &rewriter) const override { @@ -142,7 +142,7 @@ /// permutation_map: (d0, d1, d2, d3) -> (d2, d3) struct TransferWritePermutationLowering : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(vector::TransferWriteOp op, PatternRewriter &rewriter) const override { @@ -201,7 +201,7 @@ /// permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3) /// vector.broadcast %v struct TransferOpReduceRank : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(vector::TransferReadOp op, PatternRewriter &rewriter) const override { @@ -271,8 +271,8 @@ }; void mlir::vector::populateVectorTransferPermutationMapLoweringPatterns( - RewritePatternSet &patterns) { + RewritePatternSet &patterns, PatternBenefit benefit) { patterns.add( - patterns.getContext()); + patterns.getContext(), benefit); } diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -216,7 +216,7 @@ // %1 = user %0 : vector<5x4x2xf32> // struct ShapeCastOpFolder : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp, PatternRewriter &rewriter) const override { @@ -250,7 +250,7 @@ /// Progressive lowering of BroadcastOp. 
class BroadcastOpLowering : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(vector::BroadcastOp op, PatternRewriter &rewriter) const override { @@ -381,11 +381,11 @@ /// %x = vector.insert .., .. [.., ..] class TransposeOpLowering : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; TransposeOpLowering(vector::VectorTransformsOptions vectorTransformOptions, - MLIRContext *context) - : OpRewritePattern(context), + MLIRContext *context, PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), vectorTransformOptions(vectorTransformOptions) {} LogicalResult matchAndRewrite(vector::TransposeOp op, @@ -470,12 +470,12 @@ class TransposeOp2DToShuffleLowering : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; TransposeOp2DToShuffleLowering( vector::VectorTransformsOptions vectorTransformOptions, - MLIRContext *context) - : OpRewritePattern(context), + MLIRContext *context, PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), vectorTransformOptions(vectorTransformOptions) {} LogicalResult matchAndRewrite(vector::TransposeOp op, @@ -534,7 +534,7 @@ /// class OuterProductOpLowering : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(vector::OuterProductOp op, PatternRewriter &rewriter) const override { @@ -593,9 +593,9 @@ } ContractOpToElementwise( vector::VectorTransformsOptions vectorTransformOptions, - MLIRContext *context, + MLIRContext *context, PatternBenefit benefit = 1, const FilterConstraintType &constraint = defaultFilter) - : OpRewritePattern(context), + : OpRewritePattern(context, benefit), vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {} LogicalResult 
matchAndRewrite(vector::ContractionOp contractOp, @@ -715,7 +715,7 @@ /// will be folded at LLVM IR level. class ConstantMaskOpLowering : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(vector::ConstantMaskOp op, PatternRewriter &rewriter) const override { @@ -789,7 +789,7 @@ /// until a one-dimensional vector is reached. class CreateMaskOpLowering : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(vector::CreateMaskOp op, PatternRewriter &rewriter) const override { @@ -835,7 +835,7 @@ class ShapeCastOp2DDownCastRewritePattern : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(vector::ShapeCastOp op, PatternRewriter &rewriter) const override { @@ -868,7 +868,7 @@ class ShapeCastOp2DUpCastRewritePattern : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(vector::ShapeCastOp op, PatternRewriter &rewriter) const override { @@ -900,7 +900,7 @@ // into the right place if we get here. 
class ShapeCastOpRewritePattern : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(vector::ShapeCastOp op, PatternRewriter &rewriter) const override { @@ -974,7 +974,7 @@ /// ``` struct MultiReduceToContract : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(vector::MultiDimReductionOp reduceOp, PatternRewriter &rewriter) const override { @@ -1030,7 +1030,7 @@ /// ``` struct CombineContractTranspose : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(vector::ContractionOp contractOp, PatternRewriter &rewriter) const override { @@ -1087,7 +1087,7 @@ /// ``` struct CombineContractBroadcast : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(vector::ContractionOp contractOp, PatternRewriter &rewriter) const override { @@ -2036,8 +2036,9 @@ struct TransferReadToVectorLoadLowering : public OpRewritePattern { TransferReadToVectorLoadLowering(MLIRContext *context, - llvm::Optional maxRank) - : OpRewritePattern(context), + llvm::Optional maxRank, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), maxTransferRank(maxRank) {} LogicalResult matchAndRewrite(vector::TransferReadOp read, @@ -2124,7 +2125,7 @@ // trivial case (for architectures for which this matters). struct VectorLoadToMemrefLoadLowering : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(vector::LoadOp loadOp, PatternRewriter &rewriter) const override { @@ -2142,7 +2143,7 @@ /// Replace a 0-d vector.store with a vector.extractelement + memref.store. 
struct VectorStoreToMemrefStoreLowering : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(vector::StoreOp storeOp, PatternRewriter &rewriter) const override { @@ -2177,8 +2178,9 @@ struct TransferWriteToVectorStoreLowering : public OpRewritePattern { TransferWriteToVectorStoreLowering(MLIRContext *context, - llvm::Optional maxRank) - : OpRewritePattern(context), + llvm::Optional maxRank, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), maxTransferRank(maxRank) {} LogicalResult matchAndRewrite(vector::TransferWriteOp write, @@ -2415,6 +2417,7 @@ struct BubbleUpBitCastForStridedSliceInsert : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp, PatternRewriter &rewriter) const override { VectorType castSrcType = bitcastOp.getSourceVectorType(); @@ -2530,8 +2533,9 @@ template struct MaterializeTransferMask : public OpRewritePattern { public: - explicit MaterializeTransferMask(MLIRContext *context, bool enableIndexOpt) - : mlir::OpRewritePattern(context), + explicit MaterializeTransferMask(MLIRContext *context, bool enableIndexOpt, + PatternBenefit benefit = 1) + : mlir::OpRewritePattern(context, benefit), force32BitVectorIndices(enableIndexOpt) {} LogicalResult matchAndRewrite(ConcreteOp xferOp, @@ -2583,8 +2587,9 @@ : public OpRewritePattern { public: explicit VectorCreateMaskOpConversion(MLIRContext *context, - bool enableIndexOpt) - : mlir::OpRewritePattern(context), + bool enableIndexOpt, + PatternBenefit benefit = 1) + : mlir::OpRewritePattern(context, benefit), force32BitVectorIndices(enableIndexOpt) {} LogicalResult matchAndRewrite(vector::CreateMaskOp op, @@ -2608,7 +2613,7 @@ // Drop inner most contiguous unit dimensions from transfer_read operand. 
class DropInnerMostUnitDims : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(vector::TransferReadOp readOp, PatternRewriter &rewriter) const override { @@ -2815,7 +2820,7 @@ /// vector<2x3xi32>, vector<2xi32> /// ``` struct ScanToArithOps : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(vector::ScanOp scanOp, PatternRewriter &rewriter) const override { @@ -2896,81 +2901,87 @@ } // namespace void mlir::vector::populateVectorMaskMaterializationPatterns( - RewritePatternSet &patterns, bool force32BitVectorIndices) { + RewritePatternSet &patterns, bool force32BitVectorIndices, + PatternBenefit benefit) { patterns.add, MaterializeTransferMask>( - patterns.getContext(), force32BitVectorIndices); + patterns.getContext(), force32BitVectorIndices, benefit); } -void mlir::vector::populateShapeCastFoldingPatterns( - RewritePatternSet &patterns) { - patterns.add(patterns.getContext()); +void mlir::vector::populateShapeCastFoldingPatterns(RewritePatternSet &patterns, + PatternBenefit benefit) { + patterns.add(patterns.getContext(), benefit); } void mlir::vector::populateBubbleVectorBitCastOpPatterns( - RewritePatternSet &patterns) { + RewritePatternSet &patterns, PatternBenefit benefit) { patterns.add(patterns.getContext()); + BubbleUpBitCastForStridedSliceInsert>(patterns.getContext(), + benefit); } void mlir::vector::populateVectorBroadcastLoweringPatterns( - RewritePatternSet &patterns) { - patterns.add(patterns.getContext()); + RewritePatternSet &patterns, PatternBenefit benefit) { + patterns.add(patterns.getContext(), benefit); } void mlir::vector::populateVectorMaskOpLoweringPatterns( - RewritePatternSet &patterns) { + RewritePatternSet &patterns, PatternBenefit benefit) { patterns.add( - patterns.getContext()); + patterns.getContext(), benefit); } void 
mlir::vector::populateVectorShapeCastLoweringPatterns( - RewritePatternSet &patterns) { + RewritePatternSet &patterns, PatternBenefit benefit) { patterns.add( - patterns.getContext()); + patterns.getContext(), benefit); } void mlir::vector::populateVectorContractLoweringPatterns( - RewritePatternSet &patterns, VectorTransformsOptions options) { - patterns.add(patterns.getContext()); + RewritePatternSet &patterns, VectorTransformsOptions options, + PatternBenefit benefit) { + patterns.add(patterns.getContext(), benefit); patterns.add(options, - patterns.getContext()); + ContractionOpToOuterProductOpLowering>( + options, patterns.getContext(), benefit); } void mlir::vector::populateVectorTransposeLoweringPatterns( - RewritePatternSet &patterns, VectorTransformsOptions options) { + RewritePatternSet &patterns, VectorTransformsOptions options, + PatternBenefit benefit) { patterns.add( - options, patterns.getContext()); + options, patterns.getContext(), benefit); } void mlir::vector::populateVectorReductionToContractPatterns( - RewritePatternSet &patterns) { + RewritePatternSet &patterns, PatternBenefit benefit) { patterns.add(patterns.getContext()); + ReorderElementwiseOpsOnTranspose>(patterns.getContext(), + benefit); } void mlir::vector:: populateVectorTransferCollapseInnerMostContiguousDimsPatterns( - RewritePatternSet &patterns) { - patterns.add(patterns.getContext()); + RewritePatternSet &patterns, PatternBenefit benefit) { + patterns.add(patterns.getContext(), benefit); } void mlir::vector::populateVectorTransferLoweringPatterns( - RewritePatternSet &patterns, llvm::Optional maxTransferRank) { + RewritePatternSet &patterns, llvm::Optional maxTransferRank, + PatternBenefit benefit) { patterns.add(patterns.getContext(), - maxTransferRank); + maxTransferRank, benefit); patterns .add( - patterns.getContext()); + patterns.getContext(), benefit); } void mlir::vector::populateVectorScanLoweringPatterns( - RewritePatternSet &patterns) { - 
patterns.add(patterns.getContext()); + RewritePatternSet &patterns, PatternBenefit benefit) { + patterns.add(patterns.getContext(), benefit); } diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp @@ -181,9 +181,11 @@ struct UnrollTransferReadPattern : public OpRewritePattern { UnrollTransferReadPattern(MLIRContext *context, - const vector::UnrollVectorOptions &options) - : OpRewritePattern(context, /*benefit=*/1), + const vector::UnrollVectorOptions &options, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), options(options) {} + LogicalResult matchAndRewrite(vector::TransferReadOp readOp, PatternRewriter &rewriter) const override { // TODO: support 0-d corner case. @@ -236,9 +238,11 @@ struct UnrollTransferWritePattern : public OpRewritePattern { UnrollTransferWritePattern(MLIRContext *context, - const vector::UnrollVectorOptions &options) - : OpRewritePattern(context, /*benefit=*/1), + const vector::UnrollVectorOptions &options, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), options(options) {} + LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp, PatternRewriter &rewriter) const override { // TODO: support 0-d corner case. 
@@ -306,8 +310,9 @@ struct UnrollContractionPattern : public OpRewritePattern { UnrollContractionPattern(MLIRContext *context, - const vector::UnrollVectorOptions &options) - : OpRewritePattern(context, /*benefit=*/1), + const vector::UnrollVectorOptions &options, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), options(options) {} LogicalResult matchAndRewrite(vector::ContractionOp contractOp, @@ -408,8 +413,9 @@ struct UnrollMultiReductionPattern : public OpRewritePattern { UnrollMultiReductionPattern(MLIRContext *context, - const vector::UnrollVectorOptions &options) - : OpRewritePattern(context, /*benefit=*/1), + const vector::UnrollVectorOptions &options, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), options(options) {} LogicalResult matchAndRewrite(vector::MultiDimReductionOp reductionOp, @@ -481,9 +487,11 @@ struct UnrollElementwisePattern : public RewritePattern { UnrollElementwisePattern(MLIRContext *context, - const vector::UnrollVectorOptions &options) - : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context), + const vector::UnrollVectorOptions &options, + PatternBenefit benefit = 1) + : RewritePattern(MatchAnyOpTypeTag(), benefit, context), options(options) {} + LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1) @@ -539,7 +547,8 @@ /// %db = vector.extract_map %a[%id] : vector<32xf32> to vector<1xf32> /// %dv = arith.addf %da, %db : vector<1xf32> struct PointwiseExtractPattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(vector::ExtractMapOp extract, PatternRewriter &rewriter) const override { Operation *definedOp = extract.getVector().getDefiningOp(); @@ -570,7 +579,8 @@ /// Canonicalize an extract_map using the result of a contract operation. 
/// This propagate the extract_map to operands. struct ContractExtractPattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(vector::ExtractMapOp extract, PatternRewriter &rewriter) const override { Operation *definedOp = extract.getVector().getDefiningOp(); @@ -631,8 +641,8 @@ /// ``` struct TransferReadExtractPattern : public OpRewritePattern { - TransferReadExtractPattern(MLIRContext *context) - : OpRewritePattern(context) {} + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(vector::TransferReadOp read, PatternRewriter &rewriter) const override { // TODO: support 0-d corner case. @@ -682,8 +692,8 @@ struct TransferWriteInsertPattern : public OpRewritePattern { - TransferWriteInsertPattern(MLIRContext *context) - : OpRewritePattern(context) {} + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(vector::TransferWriteOp write, PatternRewriter &rewriter) const override { // TODO: support 0-d corner case. 
@@ -726,8 +736,9 @@ struct UnrollReductionPattern : public OpRewritePattern { UnrollReductionPattern(MLIRContext *context, - const vector::UnrollVectorOptions &options) - : OpRewritePattern(context, /*benefit=*/1), + const vector::UnrollVectorOptions &options, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), options(options) {} LogicalResult matchAndRewrite(vector::ReductionOp reductionOp, @@ -772,9 +783,11 @@ struct UnrollTranposePattern : public OpRewritePattern { UnrollTranposePattern(MLIRContext *context, - const vector::UnrollVectorOptions &options) - : OpRewritePattern(context, /*benefit=*/1), + const vector::UnrollVectorOptions &options, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), options(options) {} + LogicalResult matchAndRewrite(vector::TransposeOp tranposeOp, PatternRewriter &rewriter) const override { if (tranposeOp.getResultType().getRank() == 0) @@ -821,16 +834,17 @@ } // namespace void mlir::vector::populateVectorUnrollPatterns( - RewritePatternSet &patterns, const UnrollVectorOptions &options) { + RewritePatternSet &patterns, const UnrollVectorOptions &options, + PatternBenefit benefit) { patterns.add(patterns.getContext(), options); + UnrollTranposePattern>(patterns.getContext(), options, benefit); } void mlir::vector::populatePropagateVectorDistributionPatterns( - RewritePatternSet &patterns) { + RewritePatternSet &patterns, PatternBenefit benefit) { patterns.add( - patterns.getContext()); + patterns.getContext(), benefit); } diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -160,7 +160,7 @@ VectorContractLowering lowering = VectorContractLowering::OuterProduct; VectorTransformsOptions options{lowering}; patterns.add( - options, &getContext(), [](vector::ContractionOp op) { + options, &getContext(), 
/*benefit=*/1, [](vector::ContractionOp op) { // Only lowers vector.contract where the lhs as a type vector // where M is not 4. if (op.getRhsType().getShape()[0] == 4)