diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -110,6 +110,10 @@ /// Patterns that are used to bubble up extract slice op above linalg op. void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns); +/// Adds patterns that swap tensor.extract_slice(linalg.fill(%cst, %init)) into +/// linalg.fill(%cst, tensor.extract_slice(%init)). +void populateSwapExtractSliceWithFillPatterns(RewritePatternSet &patterns); + /// Return true if two `linalg.generic` operations with producer/consumer /// relationship through `fusedOperand` can be fused using elementwise op /// fusion. diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt @@ -24,6 +24,7 @@ Promotion.cpp Split.cpp SplitReduction.cpp + SwapExtractSliceWithFillPatterns.cpp Tiling.cpp TilingInterfaceImpl.cpp Transforms.cpp diff --git a/mlir/lib/Dialect/Linalg/Transforms/SwapExtractSliceWithFillPatterns.cpp b/mlir/lib/Dialect/Linalg/Transforms/SwapExtractSliceWithFillPatterns.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Linalg/Transforms/SwapExtractSliceWithFillPatterns.cpp @@ -0,0 +1,41 @@ +//===- SwapExtractSliceWithFillPatterns.cpp -------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/IR/PatternMatch.h" + +using namespace mlir; +using namespace mlir::linalg; + +/// Swaps tensor.extract_slice(linalg.fill(%cst, %init)) into linalg.fill(%cst, +/// tensor.extract_slice(%init)) when the linalg.fill op has no other users. +/// This helps to reduce the fill footprint. +struct SwapExtractSliceOfFill final + : public OpRewritePattern<tensor::ExtractSliceOp> { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tensor::ExtractSliceOp extractOp, + PatternRewriter &rewriter) const override { + auto fillOp = extractOp.getSource().getDefiningOp<FillOp>(); + if (!fillOp || !fillOp->hasOneUse()) + return failure(); + + auto newExtractOp = rewriter.create<tensor::ExtractSliceOp>( + extractOp.getLoc(), extractOp.getType(), fillOp.getOutputs()[0], + extractOp.getMixedOffsets(), extractOp.getMixedSizes(), + extractOp.getMixedStrides()); + rewriter.replaceOpWithNewOp<FillOp>(extractOp, fillOp.getInputs(), + ValueRange{newExtractOp.getResult()}); + return success(); + } +}; + +void mlir::linalg::populateSwapExtractSliceWithFillPatterns( + RewritePatternSet &patterns) { + patterns.add<SwapExtractSliceOfFill>(patterns.getContext()); +} diff --git a/mlir/test/Dialect/Linalg/swap-extract-slice-with-fill.mlir b/mlir/test/Dialect/Linalg/swap-extract-slice-with-fill.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/swap-extract-slice-with-fill.mlir @@ -0,0 +1,28 @@ +//RUN: mlir-opt -split-input-file -test-linalg-transform-patterns=test-swap-extract-slice-with-fill-pattern %s | FileCheck %s + +// CHECK-LABEL: func.func @swap_fill_insert_slice +// CHECK-SAME: (%[[INIT:.+]]: tensor, %[[OFFSET0:.+]]: index, %[[SIZE1:.+]]: index) +// CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[EXT:.+]] = tensor.extract_slice %[[INIT]][%[[OFFSET0]], 8, 4] [1, %[[SIZE1]], 6] [1, 3, 1] +// CHECK: 
%[[FILL:.+]] = linalg.fill ins(%[[F0]] : f32) outs(%[[EXT]] : tensor) -> tensor +// CHECK: return %[[FILL]] +func.func @swap_fill_insert_slice(%init : tensor, %offset0: index, %size1: index) -> tensor { + %f0 = arith.constant 0.000000e+00 : f32 + %0 = linalg.fill ins(%f0 : f32) outs(%init : tensor) -> tensor + %1 = tensor.extract_slice %0[%offset0, 8, 4] [1, %size1, 6] [1, 3, 1] + : tensor to tensor + return %1: tensor +} + +// ----- + +// CHECK-LABEL: func.func @dont_swap_fill_insert_slice_multi_user +// CHECK: linalg.fill +// CHECK: tensor.extract_slice +func.func @dont_swap_fill_insert_slice_multi_user(%init : tensor, %offset0: index, %size1: index) -> (tensor, tensor<2x?x6xf32>) { + %f0 = arith.constant 0.000000e+00 : f32 + %0 = linalg.fill ins(%f0 : f32) outs(%init : tensor) -> tensor + %1 = tensor.extract_slice %0[%offset0, 8, 4] [2, %size1, 6] [1, 3, 1] + : tensor to tensor<2x?x6xf32> + return %0, %1: tensor, tensor<2x?x6xf32> +} diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp --- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp @@ -123,6 +123,11 @@ llvm::cl::desc("Test rewrite of linalgOp + extract_slice into " "extract_slice + linalgOp"), llvm::cl::init(false)}; + Option testSwapExtractSliceWithFill{ + *this, "test-swap-extract-slice-with-fill-pattern", + llvm::cl::desc( + "Test patterns to swap tensor.extract_slice(linalg.fill())"), + llvm::cl::init(false)}; }; } // namespace @@ -508,6 +513,12 @@ (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); } +static void applySwapExtractSliceWithFillPattern(func::FuncOp funcOp) { + RewritePatternSet patterns(funcOp.getContext()); + populateSwapExtractSliceWithFillPatterns(patterns); + (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); +} + /// Apply transformations specified as patterns. 
void TestLinalgTransforms::runOnOperation() { auto lambda = [&](void *) { @@ -551,6 +562,8 @@ return applySplitReduction(getOperation()); if (testBubbleUpExtractSliceOpPattern) return applyBubbleUpExtractSliceOpPattern(getOperation()); + if (testSwapExtractSliceWithFill) + return applySwapExtractSliceWithFillPattern(getOperation()); } namespace mlir {