diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -67,7 +67,8 @@
 /// producer (consumer) generic operation by expanding the dimensionality of the
 /// loop in the generic op.
 void populateFoldReshapeOpsByExpansionPatterns(
-    MLIRContext *context, OwningRewritePatternList &patterns);
+    MLIRContext *context, OwningRewritePatternList &patterns,
+    bool allowFoldingUnitDimReshapes);
 
 /// Patterns to fold a collapsing (expanding) tensor_reshape operation with its
 /// producer (consumer) generic/indexed_generic operation by linearizing the
@@ -87,7 +88,8 @@
 
 /// Patterns for fusing linalg operation on tensors.
 void populateLinalgTensorOpsFusionPatterns(MLIRContext *context,
-                                           OwningRewritePatternList &patterns);
+                                           OwningRewritePatternList &patterns,
+                                           bool allowFoldingUnitDimReshapes);
 
 /// Patterns to fold unit-extent dimensions in operands/results of linalg ops on
 /// tensors.
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -37,6 +37,12 @@
 def LinalgFusionOfTensorOps : Pass<"linalg-fusion-for-tensor-ops"> {
   let summary = "Fuse operations on RankedTensorType in linalg dialect";
   let constructor = "mlir::createLinalgFusionOfTensorOpsPass()";
+  let options = [
+    Option<"allowFoldingUnitDimReshapes", "allow-folding-unit-dim-reshapes",
+           "bool", /*default=*/"false",
+           "Allow fusing linalg.tensor_reshape ops that perform unit "
+           "dimension collapsing">
+  ];
   let dependentDialects = ["linalg::LinalgDialect", "AffineDialect"];
 }
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -868,9 +868,14 @@
 /// generic/indexed_generic op, when the reshape op is collapsing
 /// dimensions. The dimensionality of the loop in the consumer is expanded.
 template <typename GenericOpTy>
-struct FoldWithProducerReshapeOpByExpansion
+class FoldWithProducerReshapeOpByExpansion
     : public OpRewritePattern<GenericOpTy> {
-  using OpRewritePattern<GenericOpTy>::OpRewritePattern;
+public:
+  FoldWithProducerReshapeOpByExpansion(MLIRContext *context,
+                                       bool foldUnitDimReshapes,
+                                       PatternBenefit benefit = 1)
+      : OpRewritePattern<GenericOpTy>(context, benefit),
+        allowFoldingUnitDimReshapes(foldUnitDimReshapes) {}
 
   LogicalResult matchAndRewrite(GenericOpTy genericOp,
                                 PatternRewriter &rewriter) const override {
@@ -887,8 +892,9 @@
       if (reshapeOp.getSrcType().getRank() <
               reshapeOp.getResultType().getRank() ||
           !isFusableWithReshapeByDimExpansion(linalgOp, operand.index()) ||
-          isUnitDimExpansionOnly(reshapeOp.getSrcType().getShape(),
-                                 reshapeOp.getReassociationMaps()))
+          (!allowFoldingUnitDimReshapes &&
+           isUnitDimExpansionOnly(reshapeOp.getSrcType().getShape(),
+                                  reshapeOp.getReassociationMaps())))
         continue;
 
       Optional<SmallVector<Value, 1>> replacementValues =
@@ -903,6 +909,9 @@
     }
     return failure();
   }
+
+private:
+  bool allowFoldingUnitDimReshapes;
 };
 
 /// Pattern to fold tensor_reshape op with its producer.
@@ -1114,7 +1123,8 @@
   void runOnOperation() override {
     OwningRewritePatternList patterns;
     Operation *op = getOperation();
-    populateLinalgTensorOpsFusionPatterns(op->getContext(), patterns);
+    populateLinalgTensorOpsFusionPatterns(op->getContext(), patterns,
+                                          allowFoldingUnitDimReshapes);
     (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
   }
 };
@@ -1149,19 +1159,22 @@
 }
 
 void mlir::populateFoldReshapeOpsByExpansionPatterns(
-    MLIRContext *context, OwningRewritePatternList &patterns) {
-  patterns.insert<FoldReshapeWithGenericOpByExpansion,
-                  FoldWithProducerReshapeOpByExpansion<GenericOp>,
+    MLIRContext *context, OwningRewritePatternList &patterns,
+    bool allowFoldingUnitDimReshapes) {
+  patterns.insert<FoldReshapeWithGenericOpByExpansion>(context);
+  patterns.insert<FoldWithProducerReshapeOpByExpansion<GenericOp>,
                   FoldWithProducerReshapeOpByExpansion<IndexedGenericOp>>(
-      context);
+      context, allowFoldingUnitDimReshapes);
 }
 
 void mlir::populateLinalgTensorOpsFusionPatterns(
-    MLIRContext *context, OwningRewritePatternList &patterns) {
+    MLIRContext *context, OwningRewritePatternList &patterns,
+    bool allowFoldingUnitDimReshapes) {
   patterns.insert<FuseTensorOps<GenericOp>, FuseTensorOps<IndexedGenericOp>,
                   FoldSplatConstants<GenericOp>,
                   FoldSplatConstants<IndexedGenericOp>>(context);
-  populateFoldReshapeOpsByExpansionPatterns(context, patterns);
+  populateFoldReshapeOpsByExpansionPatterns(context, patterns,
+                                            allowFoldingUnitDimReshapes);
   GenericOp::getCanonicalizationPatterns(patterns, context);
   IndexedGenericOp::getCanonicalizationPatterns(patterns, context);
   TensorReshapeOp::getCanonicalizationPatterns(patterns, context);
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -1,4 +1,5 @@
-// RUN: mlir-opt %s -linalg-fusion-for-tensor-ops -split-input-file -verify-each=0 | FileCheck %s
+// RUN: mlir-opt %s -linalg-fusion-for-tensor-ops="allow-folding-unit-dim-reshapes=false" -split-input-file -verify-each=0 | FileCheck %s
+// RUN: mlir-opt %s -linalg-fusion-for-tensor-ops="allow-folding-unit-dim-reshapes=true" -split-input-file -verify-each=0 | FileCheck %s --check-prefix=FOLDUNITDIM
 
 #map0 = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
 #map1 = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
@@ -300,7 +301,7 @@
       %5 = addi %3, %4 : i32
       %6 = index_cast %arg2 : index to i32
       %7 = addi %5, %6 : i32
-      linalg.yield %7 : i32
+      linalg.yield %7 : i32
     } -> tensor<6x4x210xi32>
   %d = linalg.tensor_reshape %c
      [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1)>,
@@ -531,3 +532,11 @@
 // CHECK-DAG: linalg.tensor_reshape
 // CHECK-DAG: linalg.init_tensor
 //     CHECK: linalg.generic
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<?x2x4xf32>, tensor<?x2x4xf32>)
+
+// FOLDUNITDIM: func @unit_dim_reshape_expansion_full
+// FOLDUNITDIM: linalg.init_tensor
+// FOLDUNITDIM-COUNT-2: linalg.tensor_reshape
+// FOLDUNITDIM: linalg.generic
+// FOLDUNITDIM-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x?x1x2x1x4xf32>, tensor<1x?x1x2x1x4xf32>)
+
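Usage note (not part of the patch): a minimal sketch of how a downstream pass could opt in to the new behavior through the updated populate entry point. It mirrors the runOnOperation change above; `MyFusionPass` is a hypothetical pass class used only for illustration.

  void MyFusionPass::runOnOperation() {
    OwningRewritePatternList patterns;
    Operation *op = getOperation();
    // Pass `true` to also fuse linalg.tensor_reshape ops that only
    // expand/collapse unit dimensions; the corresponding pass option
    // `allow-folding-unit-dim-reshapes` defaults to `false`.
    populateLinalgTensorOpsFusionPatterns(op->getContext(), patterns,
                                          /*allowFoldingUnitDimReshapes=*/true);
    (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
  }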