diff --git a/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp b/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
@@ -63,6 +63,7 @@
     promoteSingleIterationLoops(cast<FuncOp>(op));
     hoistViewAllocOps(cast<FuncOp>(op));
     hoistRedundantVectorTransfers(cast<FuncOp>(op));
+    hoistRedundantVectorTransfersOnTensor(cast<FuncOp>(op));
     return success();
   };
   (void)linalg::applyStagedPatterns(
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -21,10 +21,13 @@
 #include "mlir/Dialect/Vector/VectorUtils.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/Dominance.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/LoopUtils.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/Support/Debug.h"

+using llvm::dbgs;
+
 #define DEBUG_TYPE "linalg-hoisting"

 #define DBGS() (dbgs() << '[' << DEBUG_TYPE << "] ")
@@ -32,8 +35,6 @@
 using namespace mlir;
 using namespace mlir::linalg;

-using llvm::dbgs;
-
 void mlir::linalg::hoistViewAllocOps(FuncOp func) {
   bool changed = true;
   while (changed) {
@@ -81,35 +82,145 @@
   }
 }

-/// Look for a transfer_read, in the given tensor uses, accessing the same
-/// offset as the transfer_write.
-static vector::TransferReadOp
-findMatchingTransferRead(vector::TransferWriteOp write, Value srcTensor) {
+namespace {
+/// Represents a unit of hoistable TransferWriteOp. This may comprise other
+/// instructions that need to be hoisted too.
+struct HoistableWrite {
+  vector::TransferWriteOp transferWriteOp;
+  SubTensorInsertOp subTensorInsertOp;
+};
+/// Represents a unit of hoistable TransferReadOp. This may comprise other
+/// instructions that need to be hoisted too.
+struct HoistableRead {
+  vector::TransferReadOp transferReadOp;
+  SubTensorOp subTensorOp;
+};
+} // namespace
+
+/// Return true if op1 and op2 are the same constant or the same SSA value.
+static bool isEqualOffsetSizeOrStride(OpFoldResult op1, OpFoldResult op2) {
+  auto getConstantIntValue = [](OpFoldResult ofr) -> llvm::Optional<int64_t> {
+    Attribute attr = ofr.dyn_cast<Attribute>();
+    // Note: isa+cast-like pattern allows writing the condition below as 1 line.
+    if (!attr && ofr.get<Value>().getDefiningOp<ConstantOp>())
+      attr = ofr.get<Value>().getDefiningOp<ConstantOp>().getValue();
+    if (auto intAttr = attr.dyn_cast_or_null<IntegerAttr>())
+      return intAttr.getValue().getSExtValue();
+    return llvm::None;
+  };
+  auto cst1 = getConstantIntValue(op1), cst2 = getConstantIntValue(op2);
+  if (cst1 && cst2 && *cst1 == *cst2)
+    return true;
+  auto v1 = op1.dyn_cast<Value>(), v2 = op2.dyn_cast<Value>();
+  return v1 && v2 && v1 == v2;
+}
+
+/// Return true if all offsets, sizes and strides are equal.
+static bool sameOffsetsSizesAndStrides(SubTensorOp s, SubTensorInsertOp si) {
+  if (s.static_offsets().size() != si.static_offsets().size())
+    return false;
+  if (s.static_sizes().size() != si.static_sizes().size())
+    return false;
+  if (s.static_strides().size() != si.static_strides().size())
+    return false;
+  for (auto it : llvm::zip(s.getMixedOffsets(), si.getMixedOffsets()))
+    if (!isEqualOffsetSizeOrStride(std::get<0>(it), std::get<1>(it)))
+      return false;
+  for (auto it : llvm::zip(s.getMixedSizes(), si.getMixedSizes()))
+    if (!isEqualOffsetSizeOrStride(std::get<0>(it), std::get<1>(it)))
+      return false;
+  for (auto it : llvm::zip(s.getMixedStrides(), si.getMixedStrides()))
+    if (!isEqualOffsetSizeOrStride(std::get<0>(it), std::get<1>(it)))
+      return false;
+  return true;
+}
+
+/// Look for a HoistableRead, in the given tensor uses, accessing the same
+/// offset as the HoistableWrite.
+static HoistableRead findMatchingTransferRead(HoistableWrite write,
+                                              Value srcTensor) {
+  assert(write.transferWriteOp &&
+         "expected hoistable write to have a .transfer_write");
+
+  LLVM_DEBUG(DBGS() << "findMatchingTransferRead for: "
+                    << *write.transferWriteOp.getOperation() << "\n");
+  if (write.subTensorInsertOp)
+    LLVM_DEBUG(DBGS() << "findMatchingTransferRead subTensorInsertOp: "
+                      << *write.subTensorInsertOp.getOperation() << "\n");
+
   for (Operation *user : srcTensor.getUsers()) {
-    auto read = dyn_cast<vector::TransferReadOp>(user);
-    if (read && read.indices() == write.indices() &&
-        read.getVectorType() == write.getVectorType()) {
-      return read;
+    LLVM_DEBUG(DBGS() << "findMatchingTransferRead inspect user: " << *user
+                      << "\n");
+
+    // If HoistableWrite involves a SubTensorInsertOp, we need to find a
+    // matching SubTensorOp.
+    SubTensorOp subTensorOp;
+    Operation *maybeTransferReadUser = user;
+    if (write.subTensorInsertOp) {
+      subTensorOp = dyn_cast<SubTensorOp>(user);
+      if (!subTensorOp || subTensorOp.getResult().getType() !=
+                              write.subTensorInsertOp.source().getType())
+        continue;
+
+      LLVM_DEBUG(DBGS() << "check whether sameOffsetsSizesAndStrides: "
+                        << *subTensorOp << " vs " << *write.subTensorInsertOp
+                        << "\n");
+      if (!sameOffsetsSizesAndStrides(subTensorOp, write.subTensorInsertOp))
+        continue;
+
+      LLVM_DEBUG(DBGS() << "sameOffsetsSizesAndStrides: SUCCESS\n");
+      // If we got here, subTensorOp is hoistable iff it has exactly 2 uses:
+      // 1. the transfer_write we want to hoist.
+      // 2. a matching transfer_read.
+      // Anything else, we skip.
+      bool skip = false;
+      Operation *otherUser = nullptr;
+      for (Operation *u : subTensorOp->getUsers()) {
+        if (u == write.transferWriteOp)
+          continue;
+        if (otherUser) {
+          skip = true;
+          break;
+        }
+        otherUser = u;
+      }
+      if (skip || !otherUser)
+        continue;
+      maybeTransferReadUser = otherUser;
     }
+
+    LLVM_DEBUG(DBGS() << "maybeTransferReadUser: " << *maybeTransferReadUser
+                      << "\n");
+    auto read = dyn_cast<vector::TransferReadOp>(maybeTransferReadUser);
+    if (read && read.indices() == write.transferWriteOp.indices() &&
+        read.getVectorType() == write.transferWriteOp.getVectorType())
+      return HoistableRead{read, subTensorOp};
   }
-  return nullptr;
+  return HoistableRead();
 }

-/// Check if the chunk of data inserted by the transfer_write in the given
-/// tensor are read by any other op than the read candidate.
-static bool tensorChunkAccessedByUnknownOp(vector::TransferWriteOp write,
-                                           vector::TransferReadOp candidateRead,
-                                           Value srcTensor) {
+/// Check if the chunk of data inserted by the HoistableWrite is read by any
+/// other op than the HoistableRead candidate.
+static bool tensorChunkAccessedByUnknownOp(HoistableWrite write,
+                                           HoistableRead candidateRead,
+                                           BlockArgument tensorArg) {
   // Make sure none of the other uses read the part of the tensor modified
   // by the transfer_write.
   llvm::SmallVector<Value::use_range, 1> uses;
-  uses.push_back(srcTensor.getUses());
+  uses.push_back(tensorArg.getUses());
   while (!uses.empty()) {
     for (OpOperand &use : uses.pop_back_val()) {
       Operation *user = use.getOwner();
       // Skip the candidate use, only inspect the "other" uses.
-      if (user == candidateRead.getOperation() || user == write.getOperation())
+      if (user == candidateRead.transferReadOp ||
+          user == candidateRead.subTensorOp || user == write.transferWriteOp ||
+          user == write.subTensorInsertOp)
         continue;
+      // Consider all transitive uses through a subtensor / subtensor_insert.
+      // TODO: atm we just bail because a stronger analysis is needed for these
+      // cases.
+      if (isa<SubTensorOp, SubTensorInsertOp>(user))
+        return true;
       // Consider all transitive uses through a vector.transfer_write.
       if (auto writeUser = dyn_cast<vector::TransferWriteOp>(user)) {
         uses.push_back(writeUser->getResult(0).getUses());
@@ -128,8 +239,8 @@
       // Follow the use yield as long as it doesn't escape the original
       // region.
       scf::YieldOp yieldUser = dyn_cast<scf::YieldOp>(user);
-      if (yieldUser &&
-          write->getParentOp()->isAncestor(yieldUser->getParentOp())) {
+      if (yieldUser && write.transferWriteOp->getParentOp()->isAncestor(
+                           yieldUser->getParentOp())) {
         Value ret = yieldUser->getParentOp()->getResult(use.getOperandNumber());
         uses.push_back(ret.getUses());
         continue;
@@ -137,7 +248,8 @@
       auto read = dyn_cast<vector::TransferReadOp>(user);
       if (!read || !isDisjointTransferIndices(
                        cast<VectorTransferOpInterface>(read.getOperation()),
-                       cast<VectorTransferOpInterface>(write.getOperation()))) {
+                       cast<VectorTransferOpInterface>(
+                           write.transferWriteOp.getOperation()))) {
         return true;
       }
     }
@@ -145,6 +257,118 @@
   return false;
 }

+/// Return the `forOp`-invariant HoistableWrite that produces `yieldOperand`.
+/// Return the null HoistableWrite() if it is not comprised of a
+/// vector.transfer_write + optional subtensor_insert or if any of the
+/// indexings is `forOp`-dependent.
+static HoistableWrite
+getLoopInvariantTransferWriteOpDefining(scf::ForOp forOp,
+                                        OpOperand &yieldOperand) {
+  Value v = yieldOperand.get();
+  if (auto write = v.getDefiningOp<vector::TransferWriteOp>()) {
+    // Indexing must not depend on `forOp`.
+    for (Value operand : write.indices())
+      if (!forOp.isDefinedOutsideOfLoop(operand))
+        return HoistableWrite();
+
+    return HoistableWrite{write, nullptr};
+  }
+
+  if (auto subTensorInsertOp = v.getDefiningOp<SubTensorInsertOp>()) {
+    // Inserted subTensor must come from vector.transfer_write.
+    auto write =
+        subTensorInsertOp.source().getDefiningOp<vector::TransferWriteOp>();
+    if (!write)
+      return HoistableWrite();
+
+    // Tensor inserted into must be a BBArg at position matching yieldOperand's.
+    auto bbArg = subTensorInsertOp.dest().dyn_cast<BlockArgument>();
+    if (!bbArg || bbArg.getOwner()->getParentOp() != forOp ||
+        bbArg.getArgNumber() != /*num iv=*/1 + yieldOperand.getOperandNumber())
+      return HoistableWrite();
+
+    // Indexing inserted into must not depend on `forOp`.
+    for (Value operand : subTensorInsertOp->getOperands().drop_front(
+             SubTensorInsertOp::getOffsetSizeAndStrideStartOperandIndex()))
+      if (!forOp.isDefinedOutsideOfLoop(operand))
+        return HoistableWrite();
+
+    return HoistableWrite{write, subTensorInsertOp};
+  }
+
+  return HoistableWrite();
+}
+
+/// Mechanical hoisting of a matching HoistableRead / HoistableWrite pair.
+static void hoistReadWrite(HoistableRead read, HoistableWrite write,
+                           BlockArgument tensorBBArg) {
+  scf::ForOp forOp = cast<scf::ForOp>(tensorBBArg.getOwner()->getParentOp());
+  assert(read.transferReadOp && write.transferWriteOp &&
+         "expected transfer_read and transfer_write ops to be set");
+  assert(((read.subTensorOp && write.subTensorInsertOp) ||
+          (!read.subTensorOp && !write.subTensorInsertOp)) &&
+         "expected matching subtensor / subtensor_insert");
+  LLVM_DEBUG(DBGS() << "In forOp:\n"
+                    << *forOp.getOperation()
+                    << "\nHoist: " << *read.transferReadOp.getOperation()
+                    << "\nHoist: " << *write.transferWriteOp.getOperation()
+                    << "\nInvolving: " << tensorBBArg << "\n");
+
+  // If a read subtensor is present, hoist it.
+  if (read.subTensorOp && failed(forOp.moveOutOfLoop({read.subTensorOp})))
+    llvm_unreachable("Unexpected failure moving subtensor out of loop");
+
+  // Hoist the transfer_read op.
+  if (failed(forOp.moveOutOfLoop({read.transferReadOp})))
+    llvm_unreachable("Unexpected failure moving transfer read out of loop");
+
+  // TODO: don't hardcode /*numIvs=*/1.
+  assert(tensorBBArg.getArgNumber() >= /*numIvs=*/1);
+  unsigned initArgNumber = tensorBBArg.getArgNumber() - /*numIvs=*/1;
+
+  // Update the source tensor.
+  if (read.subTensorOp)
+    read.subTensorOp.sourceMutable().assign(forOp.initArgs()[initArgNumber]);
+  else
+    read.transferReadOp.sourceMutable().assign(forOp.initArgs()[initArgNumber]);
+
+  // Hoist write after.
+  if (write.subTensorInsertOp)
+    write.subTensorInsertOp->moveAfter(forOp);
+  write.transferWriteOp->moveAfter(forOp);
+
+  // Update the yield.
+  auto yieldOp = cast<scf::YieldOp>(forOp.region().front().getTerminator());
+  if (write.subTensorInsertOp)
+    yieldOp->setOperand(initArgNumber, write.subTensorInsertOp.dest());
+  else
+    yieldOp->setOperand(initArgNumber, write.transferWriteOp.source());
+
+  // Rewrite `loop` with additional new yields.
+  OpBuilder b(read.transferReadOp);
+  auto newForOp = cloneWithNewYields(b, forOp, read.transferReadOp.vector(),
+                                     write.transferWriteOp.vector());
+  // Transfer write has been hoisted, need to update the vector and tensor
+  // source. Replace the result of the loop to use the new tensor created
+  // outside the loop.
+  // Depending on whether a subtensor_insert is present or not, it carries the
+  // update on the tensor operands.
+  if (write.subTensorInsertOp) {
+    newForOp.getResult(initArgNumber)
+        .replaceAllUsesWith(write.subTensorInsertOp.getResult());
+    write.transferWriteOp.sourceMutable().assign(read.subTensorOp.result());
+    write.subTensorInsertOp.destMutable().assign(read.subTensorOp.source());
+  } else {
+    newForOp.getResult(initArgNumber)
+        .replaceAllUsesWith(write.transferWriteOp.getResult(0));
+    write.transferWriteOp.sourceMutable().assign(
+        newForOp.getResult(initArgNumber));
+  }
+
+  // Always update with the newly yielded tensor and vector.
+  write.transferWriteOp.vectorMutable().assign(newForOp.getResults().back());
+}
+
 // To hoist transfer op on tensor the logic can be significantly simplified
 // compared to the case on buffer. The transformation follows this logic:
 // 1. Look for transfer_write with a single use from ForOp yield
@@ -163,57 +387,48 @@
   func.walk([&](scf::ForOp forOp) {
     Operation *yield = forOp.getBody()->getTerminator();
     for (auto it : llvm::enumerate(forOp.getRegionIterArgs())) {
-      Value ret = yield->getOperand(it.index());
-      auto write = ret.getDefiningOp<vector::TransferWriteOp>();
-      if (!write || !write->hasOneUse())
+      OpOperand &ret = yield->getOpOperand(it.index());
+      HoistableWrite write =
+          getLoopInvariantTransferWriteOpDefining(forOp, ret);
+      if (!write.transferWriteOp || !write.transferWriteOp->hasOneUse())
         continue;
-      LLVM_DEBUG(DBGS() << "Candidate write for hoisting: "
-                        << *write.getOperation() << "\n");
-      if (llvm::any_of(write.indices(), [&forOp](Value index) {
-            return !forOp.isDefinedOutsideOfLoop(index);
-          }))
+      LLVM_DEBUG(dbgs() << "\n";
+                 DBGS() << "Candidate write for hoisting: "
+                        << *write.transferWriteOp.getOperation() << "\n");
+      if (write.subTensorInsertOp)
+        LLVM_DEBUG(DBGS() << "Candidate subtensor_insert for hoisting: "
+                          << *write.subTensorInsertOp.getOperation() << "\n");
+      if (llvm::any_of(write.transferWriteOp.indices(),
+                       [&forOp](Value index) {
+                         return !forOp.isDefinedOutsideOfLoop(index);
+                       }))
        continue;
       // Find a read with the same type and indices.
-      vector::TransferReadOp matchingRead =
+      HoistableRead matchingRead =
           findMatchingTransferRead(write, it.value());
       // Make sure none of the other uses read the part of the tensor modified
       // by the transfer_write.
-      if (!matchingRead ||
+      if (!matchingRead.transferReadOp ||
           tensorChunkAccessedByUnknownOp(write, matchingRead, it.value()))
         continue;
-      // Hoist read before.
-      if (failed(forOp.moveOutOfLoop({matchingRead})))
-        llvm_unreachable(
-            "Unexpected failure to move transfer read out of loop");
-      // Update the source tensor.
-      matchingRead.sourceMutable().assign(forOp.initArgs()[it.index()]);
-
-      // Hoist write after.
-      write->moveAfter(forOp);
-      yield->setOperand(it.index(), write.source());
-
-      // Rewrite `loop` with new yields by cloning and erase the original
-      // loop.
-      OpBuilder b(matchingRead);
-      auto newForOp =
-          cloneWithNewYields(b, forOp, matchingRead.vector(), write.vector());
-
-      // Transfer write has been hoisted, need to update the vector and tensor
-      // source. Replace the result of the loop to use the new tensor created
-      // outside the loop.
-      newForOp.getResult(it.index()).replaceAllUsesWith(write.getResult(0));
-      write.vectorMutable().assign(newForOp.getResults().back());
-      write.sourceMutable().assign(newForOp.getResult(it.index()));
-
+      LLVM_DEBUG(DBGS() << "Start hoisting\n");
+      hoistReadWrite(matchingRead, write, it.value());
       changed = true;
       forOp.erase();
-      // Need to interrupt and restart because erasing the loop messes up the
-      // walk.
+
+      // Need to interrupt and restart: erasing the loop messes up the walk.
       return WalkResult::interrupt();
     }
     return WalkResult::advance();
   });
+  // Apply canonicalization so the newForOp + yield folds immediately, thus
+  // cleaning up the IR and potentially enabling more hoisting.
+  if (changed) {
+    OwningRewritePatternList patterns;
+    scf::ForOp::getCanonicalizationPatterns(patterns, func->getContext());
+    (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
+  }
  }
}
diff --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir
--- a/mlir/test/Dialect/Linalg/hoisting.mlir
+++ b/mlir/test/Dialect/Linalg/hoisting.mlir
@@ -1,5 +1,7 @@
-// RUN: mlir-opt %s -test-linalg-hoisting=test-hoist-view-allocs -allow-unregistered-dialect | FileCheck %s
-// RUN: mlir-opt %s -test-linalg-hoisting=test-hoist-redundant-transfers -allow-unregistered-dialect | FileCheck %s --check-prefix=VECTOR_TRANSFERS
+// RUN: mlir-opt %s -test-linalg-hoisting=test-hoist-view-allocs -allow-unregistered-dialect -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-hoisting=test-hoist-redundant-transfers -allow-unregistered-dialect -split-input-file | FileCheck %s --check-prefix=VECTOR_TRANSFERS
+
+// -----

 // CHECK-LABEL: func @hoist_allocs(
 // CHECK-SAME: %[[VAL:[a-zA-Z0-9]*]]: index,
@@ -82,6 +84,8 @@
   return
 }

+// -----
+
 // VECTOR_TRANSFERS-LABEL: func @hoist_vector_transfer_pairs(
 // VECTOR_TRANSFERS-SAME: %[[MEMREF0:[a-zA-Z0-9]*]]: memref<?x?xf32>,
 // VECTOR_TRANSFERS-SAME: %[[MEMREF1:[a-zA-Z0-9]*]]: memref<?x?xf32>,
@@ -152,6 +156,8 @@
   return
 }

+// -----
+
 // VECTOR_TRANSFERS-LABEL: func @hoist_vector_transfer_pairs_disjoint(
 // VECTOR_TRANSFERS-SAME: %[[MEMREF0:[a-zA-Z0-9]*]]: memref<?x?xf32>,
 // VECTOR_TRANSFERS-SAME: %[[MEMREF1:[a-zA-Z0-9]*]]: memref<?x?xf32>,
@@ -231,6 +237,8 @@
   return
 }

+// -----
+
 // VECTOR_TRANSFERS-LABEL: func @hoist_vector_transfer_pairs_tensor
 func @hoist_vector_transfer_pairs_tensor(
     %tensor0: tensor<?x?xf32>, %tensor1: tensor<?x?xf32>, %tensor2: tensor<?x?xf32>,
@@ -243,11 +251,10 @@
 // VECTOR_TRANSFERS: vector.transfer_read %{{.*}} : tensor<?x?xf32>, vector<1xf32>
 // VECTOR_TRANSFERS: scf.for {{.*}} iter_args({{.*}}) ->
-// VECTOR_TRANSFERS-SAME: (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<1xf32>) {
+// VECTOR_TRANSFERS-SAME: (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<1xf32>) {
 // VECTOR_TRANSFERS: vector.transfer_read %{{.*}} : tensor<?x?xf32>, vector<2xf32>
 // VECTOR_TRANSFERS: scf.for {{.*}} iter_args({{.*}}) ->
-// VECTOR_TRANSFERS-SAME: (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<2xf32>, vector<1xf32>) {
-// VECTOR_TRANSFERS: vector.transfer_read %{{.*}} : tensor<?x?xf32>, vector<3xf32>
+// VECTOR_TRANSFERS-SAME: (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<2xf32>, vector<1xf32>) {
 // VECTOR_TRANSFERS: vector.transfer_read %{{.*}} : tensor<?x?xf32>, vector<4xf32>
 // VECTOR_TRANSFERS: "some_crippling_use"(%{{.*}}) : (tensor<?x?xf32>) -> ()
 // VECTOR_TRANSFERS: vector.transfer_read %{{.*}} : tensor<?x?xf32>, vector<5xf32>
@@ -261,11 +268,11 @@
 // VECTOR_TRANSFERS: vector.transfer_write %{{.*}} : vector<5xf32>, tensor<?x?xf32>
 // VECTOR_TRANSFERS: "some_crippling_use"(%{{.*}}) : (tensor<?x?xf32>) -> ()
 // VECTOR_TRANSFERS: scf.yield {{.*}} :
-// VECTOR_TRANSFERS-SAME: tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<2xf32>, vector<1xf32>
+// VECTOR_TRANSFERS-SAME: tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<2xf32>, vector<1xf32>
 // VECTOR_TRANSFERS: }
 // VECTOR_TRANSFERS: vector.transfer_write %{{.*}} : vector<2xf32>, tensor<?x?xf32>
 // VECTOR_TRANSFERS: scf.yield {{.*}} :
-// VECTOR_TRANSFERS-SAME: tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<1xf32>
+// VECTOR_TRANSFERS-SAME: tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<1xf32>
 // VECTOR_TRANSFERS: }
 // VECTOR_TRANSFERS: vector.transfer_write %{{.*}} : vector<1xf32>, tensor<?x?xf32>
   %0:6 = scf.for %i = %lb to %ub step %step
@@ -280,7 +287,6 @@
      tensor<?x?xf32>, tensor<?x?xf32>) {
     %r0 = vector.transfer_read %arg7[%c0, %c0], %cst: tensor<?x?xf32>, vector<1xf32>
     %r1 = vector.transfer_read %arg6[%i, %i], %cst: tensor<?x?xf32>, vector<2xf32>
-    %r2 = vector.transfer_read %arg8[%c0, %c0], %cst: tensor<?x?xf32>, vector<3xf32>
     %r3 = vector.transfer_read %arg9[%c0, %c0], %cst: tensor<?x?xf32>, vector<4xf32>
     "some_crippling_use"(%arg10) : (tensor<?x?xf32>) -> ()
     %r4 = vector.transfer_read %arg10[%c0, %c0], %cst: tensor<?x?xf32>, vector<5xf32>
@@ -312,6 +318,8 @@
     tensor<?x?xf32>, tensor<?x?xf32>
 }

+// -----
+
 // VECTOR_TRANSFERS-LABEL: func @hoist_vector_transfer_pairs_disjoint_tensor(
 // VECTOR_TRANSFERS-SAME: %[[TENSOR0:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
 // VECTOR_TRANSFERS-SAME: %[[TENSOR1:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
@@ -332,10 +340,10 @@
 // VECTOR_TRANSFERS: vector.transfer_read %[[TENSOR2]]{{.*}} : tensor<?x?xf32>, vector<3xf32>
 // VECTOR_TRANSFERS: vector.transfer_read %[[TENSOR3]]{{.*}} : tensor<?x?xf32>, vector<4xf32>
 // VECTOR_TRANSFERS: vector.transfer_read %[[TENSOR3]]{{.*}} : tensor<?x?xf32>, vector<4xf32>
-// VECTOR_TRANSFERS: %[[R:.*]]:8 = scf.for {{.*}} iter_args({{.*}}) ->
-// VECTOR_TRANSFERS-SAME: (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<3xf32>, vector<3xf32>, vector<4xf32>, vector<4xf32>) {
+// VECTOR_TRANSFERS: %[[R:.*]]:6 = scf.for {{.*}} iter_args({{.*}}) ->
+// VECTOR_TRANSFERS-SAME: (tensor<?x?xf32>, tensor<?x?xf32>, vector<3xf32>, vector<3xf32>, vector<4xf32>, vector<4xf32>) {
 // VECTOR_TRANSFERS: scf.for {{.*}} iter_args({{.*}}) ->
-// VECTOR_TRANSFERS-SAME: (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<3xf32>, vector<3xf32>, vector<4xf32>, vector<4xf32>) {
+// VECTOR_TRANSFERS-SAME: (tensor<?x?xf32>, tensor<?x?xf32>, vector<3xf32>, vector<3xf32>, vector<4xf32>, vector<4xf32>) {
 // VECTOR_TRANSFERS: vector.transfer_read %[[TENSOR1]]{{.*}} : tensor<?x?xf32>, vector<2xf32>
 // VECTOR_TRANSFERS: vector.transfer_read %[[TENSOR1]]{{.*}} : tensor<?x?xf32>, vector<2xf32>
 // VECTOR_TRANSFERS: "some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32>
@@ -349,15 +357,15 @@
 // VECTOR_TRANSFERS: vector.transfer_write %{{.*}}, %{{.*}}{{.*}} : vector<2xf32>, tensor<?x?xf32>
 // VECTOR_TRANSFERS: vector.transfer_write %{{.*}}, %{{.*}}{{.*}} : vector<2xf32>, tensor<?x?xf32>
 // VECTOR_TRANSFERS: scf.yield {{.*}} :
-// VECTOR_TRANSFERS-SAME: tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<3xf32>, vector<3xf32>, vector<4xf32>, vector<4xf32>
+// VECTOR_TRANSFERS-SAME: tensor<?x?xf32>, tensor<?x?xf32>, vector<3xf32>, vector<3xf32>, vector<4xf32>, vector<4xf32>
 // VECTOR_TRANSFERS: }
 // VECTOR_TRANSFERS: scf.yield {{.*}} :
-// VECTOR_TRANSFERS-SAME: tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<3xf32>, vector<3xf32>, vector<4xf32>, vector<4xf32>
+// VECTOR_TRANSFERS-SAME: tensor<?x?xf32>, tensor<?x?xf32>, vector<3xf32>, vector<3xf32>, vector<4xf32>, vector<4xf32>
 // VECTOR_TRANSFERS: }
-// VECTOR_TRANSFERS: %[[TENSOR4:.*]] = vector.transfer_write %{{.*}}, %[[R]]#3{{.*}} : vector<4xf32>, tensor<?x?xf32>
-// VECTOR_TRANSFERS: vector.transfer_write %{{.*}}, %[[TENSOR4]]{{.*}} : vector<4xf32>, tensor<?x?xf32>
-// VECTOR_TRANSFERS: %[[TENSOR5:.*]] = vector.transfer_write %{{.*}}, %[[R]]#2{{.*}} : vector<3xf32>, tensor<?x?xf32>
-// VECTOR_TRANSFERS: vector.transfer_write %{{.*}}, %[[TENSOR5]]{{.*}} : vector<3xf32>, tensor<?x?xf32>
+// VECTOR_TRANSFERS: %[[TENSOR4:.*]] = vector.transfer_write %[[R]]#5, %[[TENSOR3]]{{.*}} : vector<4xf32>, tensor<?x?xf32>
+// VECTOR_TRANSFERS: vector.transfer_write %[[R]]#4, %[[TENSOR4]]{{.*}} : vector<4xf32>, tensor<?x?xf32>
+// VECTOR_TRANSFERS: %[[TENSOR5:.*]] = vector.transfer_write %[[R]]#3, %[[TENSOR2]]{{.*}} : vector<3xf32>, tensor<?x?xf32>
+// VECTOR_TRANSFERS: vector.transfer_write %[[R]]#2, %[[TENSOR5]]{{.*}} : vector<3xf32>, tensor<?x?xf32>
   %0:4 = scf.for %i = %lb to %ub step %step
   iter_args(%arg0 = %tensor0, %arg1 = %tensor1, %arg2 = %tensor2, %arg3 = %tensor3)
@@ -396,3 +404,111 @@
   }
   return %0#0, %0#1, %0#2, %0#3 : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
 }
+
+// -----
+
+// VECTOR_TRANSFERS-LABEL: func @hoist_vector_transfer_pairs_tensor_and_subtensors
+// VECTOR_TRANSFERS-SAME: %[[TENSOR0:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
+// VECTOR_TRANSFERS-SAME: %[[TENSOR1:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
+// VECTOR_TRANSFERS-SAME: %[[TENSOR2:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
+// VECTOR_TRANSFERS-SAME: %[[TENSOR3:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
+// VECTOR_TRANSFERS-SAME: %[[TENSOR4:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
+// VECTOR_TRANSFERS-SAME: %[[TENSOR5:[a-zA-Z0-9]*]]: tensor<?x?xf32>
+func @hoist_vector_transfer_pairs_tensor_and_subtensors(
+    %tensor0: tensor<?x?xf32>, %tensor1: tensor<?x?xf32>, %tensor2: tensor<?x?xf32>,
+    %tensor3: tensor<?x?xf32>, %tensor4: tensor<?x?xf32>, %tensor5: tensor<?x?xf32>,
+    %val: index, %lb : index, %ub : index, %step: index) ->
+  (
+    tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>//, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
+  ) {
+  %c0 = constant 0 : index
+  %cst = constant 0.0 : f32
+
+  // VECTOR_TRANSFERS: scf.for %[[I:.*]] = {{.*}} iter_args(
+  // VECTOR_TRANSFERS-SAME: %[[TENSOR0_ARG:[0-9a-zA-Z]+]] = %[[TENSOR0]],
+  // VECTOR_TRANSFERS-SAME: %[[TENSOR1_ARG:[0-9a-zA-Z]+]] = %[[TENSOR1]],
+  // VECTOR_TRANSFERS-SAME: %[[TENSOR2_ARG:[0-9a-zA-Z]+]] = %[[TENSOR2]]
+  // VECTOR_TRANSFERS-SAME: ) ->
+  // VECTOR_TRANSFERS-SAME: (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
+  %0:3 = scf.for %i = %lb to %ub step %step
+    iter_args(%arg0 = %tensor0, %arg1 = %tensor1, %arg2 = %tensor2)
+      -> (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) {
+
+    // Hoisted
+    // VECTOR_TRANSFERS: %[[ST0:.*]] = subtensor %[[TENSOR0_ARG]][%[[I]], %[[I]]]{{.*}}: tensor<?x?xf32> to tensor<?x?xf32>
+    // VECTOR_TRANSFERS: %[[V0:.*]] = vector.transfer_read %[[ST0]]{{.*}} : tensor<?x?xf32>, vector<1xf32>
+
+    // VECTOR_TRANSFERS: %[[R:.*]]:3 = scf.for %[[J:.*]] = {{.*}} iter_args(
+    // VECTOR_TRANSFERS-SAME: %[[TENSOR1_ARG_L2:[0-9a-zA-Z]+]] = %[[TENSOR1_ARG]]
+    // VECTOR_TRANSFERS-SAME: %[[TENSOR2_ARG_L2:[0-9a-zA-Z]+]] = %[[TENSOR2_ARG]]
+    // VECTOR_TRANSFERS-SAME: %[[V0_ARG_L2:[0-9a-zA-Z]+]] = %[[V0]]
+    // VECTOR_TRANSFERS-SAME: ) ->
+    // VECTOR_TRANSFERS-SAME: (tensor<?x?xf32>, tensor<?x?xf32>, vector<1xf32>
+    %1:3 = scf.for %j = %lb to %ub step %step
+      iter_args(%arg6 = %arg0, %arg7 = %arg1, %arg8 = %arg2)
+        -> (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) {
+      // Hoists.
+      %st0 = subtensor %arg6[%i, %i][%step, %step][1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+      %r0 = vector.transfer_read %st0[%c0, %c0], %cst: tensor<?x?xf32>, vector<1xf32>
+
+      // VECTOR_TRANSFERS: %[[ST1:.*]] = subtensor %[[TENSOR1_ARG_L2]][%[[J]],{{.*}}: tensor<?x?xf32> to tensor<?x?xf32>
+      // VECTOR_TRANSFERS: %[[V1:.*]] = vector.transfer_read %[[ST1]]{{.*}} : tensor<?x?xf32>, vector<2xf32>
+      // Does not hoist (subtensor depends on %j).
+      %st1 = subtensor %arg7[%j, %c0][%step, %step][1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+      %r1 = vector.transfer_read %st1[%c0, %c0], %cst: tensor<?x?xf32>, vector<2xf32>
+
+      // VECTOR_TRANSFERS: %[[ST2:.*]] = subtensor %[[TENSOR2_ARG_L2]][%[[I]],{{.*}}: tensor<?x?xf32> to tensor<?x?xf32>
+      // VECTOR_TRANSFERS: %[[V2:.*]] = vector.transfer_read %[[ST2]]{{.*}} : tensor<?x?xf32>, vector<3xf32>
+      // Does not hoist, 2 subtensor uses of %arg8.
+      %st2 = subtensor %arg8[%i, %c0][%step, %step][1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+      %r2 = vector.transfer_read %st2[%c0, %c0], %cst: tensor<?x?xf32>, vector<3xf32>
+
+      // VECTOR_TRANSFERS: %[[U0:.*]] = "some_use"(%[[V0_ARG_L2]]) : (vector<1xf32>) -> vector<1xf32>
+      // VECTOR_TRANSFERS: %[[U1:.*]] = "some_use"(%[[V1]]) : (vector<2xf32>) -> vector<2xf32>
+      // VECTOR_TRANSFERS: %[[U2:.*]] = "some_use"(%[[V2]]) : (vector<3xf32>) -> vector<3xf32>
+      %u0 = "some_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
+      %u1 = "some_use"(%r1) : (vector<2xf32>) -> vector<2xf32>
+      %u2 = "some_use"(%r2) : (vector<3xf32>) -> vector<3xf32>
+
+      // Hoists.
+      %w0 = vector.transfer_write %u0, %st0[%c0, %c0] : vector<1xf32>, tensor<?x?xf32>
+
+      // VECTOR_TRANSFERS-DAG: %[[STI1:.*]] = vector.transfer_write %[[U1]], %{{.*}} : vector<2xf32>, tensor<?x?xf32>
+      // Does not hoist (associated subtensor depends on %j).
+      %w1 = vector.transfer_write %u1, %st1[%i, %i] : vector<2xf32>, tensor<?x?xf32>
+
+      // VECTOR_TRANSFERS-DAG: %[[STI2:.*]] = vector.transfer_write %[[U2]], %{{.*}} : vector<3xf32>, tensor<?x?xf32>
+      // Does not hoist, 2 subtensor / subtensor_insert for %arg8.
+      %w2 = vector.transfer_write %u2, %st2[%c0, %c0] : vector<3xf32>, tensor<?x?xf32>
+
+      // Hoists.
+      %sti0 = subtensor_insert %w0 into %arg6[%i, %i][%step, %step][1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
+
+      // VECTOR_TRANSFERS-DAG: subtensor_insert %[[STI1]] into %[[TENSOR1_ARG_L2]][%[[J]],{{.*}}: tensor<?x?xf32> into tensor<?x?xf32>
+      // Does not hoist (depends on %j).
+      %sti1 = subtensor_insert %w1 into %arg7[%j, %c0][%step, %step][1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
+
+      // VECTOR_TRANSFERS-DAG: subtensor_insert %[[STI2]] into %[[TENSOR2_ARG_L2]][%[[I]],{{.*}}: tensor<?x?xf32> into tensor<?x?xf32>
+      // Does not hoist, 2 subtensor / subtensor_insert for %arg8.
+      %sti2 = subtensor_insert %w2 into %arg8[%i, %c0][%step, %step][1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
+      %st22 = subtensor %sti2[%i, %c0][%step, %step][1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+      %sti22 = subtensor_insert %st22 into %arg8[%i, %c0][%step, %step][1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
+
+      // VECTOR_TRANSFERS: scf.yield {{.*}} : tensor<?x?xf32>, tensor<?x?xf32>, vector<1xf32>
+      // VECTOR_TRANSFERS: }
+      scf.yield %sti0, %sti1, %sti22:
+        tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
+    }
+
+    // Hoisted
+    // VECTOR_TRANSFERS: %[[STI0:.*]] = vector.transfer_write %[[R]]#2, %[[ST0]]{{.*}} : vector<1xf32>, tensor<?x?xf32>
+    // VECTOR_TRANSFERS: subtensor_insert %[[STI0]] into %[[TENSOR0_ARG]][%[[I]], %[[I]]]{{.*}} : tensor<?x?xf32> into tensor<?x?xf32>
+
+    // VECTOR_TRANSFERS: scf.yield {{.*}} : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
+    scf.yield %1#0, %1#1, %1#2 :
+      tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
+
+    // VECTOR_TRANSFERS: }
+  }
+  return %0#0, %0#1, %0#2 : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
+}
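
For reference, a minimal before/after sketch of the subtensor-aware hoisting this patch implements. This is a hand-written illustration, not part of the patch or its tests: the values %t, %a, %b, %sz0, %sz1, %lb, %ub, %step, %c0 and %cst are hypothetical and assumed defined in the surrounding function, and the "after" form assumes the scf.for canonicalization applied at the end of the pass has removed the now-dead tensor iter_arg.

// Before: the subtensor / transfer pair is loop-invariant.
%0 = scf.for %i = %lb to %ub step %step iter_args(%arg0 = %t) -> (tensor<?x?xf32>) {
  %st = subtensor %arg0[%a, %b][%sz0, %sz1][1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
  %r = vector.transfer_read %st[%c0, %c0], %cst : tensor<?x?xf32>, vector<4xf32>
  %u = "some_use"(%r) : (vector<4xf32>) -> vector<4xf32>
  %w = vector.transfer_write %u, %st[%c0, %c0] : vector<4xf32>, tensor<?x?xf32>
  %sti = subtensor_insert %w into %arg0[%a, %b][%sz0, %sz1][1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
  scf.yield %sti : tensor<?x?xf32>
}

// After hoisting (and scf.for canonicalization): only the vector is loop-carried.
%st = subtensor %t[%a, %b][%sz0, %sz1][1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
%r = vector.transfer_read %st[%c0, %c0], %cst : tensor<?x?xf32>, vector<4xf32>
%1 = scf.for %i = %lb to %ub step %step iter_args(%v = %r) -> (vector<4xf32>) {
  %u = "some_use"(%v) : (vector<4xf32>) -> vector<4xf32>
  scf.yield %u : vector<4xf32>
}
%w = vector.transfer_write %1, %st[%c0, %c0] : vector<4xf32>, tensor<?x?xf32>
%2 = subtensor_insert %w into %t[%a, %b][%sz0, %sz1][1, 1] : tensor<?x?xf32> into tensor<?x?xf32>

The read-side subtensor and transfer_read move above the loop, the transfer_write and subtensor_insert move below it, and the vector becomes the loop-carried value, matching what hoistReadWrite sets up before canonicalization cleans up the leftover tensor iter_arg.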