diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h @@ -41,6 +41,12 @@ /// destination tensor of its producer tensor.insert_slice op. void populateExtractFromInsertSliceDestOpPatterns(RewritePatternSet &patterns); +/// Collects patterns to hoist pairing tensor.extract_slice/insert_slice ops +/// out of scf.for loops when possible. The slice op pair should have matching +/// loop-invariant offsets/sizes/strides; they should extract from and insert +/// into the same loop carried value. +void populateHoistExtractInsertSliceOpPatterns(RewritePatternSet &patterns); + } // namespace tensor } // namespace mlir diff --git a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt @@ -4,6 +4,7 @@ ExtractFromInsertSliceDest.cpp ExtractSliceFromReshape.cpp FoldConsecutiveInsertExtract.cpp + HoistExtractInsertSlice.cpp SplitPadding.cpp SwapExtractSliceWithProducer.cpp diff --git a/mlir/lib/Dialect/Tensor/Transforms/HoistExtractInsertSlice.cpp b/mlir/lib/Dialect/Tensor/Transforms/HoistExtractInsertSlice.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Tensor/Transforms/HoistExtractInsertSlice.cpp @@ -0,0 +1,214 @@ +//===- HoistExtractInsertSlice.cpp ----------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Utils/Utils.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/ViewLikeInterface.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "mlir-hoist-extract-insert-slice"
+
+using namespace mlir;
+using namespace mlir::tensor;
+
+namespace {
+
+/// Hoists pairs of tensor.extract_slice and tensor.insert_slice ops out of the
+/// surrounding scf.for loops.
+///
+/// This requires the extract/insert slice op pair to have the exact same
+/// loop-invariant offsets, strides, and sizes. Also they should extract from /
+/// insert into the same loop carried value.
+struct HoistExtractInsertSlice : OpRewritePattern<scf::ForOp> {
+  using OpRewritePattern<scf::ForOp>::OpRewritePattern;
+
+  /// Tries every loop carried value of `forOp` in order and hoists the first
+  /// extract/insert slice pair that qualifies. Returns failure if none does.
+  LogicalResult matchAndRewrite(scf::ForOp forOp,
+                                PatternRewriter &rewriter) const override {
+    for (unsigned i = 0, e = forOp.getNumRegionIterArgs(); i != e; ++i)
+      if (succeeded(hoistLoopCarriedValueUses(forOp, i, rewriter)))
+        return success();
+
+    return failure();
+  }
+
+  /// Hoists extract/insert slice ops that are users of the `index`-th loop
+  /// carried value out of the given `forOp`. Returns failure (leaving the IR
+  /// untouched) if no hoistable pair exists for that value.
+  LogicalResult hoistLoopCarriedValueUses(scf::ForOp forOp, unsigned index,
+                                          PatternRewriter &rewriter) const {
+    Value loopValue = forOp.getRegionIterArgs()[index];
+    LLVM_DEBUG(llvm::dbgs() << "inspecting loop value #" << index << "\n");
+
+    // Make sure the users of the loop carried value are all extract/insert
+    // slice ops. This helps to simplify further logic.
+    SmallVector<ExtractSliceOp> extractOps;
+    for (Operation *user : loopValue.getUsers()) {
+      if (auto extractUser = dyn_cast<ExtractSliceOp>(user)) {
+        extractOps.push_back(extractUser);
+        continue;
+      }
+      auto insertUser = dyn_cast<InsertSliceOp>(user);
+      if (!insertUser) {
+        LLVM_DEBUG(llvm::dbgs()
+                   << "loop carried value has non extract/insert slice user\n");
+        return failure();
+      }
+      // After hoisting, the loop carried value no longer receives the hoisted
+      // slice update in each iteration. An insert_slice that reads the whole
+      // loop carried value as its *source* would then observe stale data.
+      if (insertUser.getSource() == loopValue) {
+        LLVM_DEBUG(llvm::dbgs()
+                   << "loop carried value used as insert slice source\n");
+        return failure();
+      }
+    }
+
+    InsertSliceOp insertOp = getHoistableInsertSlice(forOp, index);
+    if (!insertOp)
+      return failure();
+    ExtractSliceOp extractOp = findMatchingExtractSlice(insertOp, extractOps);
+    if (!extractOp)
+      return failure();
+
+    // The extract_slice op is moved in front of the loop, so its own
+    // offset/size/stride operands must be available there. Matching against
+    // the insert_slice op is not enough: an operand can be a constant that is
+    // materialized inside the loop body.
+    for (Value v : extractOp->getOperands().drop_front(
+             ExtractSliceOp::getOffsetSizeAndStrideStartOperandIndex())) {
+      if (!forOp.isDefinedOutsideOfLoop(v)) {
+        LLVM_DEBUG(llvm::dbgs()
+                   << "slice offset/size/stride defined inside loop:" << v
+                   << "\n");
+        return failure();
+      }
+    }
+
+    hoistExtractInsertSlice(forOp, index, extractOp, insertOp, rewriter);
+
+    return success();
+  }
+
+  /// Verifies that the `index`-th yielded value is coming from a hoistable
+  /// insert_slice op and returns the insert_slice op. Returns a null op
+  /// otherwise.
+  InsertSliceOp getHoistableInsertSlice(scf::ForOp forOp,
+                                        unsigned index) const {
+    // Expect the yielded value to come from an insert_slice op.
+    auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
+    Value yieldValue = yieldOp.getOperands()[index];
+    auto insertOp = yieldValue.getDefiningOp<InsertSliceOp>();
+    if (!insertOp) {
+      LLVM_DEBUG(llvm::dbgs()
+                 << "yielded value not coming from insert slice op\n");
+      return nullptr;
+    }
+    LLVM_DEBUG(llvm::dbgs() << "last insert op: " << insertOp << "\n");
+
+    // The destination tensor of the insert_slice op should be the block
+    // argument representing the loop carried value. Allow a chain of
+    // insert_slice ops that build upon one another, as long as the first
+    // insert_slice op of the chain inserts into the block argument.
+    Value insertDest = insertOp.getDest();
+    while (auto prevOp = insertDest.getDefiningOp<InsertSliceOp>()) {
+      LLVM_DEBUG(llvm::dbgs() << "prevous insert op: " << prevOp << "\n");
+      // To be conservative, require every previous slice in the chain to be
+      // disjoint from the one we want to hoist.
+      if (!isDisjointSlices(prevOp, insertOp)) {
+        LLVM_DEBUG(llvm::dbgs() << "insert slice op not disjoint with: "
+                                << prevOp << "\n");
+        return nullptr;
+      }
+      insertDest = prevOp.getDest();
+    }
+
+    auto destBlockArg = insertDest.dyn_cast<BlockArgument>();
+    if (!destBlockArg) {
+      LLVM_DEBUG(llvm::dbgs()
+                 << "no insert slice (chain) updating loop carried value\n");
+      return nullptr;
+    }
+    if (destBlockArg.getOwner()->getParentOp() != forOp) {
+      LLVM_DEBUG(llvm::dbgs()
+                 << "insert slice updating other loop's carried value\n");
+      return nullptr;
+    }
+    // Block argument #0 of the loop body is the induction variable, hence the
+    // `index`-th loop carried value is block argument #(index + 1).
+    if (destBlockArg.getArgNumber() != index + 1) {
+      LLVM_DEBUG(llvm::dbgs()
+                 << "index mismatch between yield and insert slice dest\n");
+      return nullptr;
+    }
+
+    // All insert_slice offsets/sizes/strides must be loop invariant.
+    for (Value v : insertOp->getOperands().drop_front(
+             InsertSliceOp::getOffsetSizeAndStrideStartOperandIndex())) {
+      if (!forOp.isDefinedOutsideOfLoop(v)) {
+        LLVM_DEBUG(llvm::dbgs()
+                   << "slice offset/size/stride defined inside loop:" << v
+                   << "\n");
+        return nullptr;
+      }
+    }
+
+    LLVM_DEBUG(llvm::dbgs() << "hoistable insert op: " << insertOp << "\n");
+    return insertOp;
+  }
+
+  /// Finds the extract_slice op in `extractOps` that has the same
+  /// offsets/strides/sizes as the given `insertOp`. Returns a null op if there
+  /// is no such op or if any other op in `extractOps` may overlap with it.
+  ExtractSliceOp
+  findMatchingExtractSlice(InsertSliceOp insertOp,
+                           ArrayRef<ExtractSliceOp> extractOps) const {
+    // Two slice parameters match only if they are equal constants or the very
+    // same SSA value. Comparing just the getConstantIntValue() results would
+    // treat any two non-constant values as equal, because both are "empty".
+    const auto isSame = [](OpFoldResult a, OpFoldResult b) {
+      return isEqualConstantIntOrValue(a, b);
+    };
+
+    ExtractSliceOp extractOp;
+    for (ExtractSliceOp candidate : extractOps) {
+      if (candidate.getType() == insertOp.getSourceType() &&
+          candidate.isSameAs(insertOp, isSame)) {
+        extractOp = candidate;
+        break;
+      }
+    }
+    if (!extractOp) {
+      LLVM_DEBUG(llvm::dbgs()
+                 << "missing matched extract slice for yielded insert slice\n");
+      return nullptr;
+    }
+
+    // To be conservative, make sure all other extract_slice ops are disjoint
+    // from the matched one. All of them read the same loop carried value, so
+    // their position in the use list does not matter. (The insert_slice ops
+    // in the chain have already been checked to be disjoint.)
+    for (ExtractSliceOp other : extractOps) {
+      if (other == extractOp)
+        continue;
+      if (!isDisjointSlices(other, extractOp)) {
+        LLVM_DEBUG(llvm::dbgs() << "insert slice op chain not disjoint with: "
+                                << other << "\n");
+        return nullptr;
+      }
+    }
+
+    LLVM_DEBUG(llvm::dbgs() << "matching extract op: " << extractOp << "\n");
+    return extractOp;
+  }
+
+  /// Hoists the `extractOp` and `insertOp` pair that updates the `index`-th
+  /// loop carried value out of the given `forOp`. The loop is replaced by a
+  /// new loop that carries the slice as an additional loop carried value; the
+  /// old loop is erased.
+  ///
+  /// NOTE(review): the operand updates and op moves below still modify the IR
+  /// directly; confirm whether they should also be reported to the rewriter.
+  void hoistExtractInsertSlice(scf::ForOp forOp, unsigned index,
+                               ExtractSliceOp extractOp,
+                               InsertSliceOp insertOp,
+                               PatternRewriter &rewriter) const {
+    // Update the extract_slice op's source and move it out.
+    extractOp.getSourceMutable().assign(forOp.getInitArgs()[index]);
+    forOp.moveOutOfLoop(extractOp);
+
+    // Update the terminator yielded value and move the insert_slice op out.
+    auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
+    yieldOp->setOperand(index, insertOp.getDest());
+    insertOp->moveAfter(forOp);
+
+    // Build a new loop to additionally yield the insert_slice op's source.
+    // Create it through the rewriter so that the pattern driver learns about
+    // the new loop and can revisit it.
+    NewYieldValueFn yieldFn = [&](OpBuilder &, Location,
+                                  ArrayRef<BlockArgument>) {
+      return SmallVector<Value>{insertOp.getSource()};
+    };
+    auto newForOp = replaceLoopWithNewYields(rewriter, forOp,
+                                             extractOp.getResult(), yieldFn);
+
+    // Point all uses of the loop result value to the hoisted insert_slice.
+    newForOp.getResult(index).replaceAllUsesWith(insertOp.getResult());
+    // Fix hoisted insert_slice op's source and destination tensors. This must
+    // come after the replacement above; otherwise the insert_slice op would
+    // end up using its own result as destination.
+    insertOp.getSourceMutable().assign(newForOp.getResults().back());
+    insertOp.getDestMutable().assign(newForOp.getResult(index));
+
+    // Erase through the rewriter so the driver drops the dead loop (and the
+    // ops nested in it) from its worklist.
+    rewriter.eraseOp(forOp);
+  }
+};
+
+} // namespace
+
+/// Adds the pattern hoisting matching tensor.extract_slice/insert_slice pairs
+/// out of scf.for loops to `patterns`.
+void tensor::populateHoistExtractInsertSliceOpPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<HoistExtractInsertSlice>(patterns.getContext());
+}
diff --git a/mlir/test/Dialect/Tensor/hoist-extract-insert-slice.mlir b/mlir/test/Dialect/Tensor/hoist-extract-insert-slice.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/hoist-extract-insert-slice.mlir
@@ -0,0 +1,206 @@
+// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-hoist-extract-insert-slice -allow-unregistered-dialect -canonicalize %s | FileCheck %s
+
+func.func @hoist_slices_in_double_loop(
+    %input: tensor<1x9x9x3xf32>, %filter: tensor<3x3x3x16xf32>, %init: tensor<1x2x2x4xf32>,
+    %offset0: index, %offset1: index, %offset2: index) -> tensor<1x2x2x4xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c3 = arith.constant 3 : index
+  %0 = scf.for %iv0 = %c0 to %c3 step %c1 iter_args(%arg0 = %init) -> (tensor<1x2x2x4xf32>) {
+    %1 = scf.for %iv1 = %c0 to %c3 step %c1 iter_args(%arg1 = %arg0) -> (tensor<1x2x2x4xf32>) {
+      %2 = affine.apply affine_map<(d0, d1) -> (d0 * 2 + d1)>(%offset0, %iv0)
+      %3 = affine.apply affine_map<(d0, d1) -> (d0 * 2 + d1)>(%offset1, %iv1)
+      %4 = tensor.extract_slice %input[%c0, %2, %3, %c0] [1, 1, 3, 3] [1, 1, 1, 1] : tensor<1x9x9x3xf32> to tensor<1x3x3xf32>
+      %5 = tensor.extract_slice %filter[%iv0, %iv1, %c0, %offset2] [1, 1, 3, 4] [1, 1, 1, 1] : tensor<3x3x3x16xf32> to tensor<1x3x4xf32>
+      %6 = tensor.extract_slice %arg1[%c0, %c0, %c0, %c0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x2x4xf32> to tensor<1x2x4xf32>
+      %7 = linalg.conv_1d_nwc_wcf {dilations = dense<1> : vector<1xi64>, strides = dense<2> : vector<1xi64>}
+        ins(%4, %5 : tensor<1x3x3xf32>, tensor<1x3x4xf32>) outs(%6 : tensor<1x2x4xf32>) -> tensor<1x2x4xf32>
+      %8 = tensor.insert_slice %7 into %arg1[0, 0, 0, 0]
[1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x4xf32> into tensor<1x2x2x4xf32>
+      %9 = affine.apply affine_map<(d0, d1) -> (d0 * 2 + d1 + 2)>(%offset0, %iv0)
+      %10 = affine.apply affine_map<(d0, d1) -> (d0 * 2 + d1)>(%offset1, %iv1)
+      %11 = tensor.extract_slice %input[%c0, %9, %10, %c0] [1, 1, 3, 3] [1, 1, 1, 1] : tensor<1x9x9x3xf32> to tensor<1x3x3xf32>
+      %12 = tensor.extract_slice %filter[%iv0, %iv1, %c0, %offset2] [1, 1, 3, 4] [1, 1, 1, 1] : tensor<3x3x3x16xf32> to tensor<1x3x4xf32>
+      %13 = tensor.extract_slice %arg1[%c0, %c1, %c0, %c0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x2x4xf32> to tensor<1x2x4xf32>
+      %14 = linalg.conv_1d_nwc_wcf {dilations = dense<1> : vector<1xi64>, strides = dense<2> : vector<1xi64>}
+        ins(%11, %12 : tensor<1x3x3xf32>, tensor<1x3x4xf32>) outs(%13 : tensor<1x2x4xf32>) -> tensor<1x2x4xf32>
+      %15 = tensor.insert_slice %14 into %8[0, 1, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x4xf32> into tensor<1x2x2x4xf32>
+      scf.yield %15 : tensor<1x2x2x4xf32>
+    }
+    scf.yield %1 : tensor<1x2x2x4xf32>
+  }
+  return %0 : tensor<1x2x2x4xf32>
+}
+
+// CHECK-LABEL: func.func @hoist_slices_in_double_loop
+// CHECK-SAME: (%[[INPUT:.+]]: tensor<1x9x9x3xf32>, %[[FILTER:.+]]: tensor<3x3x3x16xf32>, %[[INIT:.+]]: tensor<1x2x2x4xf32>
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[C3:.+]] = arith.constant 3 : index
+// CHECK: %[[INIT_SLICE1:.+]] = tensor.extract_slice %[[INIT]][0, 1, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x2x4xf32> to tensor<1x2x4xf32>
+// CHECK: %[[INIT_SLICE0:.+]] = tensor.extract_slice %[[INIT]][0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x2x4xf32> to tensor<1x2x4xf32>
+// CHECK: %[[FOR0:.+]]:2 = scf.for %[[IV0:.+]] = %[[C0]] to %[[C3]] step %[[C1]]
+// CHECK-SAME: iter_args(%[[FOR0_ARG1:.+]] = %[[INIT_SLICE1]], %[[FOR0_ARG0:.+]] = %[[INIT_SLICE0]])
+// CHECK: %[[FOR1:.+]]:2 = scf.for %[[IV1:.+]] = %[[C0]] to %[[C3]] step %[[C1]]
+// CHECK-SAME: iter_args(%[[FOR1_ARG1:.+]] = %[[FOR0_ARG1]], %[[FOR1_ARG0:.+]] = %[[FOR0_ARG0]])
+// CHECK: %[[INPUT_SLICE0:.+]] = tensor.extract_slice %[[INPUT]]
+// CHECK: %[[FILTER_SLICE0:.+]] = tensor.extract_slice %[[FILTER]]
+// CHECK: %[[CONV0:.+]] = linalg.conv_1d_nwc_wcf
+// CHECK-SAME: ins(%[[INPUT_SLICE0]], %[[FILTER_SLICE0]]
+// CHECK-SAME: outs(%[[FOR1_ARG0]] : tensor<1x2x4xf32>)
+// CHECK: %[[INPUT_SLICE1:.+]] = tensor.extract_slice %[[INPUT]]
+// CHECK: %[[FILTER_SLICE1:.+]] = tensor.extract_slice %[[FILTER]]
+// CHECK: %[[CONV1:.+]] = linalg.conv_1d_nwc_wcf
+// CHECK-SAME: ins(%[[INPUT_SLICE1]], %[[FILTER_SLICE1]]
+// CHECK-SAME: outs(%[[FOR1_ARG1]] : tensor<1x2x4xf32>)
+// CHECK: scf.yield %[[CONV1]], %[[CONV0]] : tensor<1x2x4xf32>, tensor<1x2x4xf32>
+// CHECK: }
+// CHECK: scf.yield %[[FOR1]]#0, %[[FOR1]]#1 : tensor<1x2x4xf32>, tensor<1x2x4xf32>
+// CHECK: }
+// CHECK: %[[INSERT0:.+]] = tensor.insert_slice %[[FOR0]]#1 into %[[INIT]][0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x4xf32> into tensor<1x2x2x4xf32>
+// CHECK: %[[INSERT1:.+]] = tensor.insert_slice %[[FOR0]]#0 into %[[INSERT0]][0, 1, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x4xf32> into tensor<1x2x2x4xf32>
+// CHECK: return %[[INSERT1]] : tensor<1x2x2x4xf32>
+
+// -----
+
+func.func @dont_hoist_non_extract_insert_slice_usage_of_loop_carried_value(
+    %input: tensor<1x9x9x3xf32>, %filter: tensor<3x3x3x16xf32>, %init: tensor<1x2x2x4xf32>,
+    %offset0: index, %offset1: index, %offset2: index) -> tensor<1x2x2x4xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c3 = arith.constant 3 : index
+  %0 = scf.for %iv0 = %c0 to %c3 step %c1 iter_args(%arg0 = %init) -> (tensor<1x2x2x4xf32>) {
+    %1 = scf.for %iv1 = %c0 to %c3 step %c1 iter_args(%arg1 = %arg0) -> (tensor<1x2x2x4xf32>) {
+      %2 = affine.apply affine_map<(d0, d1) -> (d0 * 2 + d1)>(%offset0, %iv0)
+      %3 = affine.apply affine_map<(d0, d1) -> (d0 * 2 + d1)>(%offset1, %iv1)
+      %4 = tensor.extract_slice %input[%c0, %2, %3, %c0] [1, 1, 3, 3] [1, 1, 1, 1] : tensor<1x9x9x3xf32> to tensor<1x3x3xf32>
+      %5 = tensor.extract_slice %filter[%iv0, %iv1, %c0, %offset2] [1, 1, 3, 4] [1, 1, 1, 1] : tensor<3x3x3x16xf32> to tensor<1x3x4xf32>
+      %6 = tensor.extract_slice %arg1[%c0, %c0, %c0, %c0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x2x4xf32> to tensor<1x2x4xf32>
+      %7 = linalg.conv_1d_nwc_wcf {dilations = dense<1> : vector<1xi64>, strides = dense<2> : vector<1xi64>}
+        ins(%4, %5 : tensor<1x3x3xf32>, tensor<1x3x4xf32>) outs(%6 : tensor<1x2x4xf32>) -> tensor<1x2x4xf32>
+      %8 = tensor.insert_slice %7 into %arg1[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x4xf32> into tensor<1x2x2x4xf32>
+      "dialect.op"(%arg1) : (tensor<1x2x2x4xf32>) -> ()
+      scf.yield %8 : tensor<1x2x2x4xf32>
+    }
+    scf.yield %1 : tensor<1x2x2x4xf32>
+  }
+  return %0 : tensor<1x2x2x4xf32>
+}
+
+// CHECK-LABEL: func.func @dont_hoist_non_extract_insert_slice_usage_of_loop_carried_value
+// CHECK-NOT: tensor.extract_slice
+// CHECK-COUNT-2: scf.for
+// CHECK-COUNT-3: tensor.extract_slice
+// CHECK: tensor.insert_slice
+// CHECK-COUNT-2: scf.yield
+// CHECK-NOT: tensor.insert_slice
+
+// -----
+
+// The loop-dependent %mod offset must be the only thing that blocks hoisting here.
+func.func @dont_hoist_loop_dependent_slice_parameters(
+    %input: tensor<1x9x9x3xf32>, %filter: tensor<3x3x3x16xf32>, %init: tensor<1x2x2x4xf32>,
+    %offset0: index, %offset1: index, %offset2: index) -> tensor<1x2x2x4xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c3 = arith.constant 3 : index
+  %0 = scf.for %iv0 = %c0 to %c3 step %c1 iter_args(%arg0 = %init) -> (tensor<1x2x2x4xf32>) {
+    %1 = scf.for %iv1 = %c0 to %c3 step %c1 iter_args(%arg1 = %arg0) -> (tensor<1x2x2x4xf32>) {
+      %2 = affine.apply affine_map<(d0, d1) -> (d0 * 2 + d1)>(%offset0, %iv0)
+      %3 = affine.apply affine_map<(d0, d1) -> (d0 * 2 + d1)>(%offset1, %iv1)
+      %4 = tensor.extract_slice %input[%c0, %2, %3, %c0] [1, 1, 3, 3] [1, 1, 1, 1] : tensor<1x9x9x3xf32> to tensor<1x3x3xf32>
+      %5 = tensor.extract_slice %filter[%iv0, %iv1, %c0, %offset2] [1, 1, 3, 4] [1, 1, 1, 1] : tensor<3x3x3x16xf32> to tensor<1x3x4xf32>
+      %mod = affine.apply affine_map<(d0) -> (d0 mod 2)>(%iv1)
+      %6 = tensor.extract_slice %arg1[%c0, %mod, %c0, %c0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x2x4xf32> to tensor<1x2x4xf32>
+      %7 = linalg.conv_1d_nwc_wcf {dilations = dense<1> : vector<1xi64>, strides = dense<2> : vector<1xi64>}
+        ins(%4, %5 : tensor<1x3x3xf32>, tensor<1x3x4xf32>) outs(%6 : tensor<1x2x4xf32>) -> tensor<1x2x4xf32>
+      %8 = tensor.insert_slice %7 into %arg1[0, %mod, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x4xf32> into tensor<1x2x2x4xf32>
+      scf.yield %8 : tensor<1x2x2x4xf32>
+    }
+    scf.yield %1 : tensor<1x2x2x4xf32>
+  }
+  return %0 : tensor<1x2x2x4xf32>
+}
+
+// CHECK-LABEL: func.func @dont_hoist_loop_dependent_slice_parameters
+// CHECK-NOT: tensor.extract_slice
+// CHECK-COUNT-2: scf.for
+// CHECK-COUNT-3: tensor.extract_slice
+// CHECK: tensor.insert_slice
+// CHECK-COUNT-2: scf.yield
+// CHECK-NOT: tensor.insert_slice
+
+// -----
+
+func.func @dont_hoist_slices_not_disjoint(
+    %input: tensor<1x9x9x3xf32>, %filter: tensor<3x3x3x16xf32>, %init: tensor<1x2x2x4xf32>,
+    %offset0: index, %offset1: index, %offset2: index) -> tensor<1x2x2x4xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c3 = arith.constant 3 : index
+  %0 = scf.for %iv0 = %c0 to %c3 step %c1 iter_args(%arg0 = %init) -> (tensor<1x2x2x4xf32>) {
+    %1 = scf.for %iv1 = %c0 to %c3 step %c1 iter_args(%arg1 = %arg0) -> (tensor<1x2x2x4xf32>) {
+      %2 = affine.apply affine_map<(d0, d1) -> (d0 * 2 + d1)>(%offset0, %iv0)
+      %3 = affine.apply affine_map<(d0, d1) -> (d0 * 2 + d1)>(%offset1, %iv1)
+      %4 = tensor.extract_slice %input[%c0, %2, %3, %c0] [1, 1, 3, 3] [1, 1, 1, 1] : tensor<1x9x9x3xf32> to tensor<1x3x3xf32>
+      %5 = tensor.extract_slice %filter[%iv0, %iv1, %c0, %offset2] [1, 1, 3, 4] [1, 1, 1, 1] : tensor<3x3x3x16xf32> to tensor<1x3x4xf32>
+      %6 = tensor.extract_slice %arg1[%c0, %c0, %c0, %c0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x2x4xf32> to tensor<1x2x4xf32>
+      %7 = linalg.conv_1d_nwc_wcf {dilations = dense<1> : vector<1xi64>, strides = dense<2> : vector<1xi64>}
+        ins(%4, %5 : tensor<1x3x3xf32>, tensor<1x3x4xf32>) outs(%6 : tensor<1x2x4xf32>) -> tensor<1x2x4xf32>
+      %8 = tensor.insert_slice %7 into %arg1[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x4xf32> into tensor<1x2x2x4xf32>
+      %9 = affine.apply affine_map<(d0, d1) -> (d0 * 2 + d1 + 2)>(%offset0, %iv0)
+      %10 = affine.apply affine_map<(d0, d1) -> (d0 * 2 + d1)>(%offset1, %iv1)
+      %11 = tensor.extract_slice %input[%c0, %9, %10, %c0] [1, 1, 3, 3] [1, 1, 1, 1] : tensor<1x9x9x3xf32> to tensor<1x3x3xf32>
+      %12 = tensor.extract_slice %filter[%iv0, %iv1, %c0, %offset2] [1, 1, 3, 4] [1, 1, 1, 1] : tensor<3x3x3x16xf32> to tensor<1x3x4xf32>
+      %13 = tensor.extract_slice %arg1[%c0, %c0, %c1, %c0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x2x4xf32> to tensor<1x2x4xf32>
+      %14 = linalg.conv_1d_nwc_wcf {dilations = dense<1> : vector<1xi64>, strides = dense<2> : vector<1xi64>}
+        ins(%11, %12 : tensor<1x3x3xf32>, tensor<1x3x4xf32>) outs(%13 : tensor<1x2x4xf32>) -> tensor<1x2x4xf32>
+      %15 = tensor.insert_slice %14 into %8[0, 0, 1, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x4xf32> into tensor<1x2x2x4xf32>
+      scf.yield %15 : tensor<1x2x2x4xf32>
+    }
+    scf.yield %1 : tensor<1x2x2x4xf32>
+  }
+  return %0 : tensor<1x2x2x4xf32>
+}
+
+// CHECK-LABEL: func.func @dont_hoist_slices_not_disjoint
+// CHECK-NOT: tensor.extract_slice
+// CHECK-COUNT-2: scf.for
+// CHECK-COUNT-3: tensor.extract_slice
+// CHECK: tensor.insert_slice
+// CHECK-COUNT-3: tensor.extract_slice
+// CHECK: tensor.insert_slice
+// CHECK-COUNT-2: scf.yield
+// CHECK-NOT: tensor.insert_slice
+
+// -----
+
+func.func @dont_hoist_mismatched_extract_insert_slice(
+    %input: tensor<1x9x9x3xf32>, %filter: tensor<3x3x3x16xf32>, %init: tensor<1x2x2x4xf32>,
+    %offset0: index, %offset1: index, %offset2: index) -> tensor<1x2x2x4xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c3 = arith.constant
3 : index
+  %0 = scf.for %iv0 = %c0 to %c3 step %c1 iter_args(%arg0 = %init) -> (tensor<1x2x2x4xf32>) {
+    %1 = scf.for %iv1 = %c0 to %c3 step %c1 iter_args(%arg1 = %arg0) -> (tensor<1x2x2x4xf32>) {
+      %2 = affine.apply affine_map<(d0, d1) -> (d0 * 2 + d1)>(%offset0, %iv0)
+      %3 = affine.apply affine_map<(d0, d1) -> (d0 * 2 + d1)>(%offset1, %iv1)
+      %4 = tensor.extract_slice %input[%c0, %2, %3, %c0] [1, 1, 3, 3] [1, 1, 1, 1] : tensor<1x9x9x3xf32> to tensor<1x3x3xf32>
+      %5 = tensor.extract_slice %filter[%iv0, %iv1, %c0, %offset2] [1, 1, 3, 4] [1, 1, 1, 1] : tensor<3x3x3x16xf32> to tensor<1x3x4xf32>
+      %6 = tensor.extract_slice %arg1[%c0, %c0, %c0, %c0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x2x4xf32> to tensor<1x2x4xf32>
+      %7 = linalg.conv_1d_nwc_wcf {dilations = dense<1> : vector<1xi64>, strides = dense<2> : vector<1xi64>}
+        ins(%4, %5 : tensor<1x3x3xf32>, tensor<1x3x4xf32>) outs(%6 : tensor<1x2x4xf32>) -> tensor<1x2x4xf32>
+      %8 = tensor.insert_slice %7 into %arg1[0, 1, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x4xf32> into tensor<1x2x2x4xf32>
+      scf.yield %8 : tensor<1x2x2x4xf32>
+    }
+    scf.yield %1 : tensor<1x2x2x4xf32>
+  }
+  return %0 : tensor<1x2x2x4xf32>
+}
+
+// CHECK-LABEL: func.func @dont_hoist_mismatched_extract_insert_slice
+// CHECK-NOT: tensor.extract_slice
+// CHECK-COUNT-2: scf.for
+// CHECK-COUNT-3: tensor.extract_slice
+// CHECK: tensor.insert_slice
+// CHECK-COUNT-2: scf.yield
+// CHECK-NOT: tensor.insert_slice
diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
--- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
@@ -65,6 +65,12 @@
           "Test folding consecutive tensor.insert_slice/tensor.extract_slice"),
       llvm::cl::init(false)};
 
+  Option<bool> testHoistExtractInsertSlice{
+      *this, "test-hoist-extract-insert-slice",
+      llvm::cl::desc("Test hoisting tensor.insert_slice/tensor.extract_slice "
+                     "out of loops"),
+      llvm::cl::init(false)};
+
   Option<bool> testRewriteExtractSliceWithTiledCollapseShape{
       *this, "test-rewrite-extract-slice-from-collapse-shape",
       llvm::cl::desc("Test swapping tensor.extract_slice of a collapse_shape "
@@ -114,6 +120,15 @@
   (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
 }
 
+/// Greedily applies the extract/insert slice hoisting patterns together with
+/// the scf.for canonicalization patterns to everything nested in `rootOp`.
+static void applyHoistExtractInsertSlicePatterns(Operation *rootOp) {
+  RewritePatternSet patterns(rootOp->getContext());
+  tensor::populateHoistExtractInsertSliceOpPatterns(patterns);
+  scf::ForOp::getCanonicalizationPatterns(patterns, rootOp->getContext());
+  (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
+}
+
 namespace {
 /// Base pattern to rewrite a `tensor.collapse_shape -> tensor.extract_slice`.
 /// The `tensor.extract_slice` is replaced by a loop or gather operation that
@@ -249,6 +264,8 @@
     applyFoldConsecutiveInsertExtractSlicePatterns(rootOp);
   if (testExtractFrominsertSliceDest)
     applyExtractFromInsertSliceDestPatterns(rootOp);
+  if (testHoistExtractInsertSlice)
+    applyHoistExtractInsertSlicePatterns(rootOp);
   if (testRewriteExtractSliceWithTiledCollapseShape) {
     if (failed(
             applyRewriteExtractFromCollapseShapePatterns(rootOp, useForeach)))