diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -65,7 +65,8 @@
 /// Patterns to fold an expanding (collapsing) tensor_reshape operation with its
 /// producer (consumer) generic operation by expanding the dimensionality of the
 /// loop in the generic op.
-void populateFoldReshapeOpsByExpansionPatterns(RewritePatternSet &patterns);
+void populateFoldReshapeOpsByExpansionPatterns(
+    RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes = false);
 
 /// Patterns to fold a collapsing (expanding) tensor_reshape operation with its
 /// producer (consumer) generic/indexed_generic operation by linearizing the
@@ -83,7 +84,8 @@
     RewritePatternSet &patterns);
 
 /// Patterns for fusing linalg operation on tensors.
-void populateLinalgTensorOpsFusionPatterns(RewritePatternSet &patterns);
+void populateLinalgTensorOpsFusionPatterns(
+    RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes = false);
 
 /// Patterns to fold unit-extent dimensions in operands/results of linalg ops on
 /// tensors.
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -37,6 +37,12 @@
 def LinalgFusionOfTensorOps : Pass<"linalg-fusion-for-tensor-ops"> {
   let summary = "Fuse operations on RankedTensorType in linalg dialect";
   let constructor = "mlir::createLinalgFusionOfTensorOpsPass()";
+  let options = [
+    Option<"allowFoldingUnitDimReshapes", "allow-folding-unit-dim-reshapes",
+           "bool", /*default=*/"false",
+           "Allow fusing linalg.tensor_reshape ops that perform unit "
+           "dimension collapsing">
+  ];
   let dependentDialects = ["linalg::LinalgDialect", "AffineDialect"];
 }
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -897,9 +897,14 @@
 /// generic/indexed_generic op, when the reshape op is collapsing
 /// dimensions. The dimensionality of the loop in the consumer is expanded.
 template <typename GenericOpTy>
-struct FoldWithProducerReshapeOpByExpansion
+class FoldWithProducerReshapeOpByExpansion
     : public OpRewritePattern<GenericOpTy> {
-  using OpRewritePattern<GenericOpTy>::OpRewritePattern;
+public:
+  FoldWithProducerReshapeOpByExpansion(MLIRContext *context,
+                                       bool foldUnitDimReshapes,
+                                       PatternBenefit benefit = 1)
+      : OpRewritePattern<GenericOpTy>(context, benefit),
+        allowFoldingUnitDimReshapes(foldUnitDimReshapes) {}
 
   LogicalResult matchAndRewrite(GenericOpTy genericOp,
                                 PatternRewriter &rewriter) const override {
@@ -916,8 +921,9 @@
       if (reshapeOp.getSrcType().getRank() <
               reshapeOp.getResultType().getRank() ||
           !isFusableWithReshapeByDimExpansion(linalgOp, operand.index()) ||
-          isUnitDimExpansionOnly(reshapeOp.getSrcType().getShape(),
-                                 reshapeOp.getReassociationMaps()))
+          (!allowFoldingUnitDimReshapes &&
+           isUnitDimExpansionOnly(reshapeOp.getSrcType().getShape(),
+                                  reshapeOp.getReassociationMaps())))
         continue;
 
       Optional<SmallVector<Value>> replacementValues =
@@ -930,6 +936,9 @@
     }
     return failure();
   }
+
+private:
+  bool allowFoldingUnitDimReshapes;
 };
 
 /// Pattern to fold tensor_reshape op with its producer.
@@ -1134,7 +1143,8 @@
   void runOnOperation() override {
     Operation *op = getOperation();
    RewritePatternSet patterns(op->getContext());
-    populateLinalgTensorOpsFusionPatterns(patterns);
+    populateLinalgTensorOpsFusionPatterns(patterns,
+                                          allowFoldingUnitDimReshapes);
     (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
   }
 };
@@ -1171,20 +1181,22 @@
 }
 
 void mlir::populateFoldReshapeOpsByExpansionPatterns(
-    RewritePatternSet &patterns) {
-  patterns.add<FoldReshapeWithGenericOpByExpansion,
-               FoldWithProducerReshapeOpByExpansion<GenericOp>,
+    RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes) {
+  patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext());
+  patterns.add<FoldWithProducerReshapeOpByExpansion<GenericOp>,
               FoldWithProducerReshapeOpByExpansion<IndexedGenericOp>>(
-      patterns.getContext());
+      patterns.getContext(), allowFoldingUnitDimReshapes);
 }
 
-void mlir::populateLinalgTensorOpsFusionPatterns(RewritePatternSet &patterns) {
+void mlir::populateLinalgTensorOpsFusionPatterns(
+    RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes) {
   auto *context = patterns.getContext();
   patterns
       .add<FuseTensorOps<GenericOp>, FuseTensorOps<IndexedGenericOp>,
           FuseTensorOps<TensorReshapeOp>, FoldSplatConstants<GenericOp>,
           FoldSplatConstants<IndexedGenericOp>>(context);
-  populateFoldReshapeOpsByExpansionPatterns(patterns);
+  populateFoldReshapeOpsByExpansionPatterns(patterns,
+                                            allowFoldingUnitDimReshapes);
   GenericOp::getCanonicalizationPatterns(patterns, context);
   IndexedGenericOp::getCanonicalizationPatterns(patterns, context);
   TensorReshapeOp::getCanonicalizationPatterns(patterns, context);
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -1,4 +1,5 @@
-// RUN: mlir-opt %s -linalg-fusion-for-tensor-ops -split-input-file -verify-each=0 | FileCheck %s
+// RUN: mlir-opt %s -linalg-fusion-for-tensor-ops="allow-folding-unit-dim-reshapes=false" -split-input-file -verify-each=0 | FileCheck %s
+// RUN: mlir-opt %s -linalg-fusion-for-tensor-ops="allow-folding-unit-dim-reshapes=true" -split-input-file -verify-each=0 | FileCheck %s --check-prefix=FOLDUNITDIM
 
 #map0 = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
 #map1 = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
@@ -300,7 +301,7 @@
     %5 = addi %3, %4 : i32
     %6 = index_cast %arg2 : index to i32
     %7 = addi %5, %6 : i32
-    linalg.yield %7 : i32 
+    linalg.yield %7 : i32
   } -> tensor<6x4x210xi32>
   %d = linalg.tensor_reshape %c
      [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1)>,
@@ -531,3 +532,11 @@
 // CHECK-DAG: linalg.tensor_reshape
 // CHECK-DAG: linalg.init_tensor
 //     CHECK: linalg.generic
+// CHECK-SAME:   ins(%{{.+}}, %{{.+}} : tensor<?x2x4xf32>, tensor<?x2x4xf32>)
+
+// FOLDUNITDIM: func @unit_dim_reshape_expansion_full
+// FOLDUNITDIM:   linalg.init_tensor
+// FOLDUNITDIM-COUNT-2: linalg.tensor_reshape
+// FOLDUNITDIM:   linalg.generic
+// FOLDUNITDIM-SAME:   ins(%{{.+}}, %{{.+}} : tensor<1x?x1x2x1x4xf32>, tensor<1x?x1x2x1x4xf32>)
+
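Note for downstream users: the pass-level flag is exercised by the RUN lines above; when wiring the patterns manually instead, the new `allowFoldingUnitDimReshapes` argument is supplied at population time. A minimal sketch follows; the helper name `runTensorOpFusion` and the surrounding boilerplate are illustrative assumptions, and only the populate call with its new argument comes from this patch:

  #include "mlir/Dialect/Linalg/Passes.h"
  #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

  // Sketch: run tensor-op fusion with unit-dim reshape folding opted in.
  void runTensorOpFusion(mlir::Operation *op) {
    mlir::RewritePatternSet patterns(op->getContext());
    // Opt in to fusing linalg.tensor_reshape ops that only collapse or
    // expand unit dimensions (off by default after this change).
    mlir::populateLinalgTensorOpsFusionPatterns(
        patterns, /*allowFoldingUnitDimReshapes=*/true);
    (void)mlir::applyPatternsAndFoldGreedily(op->getRegions(),
                                             std::move(patterns));
  }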