diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp --- a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp +++ b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp @@ -65,6 +65,7 @@ llvm::SmallVector returnValues; // returned async values Value coroHandle; // coroutine handle (!async.coro.handle value) + Block *entry; // coroutine entry block Block *setError; // switch completion token and all values to error state Block *cleanup; // coroutine cleanup block Block *suspend; // coroutine suspension block @@ -73,16 +74,15 @@ /// Utility to partially update the regular function CFG to the coroutine CFG /// compatible with LLVM coroutines switched-resume lowering using -/// `async.runtime.*` and `async.coro.*` operations. Modifies the entry block -/// by prepending its ops with coroutine setup. Also inserts trailing blocks. +/// `async.runtime.*` and `async.coro.*` operations. Adds a new entry block +/// that branches into preexisting entry block. Also inserts trailing blocks. /// /// The result types of the passed `func` must start with an `async.token` /// and be continued with some number of `async.value`s. /// -/// It's up to the caller of this function to fix up the terminators of the -/// preexisting blocks of the passed func op. If the passed `func` is legal, -/// this typically means rewriting every return op as a yield op and a branch op -/// to the suspend block. +/// The func given to this function needs to have been preprocessed so that +/// every block ends in either a yield op or a branching op (an op with +/// successors). A branch to the cleanup block is inserted after each yield.
/// /// See LLVM coroutines documentation: https://llvm.org/docs/Coroutines.html /// @@ -102,9 +102,9 @@ /// %value = : !async.value // create async value /// %id = async.coro.id // create a coroutine id /// %hdl = async.coro.begin %id // create a coroutine handle -/// /* other ops of the preexisting entry block */ +/// br ^preexisting_entry_block /// -/// /* other preexisting blocks */ +/// /* preexisting blocks modified to branch to the cleanup block */ /// /// ^set_error: // this block created lazily only if needed (see code below) /// async.runtime.set_error %token : !async.token @@ -125,6 +125,8 @@ MLIRContext *ctx = func.getContext(); Block *entryBlock = &func.getBlocks().front(); + Block *originalEntryBlock = + entryBlock->splitBlock(entryBlock->getOperations().begin()); auto builder = ImplicitLocOpBuilder::atBlockBegin(func->getLoc(), entryBlock); // ------------------------------------------------------------------------ // @@ -142,6 +144,7 @@ auto coroIdOp = builder.create(CoroIdType::get(ctx)); auto coroHdlOp = builder.create(CoroHandleType::get(ctx), coroIdOp.id()); + builder.create(originalEntryBlock); Block *cleanupBlock = func.addBlock(); Block *suspendBlock = func.addBlock(); @@ -173,11 +176,25 @@ // `async.await` op lowering will create resume blocks for async // continuations, and will conditionally branch to cleanup or suspend blocks.
+ for (Block &block : func.body().getBlocks()) { + if (&block == entryBlock || &block == cleanupBlock || + &block == suspendBlock) + continue; + Operation *terminator = block.getTerminator(); + if (auto yield = dyn_cast(terminator)) { + builder.setInsertionPointToEnd(&block); + builder.create(cleanupBlock); + } else { + assert(terminator->getNumSuccessors() > 0 && "Unexpected terminator"); + } + } + CoroMachinery machinery; machinery.func = func; machinery.asyncToken = retToken; machinery.returnValues = retValues; machinery.coroHandle = coroHdlOp.handle(); + machinery.entry = entryBlock; machinery.setError = nullptr; // created lazily only if needed machinery.cleanup = cleanupBlock; machinery.suspend = suspendBlock; @@ -239,68 +256,75 @@ symbolTable.insert(func); SymbolTable::setSymbolVisibility(func, SymbolTable::Visibility::Private); + auto builder = ImplicitLocOpBuilder::atBlockBegin(loc, func.addEntryBlock()); + + // Prepare for coroutine conversion by creating the body of the function. + { + size_t numDependencies = execute.dependencies().size(); + size_t numOperands = execute.operands().size(); + + // Await on all dependencies before starting to execute the body region. + for (size_t i = 0; i < numDependencies; ++i) + builder.create(func.getArgument(i)); + + // Await on all async value operands and unwrap the payload. + SmallVector unwrappedOperands(numOperands); + for (size_t i = 0; i < numOperands; ++i) { + Value operand = func.getArgument(numDependencies + i); + unwrappedOperands[i] = builder.create(loc, operand).result(); + } + + // Map from function inputs defined above the execute op to the function + // arguments. + BlockAndValueMapping valueMapping; + valueMapping.map(functionInputs, func.getArguments()); + valueMapping.map(execute.body().getArguments(), unwrappedOperands); - // Prepare a function for coroutine lowering by adding entry/cleanup/suspend - // blocks, adding async.coro operations and setting up control flow.
- func.addEntryBlock(); + // Clone all operations from the execute operation body into the outlined + // function body. + for (Operation &op : execute.body().getOps()) + builder.clone(op, valueMapping); + } + + // Add entry/cleanup/suspend blocks and set up coroutine control flow. CoroMachinery coro = setupCoroMachinery(func); // Suspend async function at the end of an entry block, and resume it using // Async resume operation (execution will be resumed in a thread managed by // the async runtime). - Block *entryBlock = &func.getBlocks().front(); - auto builder = ImplicitLocOpBuilder::atBlockEnd(loc, entryBlock); + { + if (auto branch = dyn_cast(coro.entry->getTerminator())) { - // Save the coroutine state: async.coro.save - auto coroSaveOp = - builder.create(CoroStateType::get(ctx), coro.coroHandle); + builder.setInsertionPointToEnd(coro.entry); - // Pass coroutine to the runtime to be resumed on a runtime managed thread. - builder.create(coro.coroHandle); - builder.create(coro.cleanup); + // Save the coroutine state: async.coro.save + auto coroSaveOp = + builder.create(CoroStateType::get(ctx), coro.coroHandle); - // Split the entry block before the terminator (branch to suspend block). - auto *terminatorOp = entryBlock->getTerminator(); - Block *suspended = terminatorOp->getBlock(); - Block *resume = suspended->splitBlock(terminatorOp); - - // Add async.coro.suspend as a suspended block terminator. - builder.setInsertionPointToEnd(suspended); - builder.create(coroSaveOp.state(), coro.suspend, resume, - coro.cleanup); - - size_t numDependencies = execute.dependencies().size(); - size_t numOperands = execute.operands().size(); - - // Await on all dependencies before starting to execute the body region. - builder.setInsertionPointToStart(resume); - for (size_t i = 0; i < numDependencies; ++i) - builder.create(func.getArgument(i)); - - // Await on all async value operands and unwrap the payload.
- SmallVector unwrappedOperands(numOperands); - for (size_t i = 0; i < numOperands; ++i) { - Value operand = func.getArgument(numDependencies + i); - unwrappedOperands[i] = builder.create(loc, operand).result(); - } + // Pass coroutine to the runtime to be resumed on a runtime managed + // thread. + builder.create(coro.coroHandle); - // Map from function inputs defined above the execute op to the function - // arguments. - BlockAndValueMapping valueMapping; - valueMapping.map(functionInputs, func.getArguments()); - valueMapping.map(execute.body().getArguments(), unwrappedOperands); + // Add async.coro.suspend as a suspended block terminator. + builder.create(coroSaveOp.state(), coro.suspend, + branch.getDest(), coro.cleanup); - // Clone all operations from the execute operation body into the outlined - // function body. - for (Operation &op : execute.body().getOps()) - builder.clone(op, valueMapping); + branch.erase(); + } else { + llvm_unreachable("Unexpected terminator"); + } + } // Replace the original `async.execute` with a call to outlined function.
- ImplicitLocOpBuilder callBuilder(loc, execute); - auto callOutlinedFunc = callBuilder.create( - func.getName(), execute.getResultTypes(), functionInputs.getArrayRef()); - execute.replaceAllUsesWith(callOutlinedFunc.getResults()); - execute.erase(); + { + ImplicitLocOpBuilder callBuilder(loc, execute); + auto callOutlinedFunc = callBuilder.create( + func.getName(), execute.getResultTypes(), functionInputs.getArrayRef()); + execute.replaceAllUsesWith(callOutlinedFunc.getResults()); + execute.erase(); + } + + LLVM_DEBUG(llvm::dbgs() << "func:\n" << func << "\n"); return {func, coro}; } @@ -571,21 +595,17 @@ func.setResultAttrs(i, func.getResultAttrs(i - 1)); } func.setResultAttrs(0, ArrayRef()); - CoroMachinery coro = setupCoroMachinery(func); for (Block &block : func.getBlocks()) { - if (&block != coro.suspend) { - auto *terminator = block.getTerminator(); - if (auto returnOp = dyn_cast(*terminator)) { - ImplicitLocOpBuilder builder(loc, returnOp); - builder.create(returnOp.getOperands()); - builder.create(coro.cleanup); - returnOp.erase(); - } else { - assert(dyn_cast(*terminator) && "Unexpected terminator."); - } + auto *terminator = block.getTerminator(); + if (auto returnOp = dyn_cast(*terminator)) { + ImplicitLocOpBuilder builder(loc, returnOp); + builder.create(returnOp.getOperands()); + returnOp.erase(); + } else { + assert(terminator->getNumSuccessors() > 0 && "Unexpected terminator."); + } } - return coro; + return setupCoroMachinery(func); } /// Rewrites a call into a function that has been rewiten as a coroutine.
diff --git a/mlir/test/Dialect/Async/async-to-async-runtime-eliminate-blocking.mlir b/mlir/test/Dialect/Async/async-to-async-runtime-eliminate-blocking.mlir --- a/mlir/test/Dialect/Async/async-to-async-runtime-eliminate-blocking.mlir +++ b/mlir/test/Dialect/Async/async-to-async-runtime-eliminate-blocking.mlir @@ -12,19 +12,20 @@ // CHECK: %[[RETURNED_STORAGE:.*]] = async.runtime.create : !async.value // CHECK: %[[ID:.*]] = async.coro.id // CHECK: %[[HDL:.*]] = async.coro.begin %[[ID]] - -// CHECK: %[[VAL:.*]] = addf %[[ARG]], %[[ARG]] : f32 +// CHECK: br ^[[ORIGINAL_ENTRY:.*]] +// CHECK: ^[[ORIGINAL_ENTRY]]: +// CHECK: %[[VAL:.*]] = addf %[[ARG]], %[[ARG]] : f32 %0 = addf %arg0, %arg0 : f32 -// CHECK: %[[VAL_STORAGE:.*]] = async.runtime.create : !async.value +// CHECK: %[[VAL_STORAGE:.*]] = async.runtime.create : !async.value %1 = async.runtime.create: !async.value -// CHECK: async.runtime.store %[[VAL]], %[[VAL_STORAGE]] : !async.value +// CHECK: async.runtime.store %[[VAL]], %[[VAL_STORAGE]] : !async.value async.runtime.store %0, %1: !async.value -// CHECK: async.runtime.set_available %[[VAL_STORAGE]] : !async.value +// CHECK: async.runtime.set_available %[[VAL_STORAGE]] : !async.value async.runtime.set_available %1: !async.value -// CHECK: %[[SAVED:.*]] = async.coro.save %[[HDL]] -// CHECK: async.runtime.await_and_resume %[[VAL_STORAGE]], %[[HDL]] -// CHECK: async.coro.suspend %[[SAVED]] +// CHECK: %[[SAVED:.*]] = async.coro.save %[[HDL]] +// CHECK: async.runtime.await_and_resume %[[VAL_STORAGE]], %[[HDL]] +// CHECK: async.coro.suspend %[[SAVED]] // CHECK-SAME: ^[[SUSPEND:.*]], ^[[RESUME:.*]], ^[[CLEANUP:.*]] %2 = async.await %1 : !async.value @@ -64,13 +65,15 @@ // CHECK: %[[RETURNED_STORAGE:.*]] = async.runtime.create : !async.value // CHECK: %[[ID:.*]] = async.coro.id // CHECK: %[[HDL:.*]] = async.coro.begin %[[ID]] +// CHECK: br ^[[ORIGINAL_ENTRY:.*]] +// CHECK: ^[[ORIGINAL_ENTRY]]: -// CHECK: %[[CONSTANT:.*]] = constant +// CHECK: %[[CONSTANT:.*]] = constant
%c = constant 1.0 : f32 -// CHECK: %[[RETURNED_TO_CALLER:.*]]:2 = call @simple_callee(%[[CONSTANT]]) : (f32) -> (!async.token, !async.value) -// CHECK: %[[SAVED:.*]] = async.coro.save %[[HDL]] -// CHECK: async.runtime.await_and_resume %[[RETURNED_TO_CALLER]]#0, %[[HDL]] -// CHECK: async.coro.suspend %[[SAVED]] +// CHECK: %[[RETURNED_TO_CALLER:.*]]:2 = call @simple_callee(%[[CONSTANT]]) : (f32) -> (!async.token, !async.value) +// CHECK: %[[SAVED:.*]] = async.coro.save %[[HDL]] +// CHECK: async.runtime.await_and_resume %[[RETURNED_TO_CALLER]]#0, %[[HDL]] +// CHECK: async.coro.suspend %[[SAVED]] // CHECK-SAME: ^[[SUSPEND:.*]], ^[[RESUME:.*]], ^[[CLEANUP:.*]] %r = call @simple_callee(%c): (f32) -> f32 @@ -111,13 +114,15 @@ // CHECK: %[[RETURNED_STORAGE:.*]] = async.runtime.create : !async.value // CHECK: %[[ID:.*]] = async.coro.id // CHECK: %[[HDL:.*]] = async.coro.begin %[[ID]] +// CHECK: br ^[[ORIGINAL_ENTRY:.*]] +// CHECK: ^[[ORIGINAL_ENTRY]]: -// CHECK: %[[CONSTANT:.*]] = constant +// CHECK: %[[CONSTANT:.*]] = constant %c = constant 1.0 : f32 -// CHECK: %[[RETURNED_TO_CALLER_1:.*]]:2 = call @simple_callee(%[[CONSTANT]]) : (f32) -> (!async.token, !async.value) -// CHECK: %[[SAVED_1:.*]] = async.coro.save %[[HDL]] -// CHECK: async.runtime.await_and_resume %[[RETURNED_TO_CALLER_1]]#0, %[[HDL]] -// CHECK: async.coro.suspend %[[SAVED_1]] +// CHECK: %[[RETURNED_TO_CALLER_1:.*]]:2 = call @simple_callee(%[[CONSTANT]]) : (f32) -> (!async.token, !async.value) +// CHECK: %[[SAVED_1:.*]] = async.coro.save %[[HDL]] +// CHECK: async.runtime.await_and_resume %[[RETURNED_TO_CALLER_1]]#0, %[[HDL]] +// CHECK: async.coro.suspend %[[SAVED_1]] // CHECK-SAME: ^[[SUSPEND:.*]], ^[[RESUME_1:.*]], ^[[CLEANUP:.*]] %r = call @simple_callee(%c): (f32) -> f32 diff --git a/mlir/test/Dialect/Async/async-to-async-runtime.mlir b/mlir/test/Dialect/Async/async-to-async-runtime.mlir --- a/mlir/test/Dialect/Async/async-to-async-runtime.mlir +++ b/mlir/test/Dialect/Async/async-to-async-runtime.mlir @@
-328,8 +328,8 @@ // ----- -// CHECK-LABEL: @execute_asserttion -func @execute_asserttion(%arg0: i1) { +// CHECK-LABEL: @execute_assertion +func @execute_assertion(%arg0: i1) { %token = async.execute { assert %arg0, "error" async.yield