diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -1723,6 +1723,10 @@ dominated by the transfer_write (i.e. no aliasing between the write and the read across the loop) + WARNING: This hoisting does not model parallelism and is generally incorrect + when used on distributed loops with memref semantics! + TODO: obsolete and should be retired. + #### Return modes: The operation always succeeds and returns a handle to the transformed @@ -1745,4 +1749,51 @@ }]; } +//===----------------------------------------------------------------------===// +// HoistRedundantTensorSubsetsOp +//===----------------------------------------------------------------------===// + +def HoistRedundantTensorSubsetsOp : + Op { + let description = [{ + Hoists supported tensor subset extract/insert operation pairs out of + immediately enclosing loop iteratively, if the following conditions + are true: + 1. The 2 ops access the same tensor subset. + 2. All operands are invariant under the enclosing loop. + + The supported subset extract/insert operation pairs currently comprise: + - tensor.extract_slice / tensor.insert_slice + - vector.transfer_read / vector.transfer_write on tensors + + Only scf.for loops are currently supported. + + When applied to: + 1. an scf.for loop, hoist out of this loop only. + 2. a non-loop op, apply hoisting to all the contained loop ops. + + #### Return modes: + + The operation always succeeds and returns a handle to the transformed + function op. 
+ }]; + + let arguments = (ins TransformHandleTypeInterface:$target); + let results = (outs TransformHandleTypeInterface:$transformed); + + let assemblyFormat = "$target attr-dict `:` functional-type(operands, results) "; + + let builders = [ + OpBuilder<(ins "Value":$target)>, + ]; + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + #endif // LINALG_TRANSFORM_OPS diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h @@ -10,9 +10,13 @@ #define MLIR_DIALECT_LINALG_TRANSFORMS_HOISTING_H_ namespace mlir { +class RewriterBase; namespace func { class FuncOp; } // namespace func +namespace scf { +class ForOp; +} // namespace scf namespace linalg { @@ -28,11 +32,39 @@ /// function on the candidate loop above which to hoist. Hoisting the transfers /// results in scf::ForOp yielding the value that originally transited through /// memory. -// TODO: generalize on a per-need basis. +/// +/// WARNING: This hoisting does not model parallelism and is generally incorrect +/// when used on distributed loops with memref semantics! +// TODO: obsolete and should be retired void hoistRedundantVectorTransfers(func::FuncOp func); -/// Same behavior as `hoistRedundantVectorTransfers` but works on tensors -/// instead of buffers. +/// Greedily hoist redundant subset extract/insert operations on tensors outside +/// of `forOp`. The logic follows: +/// 1. Look for a write walking back from the `forOp` yield. +/// 2. Check the uses of the matching block argument and look for a matching +/// read (i.e. extract_slice or transfer_read) with matching indices. +/// 3. 
In the case of a transfer_write, we can bypass other +/// non-WAW-conflicting operations and find more hoisting opportunities. +/// 4. Hoist the read/write pair and update the tensor SSA links. +/// +/// Return the unmodified `forOp` if no hoisting occurred. +/// Return a new scf::ForOp if hoisting on tensors occurred. +/// +/// After this transformation the returned scf::ForOp may have unused arguments +/// that can be removed by application of canonicalization patterns. +// TODO: This should be further generalized along a few different axes: +// - Other loops than scf.ForOp that operate on tensors (both sequential and +// parallel loops). +// - Other subset extract/insert pairs than tensor.extract/insert_slice and +// vector.transfer_read/write. +// - More general areSubsetDisjoint analysis/interface to work across all +// subset op types and allow bypassing non-WAW-conflicting operations in +// more cases. +scf::ForOp hoistRedundantSubsetExtractInsert(RewriterBase &rewriter, + scf::ForOp forOp); + +/// Call into `hoistRedundantSubsetExtractInsert` without a RewriterBase. +// TODO: obsolete and should be retired void hoistRedundantVectorTransfersOnTensor(func::FuncOp func); } // namespace linalg diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td --- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td +++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td @@ -30,7 +30,17 @@ list traits = []> : Tensor_Op { code extraBaseClassDeclaration = [{ - /// Returns the dynamic sizes for this subview operation if specified. + /// Return the type of the base tensor operand. + ::mlir::RankedTensorType getSourceType() { + return getSource().getType().cast(); + } + + /// Return the type of the result tensor. + ::mlir::RankedTensorType getResultType() { + return getResult().getType().cast(); + } + + /// Return the dynamic sizes for this subview operation if specified. 
::mlir::Operation::operand_range getDynamicSizes() { return getSizes(); } /// Return the list of Range (i.e. offset, size, stride). Each @@ -105,7 +115,7 @@ %c0 = arith.constant 0 : index %x = tensor.dim %A, %c0 : tensor<4x?xf32> - // Returns the dynamic dimension of %A. + // Return the dynamic dimension of %A. %c1 = arith.constant 1 : index %y = tensor.dim %A, %c1 : memref<4x?xf32> @@ -361,14 +371,10 @@ ]; let extraClassDeclaration = extraBaseClassDeclaration # [{ - /// Returns the type of the base tensor operand. - RankedTensorType getSourceType() { - return getSource().getType().cast(); - } - /// The result of an extract_slice is always a tensor. + // TODO: deprecate RankedTensorType getType() { - return getResult().getType().cast(); + return getResultType(); } /// Compute the rank-reduction mask that can be applied to map the source @@ -834,25 +840,21 @@ ]; let extraClassDeclaration = extraBaseClassDeclaration # [{ - /// Returns the type of the base tensor operand. - RankedTensorType getSourceType() { - return getSource().getType().cast(); - } - /// The result of a insert_slice is always a tensor. + // TODO: Deprecate this method. RankedTensorType getType() { - return getResult().getType().cast(); + return getResultType(); } /// The `dest` type is the same as the result type. RankedTensorType getDestType() { - return getType(); + return getResultType(); } /// Return the expected rank of each of the`static_offsets`, `static_sizes` /// and `static_strides` attributes. std::array getArrayAttrMaxRanks() { - unsigned rank = getType().getRank(); + unsigned rank = getResultType().getRank(); return {rank, rank, rank}; } diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h --- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h +++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h @@ -82,6 +82,8 @@ /// that come from the fact there is no IndexAttr and that IndexType have no /// bitwidth. 
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2); +bool isEqualConstantIntOrValueArray(ArrayRef ofrs1, + ArrayRef ofrs2); /// Helper function to convert a vector of `OpFoldResult`s into a vector of /// `Value`s. For each `OpFoldResult` in `valueOrAttrVec` return the fold diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -3067,11 +3067,40 @@ transform::HoistRedundantVectorTransfersOp::applyToOne( func::FuncOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { + // WARNING: This hoisting does not model parallelism and is generally + // incorrect when used on distributed loops with memref semantics! + // TODO: obsolete and should be retired. linalg::hoistRedundantVectorTransfers(target); - linalg::hoistRedundantVectorTransfersOnTensor(target); results.push_back(target); return DiagnosedSilenceableFailure::success(); } + +//===----------------------------------------------------------------------===// +// HoistRedundantTensorSubsetsOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure +transform::HoistRedundantTensorSubsetsOp::applyToOne( + Operation *target, transform::ApplyToEachResultList &results, + transform::TransformState &state) { + IRRewriter rewriter(target->getContext()); + auto forOp = dyn_cast(target); + if (forOp) { + scf::ForOp newForOp = + linalg::hoistRedundantSubsetExtractInsert(rewriter, forOp); + results.push_back(newForOp); + return DiagnosedSilenceableFailure::success(); + } + + // TODO: walking in some reverse / inside-out order would be more efficient + // and would capture more cases. 
+ target->walk([&](scf::ForOp forOp) { + hoistRedundantSubsetExtractInsert(rewriter, forOp); + }); + results.push_back(target); + return DiagnosedSilenceableFailure::success(); +} + //===----------------------------------------------------------------------===// // Transform op registration //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt @@ -24,6 +24,7 @@ Promotion.cpp Split.cpp SplitReduction.cpp + SubsetHoisting.cpp SwapExtractSliceWithFillPatterns.cpp Tiling.cpp TilingInterfaceImpl.cpp diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp @@ -43,374 +43,13 @@ using namespace mlir; using namespace mlir::linalg; -namespace { -/// Represents a unit of hoistable TransferWriteOp. This may comprise other -/// instructions that need to be hoisted too. -struct HoistableWrite { - vector::TransferWriteOp transferWriteOp; - tensor::InsertSliceOp insertSliceOp; -}; -/// Represents a unit of hoistable TransferReadOp. This may comprise other -/// instructions that need to be hoisted too. -struct HoistableRead { - vector::TransferReadOp transferReadOp; - tensor::ExtractSliceOp extractSliceOp; -}; -} // namespace - -/// Return true if op1 and op2 are the same constant or the same SSA value. -static bool isEqualOffsetSizeOrStride(OpFoldResult op1, OpFoldResult op2) { - auto getConstantIntValue = [](OpFoldResult ofr) -> std::optional { - Attribute attr = ofr.dyn_cast(); - // Note: isa+cast-like pattern allows writing the condition below as 1 line. 
- if (!attr && ofr.get().getDefiningOp()) - attr = ofr.get().getDefiningOp().getValue(); - if (auto intAttr = attr.dyn_cast_or_null()) - return intAttr.getValue().getSExtValue(); - return std::nullopt; - }; - auto cst1 = getConstantIntValue(op1), cst2 = getConstantIntValue(op2); - if (cst1 && cst2 && *cst1 == *cst2) - return true; - auto v1 = op1.dyn_cast(), v2 = op2.dyn_cast(); - return v1 && v2 && v1 == v2; -} - -/// Return true is all offsets, sizes and strides are equal. -static bool sameOffsetsSizesAndStrides(tensor::ExtractSliceOp s, - tensor::InsertSliceOp si) { - if (s.getStaticOffsets().size() != si.getStaticOffsets().size()) - return false; - if (s.getStaticSizes().size() != si.getStaticSizes().size()) - return false; - if (s.getStaticStrides().size() != si.getStaticStrides().size()) - return false; - for (auto it : llvm::zip(s.getMixedOffsets(), si.getMixedOffsets())) - if (!isEqualOffsetSizeOrStride(std::get<0>(it), std::get<1>(it))) - return false; - for (auto it : llvm::zip(s.getMixedSizes(), si.getMixedSizes())) - if (!isEqualOffsetSizeOrStride(std::get<0>(it), std::get<1>(it))) - return false; - for (auto it : llvm::zip(s.getMixedStrides(), si.getMixedStrides())) - if (!isEqualOffsetSizeOrStride(std::get<0>(it), std::get<1>(it))) - return false; - return true; -} - -/// Look for a HoistableRead, in the given tensor uses, accessing the same -/// offset as the HoistableWrite. 
-static HoistableRead findMatchingTransferRead(HoistableWrite write, - Value srcTensor) { - assert(write.transferWriteOp && - "expected hoistable write to have a .transfer_write"); - - LLVM_DEBUG(DBGS() << "findMatchingTransferRead for: " - << *write.transferWriteOp.getOperation() << "\n"); - if (write.insertSliceOp) - LLVM_DEBUG(DBGS() << "findMatchingTransferRead inserSliceOp: " - << *write.insertSliceOp.getOperation() << "\n"); - SmallVector users(srcTensor.getUsers().begin(), - srcTensor.getUsers().end()); - while (!users.empty()) { - Operation *user = users.pop_back_val(); - LLVM_DEBUG(DBGS() << "findMatchingTransferRead inspect user: " << *user - << "\n"); - - // If HoistableWrite involves a InsertSliceOp, we need to find a - // matching ExtractSliceOp. - tensor::ExtractSliceOp sliceOp; - Operation *maybeTransferReadUser = user; - if (write.insertSliceOp) { - sliceOp = dyn_cast(user); - if (!sliceOp || sliceOp.getResult().getType() != - write.insertSliceOp.getSource().getType()) - continue; - - LLVM_DEBUG(DBGS() << "check whether sameOffsetsSizesAndStrides: " - << *sliceOp << " vs " << *write.insertSliceOp << "\n"); - if (!sameOffsetsSizesAndStrides(sliceOp, write.insertSliceOp)) - continue; - - LLVM_DEBUG(DBGS() << "sameOffsetsSizesAndStrides: SUCCESS\n"); - // If we got here, sliceOp is hoistable iff it has exactly 2 uses: - // 1. the transfer_write we want to hoist. - // 2. a matching transfer_read. - // Anything else, we skip. 
- bool skip = false; - Operation *otherUser = nullptr; - for (Operation *u : sliceOp->getUsers()) { - if (u == write.transferWriteOp) - continue; - if (otherUser) { - skip = true; - break; - } - otherUser = u; - } - if (skip || !otherUser) - continue; - maybeTransferReadUser = otherUser; - } - - LLVM_DEBUG(DBGS() << "maybeTransferReadUser: " << *maybeTransferReadUser - << "\n"); - auto read = dyn_cast(maybeTransferReadUser); - if (read && read.getIndices() == write.transferWriteOp.getIndices() && - read.getVectorType() == write.transferWriteOp.getVectorType()) - return HoistableRead{read, sliceOp}; - - if (isa(user)) { - // If we find a write with disjoint indices recurse through its uses. - if (vector::isDisjointTransferIndices( - cast(user), - cast( - write.transferWriteOp.getOperation()))) { - users.append(user->getUsers().begin(), user->getUsers().end()); - } - } - } - return HoistableRead(); -} - -/// Check if the chunk of data inserted by the HoistableWrite are read by any -/// other op than the HoistableRead candidate. -static bool tensorChunkAccessedByUnknownOp(HoistableWrite write, - HoistableRead candidateRead, - BlockArgument tensorArg) { - // Make sure none of the other uses read the part of the tensor modified - // by the transfer_write. - llvm::SmallVector uses; - uses.push_back(tensorArg.getUses()); - while (!uses.empty()) { - for (OpOperand &use : uses.pop_back_val()) { - Operation *user = use.getOwner(); - // Skip the candidate use, only inspect the "other" uses. - if (user == candidateRead.transferReadOp || - user == candidateRead.extractSliceOp || - user == write.transferWriteOp || user == write.insertSliceOp) - continue; - // Consider all transitive uses through a extract_slice / insert_slice. - // TODO: atm we just bail because a stronger analysis is needed for these - // cases. - if (isa(user)) - return true; - // Consider all transitive uses through a vector.transfer_write. 
- if (auto writeUser = dyn_cast(user)) { - uses.push_back(writeUser->getResult(0).getUses()); - continue; - } - // Consider all nested uses through an scf::ForOp. We may have - // pass-through tensor arguments left from previous level of - // hoisting. - if (auto forUser = dyn_cast(user)) { - Value arg = forUser.getLoopBody().getArgument( - use.getOperandNumber() - forUser.getNumControlOperands() + - /*iv value*/ 1); - uses.push_back(arg.getUses()); - continue; - } - // Follow the use yield as long as it doesn't escape the original - // region. - scf::YieldOp yieldUser = dyn_cast(user); - if (yieldUser && write.transferWriteOp->getParentOp()->isAncestor( - yieldUser->getParentOp())) { - Value ret = yieldUser->getParentOp()->getResult(use.getOperandNumber()); - uses.push_back(ret.getUses()); - continue; - } - auto read = dyn_cast(user); - if (!read || !vector::isDisjointTransferIndices( - cast(read.getOperation()), - cast( - write.transferWriteOp.getOperation()))) { - return true; - } - } - } - return false; -} - -/// Return the `forOp`-invariant HoistableWrite that produces `yieldOperand`. -/// Return the null HoistableWrite() if it is not comprised of a -/// vector.transfer_write + optional insert_slice or if any of the indexings -/// is `forOp`-dependent. -static HoistableWrite -getLoopInvariantTransferWriteOpDefining(scf::ForOp forOp, - OpOperand &yieldOperand) { - Value v = yieldOperand.get(); - if (auto write = v.getDefiningOp()) { - // Indexing must not depend on `forOp`. - for (Value operand : write.getIndices()) - if (!forOp.isDefinedOutsideOfLoop(operand)) - return HoistableWrite(); - - return HoistableWrite{write, nullptr}; - } - - if (auto insertSliceOp = v.getDefiningOp()) { - // Inserted slice must come from vector.transfer_write. - auto write = - insertSliceOp.getSource().getDefiningOp(); - if (!write) - return HoistableWrite(); - - // Tensor inserted into must be a BBArg at position matching yieldOperand's. 
- auto bbArg = insertSliceOp.getDest().dyn_cast(); - if (!bbArg || bbArg.getOwner()->getParentOp() != forOp || - bbArg.getArgNumber() != /*num iv=*/1 + yieldOperand.getOperandNumber()) - return HoistableWrite(); - - // Indexing inserted into must not depend on `forOp`. - for (Value operand : insertSliceOp->getOperands().drop_front( - tensor::InsertSliceOp::getOffsetSizeAndStrideStartOperandIndex())) - if (!forOp.isDefinedOutsideOfLoop(operand)) - return HoistableWrite(); - - return HoistableWrite{write, insertSliceOp}; - } - - return HoistableWrite(); -} - -/// Mechanical hoisting of a matching HoistableRead / HoistableWrite pair. -static void hoistReadWrite(HoistableRead read, HoistableWrite write, - BlockArgument tensorBBArg) { - scf::ForOp forOp = cast(tensorBBArg.getOwner()->getParentOp()); - assert(read.transferReadOp && write.transferWriteOp && - "expected transfer_read and transfer_write ops to be set"); - assert(((read.extractSliceOp && write.insertSliceOp) || - (!read.extractSliceOp && !write.insertSliceOp)) && - "expected matching extract_slice / insert_slice"); - LLVM_DEBUG(DBGS() << "In forOp:\n" - << *forOp.getOperation() - << "\nHoist: " << *read.transferReadOp.getOperation() - << "\nHoist: " << *write.transferWriteOp.getOperation() - << "\nInvolving: " << tensorBBArg << "\n"); - - // If a read slice is present, hoist it. - if (read.extractSliceOp) - forOp.moveOutOfLoop(read.extractSliceOp); - - // Hoist the transfer_read op. - forOp.moveOutOfLoop(read.transferReadOp); - - // TODO: don't hardcode /*numIvs=*/1. - assert(tensorBBArg.getArgNumber() >= /*numIvs=*/1); - unsigned initArgNumber = tensorBBArg.getArgNumber() - /*numIvs=*/1; - - // Update the source tensor. - if (read.extractSliceOp) - read.extractSliceOp.getSourceMutable().assign( - forOp.getInitArgs()[initArgNumber]); - else - read.transferReadOp.getSourceMutable().assign( - forOp.getInitArgs()[initArgNumber]); - - // Hoist write after. 
- if (write.insertSliceOp) - write.insertSliceOp->moveAfter(forOp); - write.transferWriteOp->moveAfter(forOp); - - // Update the yield. - auto yieldOp = cast(forOp.getRegion().front().getTerminator()); - if (write.insertSliceOp) - yieldOp->setOperand(initArgNumber, write.insertSliceOp.getDest()); - else - yieldOp->setOperand(initArgNumber, write.transferWriteOp.getSource()); - - // Rewrite `loop` with additional new yields. - OpBuilder b(read.transferReadOp); - NewYieldValueFn yieldFn = [&](OpBuilder &b, Location loc, - ArrayRef newBBArgs) { - return SmallVector{write.transferWriteOp.getVector()}; - }; - auto newForOp = replaceLoopWithNewYields( - b, forOp, read.transferReadOp.getVector(), yieldFn); - - // Transfer write has been hoisted, need to update the vector and tensor - // source. Replace the result of the loop to use the new tensor created - // outside the loop. - // Depending on whether a insert_slice is present or not, it carries the - // update on the tensor operands. - if (write.insertSliceOp) { - newForOp.getResult(initArgNumber) - .replaceAllUsesWith(write.insertSliceOp.getResult()); - write.transferWriteOp.getSourceMutable().assign( - read.extractSliceOp.getResult()); - write.insertSliceOp.getDestMutable().assign( - read.extractSliceOp.getSource()); - } else { - newForOp.getResult(initArgNumber) - .replaceAllUsesWith(write.transferWriteOp.getResult()); - write.transferWriteOp.getSourceMutable().assign( - newForOp.getResult(initArgNumber)); - } - - // Always update with the newly yield tensor and vector. - write.transferWriteOp.getVectorMutable().assign(newForOp.getResults().back()); -} - -// To hoist transfer op on tensor the logic can be significantly simplified -// compared to the case on buffer. The transformation follows this logic: -// 1. Look for transfer_write with a single use from ForOp yield -// 2. Check the uses of the matching block argument and look for a transfer_read -// with the same indices. -// 3. 
Check that all the other uses of the tensor argument are either disjoint -// tensor_read or transfer_write. For transfer_write uses recurse to make sure -// the new tensor has the same restrictions on its uses. -// 4. Hoist the tensor_read/tensor_write and update the tensor SSA links. -// After this transformation the scf.forOp may have unused arguments that can be -// remove by the canonicalization pass. void mlir::linalg::hoistRedundantVectorTransfersOnTensor(func::FuncOp func) { - bool changed = true; - while (changed) { - changed = false; - func.walk([&](scf::ForOp forOp) { - Operation *yield = forOp.getBody()->getTerminator(); - for (const auto &it : llvm::enumerate(forOp.getRegionIterArgs())) { - OpOperand &ret = yield->getOpOperand(it.index()); - HoistableWrite write = - getLoopInvariantTransferWriteOpDefining(forOp, ret); - if (!write.transferWriteOp || !write.transferWriteOp->hasOneUse()) - continue; - LLVM_DEBUG(dbgs() << "\n"; - DBGS() << "Candidate write for hoisting: " - << *write.transferWriteOp.getOperation() << "\n"); - if (write.insertSliceOp) - LLVM_DEBUG(DBGS() << "Candidate insert_slice for hoisting: " - << *write.insertSliceOp.getOperation() << "\n"); - if (llvm::any_of(write.transferWriteOp.getIndices(), - [&forOp](Value index) { - return !forOp.isDefinedOutsideOfLoop(index); - })) - continue; - // Find a read with the same type and indices. - HoistableRead matchingRead = - findMatchingTransferRead(write, it.value()); - // Make sure none of the other uses read the part of the tensor modified - // by the transfer_write. - if (!matchingRead.transferReadOp || - tensorChunkAccessedByUnknownOp(write, matchingRead, it.value())) - continue; - - LLVM_DEBUG(DBGS() << "Start hoisting\n"); - hoistReadWrite(matchingRead, write, it.value()); - changed = true; - forOp.erase(); - - // Need to interrupt and restart: erasing the loop messes up the walk. 
- return WalkResult::interrupt(); - } - return WalkResult::advance(); - }); - // Apply canonicalization so the newForOp + yield folds immediately, thus - // cleaning up the IR and potentially enabling more hoisting. - if (changed) { - RewritePatternSet patterns(func->getContext()); - scf::ForOp::getCanonicalizationPatterns(patterns, func->getContext()); - (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); - } - } + IRRewriter rewriter(func->getContext()); + // TODO: walking in some reverse / inside-out order would be more efficient + // and would capture more cases. + func.walk([&](scf::ForOp forOp) { + hoistRedundantSubsetExtractInsert(rewriter, forOp); + }); } void mlir::linalg::hoistRedundantVectorTransfers(func::FuncOp func) { diff --git a/mlir/lib/Dialect/Linalg/Transforms/SubsetHoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/SubsetHoisting.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Linalg/Transforms/SubsetHoisting.cpp @@ -0,0 +1,557 @@ +//===- SubsetHoisting.cpp - Linalg hoisting transformations----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements functions concerned with hoisting invariant subset +// operations in the context of Linalg transformations. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/Transforms/Hoisting.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SCF/Utils/Utils.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" + +#define DEBUG_TYPE "subset-hoisting" + +#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") + +using namespace mlir; +using namespace mlir::linalg; + +/// Return true if all offsets, sizes and strides are equal. +/// This is a poor man's helper for subset equality of symmetrical ops. +static bool sameOffsetsSizesAndStrides(tensor::ExtractSliceOp s, + tensor::InsertSliceOp si) { + return isEqualConstantIntOrValueArray(s.getMixedOffsets(), + si.getMixedOffsets()) && + isEqualConstantIntOrValueArray(s.getMixedSizes(), + si.getMixedSizes()) && + isEqualConstantIntOrValueArray(s.getMixedStrides(), + si.getMixedStrides()); +} + +/// Return true if the location of the subset defined by the op is invariant of +/// the loop iteration. +static bool +isSubsetLocationLoopInvariant(scf::ForOp forOp, + vector::TransferWriteOp transferWriteOp) { + for (Value operand : transferWriteOp.getIndices()) + if (!forOp.isDefinedOutsideOfLoop(operand)) + return false; + return true; +} + +/// Return true if the location of the subset defined by the op is invariant of +/// the loop iteration. 
+static bool isSubsetLocationLoopInvariant(scf::ForOp forOp, + tensor::InsertSliceOp insertSliceOp) { + for (Value operand : insertSliceOp->getOperands().drop_front( + tensor::InsertSliceOp::getOffsetSizeAndStrideStartOperandIndex())) + if (!forOp.isDefinedOutsideOfLoop(operand)) + return false; + return true; +} + +/// Greedily look for the first read such that: +/// - The read is of type `tensor.extract_slice`. +/// - The read is one of the uses of `srcTensor`. +/// - The read is to the same subset that `tensor.insert_slice` writes. +// TODO: Unify implementations once the "bypassing behavior" is the same. +static FailureOr +findHoistableMatchingInsertSlice(RewriterBase &rewriter, + tensor::InsertSliceOp insertSliceOp, + BlockArgument srcTensor) { + assert(srcTensor.getType().isa() && "not a ranked tensor"); + + auto forOp = cast(srcTensor.getOwner()->getParentOp()); + + LLVM_DEBUG(DBGS() << "--find matching read for: " << insertSliceOp << "\n"; + DBGS() << "--amongst users of: " << srcTensor << "\n"); + + SmallVector users(srcTensor.getUsers()); + if (forOp.isDefinedOutsideOfLoop(insertSliceOp.getDest())) + llvm::append_range(users, insertSliceOp.getDest().getUsers()); + + for (Operation *user : users) { + LLVM_DEBUG(DBGS() << "----inspect user: " << *user << "\n"); + auto extractSliceOp = dyn_cast(user); + // Skip ops other than extract_slice with an exact matching subset. + if (extractSliceOp) { + if (extractSliceOp.getResultType() != insertSliceOp.getSourceType() || + !sameOffsetsSizesAndStrides(extractSliceOp, insertSliceOp)) { + LLVM_DEBUG(DBGS() << "------not a matching extract_slice\n"; + DBGS() << *user << " vs " << *insertSliceOp << "\n"); + continue; + } + + // Skip extract_slice whose source is defined within the loop: we need to + // hoist that definition first otherwise dominance violations trigger. 
+ if (!extractSliceOp.getSource().isa() && + !forOp.isDefinedOutsideOfLoop(extractSliceOp.getSource())) { + LLVM_DEBUG(DBGS() << "------transfer_read vector is loop-dependent\n"); + continue; + } + return extractSliceOp; + } + + // TODO: Look through disjoint subsets, similar to vector.transfer_write + // and unify implementations. + } + + LLVM_DEBUG(DBGS() << "----no matching extract_slice"); + return failure(); +} + +/// Greedily look for the first read such that: +/// - The read is of type `vector.transfer_read`. +/// - The read is one of the uses of `srcTensor`. +/// - The read is to the same subset that `vector.transfer_write` writes. +// TODO: Unify implementations once the "bypassing behavior" is the same. +static FailureOr +findHoistableMatchingTransferRead(RewriterBase &rewriter, + vector::TransferWriteOp transferWriteOp, + BlockArgument srcTensor) { + if (!srcTensor.getType().isa()) + return failure(); + + auto forOp = cast(srcTensor.getOwner()->getParentOp()); + + LLVM_DEBUG(DBGS() << "--find matching read for: " << transferWriteOp << "\n"; + DBGS() << "--amongst users of: " << srcTensor << "\n";); + + // vector.transfer_write is a bit peculiar: we look through WAW dependencies + // to disjoint tensor subsets. This requires a while loop. + // TODO: Look through disjoint subsets for tensor.insert_slice and unify + // implementations. + SmallVector users(srcTensor.getUsers()); + // TODO: transferWriteOp.getSource is actually the destination tensor!! + if (forOp.isDefinedOutsideOfLoop(transferWriteOp.getSource())) + llvm::append_range(users, transferWriteOp.getSource().getUsers()); + while (!users.empty()) { + Operation *user = users.pop_back_val(); + LLVM_DEBUG(DBGS() << "----inspect user: " << *user << "\n"); + auto read = dyn_cast(user); + if (read) { + // Skip ops other than transfer_read with an exact matching subset. 
+ if (read.getIndices() != transferWriteOp.getIndices() || + read.getVectorType() != transferWriteOp.getVectorType()) { + LLVM_DEBUG(DBGS() << "------not a matching transfer_write\n"; + DBGS() << *user << " vs " << *transferWriteOp << "\n"); + continue; + } + + // Skip transfer_read whose vector is defined within the loop: we need + // to hoist that definition first otherwise dominance violations + // trigger. + if (!read.getSource().isa() && + !forOp.isDefinedOutsideOfLoop(read.getSource())) { + LLVM_DEBUG(DBGS() << "------transfer_read vector is loop-dependent\n"); + continue; + } + LLVM_DEBUG(DBGS() << "------found match\n"); + return read; + } + + // As an optimization, we look further through WAW dependencies to + // disjoint tensor subsets. This creates more opportunities to find a + // matching read. + if (isa(user)) { + // If we find a write with disjoint indices append all its uses. + // TODO: Generalize areSubsetsDisjoint and allow other bypass than + // just vector.transfer_write - vector.transfer_write. + if (vector::isDisjointTransferIndices( + cast(user), + cast( + transferWriteOp.getOperation()))) { + LLVM_DEBUG(DBGS() << "----follow through disjoint write\n"); + users.append(user->getUsers().begin(), user->getUsers().end()); + } else { + LLVM_DEBUG(DBGS() << "----skip non-disjoint write\n"); + } + } + } + + LLVM_DEBUG(DBGS() << "--no matching transfer_read\n"); + return rewriter.notifyMatchFailure(transferWriteOp, + "no matching transfer_read"); +} + +/// Return the `vector.transfer_write` that produces `yieldOperand`, if: +/// - The write operates on tensors. +/// - All indices are defined outside of the loop. +/// Return failure otherwise. +/// +/// This is sufficient condition to hoist the `vector.transfer_write`; other +/// operands can always be yielded by the loop where needed. +// TODO: generalize beyond scf::ForOp. +// TODO: Unify implementations once the "bypassing behavior" is the same. 
+static FailureOr +getLoopInvariantTransferWriteDefining(RewriterBase &rewriter, scf::ForOp forOp, + OpOperand &yieldOperand) { + assert(isa(yieldOperand.getOwner()) && "must be an scf.yield"); + + Value v = yieldOperand.get(); + auto transferWriteOp = v.getDefiningOp(); + if (!transferWriteOp) + return rewriter.notifyMatchFailure(v.getLoc(), "not a transfer_write"); + + if (transferWriteOp->getNumResults() == 0) { + return rewriter.notifyMatchFailure(v.getLoc(), + "unsupported transfer_write on buffers"); + } + + // We do not check that the destination is a BBarg that matches the yield + // operand as this would prevent us from bypassing other non-WAW-conflicting + // writes. + + // Indexing must not depend on `forOp`. + if (!isSubsetLocationLoopInvariant(forOp, transferWriteOp)) + return rewriter.notifyMatchFailure( + v.getLoc(), "transfer_write indexing is loop-dependent"); + + return transferWriteOp; +} + +/// Return the `tensor.insert_slice` that produces `yieldOperand`, if: +/// 1. Its destination tensor is a block argument of the `forOp`. +/// 2. The unique use of its result is a yield with operand number matching +/// the block argument. +/// 3. All indices are defined outside of the loop. +/// Return failure otherwise. +/// +/// This is a sufficient condition to hoist the `tensor.insert_slice`; other +/// operands can always be yielded by the loop where needed. +/// Note: 1. + 2. ensure that the yield / iter_args cycle results in proper +/// semantics (i.e. no ping-pong between iter_args across iterations). +// TODO: generalize beyond scf::ForOp. +// TODO: Unify implementations once the "bypassing behavior" is the same. 
+static FailureOr +getLoopInvariantInsertSliceDefining(RewriterBase &rewriter, scf::ForOp forOp, + OpOperand &yieldOperand) { + assert(isa(yieldOperand.getOwner()) && "must be an scf.yield"); + + Value v = yieldOperand.get(); + auto insertSliceOp = v.getDefiningOp(); + if (!insertSliceOp) + return rewriter.notifyMatchFailure(v.getLoc(), "not an insert_slice"); + + // Tensor inserted into must be a BBArg at position matching yield operand. + // TODO: In the future we should not perform this check if we want to bypass + // other non-WAW-conflicting writes. + auto bbArg = insertSliceOp.getDest().dyn_cast(); + if (!bbArg || bbArg.getOwner()->getParentOp() != forOp || + bbArg.getArgNumber() != + /*num iv=*/1 + yieldOperand.getOperandNumber()) + return rewriter.notifyMatchFailure(v.getLoc(), "not a matching bbarg"); + + // Indexing inserted into must not depend on `forOp`. + if (!isSubsetLocationLoopInvariant(forOp, insertSliceOp)) + return rewriter.notifyMatchFailure( + v.getLoc(), "insert_slice indexing is loop-dependent"); + + return insertSliceOp; +} + +/// Check if the chunk of data inserted by the `writeOp` is read by any op +/// other than `candidateReadOp`. Such a conflicting operation prevents +/// hoisting: return it, or nullptr if none is found. +/// The search starts from the uses of `tensorArg`. +// TODO: Generalize subset disjunction analysis/interface. +// TODO: Support more subset op types. +static Operation *isTensorChunkAccessedByUnknownOp(Operation *writeOp, + Operation *candidateReadOp, + BlockArgument tensorArg) { + // Make sure none of the other uses read the part of the tensor modified + // by `writeOp`. + llvm::SmallVector uses; + uses.push_back(tensorArg.getUses()); + while (!uses.empty()) { + for (OpOperand &use : uses.pop_back_val()) { + Operation *user = use.getOwner(); + // Skip the candidate use, only inspect the "other" uses. + if (user == candidateReadOp || user == writeOp) + continue; + + // TODO: Consider all transitive uses through + // extract_slice/insert_slice. 
Atm we just bail because a stronger + // analysis is needed for these cases. + if (isa(user)) + return user; + + // Consider all transitive uses through a vector.transfer_write. + if (isa(writeOp)) { + if (auto writeUser = dyn_cast(user)) { + uses.push_back(writeUser->getResult(0).getUses()); + continue; + } + } + + // Consider all nested uses through an scf::ForOp. We may have + // pass-through tensor arguments left from a previous level of + // hoisting. + if (auto forUser = dyn_cast(user)) { + Value arg = forUser.getLoopBody().getArgument( + use.getOperandNumber() - forUser.getNumControlOperands() + + /*iv value*/ 1); + uses.push_back(arg.getUses()); + continue; + } + + // Follow the use through the yield, only if it doesn't escape the + // original region. + scf::YieldOp yieldUser = dyn_cast(user); + if (yieldUser && + writeOp->getParentOp()->isAncestor(yieldUser->getParentOp())) { + Value ret = yieldUser->getParentOp()->getResult(use.getOperandNumber()); + uses.push_back(ret.getUses()); + continue; + } + + // If the write is a vector::TransferWriteOp, it may have been bypassed + // and we need to check that the subsets are disjoint. + if (isa(writeOp)) { + auto read = dyn_cast(user); + if (!read || !vector::isDisjointTransferIndices( + cast(read.getOperation()), + cast(writeOp))) { + return user; + } + } + } + } + return nullptr; +} + +/// Mechanical hoisting of a matching transfer_read / transfer_write pair: +/// the read is moved before the loop, the write is moved after it, and the +/// written vector is carried through the loop as an additional yielded value. +/// Return the newly created scf::ForOp with an extra yield. +// TODO: Unify implementations once the "bypassing behavior" is the same. 
+static scf::ForOp hoistTransferReadWrite( + RewriterBase &rewriter, vector::TransferReadOp transferReadOp, + vector::TransferWriteOp transferWriteOp, BlockArgument tensorBBArg) { + scf::ForOp forOp = cast(tensorBBArg.getOwner()->getParentOp()); + LLVM_DEBUG(DBGS() << "--Start hoisting\n"; + DBGS() << "--Hoist read : " << transferReadOp << "\n"; + DBGS() << "--Hoist write: " << transferWriteOp << "\n"; + DBGS() << "--Involving : " << tensorBBArg << "\n"); + + // TODO: don't hardcode /*numIvs=*/1. + assert(tensorBBArg.getArgNumber() >= /*numIvs=*/1); + int64_t initArgNumber = tensorBBArg.getArgNumber() - /*numIvs=*/1; + + // 1. Hoist the read op. Thanks to our previous checks we know this will not + // trigger dominance violations once BBArgs are updated. + // TODO: should the rewriter ever want to track this move ? + transferReadOp->moveBefore(forOp); + if (!forOp.isDefinedOutsideOfLoop(transferReadOp.getSource())) { + assert(transferReadOp.getSource() == tensorBBArg && + "transferReadOp source not defined above must be the tracked bbArg"); + rewriter.startRootUpdate(transferReadOp); + transferReadOp.getSourceMutable().assign( + forOp.getInitArgs()[initArgNumber]); + rewriter.finalizeRootUpdate(transferReadOp); + } + + // 2. Rewrite `forOp` with an additional yield. This is the quantity that is + // computed iteratively but whose storage has become loop-invariant. + NewYieldValueFn yieldFn = [&](OpBuilder &b, Location loc, + ArrayRef newBBArgs) { + return SmallVector{transferWriteOp.getVector()}; + }; + auto newForOp = replaceLoopWithNewYields( + rewriter, forOp, {transferReadOp.getVector()}, yieldFn); + rewriter.eraseOp(forOp); + + // 3. Update the yield. Invariant: initArgNumber is the destination tensor. + auto yieldOp = + cast(newForOp.getRegion().front().getTerminator()); + // TODO: transferWriteOp.getSource is actually the destination tensor!! 
+ rewriter.startRootUpdate(yieldOp); + yieldOp->setOperand(initArgNumber, transferWriteOp.getSource()); + rewriter.finalizeRootUpdate(yieldOp); + + // 4. Hoist the write after the new loop and make uses of + // newForOp.getResult(initArgNumber) flow through it. + // TODO: should the rewriter ever want to track this move ? + transferWriteOp->moveAfter(newForOp); + rewriter.startRootUpdate(transferWriteOp); + transferWriteOp.getVectorMutable().assign(newForOp.getResults().back()); + // TODO: transferWriteOp.getSource is actually the destination tensor!! + transferWriteOp.getSourceMutable().assign(newForOp.getResult(initArgNumber)); + rewriter.finalizeRootUpdate(transferWriteOp); + rewriter.replaceAllUsesExcept(newForOp.getResult(initArgNumber), + transferWriteOp.getResult(), transferWriteOp); + return newForOp; +} + +/// Mechanical hoisting of a matching extract_slice / insert_slice pair: +/// the extract is moved before the loop, the insert is moved after it, and +/// the inserted slice is carried through the loop as an additional yielded +/// value. +/// Return the newly created scf::ForOp with an extra yield. +// TODO: Unify implementations once the "bypassing behavior" is the same. +static scf::ForOp hoistExtractInsertSlice(RewriterBase &rewriter, + tensor::ExtractSliceOp extractSliceOp, + tensor::InsertSliceOp insertSliceOp, + BlockArgument tensorBBArg) { + scf::ForOp forOp = cast(tensorBBArg.getOwner()->getParentOp()); + LLVM_DEBUG(DBGS() << "--Start hoisting\n"; + DBGS() << "--Hoist read : " << extractSliceOp << "\n"; + DBGS() << "--Hoist write: " << insertSliceOp << "\n"; + DBGS() << "--Involving : " << tensorBBArg << "\n"); + + // TODO: don't hardcode /*numIvs=*/1. + assert(tensorBBArg.getArgNumber() >= /*numIvs=*/1); + int64_t initArgNumber = tensorBBArg.getArgNumber() - /*numIvs=*/1; + + // 1. Hoist the read op. Thanks to our previous checks we know this will not + // trigger dominance violations once BBArgs are updated. + // TODO: should the rewriter ever want to track this move ? 
+ extractSliceOp->moveBefore(forOp); + if (!forOp.isDefinedOutsideOfLoop(extractSliceOp.getSource())) { + assert(extractSliceOp.getSource() == tensorBBArg && + "extractSlice source not defined above must be the tracked bbArg"); + rewriter.startRootUpdate(extractSliceOp); + extractSliceOp.getSourceMutable().assign( + forOp.getInitArgs()[initArgNumber]); + rewriter.finalizeRootUpdate(extractSliceOp); + } + + // 2. Rewrite `forOp` with an additional yield. This is the quantity that is + // computed iteratively but whose storage has become loop-invariant. + NewYieldValueFn yieldFn = [&](OpBuilder &b, Location loc, + ArrayRef newBBArgs) { + return SmallVector{insertSliceOp.getSource()}; + }; + auto newForOp = replaceLoopWithNewYields(rewriter, forOp, + extractSliceOp.getResult(), yieldFn); + rewriter.eraseOp(forOp); + + // 3. Update the yield. Invariant: initArgNumber is the destination tensor. + auto yieldOp = + cast(newForOp.getRegion().front().getTerminator()); + // TODO: should the rewriter ever want to track this ? + rewriter.startRootUpdate(yieldOp); + yieldOp->setOperand(initArgNumber, insertSliceOp.getDest()); + rewriter.finalizeRootUpdate(yieldOp); + + // 4. Hoist the write after the new loop and make uses of + // newForOp.getResult(initArgNumber) flow through it. + // TODO: should the rewriter ever want to track this move ? + insertSliceOp->moveAfter(newForOp); + rewriter.startRootUpdate(insertSliceOp); + insertSliceOp.getSourceMutable().assign(newForOp.getResults().back()); + insertSliceOp.getDestMutable().assign(newForOp.getResult(initArgNumber)); + rewriter.finalizeRootUpdate(insertSliceOp); + rewriter.replaceAllUsesExcept(newForOp.getResult(initArgNumber), + insertSliceOp.getResult(), insertSliceOp); + return newForOp; +} + +/// Greedily hoist redundant subset extract/insert operations on tensors +/// outside `forOp`. +/// Return the unmodified `forOp` if no hoisting occurred. +/// Return a new scf::ForOp if hoisting on tensors occurred. 
+scf::ForOp +mlir::linalg::hoistRedundantSubsetExtractInsert(RewriterBase &rewriter, + scf::ForOp forOp) { + LLVM_DEBUG(DBGS() << "Enter hoistRedundantSubsetExtractInsert scf.for\n"); + Operation *yield = forOp.getBody()->getTerminator(); + + LLVM_DEBUG(DBGS() << "\n"; DBGS() << "Consider " << forOp << "\n"); + + scf::ForOp newForOp = forOp; + do { + forOp = newForOp; + for (const auto &it : llvm::enumerate(forOp.getRegionIterArgs())) { + LLVM_DEBUG(DBGS() << "Consider " << it.value() << "\n"); + + // 1. Find a loop-invariant subset write yielding `ret` that we can + // consider for hoisting. Both supported write kinds are probed. + // TODO: TypeSwitch when we add more cases. + OpOperand &ret = yield->getOpOperand(it.index()); + FailureOr transferWriteOp = + getLoopInvariantTransferWriteDefining(rewriter, forOp, ret); + FailureOr insertSliceOp = + getLoopInvariantInsertSliceDefining(rewriter, forOp, ret); + if (failed(transferWriteOp) && failed(insertSliceOp)) { + LLVM_DEBUG(DBGS() << "no loop invariant write defining iter_args " + << it.value() << "\n"); + continue; + } + + Operation *writeOp = succeeded(transferWriteOp) + ? transferWriteOp->getOperation() + : insertSliceOp->getOperation(); + + // 2. Only accept writes with a single use (i.e. the yield). + if (!writeOp->hasOneUse()) { + LLVM_DEBUG(DBGS() << "write with more than 1 use " << *writeOp << "\n"); + continue; + } + + LLVM_DEBUG(DBGS() << "Write to hoist: " << *writeOp << "\n"); + + // 3. Find a matching read that can also be hoisted. + Operation *matchingReadOp = nullptr; + // TODO: TypeSwitch. 
+ if (succeeded(transferWriteOp)) { + auto maybeTransferRead = findHoistableMatchingTransferRead( + rewriter, *transferWriteOp, it.value()); + if (succeeded(maybeTransferRead)) + matchingReadOp = maybeTransferRead->getOperation(); + } else if (succeeded(insertSliceOp)) { + auto maybeExtractSlice = findHoistableMatchingInsertSlice( + rewriter, *insertSliceOp, it.value()); + if (succeeded(maybeExtractSlice)) + matchingReadOp = maybeExtractSlice->getOperation(); + } else { + llvm_unreachable("unexpected case"); + } + if (!matchingReadOp) { + LLVM_DEBUG(DBGS() << "No matching read\n"); + continue; + } + + // 4. Make sure no other use reads the modified part of the tensor. + // This is necessary to guard against RAW and WAR hazards when + // non-conflicting WAW dependencies are bypassed. + Operation *maybeUnknownOp = + isTensorChunkAccessedByUnknownOp(writeOp, matchingReadOp, it.value()); + if (maybeUnknownOp) { + LLVM_DEBUG(DBGS() << "Tensor chunk accessed by unknown op, skip: " + << *maybeUnknownOp << "\n"); + continue; + } + + // 5. Perform the actual mechanical hoisting. After a successful hoist + // the iter_args scan restarts on the new loop (see the `break` below). + // TODO: TypeSwitch. 
+ LLVM_DEBUG(DBGS() << "Read to hoist: " << *matchingReadOp << "\n"); + if (succeeded(transferWriteOp)) { + newForOp = hoistTransferReadWrite( + rewriter, cast(matchingReadOp), + *transferWriteOp, it.value()); + } else if (succeeded(insertSliceOp)) { + newForOp = hoistExtractInsertSlice( + rewriter, cast(matchingReadOp), + *insertSliceOp, it.value()); + } else { + llvm_unreachable("unexpected case"); + } + + LLVM_DEBUG(DBGS() << "Module post-hoisting: " + << newForOp->getParentOfType() << "\n"); + break; + } + } while (forOp != newForOp); + + return newForOp; +} diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp --- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp +++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp @@ -136,6 +136,16 @@ return v1 && v1 == v2; } +bool isEqualConstantIntOrValueArray(ArrayRef ofrs1, + ArrayRef ofrs2) { + if (ofrs1.size() != ofrs2.size()) + return false; + for (auto [ofr1, ofr2] : llvm::zip_equal(ofrs1, ofrs2)) + if (!isEqualConstantIntOrValue(ofr1, ofr2)) + return false; + return true; +} + /// Helper function to convert a vector of `OpFoldResult`s into a vector of /// `Value`s. 
For each `OpFoldResult` in `valueOrAttrVec` return the fold result /// if it casts to a `Value` or create an index-type constant if it casts to diff --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir --- a/mlir/test/Dialect/Linalg/hoisting.mlir +++ b/mlir/test/Dialect/Linalg/hoisting.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -test-transform-dialect-interpreter --split-input-file --allow-unregistered-dialect %s | FileCheck %s +// RUN: mlir-opt -test-transform-dialect-interpreter -canonicalize --split-input-file --allow-unregistered-dialect %s | FileCheck %s // CHECK-LABEL: func @hoist_vector_transfer_pairs( // CHECK-SAME: %[[MEMREF0:[a-zA-Z0-9]*]]: memref, @@ -29,7 +29,7 @@ // CHECK: vector.transfer_read %{{.*}} : memref, vector<5xf32> // CHECK: "some_use"(%{{.*}}) : (vector<1xf32>) -> vector<1xf32> // CHECK: "some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32> -// CHECK: "some_use"(%[[MEMREF2]]) : (memref) -> vector<3xf32> +// CHECK: "some_use"(%[[MEMREF2]], %{{.*}}) : (memref, vector<3xf32>) -> vector<3xf32> // CHECK: "some_use"(%{{.*}}) : (vector<4xf32>) -> vector<4xf32> // CHECK: "some_use"(%{{.*}}) : (vector<5xf32>) -> vector<5xf32> // CHECK: vector.transfer_write %{{.*}} : vector<3xf32>, memref @@ -56,7 +56,7 @@ "some_crippling_use"(%memref5) : (memref) -> () %u0 = "some_use"(%r0) : (vector<1xf32>) -> vector<1xf32> %u1 = "some_use"(%r1) : (vector<2xf32>) -> vector<2xf32> - %u2 = "some_use"(%memref2) : (memref) -> vector<3xf32> + %u2 = "some_use"(%memref2, %r2) : (memref, vector<3xf32>) -> vector<3xf32> %u3 = "some_use"(%r3) : (vector<4xf32>) -> vector<4xf32> %u4 = "some_use"(%r4) : (vector<5xf32>) -> vector<5xf32> %u5 = "some_use"(%r5) : (vector<6xf32>) -> vector<6xf32> @@ -173,6 +173,51 @@ // ----- +// CHECK-LABEL: func @hoist_vector_transfer_pairs_in_affine_loops( +// CHECK-SAME: %[[MEMREF0:[a-zA-Z0-9]+]]: memref<64x64xi32>, +// CHECK-SAME: %[[MEMREF1:[a-zA-Z0-9]+]]: memref<64x64xi32>, +// CHECK-SAME: 
%[[MEMREF2:[a-zA-Z0-9]+]]: memref<64x64xi32>) { +// CHECK: %[[C0:.*]] = arith.constant 0 : i32 +// CHECK: affine.for %[[I:.*]] = 0 to 64 { +// CHECK: affine.for %[[J:.*]] = 0 to 64 step 16 { +// CHECK: %[[R0:.*]] = vector.transfer_read %[[MEMREF2]][%[[I]], %[[J]]], %[[C0]] : memref<64x64xi32>, vector<16xi32> +// CHECK: %[[R:.*]] = affine.for %[[K:.*]] = 0 to 64 iter_args(%[[ACC:.*]] = %[[R0]]) -> (vector<16xi32>) { +// CHECK: %[[AV:.*]] = vector.transfer_read %[[MEMREF0]][%[[I]], %[[K]]], %[[C0]] {{.*}}: memref<64x64xi32>, vector<16xi32> +// CHECK: %[[BV:.*]] = vector.transfer_read %[[MEMREF1]][%[[K]], %[[J]]], %[[C0]] {{.*}}: memref<64x64xi32>, vector<16xi32> +// CHECK: %[[T0:.*]] = arith.muli %[[AV]], %[[BV]] : vector<16xi32> +// CHECK: %[[T1:.*]] = arith.addi %[[ACC]], %[[T0]] : vector<16xi32> +// CHECK: affine.yield %[[T1]] : vector<16xi32> +// CHECK: } +// CHECK: vector.transfer_write %[[R]], %[[MEMREF2]][%[[I]], %[[J]]] : vector<16xi32>, memref<64x64xi32> +// CHECK: } +// CHECK: } +func.func @hoist_vector_transfer_pairs_in_affine_loops(%memref0: memref<64x64xi32>, %memref1: memref<64x64xi32>, %memref2: memref<64x64xi32>) { + %c0_i32 = arith.constant 0 : i32 + affine.for %arg3 = 0 to 64 { + affine.for %arg4 = 0 to 64 step 16 { + affine.for %arg5 = 0 to 64 { + %0 = vector.transfer_read %memref0[%arg3, %arg5], %c0_i32 {permutation_map = affine_map<(d0, d1) -> (0)>} : memref<64x64xi32>, vector<16xi32> + %1 = vector.transfer_read %memref1[%arg5, %arg4], %c0_i32 : memref<64x64xi32>, vector<16xi32> + %2 = vector.transfer_read %memref2[%arg3, %arg4], %c0_i32 : memref<64x64xi32>, vector<16xi32> + %3 = arith.muli %0, %1 : vector<16xi32> + %4 = arith.addi %2, %3 : vector<16xi32> + vector.transfer_write %4, %memref2[%arg3, %arg4] : vector<16xi32>, memref<64x64xi32> + } + } + } + return +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["func.func"]} in %arg1 + : (!pdl.operation) -> !pdl.operation + 
transform.structured.hoist_redundant_vector_transfers %0 + : (!pdl.operation) -> !pdl.operation +} + +// ----- + // CHECK-LABEL: func @hoist_vector_transfer_pairs_tensor func.func @hoist_vector_transfer_pairs_tensor( %tensor0: tensor, %tensor1: tensor, %tensor2: tensor, @@ -256,7 +301,7 @@ ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!pdl.operation) -> !pdl.operation - transform.structured.hoist_redundant_vector_transfers %0 + transform.structured.hoist_redundant_tensor_subsets %0 : (!pdl.operation) -> !pdl.operation } @@ -351,7 +396,7 @@ ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!pdl.operation) -> !pdl.operation - transform.structured.hoist_redundant_vector_transfers %0 + transform.structured.hoist_redundant_tensor_subsets %0 : (!pdl.operation) -> !pdl.operation } @@ -468,26 +513,26 @@ ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!pdl.operation) -> !pdl.operation - transform.structured.hoist_redundant_vector_transfers %0 + transform.structured.hoist_redundant_tensor_subsets %0 : (!pdl.operation) -> !pdl.operation } // ----- // CHECK-LABEL: func @hoist_vector_transfer_write_pairs_disjoint_tensor( -// CHECK-SAME: %[[T:.*]]: tensor, +// CHECK-SAME: %[[T0:.*]]: tensor, // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index -// CHECK-DAG: %[[R0:.*]] = vector.transfer_read %[[T]][%[[C0]], %[[C0]]], %{{.*}} : tensor, vector<2xf32> -// CHECK-DAG: %[[R1:.*]] = vector.transfer_read %[[T]][%[[C0]], %[[C3]]], %{{.*}} : tensor, vector<2xf32> -// CHECK: %[[F:.*]]:2 = scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[R3:.*]] = %[[R1:.*]], %[[R2:.*]] = %[[R0]]) -> (vector<2xf32>, vector<2xf32>) { -// CHECK: %[[R4:.*]] = "some_use"(%[[R2]]) : (vector<2xf32>) -> vector<2xf32> -// CHECK: %[[R5:.*]] = "some_use"(%[[R3]]) : (vector<2xf32>) -> vector<2xf32> -// CHECK: scf.yield 
%[[R5]], %[[R4]] : vector<2xf32>, vector<2xf32> +// CHECK: %[[F:.*]] = scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[T:.*]] = %[[T0:.*]]) -> (tensor) { +// CHECK: %[[R0:.*]] = vector.transfer_read %[[T]][%[[C0]], %[[C0]]], %{{.*}} : tensor, vector<2xf32> +// CHECK: %[[U0:.*]] = "some_use"(%[[R0]]) : (vector<2xf32>) -> vector<2xf32> +// CHECK: %[[W0:.*]] = vector.transfer_write %[[U0]], %[[T]][%[[C0]], %[[C0]]] : vector<2xf32>, tensor +// CHECK: %[[R1:.*]] = vector.transfer_read %[[W0]][%[[C0]], %[[C3]]], %{{.*}} : tensor, vector<2xf32> +// CHECK: %[[U1:.*]] = "some_use"(%[[R1]]) : (vector<2xf32>) -> vector<2xf32> +// CHECK: %[[W1:.*]] = vector.transfer_write %[[U1]], %[[W0]][%[[C0]], %[[C3]]] : vector<2xf32>, tensor +// CHECK: scf.yield %[[W1]] : tensor // CHECK: } -// CHECK: %[[W0:.*]] = vector.transfer_write %[[F]]#1, %[[T]][%[[C0]], %[[C0]]] : vector<2xf32>, tensor -// CHECK: %[[W1:.*]] = vector.transfer_write %[[F]]#0, %[[W0]][%[[C0]], %[[C3]]] : vector<2xf32>, tensor -// CHECK: return %[[W1]] : tensor +// CHECK: return %[[F]] : tensor func.func @hoist_vector_transfer_write_pairs_disjoint_tensor( %tensor: tensor, %val: index, %lb : index, %ub : index, %step: index) -> @@ -500,7 +545,12 @@ -> (tensor) { %r00 = vector.transfer_read %arg5[%c0, %c0], %cst: tensor, vector<2xf32> %u00 = "some_use"(%r00) : (vector<2xf32>) -> vector<2xf32> + // + // w10 cannot be hoisted first, otherwise the use in %w11 violates domination. %w10 = vector.transfer_write %u00, %arg5[%c0, %c0] : vector<2xf32>, tensor + // + // r01 cannot be hoisted first, otherwise the use of %w10 violates domination. + // As a consequence, nothing can hoist. 
%r01 = vector.transfer_read %w10[%c0, %c3], %cst: tensor, vector<2xf32> %u01 = "some_use"(%r01) : (vector<2xf32>) -> vector<2xf32> %w11 = vector.transfer_write %u01, %w10[%c0, %c3] : vector<2xf32>, tensor @@ -513,51 +563,119 @@ ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!pdl.operation) -> !pdl.operation - transform.structured.hoist_redundant_vector_transfers %0 + transform.structured.hoist_redundant_tensor_subsets %0 : (!pdl.operation) -> !pdl.operation } // ----- -// CHECK-LABEL: func @hoist_vector_transfer_pairs_in_affine_loops( -// CHECK-SAME: %[[MEMREF0:[a-zA-Z0-9]+]]: memref<64x64xi32>, -// CHECK-SAME: %[[MEMREF1:[a-zA-Z0-9]+]]: memref<64x64xi32>, -// CHECK-SAME: %[[MEMREF2:[a-zA-Z0-9]+]]: memref<64x64xi32>) { -// CHECK: %[[C0:.*]] = arith.constant 0 : i32 -// CHECK: affine.for %[[I:.*]] = 0 to 64 { -// CHECK: affine.for %[[J:.*]] = 0 to 64 step 16 { -// CHECK: %[[R0:.*]] = vector.transfer_read %[[MEMREF2]][%[[I]], %[[J]]], %[[C0]] : memref<64x64xi32>, vector<16xi32> -// CHECK: %[[R:.*]] = affine.for %[[K:.*]] = 0 to 64 iter_args(%[[ACC:.*]] = %[[R0]]) -> (vector<16xi32>) { -// CHECK: %[[AV:.*]] = vector.transfer_read %[[MEMREF0]][%[[I]], %[[K]]], %[[C0]] {{.*}}: memref<64x64xi32>, vector<16xi32> -// CHECK: %[[BV:.*]] = vector.transfer_read %[[MEMREF1]][%[[K]], %[[J]]], %[[C0]] {{.*}}: memref<64x64xi32>, vector<16xi32> -// CHECK: %[[T0:.*]] = arith.muli %[[AV]], %[[BV]] : vector<16xi32> -// CHECK: %[[T1:.*]] = arith.addi %[[ACC]], %[[T0]] : vector<16xi32> -// CHECK: affine.yield %[[T1]] : vector<16xi32> -// CHECK: } -// CHECK: vector.transfer_write %[[R]], %[[MEMREF2]][%[[I]], %[[J]]] : vector<16xi32>, memref<64x64xi32> -// CHECK: } -// CHECK: } -func.func @hoist_vector_transfer_pairs_in_affine_loops(%memref0: memref<64x64xi32>, %memref1: memref<64x64xi32>, %memref2: memref<64x64xi32>) { - %c0_i32 = arith.constant 0 : i32 - affine.for %arg3 = 0 to 64 { - affine.for %arg4 = 0 to 64 step 16 { - affine.for %arg5 = 
0 to 64 { - %0 = vector.transfer_read %memref0[%arg3, %arg5], %c0_i32 {permutation_map = affine_map<(d0, d1) -> (0)>} : memref<64x64xi32>, vector<16xi32> - %1 = vector.transfer_read %memref1[%arg5, %arg4], %c0_i32 : memref<64x64xi32>, vector<16xi32> - %2 = vector.transfer_read %memref2[%arg3, %arg4], %c0_i32 : memref<64x64xi32>, vector<16xi32> - %3 = arith.muli %0, %1 : vector<16xi32> - %4 = arith.addi %2, %3 : vector<16xi32> - vector.transfer_write %4, %memref2[%arg3, %arg4] : vector<16xi32>, memref<64x64xi32> - } +// CHECK-LABEL: func @hoist_vector_transfer_pairs_tensor_and_slices_static_large_tensor +// CHECK-SAME: %[[TENSOR0:[a-zA-Z0-9]*]]: tensor<100x100xf32>, +// CHECK-SAME: %[[TENSOR1:[a-zA-Z0-9]*]]: tensor<200x200xf32>, +// CHECK-SAME: %[[TENSOR2:[a-zA-Z0-9]*]]: tensor<300x300xf32> +func.func @hoist_vector_transfer_pairs_tensor_and_slices_static_large_tensor( + %tensor0: tensor<100x100xf32>, %tensor1: tensor<200x200xf32>, %tensor2: tensor<300x300xf32>, + %val: index, %lb : index, %ub : index, %step: index) -> + ( + tensor<100x100xf32>, tensor<200x200xf32>, tensor<300x300xf32> + ) { + %c0 = arith.constant 0 : index + %cst = arith.constant 0.0 : f32 + + // CHECK: scf.for %[[I:.*]] = {{.*}} iter_args( + // CHECK-SAME: %[[TENSOR0_ARG:[0-9a-zA-Z]+]] = %[[TENSOR0]], + // CHECK-SAME: %[[TENSOR1_ARG:[0-9a-zA-Z]+]] = %[[TENSOR1]], + // CHECK-SAME: %[[TENSOR2_ARG:[0-9a-zA-Z]+]] = %[[TENSOR2]] + // CHECK-SAME: ) -> + // CHECK-SAME: (tensor<100x100xf32>, tensor<200x200xf32>, tensor<300x300xf32> + %0:3 = scf.for %i = %lb to %ub step %step + iter_args(%arg0 = %tensor0, %arg1 = %tensor1, %arg2 = %tensor2) + -> (tensor<100x100xf32>, tensor<200x200xf32>, tensor<300x300xf32>) { + + // Hoisted + // CHECK: %[[ST0:.*]] = tensor.extract_slice %[[TENSOR0_ARG]][%[[I]], %[[I]]]{{.*}}: tensor<100x100xf32> to tensor + // CHECK: %[[V0:.*]] = vector.transfer_read %[[ST0]]{{.*}} : tensor, vector<1xf32> + + // CHECK: %[[R:.*]]:3 = scf.for %[[J:.*]] = {{.*}} iter_args( + // CHECK-SAME: 
%[[TENSOR1_ARG_L2:[0-9a-zA-Z]+]] = %[[TENSOR1_ARG]] + // CHECK-SAME: %[[TENSOR2_ARG_L2:[0-9a-zA-Z]+]] = %[[TENSOR2_ARG]] + // CHECK-SAME: %[[V0_ARG_L2:[0-9a-zA-Z]+]] = %[[V0]] + // CHECK-SAME: ) -> + // CHECK-SAME: (tensor<200x200xf32>, tensor<300x300xf32>, vector<1xf32> + %1:3 = scf.for %j = %lb to %ub step %step + iter_args(%arg6 = %arg0, %arg7 = %arg1, %arg8 = %arg2) + -> (tensor<100x100xf32>, tensor<200x200xf32>, tensor<300x300xf32>) { + // Hoists. + %st0 = tensor.extract_slice %arg6[%i, %i][%step, %step][1, 1] : tensor<100x100xf32> to tensor + %r0 = vector.transfer_read %st0[%c0, %c0], %cst: tensor, vector<1xf32> + + // CHECK: %[[ST1:.*]] = tensor.extract_slice %[[TENSOR1_ARG_L2]][%[[J]],{{.*}}: tensor<200x200xf32> to tensor + // CHECK: %[[V1:.*]] = vector.transfer_read %[[ST1]]{{.*}} : tensor, vector<2xf32> + // Does not hoist (slice depends on %j) + %st1 = tensor.extract_slice %arg7[%j, %c0][%step, %step][1, 1] : tensor<200x200xf32> to tensor + %r1 = vector.transfer_read %st1[%c0, %c0], %cst: tensor, vector<2xf32> + + // CHECK: %[[ST2:.*]] = tensor.extract_slice %[[TENSOR2_ARG_L2]][%[[I]],{{.*}}: tensor<300x300xf32> to tensor + // CHECK: %[[V2:.*]] = vector.transfer_read %[[ST2]]{{.*}} : tensor, vector<3xf32> + // Does not hoist, 2 slice %arg8. 
+ %st2 = tensor.extract_slice %arg8[%i, %c0][%step, %step][1, 1] : tensor<300x300xf32> to tensor + %r2 = vector.transfer_read %st2[%c0, %c0], %cst: tensor, vector<3xf32> + + // CHECK: %[[U0:.*]] = "some_use"(%[[V0_ARG_L2]]) : (vector<1xf32>) -> vector<1xf32> + // CHECK: %[[U1:.*]] = "some_use"(%[[V1]]) : (vector<2xf32>) -> vector<2xf32> + // CHECK: %[[U2:.*]] = "some_use"(%[[V2]]) : (vector<3xf32>) -> vector<3xf32> + %u0 = "some_use"(%r0) : (vector<1xf32>) -> vector<1xf32> + %u1 = "some_use"(%r1) : (vector<2xf32>) -> vector<2xf32> + %u2 = "some_use"(%r2) : (vector<3xf32>) -> vector<3xf32> + + // Hoists + %w0 = vector.transfer_write %u0, %st0[%c0, %c0] : vector<1xf32>, tensor + + // CHECK-DAG: %[[STI1:.*]] = vector.transfer_write %[[U1]], %{{.*}} : vector<2xf32>, tensor + // Does not hoist (associated slice depends on %j). + %w1 = vector.transfer_write %u1, %st1[%i, %i] : vector<2xf32>, tensor + + // CHECK-DAG: %[[STI2:.*]] = vector.transfer_write %[[U2]], %{{.*}} : vector<3xf32>, tensor + // Does not hoist, 2 slice / insert_slice for %arg8. + %w2 = vector.transfer_write %u2, %st2[%c0, %c0] : vector<3xf32>, tensor + + // Hoists. + %sti0 = tensor.insert_slice %w0 into %arg6[%i, %i][%step, %step][1, 1] : tensor into tensor<100x100xf32> + + // CHECK-DAG: tensor.insert_slice %[[STI1]] into %[[TENSOR1_ARG_L2]][%[[J]],{{.*}}: tensor into tensor<200x200xf32> + // Does not hoist (depends on %j). + %sti1 = tensor.insert_slice %w1 into %arg7[%j, %c0][%step, %step][1, 1] : tensor into tensor<200x200xf32> + + // CHECK-DAG: tensor.insert_slice %[[STI2]] into %[[TENSOR2_ARG_L2]][%[[I]],{{.*}}: tensor into tensor<300x300xf32> + // Does not hoist, 2 slice / insert_slice for %arg8. + %sti2 = tensor.insert_slice %w2 into %arg8[%i, %c0][%step, %step][1, 1] : tensor into tensor<300x300xf32> + // Extract with a different stride to make sure we cannot fold this extract with the above insert. 
+ %st22 = tensor.extract_slice %sti2[%i, %c0][%step, %step][2, 1] : tensor<300x300xf32> to tensor + %sti22 = tensor.insert_slice %st22 into %arg8[%i, %c0][%step, %step][1, 1] : tensor into tensor<300x300xf32> + + // CHECK: scf.yield {{.*}} : tensor<200x200xf32>, tensor<300x300xf32>, vector<1xf32> + // CHECK: } + scf.yield %sti0, %sti1, %sti22: + tensor<100x100xf32>, tensor<200x200xf32>, tensor<300x300xf32> } + + // Hoisted + // CHECK: %[[STI0:.*]] = vector.transfer_write %[[R]]#2, %[[ST0]]{{.*}} : vector<1xf32>, tensor + // CHECK: tensor.insert_slice %[[STI0]] into %[[TENSOR0_ARG]][%[[I]], %[[I]]]{{.*}} : tensor into tensor<100x100xf32> + + // CHECK: scf.yield {{.*}} : tensor<100x100xf32>, tensor<200x200xf32>, tensor<300x300xf32> + scf.yield %1#0, %1#1, %1#2 : + tensor<100x100xf32>, tensor<200x200xf32>, tensor<300x300xf32> + + // CHECK: } } - return + return %0#0, %0#1, %0#2 : tensor<100x100xf32>, tensor<200x200xf32>, tensor<300x300xf32> } transform.sequence failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!pdl.operation) -> !pdl.operation - transform.structured.hoist_redundant_vector_transfers %0 + transform.structured.hoist_redundant_tensor_subsets %0 : (!pdl.operation) -> !pdl.operation }