Index: mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h =================================================================== --- mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -28,6 +28,13 @@ struct LinalgFusionOptions; struct LinalgTilingOptions; +using ControlFoldingReshapesFn = std::function<bool(const OpResult &producer, const OpOperand &consumer)>; + +/// Default function to control reshape folding. Skips folding unit dimension +/// reshapes. +bool skipUnitDimReshape(const OpResult &producer, const OpOperand &consumer); + //===----------------------------------------------------------------------===// // Transformations exposed as function calls. //===----------------------------------------------------------------------===// @@ -46,7 +53,8 @@ /// producer (consumer) generic operation by expanding the dimensionality of the /// loop in the generic op. void populateFoldReshapeOpsByExpansionPatterns( - RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes = false); + RewritePatternSet &patterns, + ControlFoldingReshapesFn controlFoldingReshapes = skipUnitDimReshape); /// Patterns to fold a collapsing (expanding) tensor_reshape operation with its /// producer (consumer) generic/indexed_generic operation by linearizing the @@ -76,12 +84,13 @@ /// Options that control fusion of elementwise operations. struct LinalgElementwiseFusionOptions { - /// Enable fusion of reshapes that are introducing unit-dimensions into the - /// shape with elementwise operations. By default this is disabled. - bool allowFoldingUnitDimReshapes = false; + /// Enable fusion of reshapes into the shape with elementwise operations. By + /// default it is disabled for unit dimension reshapes. 
+ ControlFoldingReshapesFn controlFoldingReshapesFn = skipUnitDimReshape; - LinalgElementwiseFusionOptions &setAllowFoldingUnitDimReshapes(bool val) { - allowFoldingUnitDimReshapes = val; + LinalgElementwiseFusionOptions & + setControlFoldingReshapes(ControlFoldingReshapesFn fun) { + controlFoldingReshapesFn = std::move(fun); return *this; } Index: mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp =================================================================== --- mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp +++ mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp @@ -1162,10 +1162,10 @@ : public OpRewritePattern<GenericOpTy> { public: FoldWithProducerReshapeOpByExpansion(MLIRContext *context, - bool foldUnitDimReshapes, + ControlFoldingReshapesFn foldReshapes, PatternBenefit benefit = 1) : OpRewritePattern<GenericOpTy>(context, benefit), - allowFoldingUnitDimReshapes(foldUnitDimReshapes) {} + controlFoldingReshapes(foldReshapes) {} LogicalResult matchAndRewrite(GenericOpTy genericOp, PatternRewriter &rewriter) const override { @@ -1175,16 +1175,15 @@ operand.value().getDefiningOp<TensorReshapeOp>(); if (!reshapeOp) continue; - // Fold only if // - The tensor reshape op is folding. // - All constraints of fusing with reshape by expansion are met. if (reshapeOp.getSrcType().getRank() < reshapeOp.getResultType().getRank() || !isFusableWithReshapeByDimExpansion(linalgOp, operand.index()) || - (!allowFoldingUnitDimReshapes && - isUnitDimExpansionOnly(reshapeOp.getSrcType().getShape(), - reshapeOp.getReassociationMaps()))) + (!controlFoldingReshapes( + reshapeOp->getResult(0), + linalgOp.getInputOpOperands()[operand.index()]))) continue; Optional<SmallVector<Value>> replacementValues = @@ -1199,7 +1198,7 @@ } private: - bool allowFoldingUnitDimReshapes; + ControlFoldingReshapesFn controlFoldingReshapes; }; /// Pattern to fold tensor_reshape op with its producer. 
The corresponding index @@ -1391,6 +1390,13 @@ controlFn, rewriter); } +bool mlir::linalg::skipUnitDimReshape(const OpResult &producer, + const OpOperand &consumer) { + auto reshapeOp = producer.getDefiningOp<TensorReshapeOp>(); + return !isUnitDimExpansionOnly(reshapeOp.getSrcType().getShape(), + reshapeOp.getReassociationMaps()); +} + namespace { /// Patterns to fuse a generic op, with the producer of its operands. template <typename GenericOpTy> @@ -1428,10 +1434,14 @@ void runOnOperation() override { Operation *op = getOperation(); RewritePatternSet patterns(op->getContext()); + ControlFoldingReshapesFn allowFoldingFn = [](const OpResult &producer, + const OpOperand &consumer) { + return true; + }; populateElementwiseOpsFusionPatterns( patterns, - LinalgElementwiseFusionOptions().setAllowFoldingUnitDimReshapes( - allowFoldingUnitDimReshapes)); + LinalgElementwiseFusionOptions().setControlFoldingReshapes( + allowFoldingUnitDimReshapes ? allowFoldingFn : skipUnitDimReshape)); (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns)); } }; @@ -1468,11 +1478,12 @@ } void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns( - RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes) { + RewritePatternSet &patterns, + ControlFoldingReshapesFn controlFoldingReshapes) { patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext()); patterns.add<FoldWithProducerReshapeOpByExpansion<GenericOp>, FoldWithProducerReshapeOpByExpansion<IndexedGenericOp>>( - patterns.getContext(), allowFoldingUnitDimReshapes); + patterns.getContext(), controlFoldingReshapes); } void mlir::linalg::populateElementwiseOpsFusionPatterns( @@ -1482,8 +1493,8 @@ .add<FuseElementwiseOps<GenericOp>, FuseElementwiseOps<IndexedGenericOp>, FoldSplatConstants<GenericOp>, FoldSplatConstants<IndexedGenericOp>>( context, options.controlElementwiseOpsFusionFn); - populateFoldReshapeOpsByExpansionPatterns( - patterns, options.allowFoldingUnitDimReshapes); + populateFoldReshapeOpsByExpansionPatterns(patterns, + options.controlFoldingReshapesFn); AffineApplyOp::getCanonicalizationPatterns(patterns, context); GenericOp::getCanonicalizationPatterns(patterns, context); 
IndexedGenericOp::getCanonicalizationPatterns(patterns, context);