diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -61,6 +61,28 @@
                               ValueRange newYieldedValues,
                               bool replaceLoopResults = true);
 
+/// Replace the `loop` with `newIterOperands` added as new initialization
+/// values. `newYieldValuesFn` is a callback that can be used to specify
+/// the additional values to be yielded by the loop. The number of
+/// values returned by the callback should match the number of new
+/// initialization values. This function
+/// - Moves (i.e. doesn't clone) operations from the `loop` to the newly created
+///   loop
+/// - Replaces the uses of `loop` with the new loop.
+/// - `loop` isn't erased, but is left in a "no-op" state where the body of the
+///   loop just yields the basic block arguments that correspond to the
+///   initialization values of a loop. The loop is dead after this method.
+/// - All uses of the `newIterOperands` within the generated new loop
+///   are replaced with the corresponding `BlockArgument` in the loop body.
+/// TODO: This method could be used instead of `cloneWithNewYields`. Making
+/// this change though hits assertions in the walk mechanism that are unrelated
+/// to this method itself.
+using NewYieldValueFn = std::function<SmallVector<Value>(
+    OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBBArgs)>;
+scf::ForOp replaceLoopWithNewYields(OpBuilder &builder, scf::ForOp loop,
+                                    ValueRange newIterOperands,
+                                    NewYieldValueFn newYieldValuesFn);
+
 /// Outline a region with a single block into a new FuncOp.
 /// Assumes the FuncOp result types is the type of the yielded operands of the
 /// single block. This constraint makes it easy to determine the result.
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -91,6 +91,70 @@
   return newLoop;
 }
 
+scf::ForOp mlir::replaceLoopWithNewYields(OpBuilder &builder, scf::ForOp loop,
+                                          ValueRange newIterOperands,
+                                          NewYieldValueFn newYieldValuesFn) {
+  // Create a new loop before the existing one, with the extra operands.
+  OpBuilder::InsertionGuard g(builder);
+  builder.setInsertionPoint(loop);
+  auto operands = llvm::to_vector(loop.getIterOperands());
+  operands.append(newIterOperands.begin(), newIterOperands.end());
+  scf::ForOp newLoop = builder.create<scf::ForOp>(
+      loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(),
+      operands, [](OpBuilder &, Location, Value, ValueRange) {});
+
+  Block *loopBody = loop.getBody();
+  Block *newLoopBody = newLoop.getBody();
+
+  // Move the body of the original loop to the new loop.
+  newLoopBody->getOperations().splice(newLoopBody->end(),
+                                      loopBody->getOperations());
+
+  // Generate the new yield values to use by using the callback and append the
+  // yield values to the scf.yield operation.
+  auto yield = cast<scf::YieldOp>(newLoopBody->getTerminator());
+  ArrayRef<BlockArgument> newBBArgs =
+      newLoopBody->getArguments().take_back(newIterOperands.size());
+  {
+    OpBuilder::InsertionGuard yieldGuard(builder);
+    builder.setInsertionPoint(yield);
+    SmallVector<Value> newYieldedValues =
+        newYieldValuesFn(builder, loop.getLoc(), newBBArgs);
+    assert(newIterOperands.size() == newYieldedValues.size() &&
+           "expected as many new yield values as new iter operands");
+    yield.getResultsMutable().append(newYieldedValues);
+  }
+
+  // Remap the BlockArguments from the original loop to the new loop
+  // BlockArguments.
+  ArrayRef<BlockArgument> bbArgs = loopBody->getArguments();
+  for (auto it :
+       llvm::zip(bbArgs, newLoopBody->getArguments().take_front(bbArgs.size())))
+    std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
+
+  // Replace all uses of `newIterOperands` with the corresponding basic block
+  // arguments.
+  for (auto it : llvm::zip(newIterOperands, newBBArgs)) {
+    std::get<0>(it).replaceUsesWithIf(std::get<1>(it), [&](OpOperand &use) {
+      Operation *user = use.getOwner();
+      return newLoop->isProperAncestor(user);
+    });
+  }
+
+  // Replace all uses of the original loop with corresponding values from the
+  // new loop.
+  loop.replaceAllUsesWith(
+      newLoop.getResults().take_front(loop.getNumResults()));
+
+  // Add a fake yield to the original loop body that just returns the
+  // BlockArguments corresponding to the iter_args. This makes it a no-op loop.
+  // The loop is dead. The caller is expected to erase it.
+  builder.setInsertionPointToEnd(loopBody);
+  builder.create<scf::YieldOp>(loop->getLoc(), loop.getRegionIterArgs());
+
+  return newLoop;
+}
+
 /// Outline a region with a single block into a new FuncOp.
 /// Assumes the FuncOp result types is the type of the yielded operands of the
 /// single block. This constraint makes it easy to determine the result.
diff --git a/mlir/test/Transforms/scf-loop-utils.mlir b/mlir/test/Transforms/scf-loop-utils.mlir
--- a/mlir/test/Transforms/scf-loop-utils.mlir
+++ b/mlir/test/Transforms/scf-loop-utils.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -allow-unregistered-dialect -test-scf-for-utils -mlir-disable-threading %s | FileCheck %s
+// RUN: mlir-opt -allow-unregistered-dialect -test-scf-for-utils=test-clone-with-new-yields -mlir-disable-threading %s | FileCheck %s
 
 // CHECK-LABEL: @hoist
 // CHECK-SAME: %[[lb:[a-zA-Z0-9]*]]: index,
diff --git a/mlir/test/Transforms/scf-replace-with-new-yields.mlir b/mlir/test/Transforms/scf-replace-with-new-yields.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Transforms/scf-replace-with-new-yields.mlir
@@ -0,0 +1,21 @@
+
+// RUN: mlir-opt -allow-unregistered-dialect -test-scf-for-utils=test-replace-with-new-yields -mlir-disable-threading %s | FileCheck %s
+
+func.func @doubleup(%lb: index, %ub: index, %step: index, %extra_arg: f32) -> f32 {
+  %0 = scf.for %i = %lb to %ub step %step iter_args(%iter = %extra_arg) -> (f32) {
+    %1 = arith.addf %iter, %iter : f32
+    scf.yield %1: f32
+  }
+  return %0: f32
+}
+// CHECK-LABEL: func @doubleup
+// CHECK-SAME: %[[ARG:[a-zA-Z0-9]+]]: f32
+// CHECK: %[[NEWLOOP:.+]]:2 = scf.for
+// CHECK-SAME: iter_args(%[[INIT1:.+]] = %[[ARG]], %[[INIT2:.+]] = %[[ARG]]
+// CHECK: %[[DOUBLE:.+]] = arith.addf %[[INIT1]], %[[INIT1]]
+// CHECK: %[[DOUBLE2:.+]] = arith.addf %[[DOUBLE]], %[[DOUBLE]]
+// CHECK: scf.yield %[[DOUBLE]], %[[DOUBLE2]]
+// CHECK: %[[OLDLOOP:.+]] = scf.for
+// CHECK-SAME: iter_args(%[[INIT:.+]] = %[[ARG]])
+// CHECK: scf.yield %[[INIT]]
+// CHECK: return %[[NEWLOOP]]#0
diff --git a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
--- a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
+++ b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
@@ -32,29 +32,67 @@
   StringRef getArgument() const final { return "test-scf-for-utils"; }
   StringRef getDescription() const final { return "test scf.for utils"; }
   explicit TestSCFForUtilsPass() = default;
+  TestSCFForUtilsPass(const TestSCFForUtilsPass &pass) : PassWrapper(pass) {}
+
+  Option<bool> testCloneWithNewYields{
+      *this, "test-clone-with-new-yields",
+      llvm::cl::desc(
+          "Test cloning of a loop while returning additional yield values"),
+      llvm::cl::init(false)};
+
+  Option<bool> testReplaceWithNewYields{
+      *this, "test-replace-with-new-yields",
+      llvm::cl::desc("Test replacing a loop with a new loop that returns new "
+                     "additional yield values"),
+      llvm::cl::init(false)};
 
   void runOnOperation() override {
     func::FuncOp func = getOperation();
     SmallVector<scf::ForOp, 4> toErase;
 
-    func.walk([&](Operation *fakeRead) {
-      if (fakeRead->getName().getStringRef() != "fake_read")
-        return;
-      auto *fakeCompute = fakeRead->getResult(0).use_begin()->getOwner();
-      auto *fakeWrite = fakeCompute->getResult(0).use_begin()->getOwner();
-      auto loop = fakeRead->getParentOfType<scf::ForOp>();
-
-      OpBuilder b(loop);
-      loop.moveOutOfLoop(fakeRead);
-      fakeWrite->moveAfter(loop);
-      auto newLoop = cloneWithNewYields(b, loop, fakeRead->getResult(0),
-                                        fakeCompute->getResult(0));
-      fakeCompute->getResult(0).replaceAllUsesWith(
-          newLoop.getResults().take_back()[0]);
-      toErase.push_back(loop);
-    });
-    for (auto loop : llvm::reverse(toErase))
-      loop.erase();
+    if (testCloneWithNewYields) {
+      func.walk([&](Operation *fakeRead) {
+        if (fakeRead->getName().getStringRef() != "fake_read")
+          return;
+        auto *fakeCompute = fakeRead->getResult(0).use_begin()->getOwner();
+        auto *fakeWrite = fakeCompute->getResult(0).use_begin()->getOwner();
+        auto loop = fakeRead->getParentOfType<scf::ForOp>();
+
+        OpBuilder b(loop);
+        loop.moveOutOfLoop(fakeRead);
+        fakeWrite->moveAfter(loop);
+        auto newLoop = cloneWithNewYields(b, loop, fakeRead->getResult(0),
+                                          fakeCompute->getResult(0));
+        fakeCompute->getResult(0).replaceAllUsesWith(
+            newLoop.getResults().take_back()[0]);
+        toErase.push_back(loop);
+      });
+      for (auto loop : llvm::reverse(toErase))
+        loop.erase();
+    }
+
+    if (testReplaceWithNewYields) {
+      func.walk([&](scf::ForOp forOp) {
+        if (forOp.getNumResults() == 0)
+          return;
+        auto newInitValues = forOp.getInitArgs();
+        if (newInitValues.empty())
+          return;
+        NewYieldValueFn fn = [&](OpBuilder &b, Location loc,
+                                 ArrayRef<BlockArgument> newBBArgs) {
+          Block *block = newBBArgs.front().getOwner();
+          SmallVector<Value> newYieldValues;
+          for (auto yieldVal :
+               cast<scf::YieldOp>(block->getTerminator()).getResults()) {
+            newYieldValues.push_back(
+                b.create<arith::AddFOp>(loc, yieldVal, yieldVal));
+          }
+          return newYieldValues;
+        };
+        OpBuilder b(forOp);
+        replaceLoopWithNewYields(b, forOp, newInitValues, fn);
+      });
+    }
   }
 };
 
@@ -88,7 +126,8 @@
     "__test_pipelining_loop__";
 static const StringLiteral kTestPipeliningStageMarker =
     "__test_pipelining_stage__";
-/// Marker to express the order in which operations should be after pipelining.
+/// Marker to express the order in which operations should be after
+/// pipelining.
 static const StringLiteral kTestPipeliningOpOrderMarker =
     "__test_pipelining_op_order__";
 