diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -1069,6 +1069,12 @@ const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc = false); +/// Collapses dimensions of a linalg.generic op. It also collapses inputs before +/// the op and expands outputs after it. Returns the op's replacement values. +FailureOr> collapseGenericOpIterationDims( + GenericOp genericOp, ArrayRef foldedIterationDims, + RewriterBase &rewriter); + } // namespace linalg } // namespace mlir diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -1353,7 +1353,7 @@ void generateCollapsedIndexingRegion(Location loc, Block *block, const CollapsingInfo &collapsingInfo, ValueRange loopRange, - PatternRewriter &rewriter) { + RewriterBase &rewriter) { OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPointToStart(block); @@ -1389,9 +1389,9 @@ } /// Implementation of fusion with reshape operation by collapsing dimensions. -static FailureOr> collapseGenericOpIterationDims( +FailureOr> mlir::linalg::collapseGenericOpIterationDims( GenericOp genericOp, ArrayRef foldedIterationDims, - PatternRewriter &rewriter) { + RewriterBase &rewriter) { // Bail on trivial no-op cases. if (genericOp.getNumLoops() <= 1 || foldedIterationDims.empty() || llvm::all_of(foldedIterationDims, [](ReassociationIndicesRef foldedDims) { @@ -1570,7 +1570,7 @@ genericOp, collapsableIterationDims, rewriter); if (!replacements) { return rewriter.notifyMatchFailure(genericOp, - "failed to collpase dimensions"); + "failed to collapse dimensions"); } rewriter.replaceOp(genericOp, *replacements); return success();