diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/HoistPadding.h b/mlir/include/mlir/Dialect/Linalg/Transforms/HoistPadding.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/HoistPadding.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/HoistPadding.h
@@ -9,8 +9,10 @@
 #ifndef MLIR_DIALECT_LINALG_TRANSFORMS_HOIST_PADDING_H_
 #define MLIR_DIALECT_LINALG_TRANSFORMS_HOIST_PADDING_H_
 
+#include "mlir/Support/LogicalResult.h"
+
 namespace mlir {
-struct LogicalResult;
+class Value;
 
 namespace linalg {
 class PadTensorOp;
@@ -57,7 +59,8 @@
 /// }
 /// }
 /// ```
-LogicalResult hoistPaddingOnTensors(PadTensorOp &padTensorOp, int nLoops);
+FailureOr<Value> hoistPaddingOnTensors(PadTensorOp opToHoist, int numLoops,
+                                       PadTensorOp &hoistedOp);
 
 } // namespace linalg
 } // namespace mlir
diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
@@ -355,11 +355,12 @@
                                  ValueRange{ivVal, lbVal, stepVal});
 }
 
-LogicalResult mlir::linalg::hoistPaddingOnTensors(PadTensorOp &padTensorOp,
-                                                  int nLoops) {
-  LLVM_DEBUG(DBGS() << "Try to hoist " << *(padTensorOp) << " by " << nLoops
+FailureOr<Value> mlir::linalg::hoistPaddingOnTensors(PadTensorOp opToHoist,
+                                                     int numLoops,
+                                                     PadTensorOp &hoistedOp) {
+  LLVM_DEBUG(DBGS() << "Try to hoist " << *(opToHoist) << " by " << numLoops
                     << " loops\n");
-  HoistingAnalysis analysis(padTensorOp, nLoops);
+  HoistingAnalysis analysis(opToHoist, numLoops);
   if (!analysis.isValid()) {
     LLVM_DEBUG(DBGS() << "Analysis failed -> Skip\n");
     return failure();
@@ -376,8 +377,8 @@
   // Update actual number of loops, which may be smaller.
   int nPackedLoops = analysis.packingLoops.size();
 
-  Location loc = padTensorOp->getLoc();
-  RankedTensorType paddedTensorType = padTensorOp.getResultType();
+  Location loc = opToHoist->getLoc();
+  RankedTensorType paddedTensorType = opToHoist.getResultType();
   int paddedRank = paddedTensorType.getRank();
 
   // Create the packed tensor into which we amortize
@@ -404,8 +405,8 @@
   clonedLoopIvs.reserve(nPackedLoops);
   leadingPackedTensorIndexings.reserve(nPackedLoops);
   BlockAndValueMapping bvm;
-  // Insert `padTensorOp` into the backwardSlice so we clone it too.
-  analysis.backwardSlice.insert(padTensorOp);
+  // Insert `opToHoist` into the backwardSlice so we clone it too.
+  analysis.backwardSlice.insert(opToHoist);
   // Stack step 1. iteratively clone loops and push `packedTensor`.
   for (Operation *op : analysis.backwardSlice) {
     // Specifically sit out in the extract_slice(packedTensor) case: this is the
@@ -466,7 +467,7 @@
                                     b.getIndexAttr(1));
 
   Value inserted =
-      b.create<tensor::InsertSliceOp>(loc, bvm.lookup(padTensorOp.result()),
+      b.create<tensor::InsertSliceOp>(loc, bvm.lookup(opToHoist.result()),
                                       packedTensor, offsets, sizes, strides);
 
   // Stack step 3. iteratively pop the stack and propagate the yield.
@@ -480,7 +481,7 @@
 
   // Now the packed tensor is ready, replace the original padding op by a
   // 1x..x1 slice [originalLoopIvs, 0 .. 0][1 .. 1, paddedShape][1 .. 1].
-  b.setInsertionPoint(padTensorOp);
+  b.setInsertionPoint(opToHoist);
   SmallVector<Value> loopIterationCounts = llvm::to_vector<4>(
       llvm::map_range(analysis.packingLoops, [&](Operation *loop) {
         return buildLoopIterationCount(b, outer, cast<scf::ForOp>(loop));
@@ -495,18 +496,10 @@
   // strides = [1 .. 1] (defined above)
   packedTensor =
       scf::getForInductionVarOwner(clonedLoopIvs.front())->getResult(0);
-  padTensorOp.replaceAllUsesWith(
-      b.create<tensor::ExtractSliceOp>(loc, padTensorOp.getResultType(),
-                                       packedTensor, offsets, sizes, strides)
-          ->getResult(0));
+  Value newResult = b.create<tensor::ExtractSliceOp>(
+      loc, opToHoist.getResultType(), packedTensor, offsets, sizes, strides);
 
-  Operation *toErase = padTensorOp;
-
-  // Make the newly cloned `padTensorOp` available to the caller.
-  padTensorOp =
-      cast<PadTensorOp>(bvm.lookup(padTensorOp.result()).getDefiningOp());
-
-  toErase->erase();
-
-  return success();
+  // Make the newly cloned `opToHoist` available to the caller.
+  hoistedOp = cast<PadTensorOp>(bvm.lookup(opToHoist.result()).getDefiningOp());
+  return newResult;
 }
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -771,7 +771,13 @@
                             /*peeledLoops=*/{}, /*scalarizeDynamicDims=*/true);
   if (testHoistPadding) {
     getFunction().walk([&](linalg::PadTensorOp padTensorOp) {
-      (void)linalg::hoistPaddingOnTensors(padTensorOp, testHoistPadding);
+      PadTensorOp hoistedOp;
+      FailureOr<Value> newResult = linalg::hoistPaddingOnTensors(
+          padTensorOp, testHoistPadding, hoistedOp);
+      if (succeeded(newResult)) {
+        padTensorOp.getResult().replaceAllUsesWith(newResult.getValue());
+        padTensorOp->erase();
+      }
     });
   }
   if (testInterchangePattern.hasValue())
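
For reference, a minimal caller-side sketch of the updated API, mirroring the TestLinalgTransforms hunk above; the `funcOp` handle and the `/*numLoops=*/2` value are hypothetical placeholders, not part of this change:

// Hypothetical caller: hoistPaddingOnTensors no longer rewrites uses or
// erases the original op. It returns the replacement value and exposes the
// hoisted clone through `hoistedOp`, leaving RAUW and erasure to the caller.
funcOp.walk([&](linalg::PadTensorOp padOp) {
  linalg::PadTensorOp hoistedOp;
  FailureOr<Value> newResult =
      linalg::hoistPaddingOnTensors(padOp, /*numLoops=*/2, hoistedOp);
  if (failed(newResult))
    return; // Analysis failed; `padOp` is left untouched.
  padOp.getResult().replaceAllUsesWith(newResult.getValue());
  padOp->erase();
  // `hoistedOp` now refers to the cloned pad, available for further
  // transformation.
});

Keeping the replacement and erasure on the caller side lets pattern-based drivers route the rewrite through their own rewriter instead of having the transform mutate IR behind their back.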