diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -136,6 +136,9 @@
 /// tensors via rank-reducing slices.
 void populateFoldUnitExtentDimsViaSlicesPatterns(RewritePatternSet &patterns);
 
+/// A pattern that converts init operands to input operands.
+void populateMoveInitOperandsToInputPattern(RewritePatternSet &patterns);
+
 /// Patterns that are used to inline constant operands into linalg generic ops.
 void populateInlineConstantOperandsPatterns(RewritePatternSet &patterns);
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -233,7 +233,7 @@
   }
 };
 
-/// Pattern to add init operands to ins when all the loops are parallel and
+/// Pattern to move init operands to ins when all the loops are parallel and
 /// blockArgument corresponding to init is used in the region. This is a fix-up
 /// when unit reduction dimensions are all folded away. In this context, it
 /// becomes a elementwise generic op. E.g., it converts
@@ -269,7 +269,7 @@
 ///     %4 = arith.addf %in, %in_0 : f32
 ///     linalg.yield %4 : f32
 ///   } -> tensor<1x1xf32>
-struct AddInitOperandsToInput : public OpRewritePattern<GenericOp> {
+struct MoveInitOperandsToInput : public OpRewritePattern<GenericOp> {
   using OpRewritePattern<GenericOp>::OpRewritePattern;
   LogicalResult matchAndRewrite(GenericOp genericOp,
                                 PatternRewriter &rewriter) const override {
@@ -667,10 +667,10 @@
   patterns.add<ReplaceUnitExtents>(context,
                                    RankReductionStrategy::ReassociativeReshape);
   // TODO: Patterns unrelated to unit dim folding should be factored out.
-  patterns
-      .add<FoldUnitDimLoops, AddInitOperandsToInput, RankReducedExtractSliceOp,
-           RankReducedInsertSliceOp<tensor::InsertSliceOp>,
-           RankReducedInsertSliceOp<tensor::ParallelInsertSliceOp>>(context);
+  patterns.add<FoldUnitDimLoops, RankReducedExtractSliceOp,
+               RankReducedInsertSliceOp<tensor::InsertSliceOp>,
+               RankReducedInsertSliceOp<tensor::ParallelInsertSliceOp>>(
+      context);
   linalg::FillOp::getCanonicalizationPatterns(patterns, context);
   tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
   tensor::EmptyOp::getCanonicalizationPatterns(patterns, context);
@@ -688,6 +688,11 @@
   patterns.add<FoldUnitDimLoops>(context);
 }
 
+void mlir::linalg::populateMoveInitOperandsToInputPattern(
+    RewritePatternSet &patterns) {
+  patterns.add<MoveInitOperandsToInput>(patterns.getContext());
+}
+
 namespace {
 /// Pass that removes unit-extent dims within generic ops.
 struct LinalgFoldUnitExtentDimsPass
@@ -697,11 +702,13 @@
     MLIRContext *context = op->getContext();
     RewritePatternSet patterns(context);
     if (foldOneTripLoopsOnly) {
-      patterns.add<FoldUnitDimLoops, AddInitOperandsToInput>(context);
+      patterns.add<FoldUnitDimLoops, MoveInitOperandsToInput>(context);
     } else if (useRankReducingSlices) {
       populateFoldUnitExtentDimsViaSlicesPatterns(patterns);
+      populateMoveInitOperandsToInputPattern(patterns);
    } else {
       populateFoldUnitExtentDimsViaReshapesPatterns(patterns);
+      populateMoveInitOperandsToInputPattern(patterns);
     }
     (void)applyPatternsAndFoldGreedily(op, std::move(patterns));
   }
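
Usage sketch (not part of the patch): with the new populate method, a downstream pass can apply just this fix-up pattern on its own. Only populateMoveInitOperandsToInputPattern and the greedy driver come from the patch; the pass name MoveInitOperandsToInputPass and the choice of anchoring on func::FuncOp are illustrative assumptions.

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace {
// Hypothetical standalone pass that applies only the init-to-input fix-up,
// mirroring how LinalgFoldUnitExtentDimsPass drives its pattern sets above.
struct MoveInitOperandsToInputPass
    : public mlir::PassWrapper<MoveInitOperandsToInputPass,
                               mlir::OperationPass<mlir::func::FuncOp>> {
  void runOnOperation() override {
    mlir::RewritePatternSet patterns(&getContext());
    // Pull in only the pattern exposed by this patch.
    mlir::linalg::populateMoveInitOperandsToInputPattern(patterns);
    if (failed(mlir::applyPatternsAndFoldGreedily(getOperation(),
                                                  std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace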