diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -50,10 +50,6 @@
 /// buffers instead.
 std::unique_ptr<OperationPass<FuncOp>> createLinalgBufferizePass();

-/// Populate patterns that convert `ElementwiseMappable` ops to linalg
-/// parallel loops.
-void populateElementwiseToLinalgConversionPatterns(RewritePatternSet &patterns);
-
 /// Create a pass to convert named Linalg operations to Linalg generic
 /// operations.
 std::unique_ptr<OperationPass<FuncOp>> createLinalgGeneralizationPass();
@@ -62,35 +58,6 @@
 /// work on primitive types, if possible.
 std::unique_ptr<Pass> createLinalgDetensorizePass();

-/// Patterns to fold an expanding (collapsing) tensor_reshape operation with its
-/// producer (consumer) generic operation by expanding the dimensionality of the
-/// loop in the generic op.
-void populateFoldReshapeOpsByExpansionPatterns(
-    RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes = false);
-
-/// Patterns to fold a collapsing (expanding) tensor_reshape operation with its
-/// producer (consumer) generic/indexed_generic operation by linearizing the
-/// indexing map used to access the source (target) of the reshape operation in
-/// the generic/indexed_generic operation.
-void populateFoldReshapeOpsByLinearizationPatterns(RewritePatternSet &patterns);
-
-/// Patterns to fold a collapsing (expanding) tensor_reshape operation with its
-/// producer (consumer) generic/indexed_generic operation by linearizing the
-/// indexing map used to access the source (target) of the reshape operation in
-/// the generic/indexed_generic operation. The patterns are applied only when
-/// the tensor reshape involved is collapsing (introducing) unit-extent
-/// dimensions.
-void populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
-    RewritePatternSet &patterns);
-
-/// Patterns for fusing linalg operation on tensors.
-void populateLinalgTensorOpsFusionPatterns(
-    RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes = false);
-
-/// Patterns to fold unit-extent dimensions in operands/results of linalg ops on
-/// tensors.
-void populateLinalgFoldUnitExtentDimsPatterns(RewritePatternSet &patterns);
-
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -36,10 +36,43 @@
     MLIRContext *context, SmallVectorImpl<RewritePatternSet> &patterns,
    ArrayRef<int64_t> tileSizes);

+/// Populate patterns that convert `ElementwiseMappable` ops to linalg
+/// parallel loops.
+void populateElementwiseToLinalgConversionPatterns(RewritePatternSet &patterns);
+
+/// Patterns to fold an expanding (collapsing) tensor_reshape operation with its
+/// producer (consumer) generic operation by expanding the dimensionality of the
+/// loop in the generic op.
+void populateFoldReshapeOpsByExpansionPatterns(
+    RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes = false);
+
+/// Patterns to fold a collapsing (expanding) tensor_reshape operation with its
+/// producer (consumer) generic/indexed_generic operation by linearizing the
+/// indexing map used to access the source (target) of the reshape operation in
+/// the generic/indexed_generic operation.
+void populateFoldReshapeOpsByLinearizationPatterns(RewritePatternSet &patterns);
+
+/// Patterns to fold a collapsing (expanding) tensor_reshape operation with its
+/// producer (consumer) generic/indexed_generic operation by linearizing the
+/// indexing map used to access the source (target) of the reshape operation in
+/// the generic/indexed_generic operation. The patterns are applied only when
+/// the tensor reshape involved is collapsing (introducing) unit-extent
+/// dimensions.
+void populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
+    RewritePatternSet &patterns);
+
 /// Populates the given list with patterns to bufferize linalg ops.
 void populateLinalgBufferizePatterns(BufferizeTypeConverter &converter,
                                      RewritePatternSet &patterns);

+/// Patterns to fold unit-extent dimensions in operands/results of linalg ops on
+/// tensors.
+void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns);
+
+/// Patterns for fusing linalg operation on tensors.
+void populateElementwiseOpsFusionPatterns(
+    RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes = false);
+
 /// Performs standalone tiling of a single LinalgOp by `tileSizes`.
 /// and permute the loop nest according to `interchangeVector`
 /// The permutation is expressed as a list of integers that specify
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -136,11 +136,6 @@
     OpResult producerOpResult, OpOperand &consumerOpOperand);

-/// Fuse linalg operation on tensors, with the producer of the operand at
-/// position `consumerIdx` of the consumer.
-Optional<SmallVector<Value>> fuseTensorOps(PatternRewriter &rewriter,
-                                           OpOperand &consumerOpOperand);
-
 //===----------------------------------------------------------------------===//
 // Distribution utilities
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
 #include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
 #include "mlir/IR/AffineExpr.h"
@@ -556,7 +557,7 @@
 /// Patterns that are used to canonicalize the use of unit-extent dims for
 /// broadcasting.
-void mlir::populateLinalgFoldUnitExtentDimsPatterns(
+void mlir::linalg::populateFoldUnitExtentDimsPatterns(
     RewritePatternSet &patterns) {
   auto *context = patterns.getContext();
   patterns.add<FoldUnitDimLoops<GenericOp>, FoldUnitDimLoops<IndexedGenericOp>,
@@ -580,7 +581,7 @@
           .add<FoldUnitDimLoops<GenericOp>, FoldUnitDimLoops<IndexedGenericOp>>(
               context);
     else
-      populateLinalgFoldUnitExtentDimsPatterns(patterns);
+      populateFoldUnitExtentDimsPatterns(patterns);
     (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
   }
 };
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
@@ -10,6 +10,7 @@
 #include "PassDetail.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/StandardOps/Utils/Utils.h"
@@ -115,7 +116,7 @@
 };
 } // namespace

-void mlir::populateElementwiseToLinalgConversionPatterns(
+void mlir::linalg::populateElementwiseToLinalgConversionPatterns(
     RewritePatternSet &patterns) {
   patterns.add<ConvertAnyElementwiseMappableOpOnRankedTensors>(
       patterns.getContext());
@@ -131,7 +132,7 @@
   ConversionTarget target(*context);
   RewritePatternSet patterns(context);

-  populateElementwiseToLinalgConversionPatterns(patterns);
+  mlir::linalg::populateElementwiseToLinalgConversionPatterns(patterns);
   target.markUnknownOpDynamicallyLegal([](Operation *op) {
     return !isElementwiseMappableOpOnRankedTensors(op);
   });
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -26,8 +26,8 @@
 using namespace mlir::linalg;

 /// Implementation of fusion of generic ops and indexed_generic ops.
-static bool areTensorOpsFusable(LinalgOp producer, LinalgOp consumer,
-                                unsigned consumerIdx) {
+static bool areElementwiseOpsFusable(LinalgOp producer, LinalgOp consumer,
+                                     unsigned consumerIdx) {
   // Producer and consumer must have tensor semantics.
   if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics())
     return false;
@@ -91,11 +91,11 @@

 /// Generate the region of the fused tensor operation. The region of the fused
 /// op must be empty.
-static void generateFusedTensorOpRegion(PatternRewriter &rewriter,
-                                        Operation *fusedOp, LinalgOp producer,
-                                        LinalgOp consumer,
-                                        AffineMap consumerToProducerLoopsMap,
-                                        unsigned consumerIdx, unsigned nloops) {
+static void
+generateFusedElementwiseOpRegion(PatternRewriter &rewriter, Operation *fusedOp,
+                                 LinalgOp producer, LinalgOp consumer,
+                                 AffineMap consumerToProducerLoopsMap,
+                                 unsigned consumerIdx, unsigned nloops) {
   // Build the region of the fused op.
   Block &producerBlock = producer->getRegion(0).front();
   Block &consumerBlock = consumer->getRegion(0).front();
@@ -208,11 +208,11 @@
 }

 static Optional<SmallVector<Value>>
-fuseTensorOpsImpl(LinalgOp producer, OpOperand &consumerOpOperand,
-                  PatternRewriter &rewriter) {
+fuseElementwiseOpsImpl(LinalgOp producer, OpOperand &consumerOpOperand,
+                       PatternRewriter &rewriter) {
   LinalgOp consumer = cast<LinalgOp>(consumerOpOperand.getOwner());
   unsigned consumerIdx = consumerOpOperand.getOperandNumber();
-  if (!areTensorOpsFusable(producer, consumer, consumerIdx))
+  if (!areElementwiseOpsFusable(producer, consumer, consumerIdx))
     return llvm::None;

   unsigned numFusedOperands =
@@ -291,9 +291,9 @@
   AffineMap consumerToProducerLoopsMap =
       invProducerResultIndexMap.compose(consumerResultIndexMap);

-  generateFusedTensorOpRegion(rewriter, fusedOp.getOperation(), producer,
-                              consumer, consumerToProducerLoopsMap, consumerIdx,
-                              consumer.getNumLoops());
+  generateFusedElementwiseOpRegion(rewriter, fusedOp.getOperation(), producer,
+                                   consumer, consumerToProducerLoopsMap,
+                                   consumerIdx, consumer.getNumLoops());
   return SmallVector<Value>(fusedOp->getResults());
 }
@@ -1102,9 +1102,8 @@
 };
 } // namespace

-Optional<SmallVector<Value>>
-mlir::linalg::fuseTensorOps(PatternRewriter &rewriter,
-                            OpOperand &consumerOpOperand) {
+static Optional<SmallVector<Value>>
+fuseElementwiseOps(PatternRewriter &rewriter, OpOperand &consumerOpOperand) {
   Operation *producer = consumerOpOperand.get().getDefiningOp();
   if (!producer || producer->getNumResults() != 1)
     return llvm::None;
@@ -1114,14 +1113,14 @@
       !isa<GenericOp, IndexedGenericOp>(producer))
     return llvm::None;

-  return fuseTensorOpsImpl(cast<LinalgOp>(producer), consumerOpOperand,
-                           rewriter);
+  return fuseElementwiseOpsImpl(cast<LinalgOp>(producer), consumerOpOperand,
+                                rewriter);
 }

 namespace {
 /// Patterns to fuse a generic op, with the producer of its operands.
 template <typename LinalgOpTy>
-struct FuseTensorOps : public OpRewritePattern<LinalgOpTy> {
+struct FuseElementwiseOps : public OpRewritePattern<LinalgOpTy> {
   using OpRewritePattern<LinalgOpTy>::OpRewritePattern;

   LogicalResult matchAndRewrite(LinalgOpTy op,
@@ -1133,7 +1132,7 @@
       if (!producerOp || !producerOp.hasTensorSemantics())
         continue;
       Optional<SmallVector<Value>> fusedOpResults =
-          fuseTensorOps(rewriter, opOperand);
+          fuseElementwiseOps(rewriter, opOperand);
       if (fusedOpResults) {
         rewriter.replaceOp(op, *fusedOpResults);
         return success();
@@ -1149,8 +1148,7 @@
   void runOnOperation() override {
     Operation *op = getOperation();
     RewritePatternSet patterns(op->getContext());
-    populateLinalgTensorOpsFusionPatterns(patterns,
-                                          allowFoldingUnitDimReshapes);
+    populateElementwiseOpsFusionPatterns(patterns, allowFoldingUnitDimReshapes);
     (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
   }
 };
@@ -1170,7 +1168,7 @@
 } // namespace

-void mlir::populateFoldReshapeOpsByLinearizationPatterns(
+void mlir::linalg::populateFoldReshapeOpsByLinearizationPatterns(
     RewritePatternSet &patterns) {
   patterns.add,
                FoldProducerReshapeOpByLinearization,
@@ -1178,7 +1176,7 @@
       patterns.getContext());
 }

-void mlir::populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
+void mlir::linalg::populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
     RewritePatternSet &patterns) {
   patterns.add,
                FoldProducerReshapeOpByLinearization,
@@ -1186,7 +1184,7 @@
       patterns.getContext());
 }

-void mlir::populateFoldReshapeOpsByExpansionPatterns(
+void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
     RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes) {
   patterns.add(patterns.getContext());
   patterns.add,
@@ -1194,11 +1192,11 @@
       patterns.getContext(), allowFoldingUnitDimReshapes);
 }

-void mlir::populateLinalgTensorOpsFusionPatterns(
+void mlir::linalg::populateElementwiseOpsFusionPatterns(
     RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes) {
   auto *context = patterns.getContext();
   patterns
-      .add<FuseTensorOps<GenericOp>, FuseTensorOps<IndexedGenericOp>,
+      .add<FuseElementwiseOps<GenericOp>, FuseElementwiseOps<IndexedGenericOp>,
           FoldSplatConstants<GenericOp>, FoldSplatConstants<IndexedGenericOp>>(
           context);
   populateFoldReshapeOpsByExpansionPatterns(patterns,
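
For downstream users of these entry points, here is a minimal sketch of what a call site looks like after this change. It is illustrative only and not part of the patch: the pass wrapper and its name `ExampleElementwiseFusionPass` are hypothetical, while the populate/apply calls mirror the usage already visible in FusionOnTensors.cpp above (the declarations now live in Transforms.h and sit in the `mlir::linalg` namespace).

    // Hypothetical pass, shown only to illustrate the renamed/relocated API.
    #include "mlir/Dialect/Linalg/Transforms/Transforms.h" // new home of the populate* declarations
    #include "mlir/IR/BuiltinOps.h"
    #include "mlir/Pass/Pass.h"
    #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

    namespace {
    struct ExampleElementwiseFusionPass
        : public mlir::PassWrapper<ExampleElementwiseFusionPass,
                                   mlir::OperationPass<mlir::FuncOp>> {
      void runOnOperation() override {
        mlir::FuncOp funcOp = getOperation();
        mlir::RewritePatternSet patterns(funcOp.getContext());
        // Previously populateLinalgTensorOpsFusionPatterns; now declared in
        // Transforms.h inside the mlir::linalg namespace.
        mlir::linalg::populateElementwiseOpsFusionPatterns(
            patterns, /*allowFoldingUnitDimReshapes=*/false);
        (void)mlir::applyPatternsAndFoldGreedily(funcOp.getBody(),
                                                 std::move(patterns));
      }
    };
    } // namespace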