diff --git a/mlir/include/mlir/Dialect/Async/Passes.td b/mlir/include/mlir/Dialect/Async/Passes.td
--- a/mlir/include/mlir/Dialect/Async/Passes.td
+++ b/mlir/include/mlir/Dialect/Async/Passes.td
@@ -39,6 +39,13 @@
   let summary = "Lower high level async operations (e.g. async.execute) to the"
                 "explicit async.runtime and async.coro operations";
   let constructor = "mlir::createAsyncToAsyncRuntimePass()";
+  let options = [
+    // Temporary for bringup, should become the default.
+    Option<"eliminateBlockingAwaitOps", "eliminate-blocking-await-ops", "bool",
+           /*default=*/"false",
+           "Rewrite functions with blocking async.runtime.await as coroutines "
+           "with async.runtime.await_and_resume.">
+  ];
   let dependentDialects = ["async::AsyncDialect"];
 }
 
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
--- a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
@@ -73,8 +73,18 @@
 };
 } // namespace
 
-/// Builds an coroutine template compatible with LLVM coroutines switched-resume
-/// lowering using `async.runtime.*` and `async.coro.*` operations.
+/// Utility to partially update the regular function CFG to the coroutine CFG
+/// compatible with LLVM coroutines switched-resume lowering using
+/// `async.runtime.*` and `async.coro.*` operations. Modifies the entry block
+/// by prepending its ops with coroutine setup. Also inserts trailing blocks.
+///
+/// The result types of the passed `func` must start with an `async.token`
+/// and be continued with some number of `async.value`s.
+///
+/// It's up to the caller of this function to fix up the terminators of the
+/// preexisting blocks of the passed func op. If the passed `func` is legal,
+/// this typically means rewriting every return op as a yield op and a branch op
+/// to the cleanup block.
 ///
 /// See LLVM coroutines documentation: https://llvm.org/docs/Coroutines.html
 ///
@@ -87,15 +97,16 @@
 ///
 /// Coroutine structure (only the important bits):
 ///
-///   func @async_execute_fn(<function-arguments>)
-///        -> (!async.token, !async.value<T>)
+///   func @some_fn(<function-arguments>) -> (!async.token, !async.value<T>)
 ///   {
 ///     ^entry(<function-arguments>):
 ///       %token = <async token> : !async.token    // create async runtime token
 ///       %value = <async value> : !async.value<T> // create async value
 ///       %id = async.coro.id                      // create a coroutine id
 ///       %hdl = async.coro.begin %id              // create a coroutine handle
-///       br ^cleanup
+///       /* other ops of the preexisting entry block */
+///
+///     /* other preexisting blocks */
 ///
 ///     ^set_error: // this block created lazily only if needed (see code below)
 ///       async.runtime.set_error %token : !async.token
@@ -111,16 +122,11 @@
 ///       return %token, %value : !async.token, !async.value<T>
 ///   }
 ///
-/// The actual code for the async.execute operation body region will be inserted
-/// before the entry block terminator.
-///
-///
 static CoroMachinery setupCoroMachinery(FuncOp func) {
-  assert(func.getBody().empty() && "Function must have empty body");
+  assert(!func.getBlocks().empty() && "Function must have an entry block");
 
   MLIRContext *ctx = func.getContext();
-  Block *entryBlock = func.addEntryBlock();
-
+  Block *entryBlock = &func.getBlocks().front();
   auto builder = ImplicitLocOpBuilder::atBlockBegin(func->getLoc(), entryBlock);
 
   // ------------------------------------------------------------------------ //
@@ -166,10 +172,6 @@
   ret.insert(ret.end(), retValues.begin(), retValues.end());
   builder.create<ReturnOp>(ret);
 
-  // Branch from the entry block to the cleanup block to create a valid CFG.
-  builder.setInsertionPointToEnd(entryBlock);
-  builder.create<BranchOp>(cleanupBlock);
-
   // `async.await` op lowering will create resume blocks for async
   // continuations, and will conditionally branch to cleanup or suspend blocks.
 
@@ -242,13 +244,14 @@
 
   // Prepare a function for coroutine lowering by adding entry/cleanup/suspend
   // blocks, adding async.coro operations and setting up control flow.
+  func.addEntryBlock();
   CoroMachinery coro = setupCoroMachinery(func);
 
   // Suspend async function at the end of an entry block, and resume it using
   // Async resume operation (execution will be resumed in a thread managed by
   // the async runtime).
   Block *entryBlock = &func.getBlocks().front();
-  auto builder = ImplicitLocOpBuilder::atBlockTerminator(loc, entryBlock);
+  auto builder = ImplicitLocOpBuilder::atBlockEnd(loc, entryBlock);
 
   // Save the coroutine state: async.coro.save
   auto coroSaveOp =
@@ -256,6 +259,7 @@
 
   // Pass coroutine to the runtime to be resumed on a runtime managed thread.
   builder.create<RuntimeResumeOp>(coro.coroHandle);
+  builder.create<BranchOp>(coro.cleanup);
 
   // Split the entry block before the terminator (branch to suspend block).
   auto *terminatorOp = entryBlock->getTerminator();
@@ -557,6 +561,132 @@
 
 //===----------------------------------------------------------------------===//
 
+/// Rewrite a func as a coroutine by:
+/// 1) Wrapping the results into `async.value`.
+/// 2) Prepending the results with `async.token`.
+/// 3) Setting up coroutine blocks.
+/// 4) Rewriting return ops as yield op and branch op into the cleanup block.
+static CoroMachinery rewriteFuncAsCoroutine(FuncOp func) {
+  auto *ctx = func->getContext();
+  // Wrap every original result type `T` into `!async.value<T>`.
+  SmallVector<Type> resultTypes;
+  resultTypes.reserve(func.getCallableResults().size());
+  llvm::transform(func.getCallableResults(), std::back_inserter(resultTypes),
+                  [](Type type) { return ValueType::get(type); });
+  func.setType(FunctionType::get(ctx, func.getType().getInputs(), resultTypes));
+  func.insertResult(0, TokenType::get(ctx), {});
+  CoroMachinery coro = setupCoroMachinery(func);
+  for (Block &block : func.getBlocks()) {
+    if (&block == coro.suspend)
+      continue;
+
+    Operation *terminator = block.getTerminator();
+    if (auto returnOp = dyn_cast<ReturnOp>(*terminator)) {
+      ImplicitLocOpBuilder builder(returnOp.getLoc(), returnOp);
+      builder.create<YieldOp>(returnOp.getOperands());
+      builder.create<BranchOp>(coro.cleanup);
+      returnOp.erase();
+    }
+  }
+  return coro;
+}
+
+/// Rewrites a call into a function that has been rewritten as a coroutine.
+///
+/// The invocation of this function is safe only when call ops are traversed in
+/// reverse order of how they appear in a single block. See `funcsToCoroutines`.
+static void rewriteCallsiteForCoroutine(CallOp oldCall, FuncOp func) {
+  auto loc = oldCall.getLoc();
+  ImplicitLocOpBuilder callBuilder(loc, oldCall);
+  auto newCall = callBuilder.create<CallOp>(
+      func.getName(), func.getCallableResults(), oldCall.getArgOperands());
+
+  // Await on the async token and all the value results and unwrap the latter.
+  callBuilder.create<AwaitOp>(loc, newCall.getResults().front());
+  SmallVector<Value> unwrappedResults;
+  unwrappedResults.reserve(newCall->getResults().size() - 1);
+  for (Value result : newCall.getResults().drop_front())
+    unwrappedResults.push_back(
+        callBuilder.create<AwaitOp>(loc, result).result());
+  // Careful, when result of a call is piped into another call this could lead
+  // to a dangling pointer.
+  oldCall.replaceAllUsesWith(unwrappedResults);
+  oldCall.erase();
+}
+
+static LogicalResult
+funcsToCoroutines(ModuleOp module,
+                  llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) {
+  // The following code supports the general case when 2 functions mutually
+  // recurse into each other. Because of this and that we are relying on
+  // SymbolUserMap to find pointers to calling FuncOps, we cannot simply erase
+  // a FuncOp while inserting an equivalent coroutine, because that could lead
+  // to dangling pointers.
+
+  SmallVector<FuncOp> funcWorklist;
+
+  // Careful, it's okay to add a func to the worklist multiple times if and only
+  // if the loop processing the worklist will skip the functions that have
+  // already been converted to coroutines.
+  auto addToWorklist = [&outlinedFunctions, &funcWorklist](FuncOp func) {
+    // N.B. To refactor this code into a separate pass the lookup in
+    // outlinedFunctions is the most obvious obstacle. Looking at an arbitrary
+    // func and recognizing if it has a coroutine structure is messy. Passing
+    // this dict between the passes is ugly.
+    if (outlinedFunctions.find(func) == outlinedFunctions.end()) {
+      for (Operation &op : func.body().getOps()) {
+        if (isa<AwaitOp, AwaitAllOp>(op)) {
+          funcWorklist.push_back(func);
+          break;
+        }
+      }
+    }
+  };
+
+  // Seed the worklist with every func that directly contains an await op.
+  for (FuncOp func : module.getOps<FuncOp>())
+    addToWorklist(func);
+
+  SymbolTableCollection symbolTable;
+  SymbolUserMap symbolUserMap(symbolTable, module);
+
+  // Rewrite funcs, while updating call sites and adding them to the worklist.
+  while (!funcWorklist.empty()) {
+    auto func = funcWorklist.pop_back_val();
+    auto insertion = outlinedFunctions.insert({func, CoroMachinery{}});
+    if (!insertion.second)
+      // This function has already been processed because this is either
+      // the corecursive case, or a caller with multiple calls to a newly
+      // created coroutine. Either way, skip updating the call sites.
+      continue;
+    insertion.first->second = rewriteFuncAsCoroutine(func);
+    SmallVector<Operation *> users(symbolUserMap.getUsers(func).begin(),
+                                   symbolUserMap.getUsers(func).end());
+    // If there are multiple calls from the same block they need to be traversed
+    // in reverse order so that symbolUserMap references are not invalidated
+    // when updating the users of the call op which is earlier in the block.
+    llvm::sort(users, [](Operation *a, Operation *b) {
+      Block *blockA = a->getBlock();
+      Block *blockB = b->getBlock();
+      // Strict weak order: arbitrary across blocks, reverse op order within one.
+      return blockA > blockB || (blockA == blockB && b->isBeforeInBlock(a));
+    });
+    // Rewrite the callsites to await on results of the newly created coroutine.
+    for (Operation *op : users) {
+      if (CallOp call = dyn_cast<CallOp>(*op)) {
+        FuncOp caller = call->getParentOfType<FuncOp>();
+        rewriteCallsiteForCoroutine(call, func); // Careful, erases the call op.
+        addToWorklist(caller);
+      } else {
+        op->emitError("Unexpected reference to func referenced by symbol");
+        return failure();
+      }
+    }
+  }
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
 void AsyncToAsyncRuntimePass::runOnOperation() {
   ModuleOp module = getOperation();
   SymbolTable symbolTable(module);
@@ -579,6 +709,12 @@
     return outlinedFunctions.find(parentFunc) != outlinedFunctions.end();
   };
 
+  if (eliminateBlockingAwaitOps &&
+      failed(funcsToCoroutines(module, outlinedFunctions))) {
+    signalPassFailure();
+    return;
+  }
+
   // Lower async operations to async.runtime operations.
MLIRContext *ctx = module->getContext(); RewritePatternSet asyncPatterns(ctx); @@ -622,6 +758,9 @@ return outlinedFunctions.find(func) == outlinedFunctions.end(); }); + if (eliminateBlockingAwaitOps) + runtimeTarget.addIllegalOp(); + if (failed(applyPartialConversion(module, runtimeTarget, std::move(asyncPatterns)))) { signalPassFailure(); diff --git a/mlir/test/Dialect/Async/async-to-async-runtime-eliminate-blocking.mlir b/mlir/test/Dialect/Async/async-to-async-runtime-eliminate-blocking.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Async/async-to-async-runtime-eliminate-blocking.mlir @@ -0,0 +1,304 @@ +// RUN: mlir-opt %s -split-input-file \ +// RUN: -async-to-async-runtime="eliminate-blocking-await-ops=true" \ +// RUN: | FileCheck %s --dump-input=always + +// CHECK-LABEL: func @simple_callee +// CHECK-SAME: (%[[ARG:.*]]: f32) +// CHECK-SAME: -> (!async.token, !async.value {builtin.foo = "bar"}) +func @simple_callee(%arg0: f32) -> (f32 {builtin.foo = "bar"}) { +// CHECK: %[[TOKEN:.*]] = async.runtime.create : !async.token +// CHECK: %[[RETURNED_STORAGE:.*]] = async.runtime.create : !async.value +// CHECK: %[[ID:.*]] = async.coro.id +// CHECK: %[[HDL:.*]] = async.coro.begin %[[ID]] + +// CHECK: %[[VAL:.*]] = addf %[[ARG]], %[[ARG]] : f32 + %0 = addf %arg0, %arg0 : f32 +// CHECK: %[[VAL_STORAGE:.*]] = async.runtime.create : !async.value + %1 = async.runtime.create: !async.value +// CHECK: async.runtime.store %[[VAL]], %[[VAL_STORAGE]] : !async.value + async.runtime.store %0, %1: !async.value +// CHECK: async.runtime.set_available %[[VAL_STORAGE]] : !async.value + async.runtime.set_available %1: !async.value + +// CHECK: %[[SAVED:.*]] = async.coro.save %[[HDL]] +// CHECK: async.runtime.await_and_resume %[[VAL_STORAGE]], %[[HDL]] +// CHECK: async.coro.suspend %[[SAVED]] +// CHECK-SAME: ^[[SUSPEND:.*]], ^[[RESUME:.*]], ^[[CLEANUP:.*]] + %2 = async.await %1 : !async.value + +// CHECK: ^[[RESUME]]: +// CHECK: %[[IS_ERROR:.*]] = async.runtime.is_error 
%[[VAL_STORAGE]] : !async.value +// CHECK: cond_br %[[IS_ERROR]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_OK:.*]] + +// CHECK: ^[[BRANCH_OK]]: +// CHECK: %[[LOADED:.*]] = async.runtime.load %[[VAL_STORAGE]] : !async.value +// CHECK: %[[RETURNED:.*]] = mulf %[[ARG]], %[[LOADED]] : f32 +// CHECK: async.runtime.store %[[RETURNED]], %[[RETURNED_STORAGE]] : !async.value +// CHECK: async.runtime.set_available %[[RETURNED_STORAGE]] +// CHECK: async.runtime.set_available %[[TOKEN]] +// CHECK: br ^[[CLEANUP]] + %3 = mulf %arg0, %2 : f32 + return %3: f32 + +// CHECK: ^[[BRANCH_ERROR]]: +// CHECK: async.runtime.set_error %[[TOKEN]] +// CHECK: async.runtime.set_error %[[RETURNED_STORAGE]] +// CHECK: br ^[[CLEANUP]] + + +// CHECK: ^[[CLEANUP]]: +// CHECK: async.coro.free %[[ID]], %[[HDL]] +// CHECK: br ^[[SUSPEND]] + +// CHECK: ^[[SUSPEND]]: +// CHECK: async.coro.end %[[HDL]] +// CHECK: return %[[TOKEN]], %[[RETURNED_STORAGE]] : !async.token, !async.value +} + +// CHECK-LABEL: func @simple_caller() +// CHECK-SAME: -> (!async.token, !async.value) +func @simple_caller() -> f32 { +// CHECK: %[[TOKEN:.*]] = async.runtime.create : !async.token +// CHECK: %[[RETURNED_STORAGE:.*]] = async.runtime.create : !async.value +// CHECK: %[[ID:.*]] = async.coro.id +// CHECK: %[[HDL:.*]] = async.coro.begin %[[ID]] + +// CHECK: %[[CONSTANT:.*]] = constant + %c = constant 1.0 : f32 +// CHECK: %[[RETURNED_TO_CALLER:.*]]:2 = call @simple_callee(%[[CONSTANT]]) : (f32) -> (!async.token, !async.value) +// CHECK: %[[SAVED:.*]] = async.coro.save %[[HDL]] +// CHECK: async.runtime.await_and_resume %[[RETURNED_TO_CALLER]]#0, %[[HDL]] +// CHECK: async.coro.suspend %[[SAVED]] +// CHECK-SAME: ^[[SUSPEND:.*]], ^[[RESUME:.*]], ^[[CLEANUP:.*]] + %r = call @simple_callee(%c): (f32) -> f32 + +// CHECK: ^[[RESUME]]: +// CHECK: %[[IS_TOKEN_ERROR:.*]] = async.runtime.is_error %[[RETURNED_TO_CALLER]]#0 : !async.token +// CHECK: cond_br %[[IS_TOKEN_ERROR]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_TOKEN_OK:.*]] + +// CHECK: 
^[[BRANCH_TOKEN_OK]]: +// CHECK: %[[IS_VALUE_ERROR:.*]] = async.runtime.is_error %[[RETURNED_TO_CALLER]]#1 : !async.value +// CHECK: cond_br %[[IS_VALUE_ERROR]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_VALUE_OK:.*]] + +// CHECK: ^[[BRANCH_VALUE_OK]]: +// CHECK: %[[LOADED:.*]] = async.runtime.load %[[RETURNED_TO_CALLER]]#1 : !async.value +// CHECK: async.runtime.store %[[LOADED]], %[[RETURNED_STORAGE]] : !async.value +// CHECK: async.runtime.set_available %[[RETURNED_STORAGE]] +// CHECK: async.runtime.set_available %[[TOKEN]] +// CHECK: br ^[[CLEANUP]] + return %r: f32 +// CHECK: ^[[BRANCH_ERROR]]: +// CHECK: async.runtime.set_error %[[TOKEN]] +// CHECK: async.runtime.set_error %[[RETURNED_STORAGE]] +// CHECK: br ^[[CLEANUP]] + + +// CHECK: ^[[CLEANUP]]: +// CHECK: async.coro.free %[[ID]], %[[HDL]] +// CHECK: br ^[[SUSPEND]] + +// CHECK: ^[[SUSPEND]]: +// CHECK: async.coro.end %[[HDL]] +// CHECK: return %[[TOKEN]], %[[RETURNED_STORAGE]] : !async.token, !async.value +} + +// CHECK-LABEL: func @double_caller() +// CHECK-SAME: -> (!async.token, !async.value) +func @double_caller() -> f32 { +// CHECK: %[[TOKEN:.*]] = async.runtime.create : !async.token +// CHECK: %[[RETURNED_STORAGE:.*]] = async.runtime.create : !async.value +// CHECK: %[[ID:.*]] = async.coro.id +// CHECK: %[[HDL:.*]] = async.coro.begin %[[ID]] + +// CHECK: %[[CONSTANT:.*]] = constant + %c = constant 1.0 : f32 +// CHECK: %[[RETURNED_TO_CALLER_1:.*]]:2 = call @simple_callee(%[[CONSTANT]]) : (f32) -> (!async.token, !async.value) +// CHECK: %[[SAVED_1:.*]] = async.coro.save %[[HDL]] +// CHECK: async.runtime.await_and_resume %[[RETURNED_TO_CALLER_1]]#0, %[[HDL]] +// CHECK: async.coro.suspend %[[SAVED_1]] +// CHECK-SAME: ^[[SUSPEND:.*]], ^[[RESUME_1:.*]], ^[[CLEANUP:.*]] + %r = call @simple_callee(%c): (f32) -> f32 + +// CHECK: ^[[RESUME_1]]: +// CHECK: %[[IS_TOKEN_ERROR_1:.*]] = async.runtime.is_error %[[RETURNED_TO_CALLER_1]]#0 : !async.token +// CHECK: cond_br %[[IS_TOKEN_ERROR_1]], ^[[BRANCH_ERROR:.*]], 
^[[BRANCH_TOKEN_OK_1:.*]] + +// CHECK: ^[[BRANCH_TOKEN_OK_1]]: +// CHECK: %[[IS_VALUE_ERROR_1:.*]] = async.runtime.is_error %[[RETURNED_TO_CALLER_1]]#1 : !async.value +// CHECK: cond_br %[[IS_VALUE_ERROR_1]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_VALUE_OK_1:.*]] + +// CHECK: ^[[BRANCH_VALUE_OK_1]]: +// CHECK: %[[LOADED_1:.*]] = async.runtime.load %[[RETURNED_TO_CALLER_1]]#1 : !async.value +// CHECK: %[[RETURNED_TO_CALLER_2:.*]]:2 = call @simple_callee(%[[LOADED_1]]) : (f32) -> (!async.token, !async.value) +// CHECK: %[[SAVED_2:.*]] = async.coro.save %[[HDL]] +// CHECK: async.runtime.await_and_resume %[[RETURNED_TO_CALLER_2]]#0, %[[HDL]] +// CHECK: async.coro.suspend %[[SAVED_2]] +// CHECK-SAME: ^[[SUSPEND:.*]], ^[[RESUME_2:.*]], ^[[CLEANUP:.*]] + %s = call @simple_callee(%r): (f32) -> f32 + +// CHECK: ^[[RESUME_2]]: +// CHECK: %[[IS_TOKEN_ERROR_2:.*]] = async.runtime.is_error %[[RETURNED_TO_CALLER_2]]#0 : !async.token +// CHECK: cond_br %[[IS_TOKEN_ERROR_2]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_TOKEN_OK_2:.*]] + +// CHECK: ^[[BRANCH_TOKEN_OK_2]]: +// CHECK: %[[IS_VALUE_ERROR_2:.*]] = async.runtime.is_error %[[RETURNED_TO_CALLER_2]]#1 : !async.value +// CHECK: cond_br %[[IS_VALUE_ERROR_2]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_VALUE_OK_2:.*]] + +// CHECK: ^[[BRANCH_VALUE_OK_2]]: +// CHECK: %[[LOADED_2:.*]] = async.runtime.load %[[RETURNED_TO_CALLER_2]]#1 : !async.value +// CHECK: async.runtime.store %[[LOADED_2]], %[[RETURNED_STORAGE]] : !async.value +// CHECK: async.runtime.set_available %[[RETURNED_STORAGE]] +// CHECK: async.runtime.set_available %[[TOKEN]] +// CHECK: br ^[[CLEANUP]] + return %s: f32 +// CHECK: ^[[BRANCH_ERROR]]: +// CHECK: async.runtime.set_error %[[TOKEN]] +// CHECK: async.runtime.set_error %[[RETURNED_STORAGE]] +// CHECK: br ^[[CLEANUP]] + +// CHECK: ^[[CLEANUP]]: +// CHECK: async.coro.free %[[ID]], %[[HDL]] +// CHECK: br ^[[SUSPEND]] + +// CHECK: ^[[SUSPEND]]: +// CHECK: async.coro.end %[[HDL]] +// CHECK: return %[[TOKEN]], %[[RETURNED_STORAGE]] : 
!async.token, !async.value +} + +// CHECK-LABEL: func @recursive +// CHECK-SAME: (%[[ARG:.*]]: !async.token) -> !async.token +func @recursive(%arg: !async.token) { +// CHECK: %[[TOKEN:.*]] = async.runtime.create : !async.token +// CHECK: %[[ID:.*]] = async.coro.id +// CHECK: %[[HDL:.*]] = async.coro.begin %[[ID]] +// CHECK: %[[SAVED_1:.*]] = async.coro.save %[[HDL]] +// CHECK: async.runtime.await_and_resume %[[ARG]], %[[HDL]] +// CHECK: async.coro.suspend %[[SAVED_1]] +// CHECK-SAME: ^[[SUSPEND:.*]], ^[[RESUME_1:.*]], ^[[CLEANUP:.*]] + + async.await %arg : !async.token +// CHECK: ^[[RESUME_1]]: +// CHECK: %[[IS_ERROR:.*]] = async.runtime.is_error %[[ARG]] : !async.token +// CHECK: cond_br %[[IS_ERROR]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_OK:.*]] + +// CHECK: ^[[BRANCH_OK]]: +// CHECK: %[[GIVEN:.*]] = async.runtime.create : !async.token +%r = async.runtime.create : !async.token +// CHECK: async.runtime.set_available %[[GIVEN]] +async.runtime.set_available %r: !async.token +// CHECK: %[[RETURNED_TO_CALLER:.*]] = call @recursive(%[[GIVEN]]) : (!async.token) -> !async.token +call @recursive(%r): (!async.token) -> () +// CHECK: %[[SAVED_2:.*]] = async.coro.save %[[HDL]] +// CHECK: async.runtime.await_and_resume %[[RETURNED_TO_CALLER]], %[[HDL]] +// CHECK: async.coro.suspend %[[SAVED_2]] +// CHECK-SAME: ^[[SUSPEND:.*]], ^[[RESUME_2:.*]], ^[[CLEANUP:.*]] + +// CHECK: ^[[RESUME_2]]: +// CHECK: async.runtime.set_available %[[TOKEN]] +// CHECK: br ^[[CLEANUP]] + +// CHECK: ^[[BRANCH_ERROR]]: +// CHECK: async.runtime.set_error %[[TOKEN]] +// CHECK: br ^[[CLEANUP]] +return + +// CHECK: ^[[CLEANUP]]: +// CHECK: async.coro.free %[[ID]], %[[HDL]] +// CHECK: br ^[[SUSPEND]] + +// CHECK: ^[[SUSPEND]]: +// CHECK: async.coro.end %[[HDL]] +// CHECK: return %[[TOKEN]] : !async.token +} + +// CHECK-LABEL: func @corecursive1 +// CHECK-SAME: (%[[ARG:.*]]: !async.token) -> !async.token +func @corecursive1(%arg: !async.token) { +// CHECK: %[[TOKEN:.*]] = async.runtime.create : !async.token +// 
CHECK: %[[ID:.*]] = async.coro.id +// CHECK: %[[HDL:.*]] = async.coro.begin %[[ID]] +// CHECK: %[[SAVED_1:.*]] = async.coro.save %[[HDL]] +// CHECK: async.runtime.await_and_resume %[[ARG]], %[[HDL]] +// CHECK: async.coro.suspend %[[SAVED_1]] +// CHECK-SAME: ^[[SUSPEND:.*]], ^[[RESUME_1:.*]], ^[[CLEANUP:.*]] + + async.await %arg : !async.token +// CHECK: ^[[RESUME_1]]: +// CHECK: %[[IS_ERROR:.*]] = async.runtime.is_error %[[ARG]] : !async.token +// CHECK: cond_br %[[IS_ERROR]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_OK:.*]] + +// CHECK: ^[[BRANCH_OK]]: +// CHECK: %[[GIVEN:.*]] = async.runtime.create : !async.token +%r = async.runtime.create : !async.token +// CHECK: async.runtime.set_available %[[GIVEN]] +async.runtime.set_available %r: !async.token +// CHECK: %[[RETURNED_TO_CALLER:.*]] = call @corecursive2(%[[GIVEN]]) : (!async.token) -> !async.token +call @corecursive2(%r): (!async.token) -> () +// CHECK: %[[SAVED_2:.*]] = async.coro.save %[[HDL]] +// CHECK: async.runtime.await_and_resume %[[RETURNED_TO_CALLER]], %[[HDL]] +// CHECK: async.coro.suspend %[[SAVED_2]] +// CHECK-SAME: ^[[SUSPEND:.*]], ^[[RESUME_2:.*]], ^[[CLEANUP:.*]] + +// CHECK: ^[[RESUME_2]]: +// CHECK: async.runtime.set_available %[[TOKEN]] +// CHECK: br ^[[CLEANUP]] + +// CHECK: ^[[BRANCH_ERROR]]: +// CHECK: async.runtime.set_error %[[TOKEN]] +// CHECK: br ^[[CLEANUP]] +return + +// CHECK: ^[[CLEANUP]]: +// CHECK: async.coro.free %[[ID]], %[[HDL]] +// CHECK: br ^[[SUSPEND]] + +// CHECK: ^[[SUSPEND]]: +// CHECK: async.coro.end %[[HDL]] +// CHECK: return %[[TOKEN]] : !async.token +} + +// CHECK-LABEL: func @corecursive2 +// CHECK-SAME: (%[[ARG:.*]]: !async.token) -> !async.token +func @corecursive2(%arg: !async.token) { +// CHECK: %[[TOKEN:.*]] = async.runtime.create : !async.token +// CHECK: %[[ID:.*]] = async.coro.id +// CHECK: %[[HDL:.*]] = async.coro.begin %[[ID]] +// CHECK: %[[SAVED_1:.*]] = async.coro.save %[[HDL]] +// CHECK: async.runtime.await_and_resume %[[ARG]], %[[HDL]] +// CHECK: 
async.coro.suspend %[[SAVED_1]] +// CHECK-SAME: ^[[SUSPEND:.*]], ^[[RESUME_1:.*]], ^[[CLEANUP:.*]] + + async.await %arg : !async.token +// CHECK: ^[[RESUME_1]]: +// CHECK: %[[IS_ERROR:.*]] = async.runtime.is_error %[[ARG]] : !async.token +// CHECK: cond_br %[[IS_ERROR]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_OK:.*]] + +// CHECK: ^[[BRANCH_OK]]: +// CHECK: %[[GIVEN:.*]] = async.runtime.create : !async.token +%r = async.runtime.create : !async.token +// CHECK: async.runtime.set_available %[[GIVEN]] +async.runtime.set_available %r: !async.token +// CHECK: %[[RETURNED_TO_CALLER:.*]] = call @corecursive1(%[[GIVEN]]) : (!async.token) -> !async.token +call @corecursive1(%r): (!async.token) -> () +// CHECK: %[[SAVED_2:.*]] = async.coro.save %[[HDL]] +// CHECK: async.runtime.await_and_resume %[[RETURNED_TO_CALLER]], %[[HDL]] +// CHECK: async.coro.suspend %[[SAVED_2]] +// CHECK-SAME: ^[[SUSPEND:.*]], ^[[RESUME_2:.*]], ^[[CLEANUP:.*]] + +// CHECK: ^[[RESUME_2]]: +// CHECK: async.runtime.set_available %[[TOKEN]] +// CHECK: br ^[[CLEANUP]] + +// CHECK: ^[[BRANCH_ERROR]]: +// CHECK: async.runtime.set_error %[[TOKEN]] +// CHECK: br ^[[CLEANUP]] +return + +// CHECK: ^[[CLEANUP]]: +// CHECK: async.coro.free %[[ID]], %[[HDL]] +// CHECK: br ^[[SUSPEND]] + +// CHECK: ^[[SUSPEND]]: +// CHECK: async.coro.end %[[HDL]] +// CHECK: return %[[TOKEN]] : !async.token +}