diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h --- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h @@ -37,30 +37,6 @@ class ParallelOp; } // namespace scf -/// Create a clone of `loop` with `newIterOperands` added as new initialization -/// values and `newYieldedValues` added as new yielded values. The returned -/// ForOp has `newYieldedValues.size()` new result values. The `loop` induction -/// variable and `newIterOperands` are remapped to the new induction variable -/// and the new entry block arguments respectively. -/// -/// Additionally, if `replaceLoopResults` is true, all uses of -/// `loop.getResults()` are replaced with the first `loop.getNumResults()` -/// return values respectively. This additional replacement is provided as a -/// convenience to update the consumers of `loop`, in the case e.g. when `loop` -/// is soon to be deleted. -/// -/// Return the cloned loop. -/// -/// This convenience function is useful to factorize common mechanisms related -/// to hoisting roundtrips to memory into yields. It does not perform any -/// legality checks. -/// -/// Prerequisite: `newYieldedValues.size() == newYieldedValues.size()`. -scf::ForOp cloneWithNewYields(OpBuilder &b, scf::ForOp loop, - ValueRange newIterOperands, - ValueRange newYieldedValues, - bool replaceLoopResults = true); - /// Replace the `loop` with `newIterOperands` added as new initialization /// values. `newYieldValuesFn` is a callback that can be used to specify /// the additional values to be yielded by the loop. The number of @@ -74,9 +50,6 @@ /// initialization values of a loop. The loop is dead after this method. /// - All uses of the `newIterOperands` within the generated new loop /// are replaced with the corresponding `BlockArgument` in the loop body. -/// TODO: This method could be used instead of `cloneWithNewYields`. Making -/// this change though hits assertions in the walk mechanism that is unrelated -/// to this method itself. using NewYieldValueFn = std::function( OpBuilder &b, Location loc, ArrayRef newBBArgs)>; scf::ForOp replaceLoopWithNewYields(OpBuilder &builder, scf::ForOp loop, diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp @@ -305,8 +305,13 @@ // Rewrite `loop` with additional new yields. OpBuilder b(read.transferReadOp); - auto newForOp = cloneWithNewYields(b, forOp, read.transferReadOp.getVector(), - write.transferWriteOp.getVector()); + NewYieldValueFn yieldFn = [&](OpBuilder &b, Location loc, + ArrayRef newBBArgs) { + return SmallVector{write.transferWriteOp.getVector()}; + }; + auto newForOp = replaceLoopWithNewYields( + b, forOp, read.transferReadOp.getVector(), yieldFn); + // Transfer write has been hoisted, need to update the vector and tensor // source. Replace the result of the loop to use the new tensor created // outside the loop. @@ -397,10 +402,9 @@ while (changed) { changed = false; // First move loop invariant ops outside of their loop. This needs to be - // done before as we cannot move ops without interputing the function walk. - func.walk([&](LoopLikeOpInterface loopLike) { - moveLoopInvariantCode(loopLike); - }); + // done before as we cannot move ops without interrupting the function walk. + func.walk( + [&](LoopLikeOpInterface loopLike) { moveLoopInvariantCode(loopLike); }); func.walk([&](vector::TransferReadOp transferRead) { if (!transferRead.getShapedType().isa()) @@ -492,13 +496,16 @@ // Rewrite `loop` with new yields by cloning and erase the original loop. OpBuilder b(transferRead); - auto newForOp = cloneWithNewYields(b, loop, transferRead.getVector(), - transferWrite.getVector()); - - // Transfer write has been hoisted, need to update the written value to + NewYieldValueFn yieldFn = [&](OpBuilder &b, Location loc, + ArrayRef newBBArgs) { + return SmallVector{transferWrite.getVector()}; + }; + auto newForOp = + replaceLoopWithNewYields(b, loop, transferRead.getVector(), yieldFn); + + // Transfer write has been hoisted, need to update the written vector by // the value yielded by the newForOp. - transferWrite.getVector().replaceAllUsesWith( - newForOp.getResults().take_back()[0]); + transferWrite.getVectorMutable().assign(newForOp.getResults().back()); changed = true; loop.erase(); diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp --- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -36,61 +36,6 @@ }; } // namespace -scf::ForOp mlir::cloneWithNewYields(OpBuilder &b, scf::ForOp loop, - ValueRange newIterOperands, - ValueRange newYieldedValues, - bool replaceLoopResults) { - assert(newIterOperands.size() == newYieldedValues.size() && - "newIterOperands must be of the same size as newYieldedValues"); - - // Create a new loop before the existing one, with the extra operands. - OpBuilder::InsertionGuard g(b); - b.setInsertionPoint(loop); - auto operands = llvm::to_vector<4>(loop.getIterOperands()); - operands.append(newIterOperands.begin(), newIterOperands.end()); - scf::ForOp newLoop = - b.create(loop.getLoc(), loop.getLowerBound(), - loop.getUpperBound(), loop.getStep(), operands); - - auto &loopBody = *loop.getBody(); - auto &newLoopBody = *newLoop.getBody(); - // Clone / erase the yield inside the original loop to both: - // 1. augment its operands with the newYieldedValues. - // 2. automatically apply the BlockAndValueMapping on its operand - auto yield = cast(loopBody.getTerminator()); - b.setInsertionPoint(yield); - auto yieldOperands = llvm::to_vector<4>(yield.getOperands()); - yieldOperands.append(newYieldedValues.begin(), newYieldedValues.end()); - auto newYield = b.create(yield.getLoc(), yieldOperands); - - // Clone the loop body with remaps. - BlockAndValueMapping bvm; - // a. remap the induction variable. - bvm.map(loop.getInductionVar(), newLoop.getInductionVar()); - // b. remap the BB args. - bvm.map(loopBody.getArguments(), - newLoopBody.getArguments().take_front(loopBody.getNumArguments())); - // c. remap the iter args. - bvm.map(newIterOperands, - newLoop.getRegionIterArgs().take_back(newIterOperands.size())); - b.setInsertionPointToStart(&newLoopBody); - // Skip the original yield terminator which does not have enough operands. - for (auto &o : loopBody.without_terminator()) - b.clone(o, bvm); - - // Replace `loop`'s results if requested. - if (replaceLoopResults) { - for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front( - loop.getNumResults()))) - std::get<0>(it).replaceAllUsesWith(std::get<1>(it)); - } - - // TODO: this is unsafe in the context of a PatternRewrite. - newYield.erase(); - - return newLoop; -} - scf::ForOp mlir::replaceLoopWithNewYields(OpBuilder &builder, scf::ForOp loop, ValueRange newIterOperands, NewYieldValueFn newYieldValuesFn) { @@ -110,7 +55,7 @@ newLoopBody->getOperations().splice(newLoopBody->end(), loopBody->getOperations()); - // Generate the new yield values to use by using the callback and ppend the + // Generate the new yield values to use by using the callback and append the // yield values to the scf.yield operation. auto yield = cast(newLoopBody->getTerminator()); ArrayRef newBBArgs = diff --git a/mlir/test/Transforms/scf-loop-utils.mlir b/mlir/test/Transforms/scf-loop-utils.mlir deleted file mode 100644 --- a/mlir/test/Transforms/scf-loop-utils.mlir +++ /dev/null @@ -1,40 +0,0 @@ -// RUN: mlir-opt -allow-unregistered-dialect -test-scf-for-utils=test-clone-with-new-yields -mlir-disable-threading %s | FileCheck %s - -// CHECK-LABEL: @hoist -// CHECK-SAME: %[[lb:[a-zA-Z0-9]*]]: index, -// CHECK-SAME: %[[ub:[a-zA-Z0-9]*]]: index, -// CHECK-SAME: %[[step:[a-zA-Z0-9]*]]: index -func.func @hoist(%lb: index, %ub: index, %step: index) { - // CHECK: %[[A:.*]] = "fake_read"() : () -> index - // CHECK: %[[RES:.*]] = scf.for %[[I:.*]] = %[[lb]] to %[[ub]] step %[[step]] iter_args(%[[VAL:.*]] = %[[A]]) -> (index) - // CHECK: %[[YIELD:.*]] = "fake_compute"(%[[VAL]]) : (index) -> index - // CHECK: scf.yield %[[YIELD]] : index - // CHECK: "fake_write"(%[[RES]]) : (index) -> () - scf.for %i = %lb to %ub step %step { - %0 = "fake_read"() : () -> (index) - %1 = "fake_compute"(%0) : (index) -> (index) - "fake_write"(%1) : (index) -> () - } - return -} - -// CHECK-LABEL: @hoist2 -// CHECK-SAME: %[[lb:[a-zA-Z0-9]*]]: index, -// CHECK-SAME: %[[ub:[a-zA-Z0-9]*]]: index, -// CHECK-SAME: %[[step:[a-zA-Z0-9]*]]: index -// CHECK-SAME: %[[extra_arg:[a-zA-Z0-9]*]]: f32 -func.func @hoist2(%lb: index, %ub: index, %step: index, %extra_arg: f32) -> f32 { - // CHECK: %[[A:.*]] = "fake_read"() : () -> index - // CHECK: %[[RES:.*]]:2 = scf.for %[[I:.*]] = %[[lb]] to %[[ub]] step %[[step]] iter_args(%[[VAL0:.*]] = %[[extra_arg]], %[[VAL1:.*]] = %[[A]]) -> (f32, index) - // CHECK: %[[YIELD:.*]] = "fake_compute"(%[[VAL1]]) : (index) -> index - // CHECK: scf.yield %[[VAL0]], %[[YIELD]] : f32, index - // CHECK: "fake_write"(%[[RES]]#1) : (index) -> () - // CHECK: return %[[RES]]#0 : f32 - %0 = scf.for %i = %lb to %ub step %step iter_args(%iter = %extra_arg) -> (f32) { - %0 = "fake_read"() : () -> (index) - %1 = "fake_compute"(%0) : (index) -> (index) - "fake_write"(%1) : (index) -> () - scf.yield %iter: f32 - } - return %0: f32 -} diff --git a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp --- a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp +++ b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp @@ -34,12 +34,6 @@ explicit TestSCFForUtilsPass() = default; TestSCFForUtilsPass(const TestSCFForUtilsPass &pass) : PassWrapper(pass) {} - Option testCloneWithNewYields{ - *this, "test-clone-with-new-yields", - llvm::cl::desc( - "Test cloning of a loop while returning additional yield values"), - llvm::cl::init(false)}; - Option testReplaceWithNewYields{ *this, "test-replace-with-new-yields", llvm::cl::desc("Test replacing a loop with a new loop that returns new " @@ -50,27 +44,6 @@ func::FuncOp func = getOperation(); SmallVector toErase; - if (testCloneWithNewYields) { - func.walk([&](Operation *fakeRead) { - if (fakeRead->getName().getStringRef() != "fake_read") - return; - auto *fakeCompute = fakeRead->getResult(0).use_begin()->getOwner(); - auto *fakeWrite = fakeCompute->getResult(0).use_begin()->getOwner(); - auto loop = fakeRead->getParentOfType(); - - OpBuilder b(loop); - loop.moveOutOfLoop(fakeRead); - fakeWrite->moveAfter(loop); - auto newLoop = cloneWithNewYields(b, loop, fakeRead->getResult(0), - fakeCompute->getResult(0)); - fakeCompute->getResult(0).replaceAllUsesWith( - newLoop.getResults().take_back()[0]); - toErase.push_back(loop); - }); - for (auto loop : llvm::reverse(toErase)) - loop.erase(); - } - if (testReplaceWithNewYields) { func.walk([&](scf::ForOp forOp) { if (forOp.getNumResults() == 0)