diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
@@ -24,6 +24,7 @@
 #include "mlir/IR/AsmState.h"
 #include "mlir/IR/Dominance.h"
 #include "mlir/IR/Matchers.h"
+#include "mlir/Interfaces/DestinationStyleOpInterface.h"
 #include "mlir/Transforms/RegionUtils.h"
 #include "llvm/Support/Debug.h"
 
@@ -153,9 +154,9 @@
   bool isValid() { return valid.has_value() && valid.value(); }
   bool isInvalid() { return valid.has_value() && !valid.value(); }
 
-  /// Footprint of the packedTensor, computed from the packingLoops.
-  SmallVector<Value> getPackedTensorSizes(RewriterBase &rewriter,
-                                          Location loc) const;
+  /// Footprint of the hoistedPackedTensor, computed from the packingLoops.
+  SmallVector<Value> getHoistedPackedTensorSizes(RewriterBase &rewriter,
+                                                 Location loc) const;
 
   /// Performs optional hoisting to enable hoist padding to occur. This may be
   /// necessary when `sliceOp` is not defined outside of the outermost enclosing
@@ -450,8 +451,8 @@
 }
 
 SmallVector<Value>
-HoistPaddingAnalysis::getPackedTensorSizes(RewriterBase &rewriter,
-                                           Location loc) const {
+HoistPaddingAnalysis::getHoistedPackedTensorSizes(RewriterBase &rewriter,
+                                                  Location loc) const {
   SmallVector<Value> dynamicTensorSizes;
 
   // Upper bound the packing loop lengths to size the packed tensor. Taking
@@ -525,7 +526,8 @@
 // Build a packing loop nest by iteratively traversing the backward slice and
 // clone the operations, iteratively stepping into the loops that we encounter.
 // The implementation proceeds in a stack-like fashion:
-//   1. Iteratively clone and step into the loops, pushing the `packedTensor`
+//   1. Iteratively clone and step into the loops, pushing the
+//      `hoistedPackedTensor`
 //      deeper in the stack.
 //   2. At the innermost loop level, create a GenericOp if `transposeVector` is
 //      non-empty.
@@ -537,7 +539,7 @@
     ArrayRef<int64_t> transposeVector, RankedTensorType transposedTensorType,
     tensor::EmptyOp emptyOp, const HoistPaddingAnalysis &analysis) {
   SmallVector<OpFoldResult> offsets, sizes, strides;
-  SmallVector<Value> clonedLoopIvs, leadingPackedTensorIndexings;
+  SmallVector<Value> clonedLoopIvs, leadingHoistedPackedTensorIndexings;
 
   scf::ForOp outerLoop = analysis.outermostEnclosingForOp;
 
@@ -558,14 +560,14 @@
     bbArg = operand.get().dyn_cast<BlockArgument>();
   }
 
-  // Step 1. iteratively clone loops and push `packedTensor`.
-  Value packedTensor = emptyOp.getResult();
+  // Step 1. iteratively clone loops and push `hoistedPackedTensor`.
+  Value hoistedPackedTensor = emptyOp.getResult();
   OpBuilder::InsertionGuard g(rewriter);
   for (Operation *op : analysis.backwardSlice) {
-    // Specifically sit out in the extract_slice(packedTensor) case: this is
-    // the piece we seek to replace.
+    // Specifically sit out in the extract_slice(hoistedPackedTensor) case: this
+    // is the piece we seek to replace.
     if (auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(op)) {
-      if (bvm.lookupOrDefault(sliceOp.getSource()) == packedTensor) {
+      if (bvm.lookupOrDefault(sliceOp.getSource()) == hoistedPackedTensor) {
         LLVM_DEBUG(DBGS() << "--Skip: " << sliceOp << "\n");
         continue;
       }
@@ -579,11 +581,12 @@
       continue;
     }
 
-    // Create a packing loop that takes `packedTensor` as iteration argument.
+    // Create a packing loop that takes `hoistedPackedTensor` as iteration
+    // argument.
     auto clonedForOp = rewriter.create<scf::ForOp>(
         loc, bvm.lookupOrDefault(forOp.getLowerBound()),
         bvm.lookupOrDefault(forOp.getUpperBound()),
-        bvm.lookupOrDefault(forOp.getStep()), packedTensor);
+        bvm.lookupOrDefault(forOp.getStep()), hoistedPackedTensor);
 
     // Map the induction var, region args and results to the `clonedForOp`.
     bvm.map(forOp.getInductionVar(), clonedForOp.getInductionVar());
 
@@ -600,16 +603,18 @@
     // Assert the loop-independent iteration count can be computed.
     if (!loopIndependentIterationCount)
       llvm_unreachable("loop independence prerequisite not met");
-    leadingPackedTensorIndexings.push_back(loopIndependentIterationCount);
-    packedTensor = clonedForOp.getRegionIterArgs().front();
+    leadingHoistedPackedTensorIndexings.push_back(
+        loopIndependentIterationCount);
+    hoistedPackedTensor = clonedForOp.getRegionIterArgs().front();
   }
 
   // Step 2. Construct offsets, sizes and strides for the innermost level of the
   // packing loop.
   int64_t nPackedLoops = clonedLoopIvs.size();
   // offsets = [clonedLoopIvs, 0 .. 0].
-  offsets = SmallVector<OpFoldResult>{leadingPackedTensorIndexings.begin(),
-                                      leadingPackedTensorIndexings.end()};
+  offsets =
+      SmallVector<OpFoldResult>{leadingHoistedPackedTensorIndexings.begin(),
+                                leadingHoistedPackedTensorIndexings.end()};
   offsets.append(paddedRank, rewriter.getIndexAttr(0));
   // sizes = [1 .. 1, transposedShape].
   sizes = SmallVector<OpFoldResult>(nPackedLoops, rewriter.getIndexAttr(1));
 
@@ -627,7 +632,8 @@
   Value paddedTensor = bvm.lookup(opToHoist.getResult());
   if (!transposeVector.empty()) {
     Value outputTensor = rewriter.create<tensor::ExtractSliceOp>(
-        loc, transposedTensorType, packedTensor, offsets, sizes, strides);
+        loc, transposedTensorType, hoistedPackedTensor, offsets, sizes,
+        strides);
     maybeTransposeOp = makeTransposeOp(rewriter, loc, paddedTensor,
                                        outputTensor, transposeVector);
     paddedTensor = maybeTransposeOp.getResult(0);
@@ -638,7 +644,7 @@
   // Step 4. Create InsertSliceOp at the innermost loop level, inserting an
   // optionally transposed padded slice into the packed tensor.
   Value inserted = rewriter.create<tensor::InsertSliceOp>(
-      loc, paddedTensor, packedTensor, offsets, sizes, strides);
+      loc, paddedTensor, hoistedPackedTensor, offsets, sizes, strides);
 
   // Step 5. Iteratively pop the stack and propagate the yield.
   Value valueToYield = inserted;
@@ -655,7 +661,7 @@
           sizes,
           strides,
           clonedLoopIvs,
-          leadingPackedTensorIndexings,
+          leadingHoistedPackedTensorIndexings,
           maybeTransposeOp,
           cast<tensor::PadOp>(bvm.lookup(opToHoist.getResult()).getDefiningOp())};
 }
 
@@ -688,7 +694,7 @@
   SmallVector<int64_t> packedShape(nPackedLoops, ShapedType::kDynamic);
   // TODO: go grab dims when needed, atm tensor::PadOp yields a static tensor.
   llvm::append_range(packedShape, transposedTensorType->getShape());
-  auto packedTensorType = RankedTensorType::get(
+  auto hoistedPackedTensorType = RankedTensorType::get(
       packedShape, transposedTensorType->getElementType());
 
   // Set the insertion point right before the outer loop and start packing.
@@ -696,10 +702,10 @@
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(outerLoop);
   SmallVector<Value> dynamicTensorSizes =
-      analysis.getPackedTensorSizes(rewriter, loc);
+      analysis.getHoistedPackedTensorSizes(rewriter, loc);
   auto emptyOp = rewriter.create<tensor::EmptyOp>(
-      loc, packedTensorType.getShape(), packedTensorType.getElementType(),
-      dynamicTensorSizes);
+      loc, hoistedPackedTensorType.getShape(),
+      hoistedPackedTensorType.getElementType(), dynamicTensorSizes);
 
   return buildPackingLoopNestImpl(rewriter, bvm, opToHoist, transposeVector,
                                   *transposedTensorType, emptyOp, analysis);
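To make Steps 1 through 5 above concrete, here is a minimal MLIR sketch of the kind of packing loop nest this builds, assuming a single packing loop and an empty `transposeVector`. All SSA names, the `5x25` tile shape, and the helper values (`%src`, `%lb`, `%ub`, `%step`, `%num_tiles`, `%sz`, `%h`, `%cst`) are illustrative assumptions loosely modeled on the `transform-op-hoist-pad.mlir` test updated below, not output taken from the patch:

```mlir
// Hoisted packed tensor allocated above the loop nest from the empty op.
%empty = tensor.empty(%num_tiles) : tensor<?x5x25xf32>
// Step 1: cloned loop carrying `hoistedPackedTensor` as an iter arg.
%packed = scf.for %i = %lb to %ub step %step
    iter_args(%hoistedPackedTensor = %empty) -> (tensor<?x5x25xf32>) {
  %tile = tensor.extract_slice %src[%i, 0] [%sz, 25] [1, 1]
      : tensor<24x25xf32> to tensor<?x25xf32>
  // Cloned tensor.pad producing the statically shaped tile.
  %padded = tensor.pad %tile low[0, 0] high[%h, 0] {
  ^bb0(%a: index, %b: index):
    tensor.yield %cst : f32
  } : tensor<?x25xf32> to tensor<5x25xf32>
  // Loop-independent iteration count used as the leading packing offset.
  %count = affine.apply affine_map<(d0)[s0, s1] -> ((d0 - s0) ceildiv s1)>(%i)[%lb, %step]
  // Step 4: insert_slice at the innermost level with offsets = [%count, 0, 0],
  // sizes = [1, 5, 25], strides = [1, 1, 1].
  %inserted = tensor.insert_slice %padded into
      %hoistedPackedTensor[%count, 0, 0] [1, 5, 25] [1, 1, 1]
      : tensor<5x25xf32> into tensor<?x5x25xf32>
  // Step 5: propagate the yield.
  scf.yield %inserted : tensor<?x5x25xf32>
}
```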
@@ -727,14 +733,71 @@
 // hoistPaddingOnTensors Implementation.
 //===----------------------------------------------------------------------===//
 
-// If the original consumer of `sliceOp` was a `forOp` (i.e. through an iter
-// arg), propagate the `packedTensor` value through the same iter arg.
-// TODO: for multiple loops we need to track the use to the innermost loop.
-static Value padThroughLoopIterArg(RewriterBase &rewriter, Value packedTensor,
-                                   tensor::ExtractSliceOp sliceOp,
-                                   scf::ForOp forOp) {
+/// Return true if we can walk back the use-def chain from `extractSliceOp` to
+/// `expectedSource`, going through DestinationStyleOpInterface inits only.
+/// This is a poor man's analysis that is sufficient to check that the
+/// extractSliceOp matches the tensor.pad we want to hoist.
+/// In the future, it will be easier to ensure this with a matching symmetric
+/// tensor.unpad op.
+static bool tracesBackToExpectedValue(tensor::ExtractSliceOp extractSliceOp,
+                                      Value expectedSource) {
+  LLVM_DEBUG(DBGS() << "Start tracesBackToExpectedValue on: " << extractSliceOp
+                    << "\n");
+  LLVM_DEBUG(DBGS() << "--with extractSlice: " << extractSliceOp << "\n");
+  Value source = extractSliceOp.getSource();
+  LLVM_DEBUG(DBGS() << "--with starting source: " << source << "\n");
+  while (source && source != expectedSource) {
+    auto destOp =
+        dyn_cast_or_null<DestinationStyleOpInterface>(source.getDefiningOp());
+    if (!destOp)
+      break;
+    LLVM_DEBUG(DBGS() << "--step dest op: " << destOp << "\n");
+    source =
+        destOp.getDpsInitOperand(source.cast<OpResult>().getResultNumber())
+            ->get();
+  }
+  LLVM_DEBUG(DBGS() << "--final source: " << source << "\n");
+  LLVM_DEBUG(DBGS() << "--expected source: " << expectedSource << "\n");
+  return source == expectedSource;
+}
+
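A hedged sketch of a chain this walk accepts, with `%padded` in the role of `expectedSource`; the ops, shapes, and helper values (`%tile`, `%a`, `%b`, `%h`, `%sz`, `%cst`) are illustrative assumptions, not lifted from the patch:

```mlir
// Assumed defined earlier: %tile, %a, %b, %h, %sz, %cst.
%padded = tensor.pad %tile low[0, 0] high[%h, 0] {
^bb0(%i: index, %j: index):
  tensor.yield %cst : f32
} : tensor<?x25xf32> to tensor<5x25xf32>
// linalg.matmul implements DestinationStyleOpInterface, so its result #0 is
// tied to its init (outs) operand: the walk steps from %r back to %padded.
%r = linalg.matmul ins(%a, %b : tensor<5x12xf32>, tensor<12x25xf32>)
                   outs(%padded : tensor<5x25xf32>) -> tensor<5x25xf32>
// Starting from this extract_slice, the use-def walk reaches %padded, i.e.
// expectedSource, so tracesBackToExpectedValue returns true.
%s = tensor.extract_slice %r[0, 0] [%sz, 25] [1, 1]
    : tensor<5x25xf32> to tensor<?x25xf32>
```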
+/// If the original consumer of `outerSliceOp` was a `forOp` (i.e. through an
+/// iter arg), propagate the `hoistedPackedTensor` value through the same iter
+/// arg.
+/// TODO: for multiple loops we need to track the use to the innermost loop.
+///
+/// Match:
+/// ```
+///   %outerSliceOp = tensor.extract_slice ..
+///   %f = scf.for ... iter_args(%arg0 = %outerSliceOp) {
+///     %hoistedPackedTensor = tensor.pad %arg0
+///     %1 = compute %hoistedPackedTensor
+///     %2 = tensor.extract_slice %1
+///     scf.yield %2
+///   }
+/// ```
+///
+/// and rewrite as:
+/// ```
+///   %outerSliceOp = tensor.extract_slice ..
+///   %hoistedPackedTensor = tensor.pad %outerSliceOp
+///   %f = scf.for ... iter_args(%arg0 = %hoistedPackedTensor) {
+///     %1 = compute %arg0
+///     scf.yield %1
+///   }
+///   %2 = tensor.extract_slice %f
+/// ```
+///
+/// Return null when no rewrite happened.
+static tensor::ExtractSliceOp
+padThroughLoopIterArg(RewriterBase &rewriter, Value paddedValueBeforeHoisting,
+                      Value hoistedPackedTensor,
+                      tensor::ExtractSliceOp outerSliceOp, scf::ForOp forOp) {
+  LLVM_DEBUG(DBGS() << "Start padThroughLoopIterArg on: " << forOp << "\n");
+  LLVM_DEBUG(DBGS() << "--paddedValueBeforeHoisting: "
+                    << paddedValueBeforeHoisting << "\n");
   OpOperand *pUse = nullptr;
-  for (OpOperand &use : sliceOp->getUses()) {
+  for (OpOperand &use : outerSliceOp->getUses()) {
     if (use.getOwner() == forOp) {
       assert(!pUse && "Multiple slice uses in the for loop");
       pUse = &use;
@@ -742,20 +805,67 @@
   }
   assert(pUse && "No slice use in the for loop");
   OpBuilder::InsertionGuard g(rewriter);
-  rewriter.setInsertionPointAfter(packedTensor.getDefiningOp());
-  Value casted = rewriter.create<tensor::CastOp>(
-      packedTensor.getLoc(), pUse->get().getType(), packedTensor);
+  rewriter.setInsertionPointAfter(hoistedPackedTensor.getDefiningOp());
 
-  std::optional<int64_t> operandNumber =
+  std::optional<int64_t> maybeOperandNumber =
       forOp.getIterArgNumberForOpOperand(*pUse);
-  assert(operandNumber.has_value() && "expected a proper iter arg number");
+  assert(maybeOperandNumber.has_value() && "expected a proper iter arg number");
+
+  int64_t operandNumber = maybeOperandNumber.value();
+  auto yieldOp = cast<scf::YieldOp>(forOp.getBody(0)->getTerminator());
+  auto yieldingExtractSliceOp = yieldOp->getOperand(operandNumber)
+                                    .getDefiningOp<tensor::ExtractSliceOp>();
+  if (!yieldingExtractSliceOp)
+    return tensor::ExtractSliceOp();
+
+  // Poor man's analysis sufficient to ensure extractSlice matches tensor.pad.
+  // In the future, it will be easier to ensure this with a matching symmetric
+  // tensor.unpad op.
+  if (!tracesBackToExpectedValue(yieldingExtractSliceOp,
+                                 paddedValueBeforeHoisting))
+    return tensor::ExtractSliceOp();
 
   SmallVector<Value> initArgs = forOp.getInitArgs();
-  initArgs[operandNumber.value()] = casted;
-  rewriter.startRootUpdate(forOp);
-  forOp.getInitArgsMutable().assign(initArgs);
-  rewriter.finalizeRootUpdate(forOp);
-  return forOp.getRegionIterArgForOpOperand(*pUse);
+  initArgs[operandNumber] = hoistedPackedTensor;
+  SmallVector<Value> yieldOperands = yieldOp.getOperands();
+  yieldOperands[operandNumber] = yieldingExtractSliceOp.getSource();
+
+  int64_t numOriginalForOpResults = initArgs.size();
+  LLVM_DEBUG(DBGS() << "numOriginalForOpResults: " << numOriginalForOpResults
+                    << "\n");
+  tensor::ExtractSliceOp extracted;
+  {
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPointAfter(forOp);
+    extracted = rewriter.create<tensor::ExtractSliceOp>(
+        hoistedPackedTensor.getLoc(), hoistedPackedTensor,
+        outerSliceOp.getMixedOffsets(), outerSliceOp.getMixedSizes(),
+        outerSliceOp.getMixedStrides());
+    rewriter.replaceAllUsesWith(forOp.getResult(operandNumber), extracted);
+  }
+  scf::ForOp newForOp =
+      replaceLoopWithNewYields(rewriter, forOp, initArgs, yieldOperands);
+
+  LLVM_DEBUG(DBGS() << "newForOp results: " << newForOp.getNumResults()
+                    << "\n");
+  LLVM_DEBUG(DBGS() << "replace source of: " << extracted << "\n");
+  LLVM_DEBUG(DBGS() << "with result #"
+                    << numOriginalForOpResults + operandNumber
+                    << " of forOp, giving us: " << extracted << "\n");
+  rewriter.startRootUpdate(extracted);
+  extracted.getSourceMutable().assign(
+      newForOp.getResult(numOriginalForOpResults + operandNumber));
+  rewriter.finalizeRootUpdate(extracted);
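+
+  // Finally, rewire the loop body itself: redirect all remaining uses of the
+  // original padded value to the new region iter arg, i.e. the hoisted packed
+  // tensor that is now carried into the loop.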
+  LLVM_DEBUG(DBGS() << "replace uses of: " << paddedValueBeforeHoisting
+                    << "\n");
+  LLVM_DEBUG(DBGS() << "with region iter arg #"
+                    << numOriginalForOpResults + operandNumber << "\n");
+  rewriter.replaceAllUsesWith(
+      paddedValueBeforeHoisting,
+      newForOp.getRegionIterArg(numOriginalForOpResults + operandNumber));
+
+  return extracted;
 }
 
 /// Produce a tensor extracted from the packingResult. This can be used as a
@@ -781,7 +891,7 @@
   scf::ForOp outerLoop = analysis.outermostEnclosingForOp;
   ArrayRef<scf::ForOp> packingLoops = analysis.packingLoops;
 
-  Value packedTensor;
+  Value hoistedPackedTensor;
   SmallVector<Value> loopIterationCounts;
   SmallVector<OpFoldResult> offsets(nPackedLoops + paddedRank,
                                     rewriter.getIndexAttr(0));
@@ -798,29 +908,29 @@
     // offsets = [maybe_leading_ivs = originalLoopIvs, 0 .. 0].
     std::copy(loopIterationCounts.begin(), loopIterationCounts.end(),
               offsets.begin());
-    packedTensor =
+    hoistedPackedTensor =
         scf::getForInductionVarOwner(packingResult.clonedLoopIvs.front())
             ->getResult(0);
   } else {
     // If no loops were created, this is just hoisting without packing.
-    packedTensor = bvm.lookup(opToHoist.getResult());
+    hoistedPackedTensor = bvm.lookup(opToHoist.getResult());
  }
 
-  LLVM_DEBUG(DBGS() << "packedTensor: " << packedTensor << "\n");
+  LLVM_DEBUG(DBGS() << "hoistedPackedTensor: " << hoistedPackedTensor << "\n");
 
   // If the consumer of `padOp` was a `forOp`, propagate through iter args.
   scf::ForOp forOp = analysis.padConsumingForOp;
   if (forOp) {
-    packedTensor =
-        padThroughLoopIterArg(rewriter, packedTensor, analysis.sliceOp, forOp);
+    return padThroughLoopIterArg(rewriter, opToHoist, hoistedPackedTensor,
+                                 analysis.sliceOp, forOp);
   }
 
   // offsets = [maybe_leading_ivs, 0 .. 0].
   // sizes = [1 .. 1, transposedShape] (defined above).
   // strides = [1 .. 1] (defined above)
   return rewriter.create<tensor::ExtractSliceOp>(
-      loc, transposedTensorType, packedTensor, offsets, packingResult.sizes,
-      packingResult.strides);
+      loc, transposedTensorType, hoistedPackedTensor, offsets,
+      packingResult.sizes, packingResult.strides);
 }
 
 FailureOr<Value> mlir::linalg::hoistPaddingOnTensors(
diff --git a/mlir/test/Dialect/Linalg/transform-op-hoist-pad.mlir b/mlir/test/Dialect/Linalg/transform-op-hoist-pad.mlir
--- a/mlir/test/Dialect/Linalg/transform-op-hoist-pad.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-hoist-pad.mlir
@@ -161,12 +161,13 @@
 //      CHECK: scf.for %{{.*}} -> (tensor<24x25xf32>) {
 //      CHECK:   %[[PADDED:.*]] = tensor.pad %{{.*}}
 //      CHECK:     : tensor<?x25xf32> to tensor<5x25xf32>
-//      CHECK:   scf.for %{{.*}} iter_args(%[[INNER_PADDED:[0-9a-zA-Z]*]] = %[[PADDED]]) -> (tensor<5x25xf32>)
+//      CHECK:   %[[SCF_YIELD:.*]] = scf.for %{{.*}} iter_args(%[[INNER_PADDED:[0-9a-zA-Z]*]] = %[[PADDED]]) -> (tensor<5x25xf32>)
 //      CHECK:     %[[RES:.*]] = linalg.matmul {{.*}} outs(%[[INNER_PADDED]]
 // CHECK-SAME:       : tensor<5x25xf32>
 //      CHECK:     scf.yield %[[RES]] : tensor<5x25xf32>
-//      CHECK:   %[[CAST:.*]] = tensor.cast %{{.*}} : tensor<5x25xf32> to tensor<?x25xf32>
-//      CHECK:   tensor.insert_slice %[[CAST]] into %{{.*}}[%{{.*}}, 0] [%{{.*}}, 25] [1, 1]
+//      CHECK:   %[[EXTRACTED:.*]] = tensor.extract_slice %[[SCF_YIELD]][%{{.*}}, 0] [%{{.*}}, 25] [1, 1]
+// CHECK-SAME:     : tensor<5x25xf32> to tensor<?x25xf32>
+//      CHECK:   tensor.insert_slice %[[EXTRACTED]] into %{{.*}}[%{{.*}}, 0] [%{{.*}}, 25] [1, 1]
 // CHECK-SAME:     : tensor<?x25xf32> into tensor<24x25xf32>
   %0 = linalg.matmul ins(%arg0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32>
   func.return %0 : tensor<24x25xf32>
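In sum, a hedged before/after sketch of the net effect on the matmul test above, with SSA names invented and offsets/sizes following the CHECK lines: the packed tensor is now yielded out of the inner loop and the slice is re-extracted from the loop result, instead of being threaded into the loop through a `tensor.cast` on the iter arg init:

```mlir
// After this patch (sketch). Assumed defined earlier: %padded, %lhs, %rhs,
// %dest, %i, %sz, and the loop bounds %lb, %ub, %step.
%yielded = scf.for %j = %lb to %ub step %step
    iter_args(%inner = %padded) -> (tensor<5x25xf32>) {
  %res = linalg.matmul ins(%lhs, %rhs : tensor<5x12xf32>, tensor<12x25xf32>)
                       outs(%inner : tensor<5x25xf32>) -> tensor<5x25xf32>
  scf.yield %res : tensor<5x25xf32>
}
// The slice is re-extracted from the loop result and inserted back into the
// 24x25 destination.
%extracted = tensor.extract_slice %yielded[%i, 0] [%sz, 25] [1, 1]
    : tensor<5x25xf32> to tensor<?x25xf32>
%updated = tensor.insert_slice %extracted into %dest[%i, 0] [%sz, 25] [1, 1]
    : tensor<?x25xf32> into tensor<24x25xf32>
```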