diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -22,6 +22,7 @@
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SCF/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/IR/BuiltinOps.h"
@@ -40,139 +41,37 @@
 using namespace mlir;
 using namespace mlir::linalg;
 
-namespace {
-/// Represents a unit of hoistable TransferWriteOp. This may comprise other
-/// instructions that need to be hoisted too.
-struct HoistableWrite {
-  vector::TransferWriteOp transferWriteOp;
-  tensor::InsertSliceOp insertSliceOp;
-};
-/// Represents a unit of hoistable TransferReadOp. This may comprise other
-/// instructions that need to be hoisted too.
-struct HoistableRead {
-  vector::TransferReadOp transferReadOp;
-  tensor::ExtractSliceOp extractSliceOp;
-};
-} // namespace
-
-/// Return true if op1 and op2 are the same constant or the same SSA value.
-static bool isEqualOffsetSizeOrStride(OpFoldResult op1, OpFoldResult op2) {
-  auto getConstantIntValue = [](OpFoldResult ofr) -> llvm::Optional<int64_t> {
-    Attribute attr = ofr.dyn_cast<Attribute>();
-    // Note: isa+cast-like pattern allows writing the condition below as 1 line.
-    if (!attr && ofr.get<Value>().getDefiningOp<arith::ConstantOp>())
-      attr = ofr.get<Value>().getDefiningOp<arith::ConstantOp>().getValue();
-    if (auto intAttr = attr.dyn_cast_or_null<IntegerAttr>())
-      return intAttr.getValue().getSExtValue();
-    return llvm::None;
-  };
-  auto cst1 = getConstantIntValue(op1), cst2 = getConstantIntValue(op2);
-  if (cst1 && cst2 && *cst1 == *cst2)
-    return true;
-  auto v1 = op1.dyn_cast<Value>(), v2 = op2.dyn_cast<Value>();
-  return v1 && v2 && v1 == v2;
-}
-
-/// Return true is all offsets, sizes and strides are equal.
-static bool sameOffsetsSizesAndStrides(tensor::ExtractSliceOp s,
-                                       tensor::InsertSliceOp si) {
-  if (s.getStaticOffsets().size() != si.getStaticOffsets().size())
-    return false;
-  if (s.getStaticSizes().size() != si.getStaticSizes().size())
-    return false;
-  if (s.getStaticStrides().size() != si.getStaticStrides().size())
-    return false;
-  for (auto it : llvm::zip(s.getMixedOffsets(), si.getMixedOffsets()))
-    if (!isEqualOffsetSizeOrStride(std::get<0>(it), std::get<1>(it)))
-      return false;
-  for (auto it : llvm::zip(s.getMixedSizes(), si.getMixedSizes()))
-    if (!isEqualOffsetSizeOrStride(std::get<0>(it), std::get<1>(it)))
-      return false;
-  for (auto it : llvm::zip(s.getMixedStrides(), si.getMixedStrides()))
-    if (!isEqualOffsetSizeOrStride(std::get<0>(it), std::get<1>(it)))
-      return false;
-  return true;
-}
-
-/// Look for a HoistableRead, in the given tensor uses, accessing the same
-/// offset as the HoistableWrite.
-static HoistableRead findMatchingTransferRead(HoistableWrite write,
-                                              Value srcTensor) {
-  assert(write.transferWriteOp &&
-         "expected hoistable write to have a .transfer_write");
-
-  LLVM_DEBUG(DBGS() << "findMatchingTransferRead for: "
-                    << *write.transferWriteOp.getOperation() << "\n");
-  if (write.insertSliceOp)
-    LLVM_DEBUG(DBGS() << "findMatchingTransferRead inserSliceOp: "
-                      << *write.insertSliceOp.getOperation() << "\n");
-  SmallVector<Operation *> users(srcTensor.getUsers().begin(),
-                                 srcTensor.getUsers().end());
+/// Look for a vector.transfer_read, in the uses of the given `srcTensor`,
+/// accessing the same offset as the vector.transfer_write.
+static vector::TransferReadOp
+findMatchingTransferRead(vector::TransferWriteOp write, Value srcTensor) {
+  LLVM_DEBUG(DBGS() << "findMatchingTransferRead for: " << write << "\n");
+  SmallVector<Operation *> users = llvm::to_vector(srcTensor.getUsers());
   while (!users.empty()) {
     Operation *user = users.pop_back_val();
-    LLVM_DEBUG(DBGS() << "findMatchingTransferRead inspect user: " << *user
-                      << "\n");
-
-    // If HoistableWrite involves a InsertSliceOp, we need to find a
-    // matching ExtractSliceOp.
-    tensor::ExtractSliceOp sliceOp;
-    Operation *maybeTransferReadUser = user;
-    if (write.insertSliceOp) {
-      sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
-      if (!sliceOp || sliceOp.getResult().getType() !=
-                          write.insertSliceOp.getSource().getType())
-        continue;
-
-      LLVM_DEBUG(DBGS() << "check whether sameOffsetsSizesAndStrides: "
-                        << *sliceOp << " vs " << *write.insertSliceOp << "\n");
-      if (!sameOffsetsSizesAndStrides(sliceOp, write.insertSliceOp))
-        continue;
-
-      LLVM_DEBUG(DBGS() << "sameOffsetsSizesAndStrides: SUCCESS\n");
-      // If we got here, sliceOp is hoistable iff it has exactly 2 uses:
-      //   1. the transfer_write we want to hoist.
-      //   2. a matching transfer_read.
-      // Anything else, we skip.
-      bool skip = false;
-      Operation *otherUser = nullptr;
-      for (Operation *u : sliceOp->getUsers()) {
-        if (u == write.transferWriteOp)
-          continue;
-        if (otherUser) {
-          skip = true;
-          break;
-        }
-        otherUser = u;
-      }
-      if (skip || !otherUser)
-        continue;
-      maybeTransferReadUser = otherUser;
-    }
+    LLVM_DEBUG(DBGS() << "inspect potential read user: " << *user << "\n");
 
-    LLVM_DEBUG(DBGS() << "maybeTransferReadUser: " << *maybeTransferReadUser
-                      << "\n");
-    auto read = dyn_cast<vector::TransferReadOp>(maybeTransferReadUser);
-    if (read && read.getIndices() == write.transferWriteOp.getIndices() &&
-        read.getVectorType() == write.transferWriteOp.getVectorType())
-      return HoistableRead{read, sliceOp};
+    auto read = dyn_cast<vector::TransferReadOp>(user);
+    if (read && read.getIndices() == write.getIndices() &&
+        read.getVectorType() == write.getVectorType())
+      return read;
 
     if (isa<vector::TransferWriteOp>(user)) {
       // If we find a write with disjoint indices recurse through its uses.
       if (vector::isDisjointTransferIndices(
              cast<VectorTransferOpInterface>(user),
-             cast<VectorTransferOpInterface>(
-                 write.transferWriteOp.getOperation()))) {
+             cast<VectorTransferOpInterface>(*write))) {
        users.append(user->getUsers().begin(), user->getUsers().end());
      }
    }
  }
-  return HoistableRead();
+  return nullptr;
 }
 
-/// Check if the chunk of data inserted by the HoistableWrite are read by any
-/// other op than the HoistableRead candidate.
-static bool tensorChunkAccessedByUnknownOp(HoistableWrite write,
-                                           HoistableRead candidateRead,
+/// Return true if the chunk of data inserted by the vector.transfer_write op
+/// is read by any op other than the vector.transfer_read candidate.
+static bool tensorChunkAccessedByUnknownOp(vector::TransferWriteOp write,
+                                           vector::TransferReadOp candidateRead,
                                            BlockArgument tensorArg) {
   // Make sure none of the other uses read the part of the tensor modified
   // by the transfer_write.
@@ -182,13 +81,10 @@
     for (OpOperand &use : uses.pop_back_val()) {
       Operation *user = use.getOwner();
       // Skip the candidate use, only inspect the "other" uses.
-      if (user == candidateRead.transferReadOp ||
-          user == candidateRead.extractSliceOp ||
-          user == write.transferWriteOp || user == write.insertSliceOp)
+      if (user == candidateRead || user == write)
         continue;
-      // Consider all transitive uses through a extract_slice / insert_slice.
-      // TODO: atm we just bail because a stronger analysis is needed for these
-      // cases.
+      // Tensor extract/insert slice ops should be hoisted separately. Just bail
+      // out if we see them here.
       if (isa<tensor::ExtractSliceOp, tensor::InsertSliceOp>(user))
         return true;
       // Consider all transitive uses through a vector.transfer_write.
@@ -208,18 +104,18 @@
       }
       // Follow the use yield as long as it doesn't escape the original
      // region.
-      scf::YieldOp yieldUser = dyn_cast<scf::YieldOp>(user);
-      if (yieldUser && write.transferWriteOp->getParentOp()->isAncestor(
-                           yieldUser->getParentOp())) {
-        Value ret = yieldUser->getParentOp()->getResult(use.getOperandNumber());
-        uses.push_back(ret.getUses());
-        continue;
+      if (auto yieldUser = dyn_cast<scf::YieldOp>(user)) {
+        Operation *yieldParent = yieldUser->getParentOp();
+        if (write->getParentOp()->isAncestor(yieldParent)) {
+          Value ret = yieldParent->getResult(use.getOperandNumber());
+          uses.push_back(ret.getUses());
+          continue;
+        }
       }
       auto read = dyn_cast<vector::TransferReadOp>(user);
       if (!read || !vector::isDisjointTransferIndices(
-                       cast<VectorTransferOpInterface>(read.getOperation()),
-                       cast<VectorTransferOpInterface>(
-                           write.transferWriteOp.getOperation()))) {
+                       cast<VectorTransferOpInterface>(*read),
+                       cast<VectorTransferOpInterface>(*write))) {
        return true;
      }
    }
@@ -227,124 +123,70 @@
   return false;
 }
 
-/// Return the `forOp`-invariant HoistableWrite that produces `yieldOperand`.
-/// Return the null HoistableWrite() if it is not comprised of a
-/// vector.transfer_write + optional insert_slice or if any of the indexings
-/// is `forOp`-dependent.
-static HoistableWrite
-getLoopInvariantTransferWriteOpDefining(scf::ForOp forOp,
-                                        OpOperand &yieldOperand) {
+/// Return the `forOp`-invariant vector.transfer_write that produces the given
+/// `yieldOperand`. Return nullptr if `yieldOperand` is not produced by a
+/// vector.transfer_write op, or if any of the indexings is `forOp`-dependent.
+static vector::TransferWriteOp
+getLoopInvariantTransferWrite(scf::ForOp forOp, OpOperand &yieldOperand) {
   Value v = yieldOperand.get();
   if (auto write = v.getDefiningOp<vector::TransferWriteOp>()) {
     // Indexing must not depend on `forOp`.
     for (Value operand : write.getIndices())
       if (!forOp.isDefinedOutsideOfLoop(operand))
-        return HoistableWrite();
-
-    return HoistableWrite{write, nullptr};
+        return nullptr;
+    return write;
   }
-  if (auto insertSliceOp = v.getDefiningOp<tensor::InsertSliceOp>()) {
-    // Inserted slice must come from vector.transfer_write.
-    auto write =
-        insertSliceOp.getSource().getDefiningOp<vector::TransferWriteOp>();
-    if (!write)
-      return HoistableWrite();
-
-    // Tensor inserted into must be a BBArg at position matching yieldOperand's.
-    auto bbArg = insertSliceOp.getDest().dyn_cast<BlockArgument>();
-    if (!bbArg || bbArg.getOwner()->getParentOp() != forOp ||
-        bbArg.getArgNumber() != /*num iv=*/1 + yieldOperand.getOperandNumber())
-      return HoistableWrite();
-
-    // Indexing inserted into must not depend on `forOp`.
-    for (Value operand : insertSliceOp->getOperands().drop_front(
-             tensor::InsertSliceOp::getOffsetSizeAndStrideStartOperandIndex()))
-      if (!forOp.isDefinedOutsideOfLoop(operand))
-        return HoistableWrite();
-
-    return HoistableWrite{write, insertSliceOp};
-  }
-
-  return HoistableWrite();
+  return nullptr;
 }
 
-/// Mechanical hoisting of a matching HoistableRead / HoistableWrite pair.
-static void hoistReadWrite(HoistableRead read, HoistableWrite write,
+/// Mechanically hoist a matching vector transfer read/write pair involving
+/// `tensorBBArg` out of the enclosing scf.for op.
+static void hoistReadWrite(vector::TransferReadOp read,
+                           vector::TransferWriteOp write,
                            BlockArgument tensorBBArg) {
   scf::ForOp forOp = cast<scf::ForOp>(tensorBBArg.getOwner()->getParentOp());
-  assert(read.transferReadOp && write.transferWriteOp &&
-         "expected transfer_read and transfer_write ops to be set");
-  assert(((read.extractSliceOp && write.insertSliceOp) ||
-          (!read.extractSliceOp && !write.insertSliceOp)) &&
-         "expected matching extract_slice / insert_slice");
+  assert(read && write && "expected valid transfer_read and transfer_write");
   LLVM_DEBUG(DBGS() << "In forOp:\n"
-                    << *forOp.getOperation()
-                    << "\nHoist: " << *read.transferReadOp.getOperation()
-                    << "\nHoist: " << *write.transferWriteOp.getOperation()
-                    << "\nInvolving: " << tensorBBArg << "\n");
-
-  // If a read slice is present, hoist it.
-  if (read.extractSliceOp)
-    forOp.moveOutOfLoop(read.extractSliceOp);
+                    << *forOp.getOperation() << "\nHoist: " << read
+                    << "\nHoist: " << write << "\nInvolving: " << tensorBBArg
+                    << "\n");
 
   // Hoist the transfer_read op.
-  forOp.moveOutOfLoop(read.transferReadOp);
+  forOp.moveOutOfLoop(read);
 
   // TODO: don't hardcode /*numIvs=*/1.
   assert(tensorBBArg.getArgNumber() >= /*numIvs=*/1);
   unsigned initArgNumber = tensorBBArg.getArgNumber() - /*numIvs=*/1;
 
   // Update the source tensor.
-  if (read.extractSliceOp)
-    read.extractSliceOp.getSourceMutable().assign(
-        forOp.getInitArgs()[initArgNumber]);
-  else
-    read.transferReadOp.getSourceMutable().assign(
-        forOp.getInitArgs()[initArgNumber]);
+  read.getSourceMutable().assign(forOp.getInitArgs()[initArgNumber]);
 
   // Hoist write after.
-  if (write.insertSliceOp)
-    write.insertSliceOp->moveAfter(forOp);
-  write.transferWriteOp->moveAfter(forOp);
+  write->moveAfter(forOp);
 
   // Update the yield.
   auto yieldOp = cast<scf::YieldOp>(forOp.getRegion().front().getTerminator());
-  if (write.insertSliceOp)
-    yieldOp->setOperand(initArgNumber, write.insertSliceOp.getDest());
-  else
-    yieldOp->setOperand(initArgNumber, write.transferWriteOp.getSource());
+  yieldOp->setOperand(initArgNumber, write.getSource());
 
   // Rewrite `loop` with additional new yields.
-  OpBuilder b(read.transferReadOp);
+  OpBuilder b(read);
   NewYieldValueFn yieldFn = [&](OpBuilder &b, Location loc,
                                 ArrayRef<BlockArgument> newBBArgs) {
-    return SmallVector<Value>{write.transferWriteOp.getVector()};
+    return SmallVector<Value>{write.getVector()};
   };
-  auto newForOp = replaceLoopWithNewYields(
-      b, forOp, read.transferReadOp.getVector(), yieldFn);
+  auto newForOp = replaceLoopWithNewYields(b, forOp, read.getVector(), yieldFn);
 
   // Transfer write has been hoisted, need to update the vector and tensor
   // source. Replace the result of the loop to use the new tensor created
   // outside the loop.
-  // Depending on whether a insert_slice is present or not, it carries the
-  // update on the tensor operands.
-  if (write.insertSliceOp) {
-    newForOp.getResult(initArgNumber)
-        .replaceAllUsesWith(write.insertSliceOp.getResult());
-    write.transferWriteOp.getSourceMutable().assign(
-        read.extractSliceOp.getResult());
-    write.insertSliceOp.getDestMutable().assign(
-        read.extractSliceOp.getSource());
-  } else {
-    newForOp.getResult(initArgNumber)
-        .replaceAllUsesWith(write.transferWriteOp.getResult());
-    write.transferWriteOp.getSourceMutable().assign(
-        newForOp.getResult(initArgNumber));
-  }
+  newForOp.getResult(initArgNumber).replaceAllUsesWith(write.getResult());
+  write.getSourceMutable().assign(newForOp.getResult(initArgNumber));
 
   // Always update with the newly yield tensor and vector.
-  write.transferWriteOp.getVectorMutable().assign(newForOp.getResults().back());
+  write.getVectorMutable().assign(newForOp.getResults().back());
 }
 
 // To hoist transfer op on tensor the logic can be significantly simplified
@@ -363,30 +205,35 @@
   while (changed) {
     changed = false;
     func.walk([&](scf::ForOp forOp) {
+      // Hoist tensor extract/insert slices out first.
+      LLVM_DEBUG(llvm::dbgs()
+                 << "before hoisting tensor slice: " << forOp << "\n");
+      OpBuilder builder(forOp);
+      auto newForOp = tensor::hoistTensorExtractInsertSliceOps(forOp, builder);
+      if ((changed = (newForOp != forOp)))
+        forOp = newForOp;
+      LLVM_DEBUG(llvm::dbgs()
+                 << "after hoisting tensor slice: " << forOp << "\n");
+
       Operation *yield = forOp.getBody()->getTerminator();
       for (const auto &it : llvm::enumerate(forOp.getRegionIterArgs())) {
         OpOperand &ret = yield->getOpOperand(it.index());
-        HoistableWrite write =
-            getLoopInvariantTransferWriteOpDefining(forOp, ret);
-        if (!write.transferWriteOp || !write.transferWriteOp->hasOneUse())
+        vector::TransferWriteOp write =
+            getLoopInvariantTransferWrite(forOp, ret);
+        if (!write || !write->hasOneUse())
          continue;
         LLVM_DEBUG(dbgs() << "\n";
-                   DBGS() << "Candidate write for hoisting: "
-                          << *write.transferWriteOp.getOperation() << "\n");
-        if (write.insertSliceOp)
-          LLVM_DEBUG(DBGS() << "Candidate insert_slice for hoisting: "
-                            << *write.insertSliceOp.getOperation() << "\n");
-        if (llvm::any_of(write.transferWriteOp.getIndices(),
-                         [&forOp](Value index) {
-                           return !forOp.isDefinedOutsideOfLoop(index);
-                         }))
+                   DBGS() << "Candidate write for hoisting: " << write << "\n");
+        if (llvm::any_of(write.getIndices(), [&forOp](Value index) {
+              return !forOp.isDefinedOutsideOfLoop(index);
+            }))
          continue;
         // Find a read with the same type and indices.
-        HoistableRead matchingRead =
+        vector::TransferReadOp matchingRead =
            findMatchingTransferRead(write, it.value());
         // Make sure none of the other uses read the part of the tensor modified
         // by the transfer_write.
-        if (!matchingRead.transferReadOp ||
+        if (!matchingRead ||
            tensorChunkAccessedByUnknownOp(write, matchingRead, it.value()))
          continue;
 
@@ -398,7 +245,8 @@
         // Need to interrupt and restart: erasing the loop messes up the walk.
         return WalkResult::interrupt();
       }
-      return WalkResult::advance();
+
+      return changed ? WalkResult::interrupt() : WalkResult::advance();
     });
     // Apply canonicalization so the newForOp + yield folds immediately, thus
     // cleaning up the IR and potentially enabling more hoisting.
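For reference, a hand-written sketch of the end-to-end effect of the rewritten hoistReadWrite on a tensor iter_arg is given below (the "test.compute" op and the %init/%pad/%c* values are placeholders for illustration, not taken from this patch): the transfer_read is hoisted above the loop and now reads the init value, the vector travels through an extra iter_arg added by replaceLoopWithNewYields, and the transfer_write is re-anchored after the loop on the loop results.

  // Before hoisting: a loop-invariant transfer_read/transfer_write pair on %arg.
  %res = scf.for %i = %c0 to %c128 step %c8
      iter_args(%arg = %init) -> (tensor<128xf32>) {
    %r = vector.transfer_read %arg[%c0], %pad : tensor<128xf32>, vector<8xf32>
    %v = "test.compute"(%r) : (vector<8xf32>) -> vector<8xf32>
    %w = vector.transfer_write %v, %arg[%c0] : vector<8xf32>, tensor<128xf32>
    scf.yield %w : tensor<128xf32>
  }

  // After hoisting: the read uses %init, the loop yields the vector through a
  // new iter_arg, and the write consumes the loop results.
  %r0 = vector.transfer_read %init[%c0], %pad : tensor<128xf32>, vector<8xf32>
  %res:2 = scf.for %i = %c0 to %c128 step %c8
      iter_args(%arg = %init, %vec = %r0) -> (tensor<128xf32>, vector<8xf32>) {
    %v = "test.compute"(%vec) : (vector<8xf32>) -> vector<8xf32>
    scf.yield %arg, %v : tensor<128xf32>, vector<8xf32>
  }
  %w = vector.transfer_write %res#1, %res#0[%c0] : vector<8xf32>, tensor<128xf32>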