diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -28,6 +28,10 @@
 struct LinalgFusionOptions;
 struct LinalgTilingOptions;
 
+/// Default function to control reshape folding. Skips folding unit dimension
+/// reshapes.
+bool skipUnitDimReshape(const OpResult &producer, const OpOperand &consumer);
+
 //===----------------------------------------------------------------------===//
 // Transformations exposed as function calls.
 //===----------------------------------------------------------------------===//
@@ -42,11 +46,15 @@
 /// parallel loops.
 void populateElementwiseToLinalgConversionPatterns(RewritePatternSet &patterns);
 
+using ControlElementwiseOpsFusionFn =
+    std::function<bool(const OpResult &producer, const OpOperand &consumer)>;
+
 /// Patterns to fold an expanding (collapsing) tensor_reshape operation with its
 /// producer (consumer) generic operation by expanding the dimensionality of the
 /// loop in the generic op.
 void populateFoldReshapeOpsByExpansionPatterns(
-    RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes = false);
+    RewritePatternSet &patterns,
+    ControlElementwiseOpsFusionFn controlFoldingReshapes = skipUnitDimReshape);
 
 /// Patterns to fold a collapsing (expanding) tensor_reshape operation with its
 /// producer (consumer) generic/indexed_generic operation by linearizing the
@@ -71,17 +79,15 @@
 /// tensors.
 void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns);
 
-using ControlElementwiseOpsFusionFn =
-    std::function<bool(const OpResult &producer, const OpOperand &consumer)>;
-
 /// Options that control fusion of elementwise operations.
 struct LinalgElementwiseFusionOptions {
-  /// Enable fusion of reshapes that are introducing unit-dimensions into the
-  /// shape with elementwise operations. By default this is disabled.
-  bool allowFoldingUnitDimReshapes = false;
+  /// Enable fusion of reshapes into the shape with elementwise operations. By
+  /// default it is disabled for unit dimensions reshape.
+  ControlElementwiseOpsFusionFn controlFoldingReshapesFn = skipUnitDimReshape;
 
-  LinalgElementwiseFusionOptions &setAllowFoldingUnitDimReshapes(bool val) {
-    allowFoldingUnitDimReshapes = val;
+  LinalgElementwiseFusionOptions &
+  setControlFoldingReshapes(ControlElementwiseOpsFusionFn fun) {
+    controlFoldingReshapesFn = std::move(fun);
     return *this;
   }
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -1164,11 +1164,11 @@
 class FoldWithProducerReshapeOpByExpansion
     : public OpRewritePattern<GenericOpTy> {
 public:
-  FoldWithProducerReshapeOpByExpansion(MLIRContext *context,
-                                       bool foldUnitDimReshapes,
-                                       PatternBenefit benefit = 1)
+  FoldWithProducerReshapeOpByExpansion(
+      MLIRContext *context, ControlElementwiseOpsFusionFn foldReshapes,
+      PatternBenefit benefit = 1)
       : OpRewritePattern<GenericOpTy>(context, benefit),
-        allowFoldingUnitDimReshapes(foldUnitDimReshapes) {}
+        controlFoldingReshapes(foldReshapes) {}
 
   LogicalResult matchAndRewrite(GenericOpTy genericOp,
                                 PatternRewriter &rewriter) const override {
@@ -1178,16 +1178,15 @@
           operand.value().getDefiningOp<TensorReshapeOp>();
       if (!reshapeOp)
         continue;
-
       // Fold only if
       // - The tensor reshape op is folding.
       // - All constraints of fusing with reshape by expansion are met.
       if (reshapeOp.getSrcType().getRank() <
               reshapeOp.getResultType().getRank() ||
           !isFusableWithReshapeByDimExpansion(linalgOp, operand.index()) ||
-          (!allowFoldingUnitDimReshapes &&
-           isUnitDimExpansionOnly(reshapeOp.getSrcType().getShape(),
-                                  reshapeOp.getReassociationMaps())))
+          (!controlFoldingReshapes(
+              reshapeOp->getResult(0),
+              linalgOp.getInputOpOperands()[operand.index()])))
         continue;
 
       Optional<SmallVector<Value>> replacementValues =
@@ -1202,7 +1201,7 @@
   }
 
 private:
-  bool allowFoldingUnitDimReshapes;
+  ControlElementwiseOpsFusionFn controlFoldingReshapes;
 };
 
 /// Pattern to fold tensor_reshape op with its producer. The corresponding index
@@ -1394,6 +1393,13 @@
       controlFn, rewriter);
 }
 
+bool mlir::linalg::skipUnitDimReshape(const OpResult &producer,
+                                      const OpOperand &consumer) {
+  auto reshapeOp = producer.getDefiningOp<TensorReshapeOp>();
+  return !isUnitDimExpansionOnly(reshapeOp.getSrcType().getShape(),
+                                 reshapeOp.getReassociationMaps());
+}
+
 namespace {
 /// Patterns to fuse a generic op, with the producer of its operands.
 template <typename GenericOpTy>
@@ -1431,10 +1437,14 @@
   void runOnOperation() override {
     Operation *op = getOperation();
     RewritePatternSet patterns(op->getContext());
+    ControlElementwiseOpsFusionFn allowFoldingFn =
+        [](const OpResult &producer, const OpOperand &consumer) {
+          return true;
+        };
     populateElementwiseOpsFusionPatterns(
         patterns,
-        LinalgElementwiseFusionOptions().setAllowFoldingUnitDimReshapes(
-            allowFoldingUnitDimReshapes));
+        LinalgElementwiseFusionOptions().setControlFoldingReshapes(
+            allowFoldingUnitDimReshapes ? allowFoldingFn : skipUnitDimReshape));
     (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
   }
 };
@@ -1471,11 +1481,12 @@
 }
 
 void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
-    RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes) {
+    RewritePatternSet &patterns,
+    ControlElementwiseOpsFusionFn controlFoldingReshapes) {
   patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext());
   patterns.add<FoldWithProducerReshapeOpByExpansion<GenericOp>,
               FoldWithProducerReshapeOpByExpansion<IndexedGenericOp>>(
-      patterns.getContext(), allowFoldingUnitDimReshapes);
+      patterns.getContext(), controlFoldingReshapes);
 }
 
 void mlir::linalg::populateElementwiseOpsFusionPatterns(
@@ -1485,8 +1496,8 @@
       .add<FuseElementwiseOps<GenericOp>, FuseElementwiseOps<IndexedGenericOp>,
           FoldSplatConstants<GenericOp>, FoldSplatConstants<IndexedGenericOp>>(
           context, options.controlElementwiseOpsFusionFn);
-  populateFoldReshapeOpsByExpansionPatterns(
-      patterns, options.allowFoldingUnitDimReshapes);
+  populateFoldReshapeOpsByExpansionPatterns(patterns,
+                                            options.controlFoldingReshapesFn);
   AffineApplyOp::getCanonicalizationPatterns(patterns, context);
   GenericOp::getCanonicalizationPatterns(patterns, context);
   IndexedGenericOp::getCanonicalizationPatterns(patterns, context);