diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -33,7 +33,7 @@ /// Default function to control reshape folding. Skips folding unit dimension /// reshapes. -bool skipUnitDimReshape(const OpResult &producer, const OpOperand &consumer); +bool skipUnitDimReshape(const OpResult &producer, OpOperand &consumer); //===----------------------------------------------------------------------===// // Transformations exposed as function calls. @@ -49,8 +49,11 @@ /// parallel loops. void populateElementwiseToLinalgConversionPatterns(RewritePatternSet &patterns); +/// Function type which is used to control when to stop fusion. It is expected +/// that OpOperand is not modified in the callback. The OpOperand is not marked +/// as const to allow callers to use non-const methods. using ControlElementwiseOpsFusionFn = - std::function; + std::function; /// Patterns to fold an expanding (collapsing) tensor_reshape operation with its /// producer (consumer) generic operation by expanding the dimensionality of the @@ -104,7 +107,7 @@ /// can be used to abort the fusion based on non-structural constraints. This /// is the hook for cost models to control the amount of fusion done. ControlElementwiseOpsFusionFn controlElementwiseOpsFusionFn = - [](const OpResult & /*producer */, const OpOperand & /*consumer */) { + [](const OpResult & /*producer */, OpOperand & /*consumer */) { return true; }; diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -1241,7 +1241,7 @@ } bool mlir::linalg::skipUnitDimReshape(const OpResult &producer, - const OpOperand &consumer) { + OpOperand &consumer) { auto expandShapeOp = producer.getDefiningOp(); if (expandShapeOp) return !isUnitDimExpansionOnly(expandShapeOp);