diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h @@ -13,6 +13,11 @@ #include "mlir/IR/PatternMatch.h" namespace mlir { + +namespace scf { +class ForOp; +} // namespace scf + namespace tensor { /// Populates `patterns` with patterns to wrap a tensor.pad op with an scf.if op @@ -40,6 +45,16 @@ /// destination tensor of its producer tensor.insert_slice op. void populateExtractFromInsertSliceDestOpPatterns(RewritePatternSet &patterns); +/// Hoists pairing tensor.extract_slice/insert_slice ops out of scf.for loops +/// when possible. The slice op pair should have matching loop invariant +/// offsets/sizes/strides; they should extract from and insert into the same +/// loop carried value. Returns the resulting scf.for op; whenever a pair is hoisted, the input `forOp` is erased and a newly built loop is returned instead. +scf::ForOp hoistTensorExtractInsertSliceOps(scf::ForOp forOp, + OpBuilder &builder); +/// Collects patterns to hoist pairing tensor.extract_slice/insert_slice ops +/// out of scf.for loops when possible. This wraps the above utility function. 
+void populateHoistExtractInsertSliceOpPatterns(RewritePatternSet &patterns); + } // namespace tensor } // namespace mlir diff --git a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt @@ -3,6 +3,7 @@ Bufferize.cpp ExtractFromInsertSliceDestPatterns.cpp ExtractSliceFromReshapeUtils.cpp + HoistExtractInsertSlicePatterns.cpp MergeConsecutiveInsertExtractSlicePatterns.cpp SplitPaddingPatterns.cpp SwapExtractSliceWithProducerPatterns.cpp @@ -23,6 +24,7 @@ MLIRMemRefDialect MLIRPass MLIRSCFDialect + MLIRSCFUtils MLIRTensorDialect MLIRTilingInterface MLIRTransforms diff --git a/mlir/lib/Dialect/Tensor/Transforms/HoistExtractInsertSlicePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/HoistExtractInsertSlicePatterns.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Tensor/Transforms/HoistExtractInsertSlicePatterns.cpp @@ -0,0 +1,277 @@ +//===- HoistExtractInsertSlicePatterns.cpp --------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SCF/Utils/Utils.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tensor/Transforms/Transforms.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/PatternMatch.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "mlir-hoist-extract-insert-slice" + +using namespace mlir; +using namespace mlir::tensor; + +/// Verifies that the `index`-th yielded value is coming from a hoistable +/// insert_slice op and returns the insert_slice op; returns nullptr otherwise. +static InsertSliceOp +getHoistableInsertSlice(scf::ForOp forOp, unsigned index, + ArrayRef insertOps) { + auto yieldOp = cast(forOp.getBody()->getTerminator()); + Value yieldValue = yieldOp.getOperands()[index]; + + // Expect the yielded value to come from an insert_slice op. + auto insertOp = yieldValue.getDefiningOp(); + if (!insertOp) { + LLVM_DEBUG(llvm::dbgs() + << "yielded value not coming from insert slice op\n"); + return nullptr; + } + if (!insertOp->hasOneUse()) { + LLVM_DEBUG(llvm::dbgs() << "insert slice has more than one use\n"); + return nullptr; + } + LLVM_DEBUG(llvm::dbgs() << "yielded insert op: " << insertOp << "\n"); + + // Make sure this insert_slice op is updating some loop carried value. + // All insert_slice ops doing that were previously collected in `insertOps`. + if (!llvm::is_contained(insertOps, insertOp)) { + LLVM_DEBUG(llvm::dbgs() + << "insert slice op not updating loop carried value\n"); + return nullptr; + } + + // The destination tensor of the insert_slice op should be the block + // argument representing the loop carried value. 
+ Value insertDest = insertOp.getDest(); + auto destBlockArg = insertDest.dyn_cast(); + if (!destBlockArg) { + // Allow a chain of insert_slice ops that build upon one another. But the + // first insert_slice op must insert into the block argument. + while (auto prevOp = insertDest.getDefiningOp()) { + insertDest = prevOp.getDest(); + destBlockArg = insertDest.dyn_cast(); + } + } + + // Guaranteed by `insertOp` in `insertOps`. But double check: + assert(destBlockArg && destBlockArg.getOwner()->getParentOp() == forOp); + + if (destBlockArg.getArgNumber() != index + 1) { + LLVM_DEBUG(llvm::dbgs() + << "index mismatch between yield and insert slice dest\n"); + return nullptr; + } + + // All insert_slice offsets/sizes/strides must be loop invariant. + for (Value v : insertOp->getOperands().drop_front( + InsertSliceOp::getOffsetSizeAndStrideStartOperandIndex())) { + if (!forOp.isDefinedOutsideOfLoop(v)) { + LLVM_DEBUG(llvm::dbgs() << "slice offset/size/stride defined inside loop:" + << v << "\n"); + return nullptr; + } + } + + LLVM_DEBUG(llvm::dbgs() << "hoistable insert op: " << insertOp << "\n"); + return insertOp; +} + +/// Finds the extract_slice op that has the same offsets/strides/sizes as the +/// given `insertOp` from `extractOps`; returns nullptr if none matches. 
+static ExtractSliceOp +findMatchingExtractSlice(InsertSliceOp insertOp, + ArrayRef extractOps) { + unsigned opIndex = 0; + ExtractSliceOp extractOp; + for (; opIndex < extractOps.size(); ++opIndex) { + extractOp = extractOps[opIndex]; + const auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; }; + if (extractOp.getType() == insertOp.getSourceType() && + extractOp.isSameAs(insertOp, isSame)) + break; + } + if (opIndex == extractOps.size()) { + LLVM_DEBUG(llvm::dbgs() + << "missing matched extract slice for yielded insert slice\n"); + return nullptr; + } + + LLVM_DEBUG(llvm::dbgs() << "matching extract op: " << extractOp << "\n"); + return extractOp; +} + +using ForOpEraseFn = std::function; + +/// Hoists the `extractOp` and `insertOp` pair that updates the `index`-th loop +/// carried value out of the given `forOp` and returns the new scf.for op. The original `forOp` is handed to `forOpEraseFn` once the new loop is built. +static scf::ForOp hoistExtractInsertSlice(scf::ForOp forOp, unsigned index, + ExtractSliceOp extractOp, + InsertSliceOp insertOp, + OpBuilder &builder, + const ForOpEraseFn &forOpEraseFn) { + // Update the extract_slice op's source and move it out. + extractOp.getSourceMutable().assign(forOp.getInitArgs()[index]); + forOp.moveOutOfLoop(extractOp); + + // Update the terminator yielded value and move the insert_slice op out. + auto yieldOp = cast(forOp.getBody()->getTerminator()); + yieldOp->setOperand(index, insertOp.getDest()); + insertOp->moveAfter(forOp); + + // Build a new loop to additionally yield the insert_slice op's source. + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPoint(forOp); + NewYieldValueFn yieldFn = [&](OpBuilder &, Location, + ArrayRef) { + return SmallVector{insertOp.getSource()}; + }; + auto newForOp = + replaceLoopWithNewYields(builder, forOp, extractOp.getResult(), yieldFn); + + // Point all uses of the loop result value to the hoisted insert_slice. 
+ newForOp.getResult(index).replaceAllUsesWith(insertOp.getResult()); + // Fix hoisted insert_slice op's source and destination tensors. + insertOp.getSourceMutable().assign(newForOp.getResults().back()); + insertOp.getDestMutable().assign(newForOp.getResult(index)); + + forOpEraseFn(forOp); + return newForOp; +} + +/// Collects and appends all children insert_slice ops from the given `seedOp` +/// into `insertOps`, and returns true if the insert_slice op chain rooting from +/// `seedOp` does not have users other than scf.yield ops. +static bool collectInsertSliceChain(InsertSliceOp seedOp, + SmallVectorImpl &insertOps) { + SmallVector worklist; + worklist.push_back(seedOp); + while (!worklist.empty()) { + InsertSliceOp insertOp = worklist.pop_back_val(); + insertOps.push_back(insertOp); + for (Operation *user : insertOp.getResult().getUsers()) { + if (auto userOp = dyn_cast(user)) { + worklist.push_back(userOp); + } else if (!isa(user)) { + LLVM_DEBUG(llvm::dbgs() + << "non extract/insert slice user of loop carried value: " + << *user << "\n"); + return false; + } + } + } + return true; +} + +/// Hoists extract/insert slice ops that are users of the `index`-th loop +/// carried value out of the given `forOp`. Returns the new scf.for op on +/// success; returns nullptr otherwise. +static scf::ForOp hoistLoopCarriedValueUses(scf::ForOp forOp, unsigned index, + OpBuilder &builder, + const ForOpEraseFn &forOpEraseFn) { + Value loopValue = forOp.getRegionIterArgs()[index]; + LLVM_DEBUG(llvm::dbgs() << "checking loop carried value #" << index << "\n"); + + // Make sure the users of the loop carried value are all insert/extract + // slice ops. This helps to simplify further logic. 
+ SmallVector extractOps; + SmallVector insertOps; + for (Operation *user : loopValue.getUsers()) { + if (auto op = dyn_cast(user)) { + extractOps.push_back(op); + continue; + } + if (auto op = dyn_cast(user)) { + if (!collectInsertSliceChain(op, insertOps)) + return nullptr; + continue; + } + LLVM_DEBUG(llvm::dbgs() + << "non extract/insert slice user of loop carried value: " + << *user << "\n"); + return nullptr; + } + + InsertSliceOp insertOp = getHoistableInsertSlice(forOp, index, insertOps); + if (!insertOp) + return nullptr; + // To be conservative, require all other insert slice ops be disjoint with the + // one to hoist out. + for (InsertSliceOp otherOp : insertOps) { + if (otherOp != insertOp && !areDisjointSlices(otherOp, insertOp)) { + LLVM_DEBUG(llvm::dbgs() + << "insert slice op not disjoint with: " << otherOp << "\n"); + return nullptr; + } + } + + ExtractSliceOp extractOp = findMatchingExtractSlice(insertOp, extractOps); + if (!extractOp) + return nullptr; + // To be conservative, require all other extract slice ops be disjoint with + // the one to hoist out. 
+ for (ExtractSliceOp otherOp : extractOps) { + if (otherOp != extractOp && !areDisjointSlices(otherOp, extractOp)) { + LLVM_DEBUG(llvm::dbgs() + << "extract slice op not disjoint with: " << otherOp << "\n"); + return nullptr; + } + } + + return hoistExtractInsertSlice(forOp, index, extractOp, insertOp, builder, + forOpEraseFn); +} + +scf::ForOp tensor::hoistTensorExtractInsertSliceOps(scf::ForOp forOp, + OpBuilder &builder) { + auto eraseFn = [](scf::ForOp forOp) { forOp->erase(); }; + bool changed = true; + while (changed) { + changed = false; + for (unsigned i = 0; i < forOp.getNumRegionIterArgs(); ++i) + if (auto newOp = hoistLoopCarriedValueUses(forOp, i, builder, eraseFn)) { + forOp = newOp; // Use the new scf.for op for next iteration + changed = true; + break; + } + }; + return forOp; +} + +namespace { + +/// Hoists pairs of tensor.extract_slice and tensor.insert_slice ops out of the +/// surrounding scf.for loops. +/// +/// This requires the extract/insert slice op pair to have the exact same +/// loop-invariant offsets, strides, and sizes. Also they should extract from / +/// insert into the same loop carried value. NOTE(review): the hoisting helpers rewire operands and move ops directly rather than through the PatternRewriter; confirm the pattern driver tolerates that. 
+struct HoistExtractInsertSlice : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(scf::ForOp forOp, + PatternRewriter &rewriter) const override { + auto eraseFn = [&](scf::ForOp op) { rewriter.eraseOp(op); }; + for (unsigned i = 0; i < forOp.getNumRegionIterArgs(); ++i) + if (hoistLoopCarriedValueUses(forOp, i, rewriter, eraseFn)) + return success(); + return failure(); + } +}; + +} // namespace + +void tensor::populateHoistExtractInsertSliceOpPatterns( + RewritePatternSet &patterns) { + patterns.add(patterns.getContext()); +} diff --git a/mlir/test/Dialect/Tensor/hoist-extract-insert-slice.mlir b/mlir/test/Dialect/Tensor/hoist-extract-insert-slice.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Tensor/hoist-extract-insert-slice.mlir @@ -0,0 +1,262 @@ +// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-hoist-extract-insert-slice -allow-unregistered-dialect -canonicalize %s | FileCheck %s + +func.func @hoist_slices_in_double_loop( + %input: tensor<1x9x9x3xf32>, %filter: tensor<3x3x3x16xf32>, %init: tensor<1x2x2x4xf32>, + %offset0: index, %offset1: index, %offset2: index) -> tensor<1x2x2x4xf32> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c3 = arith.constant 3 : index + %0 = scf.for %iv0 = %c0 to %c3 step %c1 iter_args(%arg0 = %init) -> (tensor<1x2x2x4xf32>) { + %1 = scf.for %iv1 = %c0 to %c3 step %c1 iter_args(%arg1 = %arg0) -> (tensor<1x2x2x4xf32>) { + %6 = tensor.extract_slice %arg1[%c0, %c0, %c0, %c0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x2x4xf32> to tensor<1x2x4xf32> + %7 = "normal.compute"(%6) : (tensor<1x2x4xf32>) -> tensor<1x2x4xf32> + %8 = tensor.insert_slice %7 into %arg1[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x4xf32> into tensor<1x2x2x4xf32> + %13 = tensor.extract_slice %arg1[%c0, %c1, %c0, %c0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x2x4xf32> to tensor<1x2x4xf32> + %14 = "normal.compute"(%13) : (tensor<1x2x4xf32>) -> tensor<1x2x4xf32> 
+ %15 = tensor.insert_slice %14 into %8[0, 1, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x4xf32> into tensor<1x2x2x4xf32> + scf.yield %15 : tensor<1x2x2x4xf32> + } + scf.yield %1 : tensor<1x2x2x4xf32> + } + return %0 : tensor<1x2x2x4xf32> +} + +// CHECK-LABEL: func.func @hoist_slices_in_double_loop +// CHECK-SAME: (%[[INPUT:.+]]: tensor<1x9x9x3xf32>, %[[FILTER:.+]]: tensor<3x3x3x16xf32>, %[[INIT:.+]]: tensor<1x2x2x4xf32> +// CHECK: %[[C0:.+]] = arith.constant 0 : index +// CHECK: %[[C1:.+]] = arith.constant 1 : index +// CHECK: %[[C3:.+]] = arith.constant 3 : index +// CHECK: %[[INIT_SLICE1:.+]] = tensor.extract_slice %[[INIT]][0, 1, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x2x4xf32> to tensor<1x2x4xf32> +// CHECK: %[[INIT_SLICE0:.+]] = tensor.extract_slice %[[INIT]][0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x2x4xf32> to tensor<1x2x4xf32> +// CHECK: %[[FOR0:.+]]:2 = scf.for %[[IV0:.+]] = %[[C0]] to %[[C3]] step %[[C1]] +// CHECK-SAME: iter_args(%[[FOR0_ARG1:.+]] = %[[INIT_SLICE1]], %[[FOR0_ARG0:.+]] = %[[INIT_SLICE0]]) +// CHECK: %[[FOR1:.+]]:2 = scf.for %[[IV1:.+]] = %[[C0]] to %[[C3]] step %[[C1]] +// CHECK-SAME: iter_args(%[[FOR1_ARG1:.+]] = %[[FOR0_ARG1]], %[[FOR1_ARG0:.+]] = %[[FOR0_ARG0]]) +// CHECK: %[[COMP0:.+]] = "normal.compute"(%[[FOR1_ARG0]]) +// CHECK: %[[COMP1:.+]] = "normal.compute"(%[[FOR1_ARG1]]) +// CHECK: scf.yield %[[COMP1]], %[[COMP0]] : tensor<1x2x4xf32>, tensor<1x2x4xf32> +// CHECK: } +// CHECK: scf.yield %[[FOR1]]#0, %[[FOR1]]#1 : tensor<1x2x4xf32>, tensor<1x2x4xf32> +// CHECK: } +// CHECK: %[[INSERT0:.+]] = tensor.insert_slice %[[FOR0]]#1 into %[[INIT]][0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1] +// CHECK: %[[INSERT1:.+]] = tensor.insert_slice %[[FOR0]]#0 into %[[INSERT0]][0, 1, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1] +// CHECK: return %[[INSERT1]] : tensor<1x2x2x4xf32> + +// ----- + +func.func @hoist_long_extract_insert_chain( + %input: tensor<1x9x9x3xf32>, %filter: tensor<3x3x3x16xf32>, %init: tensor<1x4x2x4xf32>, + %offset0: index, 
%offset1: index, %offset2: index) -> tensor<1x4x2x4xf32> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %0 = scf.for %iv0 = %c0 to %c3 step %c1 iter_args(%arg0 = %init) -> (tensor<1x4x2x4xf32>) { + %1 = scf.for %iv1 = %c0 to %c3 step %c1 iter_args(%arg1 = %arg0) -> (tensor<1x4x2x4xf32>) { + %6 = tensor.extract_slice %arg1[%c0, %c0, %c0, %c0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x4x2x4xf32> to tensor<1x2x4xf32> + %7 = "normal.compute"(%6) : (tensor<1x2x4xf32>) -> tensor<1x2x4xf32> + %8 = tensor.insert_slice %7 into %arg1[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x4xf32> into tensor<1x4x2x4xf32> + + %13 = tensor.extract_slice %arg1[%c0, %c1, %c0, %c0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x4x2x4xf32> to tensor<1x2x4xf32> + %14 = "normal.compute"(%13) : (tensor<1x2x4xf32>) -> tensor<1x2x4xf32> + %15 = tensor.insert_slice %14 into %8[0, 1, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x4xf32> into tensor<1x4x2x4xf32> + + %16 = tensor.extract_slice %arg1[%c0, %c2, %c0, %c0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x4x2x4xf32> to tensor<1x2x4xf32> + %17 = "normal.compute"(%16) : (tensor<1x2x4xf32>) -> tensor<1x2x4xf32> + %18 = tensor.insert_slice %17 into %15[0, 2, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x4xf32> into tensor<1x4x2x4xf32> + + %19 = tensor.extract_slice %arg1[%c0, %c3, %c0, %c0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x4x2x4xf32> to tensor<1x2x4xf32> + %20 = "normal.compute"(%19) : (tensor<1x2x4xf32>) -> tensor<1x2x4xf32> + %21 = tensor.insert_slice %20 into %18[0, 3, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x4xf32> into tensor<1x4x2x4xf32> + scf.yield %21 : tensor<1x4x2x4xf32> + } + scf.yield %1 : tensor<1x4x2x4xf32> + } + return %0 : tensor<1x4x2x4xf32> +} + +// CHECK-LABEL: func.func @hoist_long_extract_insert_chain +// CHECK-COUNT-4: tensor.extract_slice +// CHECK-COUNT-2: scf.for +// CHECK-COUNT-2: scf.yield +// CHECK-COUNT-4: tensor.insert_slice + +// ----- + 
+func.func @dont_hoist_non_extract_insert_slice_usage_of_loop_carried_value( + %init: tensor<1x2x2x4xf32>, + %offset0: index, %offset1: index, %offset2: index) -> tensor<1x2x2x4xf32> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c3 = arith.constant 3 : index + %0 = scf.for %iv0 = %c0 to %c3 step %c1 iter_args(%arg0 = %init) -> (tensor<1x2x2x4xf32>) { + %1 = scf.for %iv1 = %c0 to %c3 step %c1 iter_args(%arg1 = %arg0) -> (tensor<1x2x2x4xf32>) { + %6 = tensor.extract_slice %arg1[%c0, %c0, %c0, %c0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x2x4xf32> to tensor<1x2x4xf32> + %7 = "normal.compute"(%6) : (tensor<1x2x4xf32>) -> tensor<1x2x4xf32> + %8 = tensor.insert_slice %7 into %arg1[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x4xf32> into tensor<1x2x2x4xf32> + "blocking.usage"(%arg1) : (tensor<1x2x2x4xf32>) -> () + scf.yield %8 : tensor<1x2x2x4xf32> + } + scf.yield %1 : tensor<1x2x2x4xf32> + } + return %0 : tensor<1x2x2x4xf32> +} + +// CHECK-LABEL: func.func @dont_hoist_non_extract_insert_slice_usage_of_loop_carried_value +// CHECK-NOT: tensor.extract_slice +// CHECK-COUNT-2: scf.for +// CHECK: tensor.extract_slice +// CHECK: tensor.insert_slice +// CHECK-COUNT-2: scf.yield +// CHECK-NOT: tensor.insert_slice + +// ----- + +func.func @dont_hoist_multi_insert_slice_uses( + %init: tensor<1x2x2x4xf32>, + %offset0: index, %offset1: index, %offset2: index) -> tensor<1x2x2x4xf32> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c3 = arith.constant 3 : index + %0 = scf.for %iv0 = %c0 to %c3 step %c1 iter_args(%arg0 = %init) -> (tensor<1x2x2x4xf32>) { + %1 = scf.for %iv1 = %c0 to %c3 step %c1 iter_args(%arg1 = %arg0) -> (tensor<1x2x2x4xf32>) { + %6 = tensor.extract_slice %arg1[%c0, %c0, %c0, %c0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x2x4xf32> to tensor<1x2x4xf32> + %7 = "normal.compute"(%6) : (tensor<1x2x4xf32>) -> tensor<1x2x4xf32> + %8 = tensor.insert_slice %7 into %arg1[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x4xf32> 
into tensor<1x2x2x4xf32> + "blocking.usage"(%8) : (tensor<1x2x2x4xf32>) -> () + scf.yield %8 : tensor<1x2x2x4xf32> + } + scf.yield %1 : tensor<1x2x2x4xf32> + } + return %0 : tensor<1x2x2x4xf32> +} + +// CHECK-LABEL: func.func @dont_hoist_multi_insert_slice_uses +// CHECK-NOT: tensor.extract_slice +// CHECK-COUNT-2: scf.for +// CHECK: tensor.extract_slice +// CHECK: tensor.insert_slice +// CHECK-COUNT-2: scf.yield +// CHECK-NOT: tensor.insert_slice + +// ----- + +func.func @dont_hoist_loop_dependent_slice_parameters( + %init: tensor<1x2x2x4xf32>, + %offset0: index, %offset1: index, %offset2: index) -> tensor<1x2x2x4xf32> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c3 = arith.constant 3 : index + %0 = scf.for %iv0 = %c0 to %c3 step %c1 iter_args(%arg0 = %init) -> (tensor<1x2x2x4xf32>) { + %1 = scf.for %iv1 = %c0 to %c3 step %c1 iter_args(%arg1 = %arg0) -> (tensor<1x2x2x4xf32>) { + %mod = affine.apply affine_map<(d0) -> (d0 mod 2)>(%iv1) + %6 = tensor.extract_slice %arg1[%c0, %mod, %c0, %c0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x2x4xf32> to tensor<1x2x4xf32> + %7 = "normal.compute"(%6) : (tensor<1x2x4xf32>) -> tensor<1x2x4xf32> + %8 = tensor.insert_slice %7 into %arg1[0, %mod, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x4xf32> into tensor<1x2x2x4xf32> + scf.yield %8 : tensor<1x2x2x4xf32> + } + scf.yield %1 : tensor<1x2x2x4xf32> + } + return %0 : tensor<1x2x2x4xf32> +} + +// CHECK-LABEL: func.func @dont_hoist_loop_dependent_slice_parameters +// CHECK-NOT: tensor.extract_slice +// CHECK-COUNT-2: scf.for +// CHECK: tensor.extract_slice +// CHECK: tensor.insert_slice +// CHECK-COUNT-2: scf.yield +// CHECK-NOT: tensor.insert_slice + +// ----- + +func.func @dont_hoist_out_of_dependent_loops( + %init: tensor<1x2x2x4xf32>, + %offset0: index, %offset1: index, %offset2: index) -> tensor<1x2x2x4xf32> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %output = 
scf.for %iv = %c0 to %c2 step %c1 iter_args(%arg = %init) -> (tensor<1x2x2x4xf32>) { + %0 = scf.for %iv0 = %c0 to %c3 step %c1 iter_args(%arg0 = %arg) -> (tensor<1x2x2x4xf32>) { + %1 = scf.for %iv1 = %c0 to %c3 step %c1 iter_args(%arg1 = %arg0) -> (tensor<1x2x2x4xf32>) { + %6 = tensor.extract_slice %arg1[%c0, %iv, %c0, %c0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x2x4xf32> to tensor<1x2x4xf32> + %7 = "normal.compute"(%6) : (tensor<1x2x4xf32>) -> tensor<1x2x4xf32> + %8 = tensor.insert_slice %7 into %arg1[0, %iv, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x4xf32> into tensor<1x2x2x4xf32> + scf.yield %8 : tensor<1x2x2x4xf32> + } + scf.yield %1 : tensor<1x2x2x4xf32> + } + scf.yield %0 : tensor<1x2x2x4xf32> + } + return %output : tensor<1x2x2x4xf32> +} + +// CHECK-LABEL: func.func @dont_hoist_out_of_dependent_loops +// CHECK: scf.for +// CHECK: tensor.extract_slice +// CHECK-COUNT-2: scf.for +// CHECK-COUNT-2: scf.yield +// CHECK: tensor.insert_slice +// CHECK: scf.yield + +// ----- + +func.func @dont_hoist_insert_slices_not_disjoint( + %init: tensor<1x2x2x4xf32>, + %offset0: index, %offset1: index, %offset2: index) -> tensor<1x2x2x4xf32> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c3 = arith.constant 3 : index + %0 = scf.for %iv0 = %c0 to %c3 step %c1 iter_args(%arg0 = %init) -> (tensor<1x2x2x4xf32>) { + %1 = scf.for %iv1 = %c0 to %c3 step %c1 iter_args(%arg1 = %arg0) -> (tensor<1x2x2x4xf32>) { + %6 = tensor.extract_slice %arg1[%c0, %c0, %c0, %c0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x2x4xf32> to tensor<1x2x4xf32> + %7 = "normal.compute"(%6) : (tensor<1x2x4xf32>) -> tensor<1x2x4xf32> + %8 = tensor.insert_slice %7 into %arg1[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x4xf32> into tensor<1x2x2x4xf32> + %13 = tensor.extract_slice %arg1[%c0, %c0, %c1, %c0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x2x4xf32> to tensor<1x2x4xf32> + %14 = "normal.compute"(%13) : (tensor<1x2x4xf32>) -> tensor<1x2x4xf32> + %15 = tensor.insert_slice %14 into %8[0, 
0, 1, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x4xf32> into tensor<1x2x2x4xf32> + scf.yield %15 : tensor<1x2x2x4xf32> + } + scf.yield %1 : tensor<1x2x2x4xf32> + } + return %0 : tensor<1x2x2x4xf32> +} + +// CHECK-LABEL: func.func @dont_hoist_insert_slices_not_disjoint +// CHECK-NOT: tensor.extract_slice +// CHECK-COUNT-2: scf.for +// CHECK: tensor.extract_slice +// CHECK: tensor.insert_slice +// CHECK: tensor.extract_slice +// CHECK: tensor.insert_slice +// CHECK-COUNT-2: scf.yield +// CHECK-NOT: tensor.insert_slice + +// ----- + +func.func @dont_hoist_mismatched_extract_insert_slice( + %init: tensor<1x2x2x4xf32>, + %offset0: index, %offset1: index, %offset2: index) -> tensor<1x2x2x4xf32> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c3 = arith.constant 3 : index + %0 = scf.for %iv0 = %c0 to %c3 step %c1 iter_args(%arg0 = %init) -> (tensor<1x2x2x4xf32>) { + %1 = scf.for %iv1 = %c0 to %c3 step %c1 iter_args(%arg1 = %arg0) -> (tensor<1x2x2x4xf32>) { + %6 = tensor.extract_slice %arg1[%c0, %c0, %c0, %c0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x2x4xf32> to tensor<1x2x4xf32> + %7 = "normal.compute"(%6) : (tensor<1x2x4xf32>) -> tensor<1x2x4xf32> + %8 = tensor.insert_slice %7 into %arg1[0, 1, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x4xf32> into tensor<1x2x2x4xf32> + scf.yield %8 : tensor<1x2x2x4xf32> + } + scf.yield %1 : tensor<1x2x2x4xf32> + } + return %0 : tensor<1x2x2x4xf32> +} + +// CHECK-LABEL: func.func @dont_hoist_mismatched_extract_insert_slice +// CHECK-NOT: tensor.extract_slice +// CHECK-COUNT-2: scf.for +// CHECK: tensor.extract_slice +// CHECK: tensor.insert_slice +// CHECK-COUNT-2: scf.yield +// CHECK-NOT: tensor.insert_slice diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp --- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp +++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp @@ -65,6 +65,12 @@ "destination tensor"), llvm::cl::init(false)}; + 
Option testHoistExtractInsertSlice{ + *this, "test-hoist-extract-insert-slice", + llvm::cl::desc("Test hoisting tensor.insert_slice/tensor.extract_slice " + "out of loops"), + llvm::cl::init(false)}; + Option testRewriteExtractSliceWithTiledCollapseShape{ *this, "test-rewrite-extract-slice-from-collapse-shape", llvm::cl::desc("Test swapping tensor.extract_slice of a collapse_shape " @@ -114,6 +120,16 @@ (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns)); } +static void applyHoistExtractInsertSlicePatterns(Operation *rootOp) { + MLIRContext *ctx = rootOp->getContext(); + RewritePatternSet patterns(ctx); + tensor::populateHoistExtractInsertSliceOpPatterns(patterns); + tensor::InsertSliceOp::getCanonicalizationPatterns(patterns, ctx); + tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, ctx); + scf::ForOp::getCanonicalizationPatterns(patterns, ctx); + (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns)); +} + namespace { /// Base pattern to rewrite a `tensor.collapse_shape -> tensor.extract_slice`. /// The `tensor.extract_slice` is replaced by a loop or gather operation that @@ -261,6 +277,8 @@ applyFoldConsecutiveInsertExtractSlicePatterns(rootOp); if (testExtractFrominsertSliceDest) applyExtractFromInsertSliceDestPatterns(rootOp); + if (testHoistExtractInsertSlice) + applyHoistExtractInsertSlicePatterns(rootOp); if (testRewriteExtractSliceWithTiledCollapseShape) { if (failed( applyRewriteExtractFromCollapseShapePatterns(rootOp, useForeach))) diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -5226,6 +5226,7 @@ ":MemRefDialect", ":Pass", ":SCFDialect", + ":SCFUtils", ":TensorDialect", ":TensorPassIncGen", ":TilingInterface",