diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1540,9 +1540,10 @@
 /// Patterns that are used to bubble up extract slice op above linalg op.
 void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns);
 
-/// Adds patterns that waps tensor.extract_slice(linalg.fill(%cst, %init)) into
-/// linalg.fill(%cst, tensor.extract_slice(%init)).
-void populateSwapExtractSliceWithFillPatterns(RewritePatternSet &patterns);
+/// Adds patterns that fold `linalg.fill` operations with their consumers.
+/// Currently handles `tensor.extract_slice` consumers (only when the slice is
+/// the fill's sole use) and `tensor.collapse_shape`/`tensor.expand_shape`.
+void populateFoldLinalgFillPatterns(RewritePatternSet &patterns);
 
 /// Patterns to apply `splitReduction` below.
 void populateSplitReductionPattern(
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -13,6 +13,7 @@
   ElementwiseToLinalg.cpp
   EliminateEmptyTensors.cpp
   EraseUnusedOperandsAndResults.cpp
+  FoldLinalgFillPatterns.cpp
   FusePadOpWithLinalgProducer.cpp
   Fusion.cpp
   Generalization.cpp
@@ -27,7 +28,6 @@
   Split.cpp
   SplitReduction.cpp
   SubsetHoisting.cpp
-  SwapExtractSliceWithFillPatterns.cpp
   Tiling.cpp
   TilingInterfaceImpl.cpp
   Transforms.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/SwapExtractSliceWithFillPatterns.cpp b/mlir/lib/Dialect/Linalg/Transforms/FoldLinalgFillPatterns.cpp
rename from mlir/lib/Dialect/Linalg/Transforms/SwapExtractSliceWithFillPatterns.cpp
rename to mlir/lib/Dialect/Linalg/Transforms/FoldLinalgFillPatterns.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/SwapExtractSliceWithFillPatterns.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FoldLinalgFillPatterns.cpp
@@ -35,7 +35,33 @@
   }
 };
 
-void mlir::linalg::populateSwapExtractSliceWithFillPatterns(
-    RewritePatternSet &patterns) {
-  patterns.add<SwapExtractSliceOfFill>(patterns.getContext());
+/// Fold `tensor.expand/collapse_shape(linalg.fill(%val, %init))` to
+/// `linalg.fill(%val, tensor.expand/collapse_shape(%init))`.
+template <typename TensorReshapeOp>
+struct SwapTensorShapeChangePatternsWithFillOp final
+    : public OpRewritePattern<TensorReshapeOp> {
+  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
+                                PatternRewriter &rewriter) const override {
+    auto fillOp = reshapeOp.getSrc().template getDefiningOp<FillOp>();
+    if (!fillOp)
+      return failure();
+    // Reshape the fill's init operand, then fill the reshaped init; the fill
+    // value is unchanged, so the result equals the reshaped original fill.
+    auto reshapedInit = rewriter.create<TensorReshapeOp>(
+        reshapeOp.getLoc(), reshapeOp.getResult().getType(),
+        fillOp.getDpsInitOperand(0)->get(),
+        reshapeOp.getReassociationIndices());
+    rewriter.replaceOpWithNewOp<FillOp>(reshapeOp, fillOp.getInputs(),
+                                        reshapedInit.getResult());
+    return success();
+  }
+};
+
+void mlir::linalg::populateFoldLinalgFillPatterns(RewritePatternSet &patterns) {
+  patterns.add<SwapExtractSliceOfFill,
+               SwapTensorShapeChangePatternsWithFillOp<tensor::ExpandShapeOp>,
+               SwapTensorShapeChangePatternsWithFillOp<tensor::CollapseShapeOp>>(
+      patterns.getContext());
 }
diff --git a/mlir/test/Dialect/Linalg/swap-extract-slice-with-fill.mlir b/mlir/test/Dialect/Linalg/fold-fill-patterns.mlir
rename from mlir/test/Dialect/Linalg/swap-extract-slice-with-fill.mlir
rename to mlir/test/Dialect/Linalg/fold-fill-patterns.mlir
--- a/mlir/test/Dialect/Linalg/swap-extract-slice-with-fill.mlir
+++ b/mlir/test/Dialect/Linalg/fold-fill-patterns.mlir
@@ -1,4 +1,4 @@
-//RUN: mlir-opt -split-input-file -test-linalg-transform-patterns=test-swap-extract-slice-with-fill-pattern %s | FileCheck %s
+//RUN: mlir-opt -split-input-file -test-linalg-transform-patterns=test-fold-linalg-fill-patterns %s | FileCheck %s
 
 // CHECK-LABEL: func.func @swap_fill_insert_slice
 // CHECK-SAME: (%[[INIT:.+]]: tensor<?x?x?xf32>, %[[OFFSET0:.+]]: index, %[[SIZE1:.+]]: index)
@@ -26,3 +26,35 @@
     : tensor<?x?x?xf32> to tensor<2x?x6xf32>
   return %0, %1: tensor<?x?x?xf32>, tensor<2x?x6xf32>
 }
+
+// -----
+
+func.func @fold_fill_with_collapse_shape(%cst : f32, %init : tensor<?x?x?x?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.fill ins(%cst : f32) outs(%init : tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
+  %1 = tensor.collapse_shape %0 [[0, 1, 2], [3, 4]] : tensor<?x?x?x?x?xf32> into tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+// CHECK-LABEL: func @fold_fill_with_collapse_shape
+// CHECK-SAME:    %[[CST:[a-zA-Z0-9]+]]: f32
+// CHECK-SAME:    %[[INIT:[a-zA-Z0-9]+]]: tensor<?x?x?x?x?xf32>
+// CHECK:         %[[INIT_COLLAPSE:.+]] = tensor.collapse_shape %[[INIT]]
+// CHECK:         %[[FILL:.+]] = linalg.fill
+// CHECK-SAME:      ins(%[[CST]] : f32)
+// CHECK-SAME:      outs(%[[INIT_COLLAPSE]] :
+// CHECK:         return %[[FILL]]
+
+// -----
+
+func.func @fold_fill_with_expand_shape(%cst : f32, %init : tensor<?x?xf32>) -> tensor<?x?x?x?x?xf32> {
+  %0 = linalg.fill ins(%cst : f32) outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %1 = tensor.expand_shape %0 [[0, 1], [2, 3, 4]] : tensor<?x?xf32> into tensor<?x?x?x?x?xf32>
+  return %1 : tensor<?x?x?x?x?xf32>
+}
+// CHECK-LABEL: func @fold_fill_with_expand_shape
+// CHECK-SAME:    %[[CST:[a-zA-Z0-9]+]]: f32
+// CHECK-SAME:    %[[INIT:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK:         %[[INIT_EXPAND:.+]] = tensor.expand_shape %[[INIT]]
+// CHECK:         %[[FILL:.+]] = linalg.fill
+// CHECK-SAME:      ins(%[[CST]] : f32)
+// CHECK-SAME:      outs(%[[INIT_EXPAND]] :
+// CHECK:         return %[[FILL]]
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -110,10 +110,10 @@
       llvm::cl::desc("Test rewrite of linalgOp + extract_slice into "
                      "extract_slice + linalgOp"),
       llvm::cl::init(false)};
-  Option<bool> testSwapExtractSliceWithFill{
-      *this, "test-swap-extract-slice-with-fill-pattern",
+  Option<bool> testFoldLinalgFillPatterns{
+      *this, "test-fold-linalg-fill-patterns",
       llvm::cl::desc(
-          "Test patterns to swap tensor.extract_slice(linalg.fill())"),
+          "Test patterns to fold linalg.fill with tensor operations"),
       llvm::cl::init(false)};
   Option<bool> testEraseUnusedOperandsAndResults{
       *this, "test-erase-unused-operands-and-results",
@@ -191,7 +191,7 @@
 
-static void applySwapExtractSliceWithFillPattern(func::FuncOp funcOp) {
+static void applyFoldLinalgFillPatterns(func::FuncOp funcOp) {
   RewritePatternSet patterns(funcOp.getContext());
-  populateSwapExtractSliceWithFillPatterns(patterns);
+  populateFoldLinalgFillPatterns(patterns);
   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
 }
 
@@ -225,7 +225,7 @@
       return applyExtractSliceOfPadTensorSwapPattern(getOperation());
     if (testBubbleUpExtractSliceOpPattern)
       return applyBubbleUpExtractSliceOpPattern(getOperation());
-    if (testSwapExtractSliceWithFill)
-      return applySwapExtractSliceWithFillPattern(getOperation());
+    if (testFoldLinalgFillPatterns)
+      return applyFoldLinalgFillPatterns(getOperation());
     if (testEraseUnusedOperandsAndResults)
       return applyEraseUnusedOperandsAndResultsPatterns(getOperation());