diff --git a/mlir/include/mlir/Dialect/Async/Passes.h b/mlir/include/mlir/Dialect/Async/Passes.h --- a/mlir/include/mlir/Dialect/Async/Passes.h +++ b/mlir/include/mlir/Dialect/Async/Passes.h @@ -26,6 +26,8 @@ std::unique_ptr> createAsyncRefCountingOptimizationPass(); +std::unique_ptr> createAsyncToAsyncRuntimePass(); + //===----------------------------------------------------------------------===// // Registration //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Async/Passes.td b/mlir/include/mlir/Dialect/Async/Passes.td --- a/mlir/include/mlir/Dialect/Async/Passes.td +++ b/mlir/include/mlir/Dialect/Async/Passes.td @@ -38,4 +38,11 @@ let dependentDialects = ["async::AsyncDialect"]; } +def AsyncToAsyncRuntime : Pass<"async-to-async-runtime", "ModuleOp"> { + let summary = "Lower high level async operations (e.g. async.execute) to the " + "explicit async.runtime and async.coro operations"; + let constructor = "mlir::createAsyncToAsyncRuntimePass()"; + let dependentDialects = ["async::AsyncDialect"]; +} + #endif // MLIR_DIALECT_ASYNC_PASSES diff --git a/mlir/integration_test/Dialect/Async/CPU/microbench-linalg-async-parallel-for.mlir b/mlir/integration_test/Dialect/Async/CPU/microbench-linalg-async-parallel-for.mlir --- a/mlir/integration_test/Dialect/Async/CPU/microbench-linalg-async-parallel-for.mlir +++ b/mlir/integration_test/Dialect/Async/CPU/microbench-linalg-async-parallel-for.mlir @@ -2,6 +2,7 @@ // RUN: -linalg-tile-to-parallel-loops="linalg-tile-sizes=256" \ // RUN: -async-parallel-for="num-concurrent-async-execute=4" \ // RUN: -async-ref-counting \ +// RUN: -async-to-async-runtime \ // RUN: -convert-async-to-llvm \ // RUN: -lower-affine \ // RUN: -convert-linalg-to-loops \ diff --git a/mlir/integration_test/Dialect/Async/CPU/test-async-parallel-for-1d.mlir b/mlir/integration_test/Dialect/Async/CPU/test-async-parallel-for-1d.mlir --- 
a/mlir/integration_test/Dialect/Async/CPU/test-async-parallel-for-1d.mlir +++ b/mlir/integration_test/Dialect/Async/CPU/test-async-parallel-for-1d.mlir @@ -1,5 +1,6 @@ // RUN: mlir-opt %s -async-parallel-for \ // RUN: -async-ref-counting \ +// RUN: -async-to-async-runtime \ // RUN: -convert-async-to-llvm \ // RUN: -convert-scf-to-std \ // RUN: -convert-std-to-llvm \ diff --git a/mlir/integration_test/Dialect/Async/CPU/test-async-parallel-for-2d.mlir b/mlir/integration_test/Dialect/Async/CPU/test-async-parallel-for-2d.mlir --- a/mlir/integration_test/Dialect/Async/CPU/test-async-parallel-for-2d.mlir +++ b/mlir/integration_test/Dialect/Async/CPU/test-async-parallel-for-2d.mlir @@ -1,4 +1,5 @@ // RUN: mlir-opt %s -async-parallel-for \ // RUN: -async-ref-counting \ +// RUN: -async-to-async-runtime \ // RUN: -convert-async-to-llvm \ // RUN: -convert-scf-to-std \ diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp --- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp +++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp @@ -14,14 +14,10 @@ #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h" -#include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" -#include "mlir/Transforms/RegionUtils.h" -#include "llvm/ADT/SetVector.h" -#include "llvm/Support/FormatVariadic.h" #define DEBUG_TYPE "convert-async-to-llvm" @@ -257,232 +253,6 @@ blockBuilder.create(ValueRange()); } -//===----------------------------------------------------------------------===// -// async.execute op outlining to the coroutine functions. 
-//===----------------------------------------------------------------------===// - -/// Function targeted for coroutine transformation has two additional blocks at -/// the end: coroutine cleanup and coroutine suspension. -/// -/// async.await op lowering additionaly creates a resume block for each -/// operation to enable non-blocking waiting via coroutine suspension. -namespace { -struct CoroMachinery { - // Async execute region returns a completion token, and an async value for - // each yielded value. - // - // %token, %result = async.execute -> !async.value { - // %0 = constant ... : T - // async.yield %0 : T - // } - Value asyncToken; // token representing completion of the async region - llvm::SmallVector returnValues; // returned async values - - Value coroHandle; // coroutine handle (!async.coro.handle value) - Block *cleanup; // coroutine cleanup block - Block *suspend; // coroutine suspension block -}; -} // namespace - -/// Builds an coroutine template compatible with LLVM coroutines switched-resume -/// lowering using `async.runtime.*` and `async.coro.*` operations. -/// -/// See LLVM coroutines documentation: https://llvm.org/docs/Coroutines.html -/// -/// - `entry` block sets up the coroutine. -/// - `cleanup` block cleans up the coroutine state. -/// - `suspend block after the @llvm.coro.end() defines what value will be -/// returned to the initial caller of a coroutine. Everything before the -/// @llvm.coro.end() will be executed at every suspension point. 
-/// -/// Coroutine structure (only the important bits): -/// -/// func @async_execute_fn() -/// -> (!async.token, !async.value) -/// { -/// ^entry(): -/// %token = : !async.token // create async runtime token -/// %value = : !async.value // create async value -/// %id = async.coro.id // create a coroutine id -/// %hdl = async.coro.begin %id // create a coroutine handle -/// br ^cleanup -/// -/// ^cleanup: -/// async.coro.free %hdl // delete the coroutine state -/// br ^suspend -/// -/// ^suspend: -/// async.coro.end %hdl // marks the end of a coroutine -/// return %token, %value : !async.token, !async.value -/// } -/// -/// The actual code for the async.execute operation body region will be inserted -/// before the entry block terminator. -/// -/// -static CoroMachinery setupCoroMachinery(FuncOp func) { - assert(func.getBody().empty() && "Function must have empty body"); - - MLIRContext *ctx = func.getContext(); - Block *entryBlock = func.addEntryBlock(); - - auto builder = ImplicitLocOpBuilder::atBlockBegin(func->getLoc(), entryBlock); - - // ------------------------------------------------------------------------ // - // Allocate async token/values that we will return from a ramp function. - // ------------------------------------------------------------------------ // - auto retToken = builder.create(TokenType::get(ctx)).result(); - - llvm::SmallVector retValues; - for (auto resType : func.getCallableResults().drop_front()) - retValues.emplace_back(builder.create(resType).result()); - - // ------------------------------------------------------------------------ // - // Initialize coroutine: get coroutine id and coroutine handle. 
- // ------------------------------------------------------------------------ // - auto coroIdOp = builder.create(CoroIdType::get(ctx)); - auto coroHdlOp = - builder.create(CoroHandleType::get(ctx), coroIdOp.id()); - - Block *cleanupBlock = func.addBlock(); - Block *suspendBlock = func.addBlock(); - - // ------------------------------------------------------------------------ // - // Coroutine cleanup block: deallocate coroutine frame, free the memory. - // ------------------------------------------------------------------------ // - builder.setInsertionPointToStart(cleanupBlock); - builder.create(coroIdOp.id(), coroHdlOp.handle()); - - // Branch into the suspend block. - builder.create(suspendBlock); - - // ------------------------------------------------------------------------ // - // Coroutine suspend block: mark the end of a coroutine and return allocated - // async token. - // ------------------------------------------------------------------------ // - builder.setInsertionPointToStart(suspendBlock); - - // Mark the end of a coroutine: async.coro.end - builder.create(coroHdlOp.handle()); - - // Return created `async.token` and `async.values` from the suspend block. - // This will be the return value of a coroutine ramp function. - SmallVector ret{retToken}; - ret.insert(ret.end(), retValues.begin(), retValues.end()); - builder.create(ret); - - // Branch from the entry block to the cleanup block to create a valid CFG. - builder.setInsertionPointToEnd(entryBlock); - builder.create(cleanupBlock); - - // `async.await` op lowering will create resume blocks for async - // continuations, and will conditionally branch to cleanup or suspend blocks. 
- - CoroMachinery machinery; - machinery.asyncToken = retToken; - machinery.returnValues = retValues; - machinery.coroHandle = coroHdlOp.handle(); - machinery.cleanup = cleanupBlock; - machinery.suspend = suspendBlock; - return machinery; -} - -/// Outline the body region attached to the `async.execute` op into a standalone -/// function. -/// -/// Note that this is not reversible transformation. -static std::pair -outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) { - ModuleOp module = execute->getParentOfType(); - - MLIRContext *ctx = module.getContext(); - Location loc = execute.getLoc(); - - // Collect all outlined function inputs. - llvm::SetVector functionInputs(execute.dependencies().begin(), - execute.dependencies().end()); - functionInputs.insert(execute.operands().begin(), execute.operands().end()); - getUsedValuesDefinedAbove(execute.body(), functionInputs); - - // Collect types for the outlined function inputs and outputs. - auto typesRange = llvm::map_range( - functionInputs, [](Value value) { return value.getType(); }); - SmallVector inputTypes(typesRange.begin(), typesRange.end()); - auto outputTypes = execute.getResultTypes(); - - auto funcType = FunctionType::get(ctx, inputTypes, outputTypes); - auto funcAttrs = ArrayRef(); - - // TODO: Derive outlined function name from the parent FuncOp (support - // multiple nested async.execute operations). - FuncOp func = FuncOp::create(loc, kAsyncFnPrefix, funcType, funcAttrs); - symbolTable.insert(func, Block::iterator(module.getBody()->getTerminator())); - - SymbolTable::setSymbolVisibility(func, SymbolTable::Visibility::Private); - - // Prepare a function for coroutine lowering by adding entry/cleanup/suspend - // blocks, adding async.coro operations and setting up control flow. 
- CoroMachinery coro = setupCoroMachinery(func); - - // Suspend async function at the end of an entry block, and resume it using - // Async resume operation (execution will be resumed in a thread managed by - // the async runtime). - Block *entryBlock = &func.getBlocks().front(); - auto builder = ImplicitLocOpBuilder::atBlockTerminator(loc, entryBlock); - - // Save the coroutine state: async.coro.save - auto coroSaveOp = - builder.create(CoroStateType::get(ctx), coro.coroHandle); - - // Pass coroutine to the runtime to be resumed on a runtime managed thread. - builder.create(coro.coroHandle); - - // Split the entry block before the terminator (branch to suspend block). - auto *terminatorOp = entryBlock->getTerminator(); - Block *suspended = terminatorOp->getBlock(); - Block *resume = suspended->splitBlock(terminatorOp); - - // Add async.coro.suspend as a suspended block terminator. - builder.setInsertionPointToEnd(suspended); - builder.create(coroSaveOp.state(), coro.suspend, resume, - coro.cleanup); - - size_t numDependencies = execute.dependencies().size(); - size_t numOperands = execute.operands().size(); - - // Await on all dependencies before starting to execute the body region. - builder.setInsertionPointToStart(resume); - for (size_t i = 0; i < numDependencies; ++i) - builder.create(func.getArgument(i)); - - // Await on all async value operands and unwrap the payload. - SmallVector unwrappedOperands(numOperands); - for (size_t i = 0; i < numOperands; ++i) { - Value operand = func.getArgument(numDependencies + i); - unwrappedOperands[i] = builder.create(loc, operand).result(); - } - - // Map from function inputs defined above the execute op to the function - // arguments. - BlockAndValueMapping valueMapping; - valueMapping.map(functionInputs, func.getArguments()); - valueMapping.map(execute.body().getArguments(), unwrappedOperands); - - // Clone all operations from the execute operation body into the outlined - // function body. 
- for (Operation &op : execute.body().getOps()) - builder.clone(op, valueMapping); - - // Replace the original `async.execute` with a call to outlined function. - ImplicitLocOpBuilder callBuilder(loc, execute); - auto callOutlinedFunc = callBuilder.create( - func.getName(), execute.getResultTypes(), functionInputs.getArrayRef()); - execute.replaceAllUsesWith(callOutlinedFunc.getResults()); - execute.erase(); - - return {func, coro}; -} - //===----------------------------------------------------------------------===// // Convert Async dialect types to LLVM types. //===----------------------------------------------------------------------===// @@ -933,6 +703,10 @@ // Cast from i8* to the LLVM pointer type. auto valueType = op.value().getType(); auto llvmValueType = getTypeConverter()->convertType(valueType); + if (!llvmValueType) + return rewriter.notifyMatchFailure( + op, "failed to convert stored value type to LLVM type"); + auto castedStoragePtr = rewriter.create( loc, LLVM::LLVMPointerType::get(llvmValueType), storagePtr.getResult(0)); @@ -972,6 +746,10 @@ // Cast from i8* to the LLVM pointer type. 
auto valueType = op.result().getType(); auto llvmValueType = getTypeConverter()->convertType(valueType); + if (!llvmValueType) + return rewriter.notifyMatchFailure( + op, "failed to convert loaded value type to LLVM type"); + auto castedStoragePtr = rewriter.create( loc, LLVM::LLVMPointerType::get(llvmValueType), storagePtr.getResult(0)); @@ -1074,205 +852,6 @@ }; } // namespace -//===----------------------------------------------------------------------===// -// Convert async.create_group operation to async.runtime.create -//===----------------------------------------------------------------------===// - -namespace { -class CreateGroupOpLowering : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(CreateGroupOp op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp( - op, GroupType::get(op->getContext())); - return success(); - } -}; -} // namespace - -//===----------------------------------------------------------------------===// -// Convert async.add_to_group operation to async.runtime.add_to_group. -//===----------------------------------------------------------------------===// - -namespace { -class AddToGroupOpLowering : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(AddToGroupOp op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp( - op, rewriter.getIndexType(), operands); - return success(); - } -}; -} // namespace - -//===----------------------------------------------------------------------===// -// Convert async.await and async.await_all operations to the async.runtime.await -// or async.runtime.await_and_resume operations. 
-//===----------------------------------------------------------------------===// - -namespace { -template -class AwaitOpLoweringBase : public OpConversionPattern { - using AwaitAdaptor = typename AwaitType::Adaptor; - -public: - AwaitOpLoweringBase( - MLIRContext *ctx, - const llvm::DenseMap &outlinedFunctions) - : OpConversionPattern(ctx), - outlinedFunctions(outlinedFunctions) {} - - LogicalResult - matchAndRewrite(AwaitType op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - // We can only await on one the `AwaitableType` (for `await` it can be - // a `token` or a `value`, for `await_all` it must be a `group`). - if (!op.operand().getType().template isa()) - return rewriter.notifyMatchFailure(op, "unsupported awaitable type"); - - // Check if await operation is inside the outlined coroutine function. - auto func = op->template getParentOfType(); - auto outlined = outlinedFunctions.find(func); - const bool isInCoroutine = outlined != outlinedFunctions.end(); - - Location loc = op->getLoc(); - Value operand = AwaitAdaptor(operands).operand(); - - // Inside regular functions we use the blocking wait operation to wait for - // the async object (token, value or group) to become available. - if (!isInCoroutine) - rewriter.create(loc, operand); - - // Inside the coroutine we convert await operation into coroutine suspension - // point, and resume execution asynchronously. - if (isInCoroutine) { - const CoroMachinery &coro = outlined->getSecond(); - Block *suspended = op->getBlock(); - - ImplicitLocOpBuilder builder(loc, op, rewriter.getListener()); - MLIRContext *ctx = op->getContext(); - - // Save the coroutine state and resume on a runtime managed thread when - // the operand becomes available. - auto coroSaveOp = - builder.create(CoroStateType::get(ctx), coro.coroHandle); - builder.create(operand, coro.coroHandle); - - // Split the entry block before the await operation. 
- Block *resume = rewriter.splitBlock(suspended, Block::iterator(op)); - - // Add async.coro.suspend as a suspended block terminator. - builder.setInsertionPointToEnd(suspended); - builder.create(coroSaveOp.state(), coro.suspend, resume, - coro.cleanup); - - // Make sure that replacement value will be constructed in resume block. - rewriter.setInsertionPointToStart(resume); - } - - // Erase or replace the await operation with the new value. - if (Value replaceWith = getReplacementValue(op, operand, rewriter)) - rewriter.replaceOp(op, replaceWith); - else - rewriter.eraseOp(op); - - return success(); - } - - virtual Value getReplacementValue(AwaitType op, Value operand, - ConversionPatternRewriter &rewriter) const { - return Value(); - } - -private: - const llvm::DenseMap &outlinedFunctions; -}; - -/// Lowering for `async.await` with a token operand. -class AwaitTokenOpLowering : public AwaitOpLoweringBase { - using Base = AwaitOpLoweringBase; - -public: - using Base::Base; -}; - -/// Lowering for `async.await` with a value operand. -class AwaitValueOpLowering : public AwaitOpLoweringBase { - using Base = AwaitOpLoweringBase; - -public: - using Base::Base; - - Value - getReplacementValue(AwaitOp op, Value operand, - ConversionPatternRewriter &rewriter) const override { - // Load from the async value storage. - auto valueType = operand.getType().cast().getValueType(); - return rewriter.create(op->getLoc(), valueType, operand); - } -}; - -/// Lowering for `async.await_all` operation. -class AwaitAllOpLowering : public AwaitOpLoweringBase { - using Base = AwaitOpLoweringBase; - -public: - using Base::Base; -}; - -} // namespace - -//===----------------------------------------------------------------------===// -// Convert async.yield operation to async.runtime operations. 
-//===----------------------------------------------------------------------===// - -class YieldOpLowering : public OpConversionPattern { -public: - YieldOpLowering( - MLIRContext *ctx, - const llvm::DenseMap &outlinedFunctions) - : OpConversionPattern(ctx), - outlinedFunctions(outlinedFunctions) {} - - LogicalResult - matchAndRewrite(async::YieldOp op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - // Check if yield operation is inside the outlined coroutine function. - auto func = op->template getParentOfType(); - auto outlined = outlinedFunctions.find(func); - if (outlined == outlinedFunctions.end()) - return rewriter.notifyMatchFailure( - op, "operation is not inside the outlined async.execute function"); - - Location loc = op->getLoc(); - const CoroMachinery &coro = outlined->getSecond(); - - // Store yielded values into the async values storage and switch async - // values state to available. - for (auto tuple : llvm::zip(operands, coro.returnValues)) { - Value yieldValue = std::get<0>(tuple); - Value asyncValue = std::get<1>(tuple); - rewriter.create(loc, yieldValue, asyncValue); - rewriter.create(loc, asyncValue); - } - - // Switch the coroutine completion token to available state. - rewriter.replaceOpWithNewOp(op, coro.asyncToken); - - return success(); - } - -private: - const llvm::DenseMap &outlinedFunctions; -}; - //===----------------------------------------------------------------------===// namespace { @@ -1284,89 +863,25 @@ void ConvertAsyncToLLVMPass::runOnOperation() { ModuleOp module = getOperation(); - SymbolTable symbolTable(module); - - MLIRContext *ctx = &getContext(); - - // Outline all `async.execute` body regions into async functions (coroutines). - llvm::DenseMap outlinedFunctions; - - // We use conversion to LLVM type to ensure that all `async.value` operands - // and results can be lowered to LLVM load and store operations. 
- LLVMTypeConverter llvmConverter(ctx); - llvmConverter.addConversion(AsyncRuntimeTypeConverter::convertAsyncTypes); - - // Returns true if the `async.value` payload is convertible to LLVM. - auto isConvertibleToLlvm = [&](Type type) -> bool { - auto valueType = type.cast().getValueType(); - return static_cast(llvmConverter.convertType(valueType)); - }; - - WalkResult outlineResult = module.walk([&](ExecuteOp execute) { - // All operands and results must be convertible to LLVM. - if (!llvm::all_of(execute.operands().getTypes(), isConvertibleToLlvm)) { - execute.emitOpError("operands payload must be convertible to LLVM type"); - return WalkResult::interrupt(); - } - if (!llvm::all_of(execute.results().getTypes(), isConvertibleToLlvm)) { - execute.emitOpError("results payload must be convertible to LLVM type"); - return WalkResult::interrupt(); - } - - outlinedFunctions.insert(outlineExecuteOp(symbolTable, execute)); - - return WalkResult::advance(); - }); - - // Failed to outline all async execute operations. - if (outlineResult.wasInterrupted()) { - signalPassFailure(); - return; - } - - LLVM_DEBUG({ - llvm::dbgs() << "Outlined " << outlinedFunctions.size() - << " async functions\n"; - }); + MLIRContext *ctx = module->getContext(); // Add declarations for all functions required by the coroutines lowering. addResumeFunction(module); addAsyncRuntimeApiDeclarations(module); addCRuntimeDeclarations(module); - // ------------------------------------------------------------------------ // - // Lower async operations to async.runtime operations. - // ------------------------------------------------------------------------ // - OwningRewritePatternList asyncPatterns; - - // Async lowering does not use type converter because it must preserve all - // types for async.runtime operations. - asyncPatterns.insert(ctx); - asyncPatterns.insert(ctx, - outlinedFunctions); - - // All high level async operations must be lowered to the runtime operations. 
- ConversionTarget runtimeTarget(*ctx); - runtimeTarget.addLegalDialect(); - runtimeTarget.addIllegalOp(); - runtimeTarget.addIllegalOp(); - - if (failed(applyPartialConversion(module, runtimeTarget, - std::move(asyncPatterns)))) { - signalPassFailure(); - return; - } - - // ------------------------------------------------------------------------ // // Lower async.runtime and async.coro operations to Async Runtime API and // LLVM coroutine intrinsics. - // ------------------------------------------------------------------------ // // Convert async dialect types and operations to LLVM dialect. AsyncRuntimeTypeConverter converter; OwningRewritePatternList patterns; + // We use conversion to LLVM type to lower async.runtime load and store + // operations. + LLVMTypeConverter llvmConverter(ctx); + llvmConverter.addConversion(AsyncRuntimeTypeConverter::convertAsyncTypes); + // Convert async types in function signatures and function calls. populateFuncOpTypeConversionPattern(patterns, ctx, converter); populateCallOpTypeConversionPattern(patterns, ctx, converter); diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp @@ -0,0 +1,512 @@ +//===- AsyncToAsyncRuntime.cpp - Lower from Async to Async Runtime --------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements lowering from high level async operations to async.coro +// and async.runtime operations. 
+// +//===----------------------------------------------------------------------===// + +#include "PassDetail.h" +#include "mlir/Dialect/Async/IR/Async.h" +#include "mlir/Dialect/Async/Passes.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/RegionUtils.h" +#include "llvm/ADT/SetVector.h" + +using namespace mlir; +using namespace mlir::async; + +#define DEBUG_TYPE "async-to-async-runtime" +// Prefix for functions outlined from `async.execute` op regions. +static constexpr const char kAsyncFnPrefix[] = "async_execute_fn"; + +namespace { + +class AsyncToAsyncRuntimePass + : public AsyncToAsyncRuntimeBase { +public: + AsyncToAsyncRuntimePass() = default; + void runOnOperation() override; +}; + +} // namespace + +//===----------------------------------------------------------------------===// +// async.execute op outlining to the coroutine functions. +//===----------------------------------------------------------------------===// + +/// Function targeted for coroutine transformation has two additional blocks at +/// the end: coroutine cleanup and coroutine suspension. +/// +/// async.await op lowering additionally creates a resume block for each +/// operation to enable non-blocking waiting via coroutine suspension. +namespace { +struct CoroMachinery { + // Async execute region returns a completion token, and an async value for + // each yielded value. + // + // %token, %result = async.execute -> !async.value { + // %0 = constant ... 
: T + // async.yield %0 : T + // } + Value asyncToken; // token representing completion of the async region + llvm::SmallVector returnValues; // returned async values + + Value coroHandle; // coroutine handle (!async.coro.handle value) + Block *cleanup; // coroutine cleanup block + Block *suspend; // coroutine suspension block +}; +} // namespace + +/// Builds a coroutine template compatible with LLVM coroutines switched-resume +/// lowering using `async.runtime.*` and `async.coro.*` operations. +/// +/// See LLVM coroutines documentation: https://llvm.org/docs/Coroutines.html +/// +/// - `entry` block sets up the coroutine. +/// - `cleanup` block cleans up the coroutine state. +/// - `suspend` block after the @llvm.coro.end() defines what value will be +/// returned to the initial caller of a coroutine. Everything before the +/// @llvm.coro.end() will be executed at every suspension point. +/// +/// Coroutine structure (only the important bits): +/// +/// func @async_execute_fn() +/// -> (!async.token, !async.value) +/// { +/// ^entry(): +/// %token = : !async.token // create async runtime token +/// %value = : !async.value // create async value +/// %id = async.coro.id // create a coroutine id +/// %hdl = async.coro.begin %id // create a coroutine handle +/// br ^cleanup +/// +/// ^cleanup: +/// async.coro.free %hdl // delete the coroutine state +/// br ^suspend +/// +/// ^suspend: +/// async.coro.end %hdl // marks the end of a coroutine +/// return %token, %value : !async.token, !async.value +/// } +/// +/// The actual code for the async.execute operation body region will be inserted +/// before the entry block terminator. 
+/// +/// +static CoroMachinery setupCoroMachinery(FuncOp func) { + assert(func.getBody().empty() && "Function must have empty body"); + + MLIRContext *ctx = func.getContext(); + Block *entryBlock = func.addEntryBlock(); + + auto builder = ImplicitLocOpBuilder::atBlockBegin(func->getLoc(), entryBlock); + + // ------------------------------------------------------------------------ // + // Allocate async token/values that we will return from a ramp function. + // ------------------------------------------------------------------------ // + auto retToken = builder.create(TokenType::get(ctx)).result(); + + llvm::SmallVector retValues; + for (auto resType : func.getCallableResults().drop_front()) + retValues.emplace_back(builder.create(resType).result()); + + // ------------------------------------------------------------------------ // + // Initialize coroutine: get coroutine id and coroutine handle. + // ------------------------------------------------------------------------ // + auto coroIdOp = builder.create(CoroIdType::get(ctx)); + auto coroHdlOp = + builder.create(CoroHandleType::get(ctx), coroIdOp.id()); + + Block *cleanupBlock = func.addBlock(); + Block *suspendBlock = func.addBlock(); + + // ------------------------------------------------------------------------ // + // Coroutine cleanup block: deallocate coroutine frame, free the memory. + // ------------------------------------------------------------------------ // + builder.setInsertionPointToStart(cleanupBlock); + builder.create(coroIdOp.id(), coroHdlOp.handle()); + + // Branch into the suspend block. + builder.create(suspendBlock); + + // ------------------------------------------------------------------------ // + // Coroutine suspend block: mark the end of a coroutine and return allocated + // async token. 
+ // ------------------------------------------------------------------------ // + builder.setInsertionPointToStart(suspendBlock); + + // Mark the end of a coroutine: async.coro.end + builder.create(coroHdlOp.handle()); + + // Return created `async.token` and `async.values` from the suspend block. + // This will be the return value of a coroutine ramp function. + SmallVector ret{retToken}; + ret.insert(ret.end(), retValues.begin(), retValues.end()); + builder.create(ret); + + // Branch from the entry block to the cleanup block to create a valid CFG. + builder.setInsertionPointToEnd(entryBlock); + builder.create(cleanupBlock); + + // `async.await` op lowering will create resume blocks for async + // continuations, and will conditionally branch to cleanup or suspend blocks. + + CoroMachinery machinery; + machinery.asyncToken = retToken; + machinery.returnValues = retValues; + machinery.coroHandle = coroHdlOp.handle(); + machinery.cleanup = cleanupBlock; + machinery.suspend = suspendBlock; + return machinery; +} + +/// Outline the body region attached to the `async.execute` op into a standalone +/// function. +/// +/// Note that this is not a reversible transformation. +static std::pair +outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) { + ModuleOp module = execute->getParentOfType(); + + MLIRContext *ctx = module.getContext(); + Location loc = execute.getLoc(); + + // Collect all outlined function inputs. + llvm::SetVector functionInputs(execute.dependencies().begin(), + execute.dependencies().end()); + functionInputs.insert(execute.operands().begin(), execute.operands().end()); + getUsedValuesDefinedAbove(execute.body(), functionInputs); + + // Collect types for the outlined function inputs and outputs. 
+ auto typesRange = llvm::map_range( + functionInputs, [](Value value) { return value.getType(); }); + SmallVector inputTypes(typesRange.begin(), typesRange.end()); + auto outputTypes = execute.getResultTypes(); + + auto funcType = FunctionType::get(ctx, inputTypes, outputTypes); + auto funcAttrs = ArrayRef(); + + // TODO: Derive outlined function name from the parent FuncOp (support + // multiple nested async.execute operations). + FuncOp func = FuncOp::create(loc, kAsyncFnPrefix, funcType, funcAttrs); + symbolTable.insert(func, Block::iterator(module.getBody()->getTerminator())); + + SymbolTable::setSymbolVisibility(func, SymbolTable::Visibility::Private); + + // Prepare a function for coroutine lowering by adding entry/cleanup/suspend + // blocks, adding async.coro operations and setting up control flow. + CoroMachinery coro = setupCoroMachinery(func); + + // Suspend async function at the end of an entry block, and resume it using + // Async resume operation (execution will be resumed in a thread managed by + // the async runtime). + Block *entryBlock = &func.getBlocks().front(); + auto builder = ImplicitLocOpBuilder::atBlockTerminator(loc, entryBlock); + + // Save the coroutine state: async.coro.save + auto coroSaveOp = + builder.create(CoroStateType::get(ctx), coro.coroHandle); + + // Pass coroutine to the runtime to be resumed on a runtime managed thread. + builder.create(coro.coroHandle); + + // Split the entry block before the terminator (branch to suspend block). + auto *terminatorOp = entryBlock->getTerminator(); + Block *suspended = terminatorOp->getBlock(); + Block *resume = suspended->splitBlock(terminatorOp); + + // Add async.coro.suspend as a suspended block terminator. 
+ builder.setInsertionPointToEnd(suspended); + builder.create(coroSaveOp.state(), coro.suspend, resume, + coro.cleanup); + + size_t numDependencies = execute.dependencies().size(); + size_t numOperands = execute.operands().size(); + + // Await on all dependencies before starting to execute the body region. + builder.setInsertionPointToStart(resume); + for (size_t i = 0; i < numDependencies; ++i) + builder.create(func.getArgument(i)); + + // Await on all async value operands and unwrap the payload. + SmallVector unwrappedOperands(numOperands); + for (size_t i = 0; i < numOperands; ++i) { + Value operand = func.getArgument(numDependencies + i); + unwrappedOperands[i] = builder.create(loc, operand).result(); + } + + // Map from function inputs defined above the execute op to the function + // arguments. + BlockAndValueMapping valueMapping; + valueMapping.map(functionInputs, func.getArguments()); + valueMapping.map(execute.body().getArguments(), unwrappedOperands); + + // Clone all operations from the execute operation body into the outlined + // function body. + for (Operation &op : execute.body().getOps()) + builder.clone(op, valueMapping); + + // Replace the original `async.execute` with a call to outlined function. 
+ ImplicitLocOpBuilder callBuilder(loc, execute); + auto callOutlinedFunc = callBuilder.create( + func.getName(), execute.getResultTypes(), functionInputs.getArrayRef()); + execute.replaceAllUsesWith(callOutlinedFunc.getResults()); + execute.erase(); + + return {func, coro}; +} + +//===----------------------------------------------------------------------===// +// Convert async.create_group operation to async.runtime.create +//===----------------------------------------------------------------------===// + +namespace { +class CreateGroupOpLowering : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(CreateGroupOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp( + op, GroupType::get(op->getContext())); + return success(); + } +}; +} // namespace + +//===----------------------------------------------------------------------===// +// Convert async.add_to_group operation to async.runtime.add_to_group. +//===----------------------------------------------------------------------===// + +namespace { +class AddToGroupOpLowering : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(AddToGroupOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp( + op, rewriter.getIndexType(), operands); + return success(); + } +}; +} // namespace + +//===----------------------------------------------------------------------===// +// Convert async.await and async.await_all operations to the async.runtime.await +// or async.runtime.await_and_resume operations. 
+//===----------------------------------------------------------------------===// + +namespace { +template +class AwaitOpLoweringBase : public OpConversionPattern { + using AwaitAdaptor = typename AwaitType::Adaptor; + +public: + AwaitOpLoweringBase( + MLIRContext *ctx, + const llvm::DenseMap &outlinedFunctions) + : OpConversionPattern(ctx), + outlinedFunctions(outlinedFunctions) {} + + LogicalResult + matchAndRewrite(AwaitType op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + // We can only await on one the `AwaitableType` (for `await` it can be + // a `token` or a `value`, for `await_all` it must be a `group`). + if (!op.operand().getType().template isa()) + return rewriter.notifyMatchFailure(op, "unsupported awaitable type"); + + // Check if await operation is inside the outlined coroutine function. + auto func = op->template getParentOfType(); + auto outlined = outlinedFunctions.find(func); + const bool isInCoroutine = outlined != outlinedFunctions.end(); + + Location loc = op->getLoc(); + Value operand = AwaitAdaptor(operands).operand(); + + // Inside regular functions we use the blocking wait operation to wait for + // the async object (token, value or group) to become available. + if (!isInCoroutine) + rewriter.create(loc, operand); + + // Inside the coroutine we convert await operation into coroutine suspension + // point, and resume execution asynchronously. + if (isInCoroutine) { + const CoroMachinery &coro = outlined->getSecond(); + Block *suspended = op->getBlock(); + + ImplicitLocOpBuilder builder(loc, op, rewriter.getListener()); + MLIRContext *ctx = op->getContext(); + + // Save the coroutine state and resume on a runtime managed thread when + // the operand becomes available. + auto coroSaveOp = + builder.create(CoroStateType::get(ctx), coro.coroHandle); + builder.create(operand, coro.coroHandle); + + // Split the entry block before the await operation. 
+ Block *resume = rewriter.splitBlock(suspended, Block::iterator(op)); + + // Add async.coro.suspend as a suspended block terminator. + builder.setInsertionPointToEnd(suspended); + builder.create(coroSaveOp.state(), coro.suspend, resume, + coro.cleanup); + + // Make sure that replacement value will be constructed in resume block. + rewriter.setInsertionPointToStart(resume); + } + + // Erase or replace the await operation with the new value. + if (Value replaceWith = getReplacementValue(op, operand, rewriter)) + rewriter.replaceOp(op, replaceWith); + else + rewriter.eraseOp(op); + + return success(); + } + + virtual Value getReplacementValue(AwaitType op, Value operand, + ConversionPatternRewriter &rewriter) const { + return Value(); + } + +private: + const llvm::DenseMap &outlinedFunctions; +}; + +/// Lowering for `async.await` with a token operand. +class AwaitTokenOpLowering : public AwaitOpLoweringBase { + using Base = AwaitOpLoweringBase; + +public: + using Base::Base; +}; + +/// Lowering for `async.await` with a value operand. +class AwaitValueOpLowering : public AwaitOpLoweringBase { + using Base = AwaitOpLoweringBase; + +public: + using Base::Base; + + Value + getReplacementValue(AwaitOp op, Value operand, + ConversionPatternRewriter &rewriter) const override { + // Load from the async value storage. + auto valueType = operand.getType().cast().getValueType(); + return rewriter.create(op->getLoc(), valueType, operand); + } +}; + +/// Lowering for `async.await_all` operation. +class AwaitAllOpLowering : public AwaitOpLoweringBase { + using Base = AwaitOpLoweringBase; + +public: + using Base::Base; +}; + +} // namespace + +//===----------------------------------------------------------------------===// +// Convert async.yield operation to async.runtime operations. 
+//===----------------------------------------------------------------------===// + +class YieldOpLowering : public OpConversionPattern { +public: + YieldOpLowering( + MLIRContext *ctx, + const llvm::DenseMap &outlinedFunctions) + : OpConversionPattern(ctx), + outlinedFunctions(outlinedFunctions) {} + + LogicalResult + matchAndRewrite(async::YieldOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + // Check if yield operation is inside the outlined coroutine function. + auto func = op->template getParentOfType(); + auto outlined = outlinedFunctions.find(func); + if (outlined == outlinedFunctions.end()) + return rewriter.notifyMatchFailure( + op, "operation is not inside the outlined async.execute function"); + + Location loc = op->getLoc(); + const CoroMachinery &coro = outlined->getSecond(); + + // Store yielded values into the async values storage and switch async + // values state to available. + for (auto tuple : llvm::zip(operands, coro.returnValues)) { + Value yieldValue = std::get<0>(tuple); + Value asyncValue = std::get<1>(tuple); + rewriter.create(loc, yieldValue, asyncValue); + rewriter.create(loc, asyncValue); + } + + // Switch the coroutine completion token to available state. + rewriter.replaceOpWithNewOp(op, coro.asyncToken); + + return success(); + } + +private: + const llvm::DenseMap &outlinedFunctions; +}; + +//===----------------------------------------------------------------------===// + +void AsyncToAsyncRuntimePass::runOnOperation() { + ModuleOp module = getOperation(); + SymbolTable symbolTable(module); + + // Outline all `async.execute` body regions into async functions (coroutines). 
+ llvm::DenseMap outlinedFunctions; + + module.walk([&](ExecuteOp execute) { + outlinedFunctions.insert(outlineExecuteOp(symbolTable, execute)); + }); + + LLVM_DEBUG({ + llvm::dbgs() << "Outlined " << outlinedFunctions.size() + << " functions built from async.execute operations\n"; + }); + + // Lower async operations to async.runtime operations. + MLIRContext *ctx = module->getContext(); + OwningRewritePatternList asyncPatterns; + + // Async lowering does not use type converter because it must preserve all + // types for async.runtime operations. + asyncPatterns.insert(ctx); + asyncPatterns.insert(ctx, + outlinedFunctions); + + // All high level async operations must be lowered to the runtime operations. + ConversionTarget runtimeTarget(*ctx); + runtimeTarget.addLegalDialect(); + runtimeTarget.addIllegalOp(); + runtimeTarget.addIllegalOp(); + + if (failed(applyPartialConversion(module, runtimeTarget, + std::move(asyncPatterns)))) { + signalPassFailure(); + return; + } +} + +std::unique_ptr> mlir::createAsyncToAsyncRuntimePass() { + return std::make_unique(); +} diff --git a/mlir/lib/Dialect/Async/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Async/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Async/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Async/Transforms/CMakeLists.txt @@ -2,6 +2,7 @@ AsyncParallelFor.cpp AsyncRefCounting.cpp AsyncRefCountingOptimization.cpp + AsyncToAsyncRuntime.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Async diff --git a/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir b/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir --- a/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir +++ b/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -split-input-file -convert-async-to-llvm | FileCheck %s +// RUN: mlir-opt %s -split-input-file -async-to-async-runtime -convert-async-to-llvm | FileCheck %s // CHECK-LABEL: reference_counting func @reference_counting(%arg0: 
!async.token) { @@ -247,8 +247,7 @@ // ----- -// RUN: mlir-opt %s -split-input-file -convert-async-to-llvm | FileCheck %s - +// CHECK-LABEL: @async_value_operands func @async_value_operands() { // CHECK: %[[RET:.*]]:2 = call @async_execute_fn %token, %result = async.execute -> !async.value { diff --git a/mlir/test/Dialect/Async/async-to-async-runtime.mlir b/mlir/test/Dialect/Async/async-to-async-runtime.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Async/async-to-async-runtime.mlir @@ -0,0 +1,303 @@ +// RUN: mlir-opt %s -split-input-file -async-to-async-runtime -print-ir-after-all | FileCheck %s --dump-input=always + +// CHECK-LABEL: @execute_no_async_args +func @execute_no_async_args(%arg0: f32, %arg1: memref<1xf32>) { + %token = async.execute { + %c0 = constant 0 : index + store %arg0, %arg1[%c0] : memref<1xf32> + async.yield + } + async.await %token : !async.token + return +} + +// Function outlined from the async.execute operation. +// CHECK-LABEL: func private @async_execute_fn +// CHECK-SAME: -> !async.token + +// Create token for return op, and mark a function as a coroutine. +// CHECK: %[[TOKEN:.*]] = async.runtime.create : !async.token +// CHECK: %[[ID:.*]] = async.coro.id +// CHECK: %[[HDL:.*]] = async.coro.begin + +// Pass a suspended coroutine to the async runtime. +// CHECK: %[[SAVED:.*]] = async.coro.save %[[HDL]] +// CHECK: async.runtime.resume %[[HDL]] +// CHECK: async.coro.suspend %[[SAVED]] +// CHECK-SAME: ^[[SUSPEND:.*]], ^[[RESUME:.*]], ^[[CLEANUP:.*]] + +// Resume coroutine after suspension. +// CHECK: ^[[RESUME]]: +// CHECK: store +// CHECK: async.runtime.set_available %[[TOKEN]] + +// Delete coroutine. +// CHECK: ^[[CLEANUP]]: +// CHECK: async.coro.free %[[ID]], %[[HDL]] + +// Suspend coroutine, and also a return statement for ramp function. 
+// CHECK: ^[[SUSPEND]]:
+// CHECK: async.coro.end %[[HDL]]
+// CHECK: return %[[TOKEN]]
+
+// -----
+
+// CHECK-LABEL: @nested_async_execute
+func @nested_async_execute(%arg0: f32, %arg1: f32, %arg2: memref<1xf32>) {
+  // CHECK: %[[TOKEN:.*]] = call @async_execute_fn_0(%arg0, %arg2, %arg1)
+  %token0 = async.execute {
+    %c0 = constant 0 : index
+
+    %token1 = async.execute {
+      %c1 = constant 1: index
+      store %arg0, %arg2[%c0] : memref<1xf32>
+      async.yield
+    }
+    async.await %token1 : !async.token
+
+    store %arg1, %arg2[%c0] : memref<1xf32>
+    async.yield
+  }
+  // CHECK: async.runtime.await %[[TOKEN]]
+  // CHECK-NEXT: return
+  async.await %token0 : !async.token
+  return
+}
+
+// Function outlined from the inner async.execute operation.
+// CHECK-LABEL: func private @async_execute_fn
+// CHECK-SAME: -> !async.token
+
+// CHECK: %[[TOKEN:.*]] = async.runtime.create : !async.token
+// CHECK: %[[ID:.*]] = async.coro.id
+// CHECK: %[[HDL:.*]] = async.coro.begin
+
+// CHECK: async.runtime.resume %[[HDL]]
+// CHECK: async.coro.suspend
+// CHECK-SAME: ^[[SUSPEND:.*]], ^[[RESUME:.*]], ^[[CLEANUP:.*]]
+
+// CHECK: ^[[RESUME]]:
+// CHECK: store
+// CHECK: async.runtime.set_available %[[TOKEN]]
+
+// Function outlined from the outer async.execute operation.
+// CHECK-LABEL: func private @async_execute_fn_0
+// CHECK-SAME: -> !async.token
+
+// CHECK: %[[TOKEN:.*]] = async.runtime.create : !async.token
+// CHECK: %[[ID:.*]] = async.coro.id
+// CHECK: %[[HDL:.*]] = async.coro.begin
+
+// Suspend the coroutine at the beginning (it is handed to the async runtime).
+// CHECK: async.runtime.resume %[[HDL]]
+// CHECK: async.coro.suspend
+// CHECK-SAME: ^[[SUSPEND:.*]], ^[[RESUME_0:.*]], ^[[CLEANUP:.*]]
+
+// Suspend the coroutine a second time to await the inner execute op's token.
+// CHECK: ^[[RESUME_0]]:
+// CHECK: %[[INNER_TOKEN:.*]] = call @async_execute_fn
+// CHECK: %[[SAVED:.*]] = async.coro.save %[[HDL]]
+// CHECK: async.runtime.await_and_resume %[[INNER_TOKEN]], %[[HDL]]
+// CHECK: async.coro.suspend %[[SAVED]]
+// CHECK-SAME: ^[[SUSPEND]], ^[[RESUME_1:.*]], ^[[CLEANUP]]
+
+// Mark the result token available after the second resumption.
+// CHECK: ^[[RESUME_1]]:
+// CHECK: store
+// CHECK: async.runtime.set_available %[[TOKEN]]
+
+// CHECK: ^[[CLEANUP]]:
+// CHECK: ^[[SUSPEND]]:
+
+// -----
+
+// CHECK-LABEL: @async_execute_token_dependency
+func @async_execute_token_dependency(%arg0: f32, %arg1: memref<1xf32>) {
+  // CHECK: %[[TOKEN:.*]] = call @async_execute_fn
+  %token = async.execute {
+    %c0 = constant 0 : index
+    store %arg0, %arg1[%c0] : memref<1xf32>
+    async.yield
+  }
+  // CHECK: call @async_execute_fn_0(%[[TOKEN]], %arg0, %arg1)
+  %token_0 = async.execute [%token] {
+    %c0 = constant 0 : index
+    store %arg0, %arg1[%c0] : memref<1xf32>
+    async.yield
+  }
+  return
+}
+
+// Function outlined from the first async.execute operation.
+// CHECK-LABEL: func private @async_execute_fn
+// CHECK-SAME: -> !async.token
+// CHECK: %[[TOKEN:.*]] = async.runtime.create : !async.token
+// CHECK: return %[[TOKEN]] : !async.token
+
+// Function outlined from the second async.execute op (has a token dependency).
+// CHECK-LABEL: func private @async_execute_fn_0
+// CHECK-SAME: %[[ARG0:.*]]: !async.token
+// CHECK-SAME: %[[ARG1:.*]]: f32
+// CHECK-SAME: %[[ARG2:.*]]: memref<1xf32>
+// CHECK-SAME: -> !async.token
+// CHECK: %[[TOKEN:.*]] = async.runtime.create : !async.token
+// CHECK: %[[HDL:.*]] = async.coro.begin
+
+// Suspend the coroutine at the beginning (it is handed to the async runtime).
+// CHECK: async.runtime.resume %[[HDL]]
+// CHECK: async.coro.suspend
+// CHECK-SAME: ^[[SUSPEND:.*]], ^[[RESUME_0:.*]], ^[[CLEANUP:.*]]
+
+// Suspend the coroutine a second time to await the token dependency.
+// CHECK: ^[[RESUME_0]]:
+// CHECK: %[[SAVED:.*]] = async.coro.save %[[HDL]]
+// CHECK: async.runtime.await_and_resume %[[ARG0]], %[[HDL]]
+// CHECK: async.coro.suspend %[[SAVED]]
+// CHECK-SAME: ^[[SUSPEND]], ^[[RESUME_1:.*]], ^[[CLEANUP]]
+
+// Mark the result token available after the second resumption.
+// CHECK: ^[[RESUME_1]]:
+// CHECK: store
+// CHECK: async.runtime.set_available %[[TOKEN]]
+
+// CHECK: ^[[CLEANUP]]:
+// CHECK: ^[[SUSPEND]]:
+
+// -----
+
+// CHECK-LABEL: @async_group_await_all
+func @async_group_await_all(%arg0: f32, %arg1: memref<1xf32>) {
+  // CHECK: %[[GROUP:.*]] = async.runtime.create : !async.group
+  %0 = async.create_group
+
+  // CHECK: %[[TOKEN:.*]] = call @async_execute_fn
+  %token = async.execute { async.yield }
+  // CHECK: async.runtime.add_to_group %[[TOKEN]], %[[GROUP]]
+  async.add_to_group %token, %0 : !async.token
+
+  // CHECK: call @async_execute_fn_0
+  async.execute {
+    async.await_all %0
+    async.yield
+  }
+
+  // CHECK: async.runtime.await %[[GROUP]] : !async.group
+  async.await_all %0
+  return
+}
+
+// Function outlined from the second async.execute operation.
+// CHECK-LABEL: func private @async_execute_fn_0
+// CHECK-SAME: (%[[ARG:.*]]: !async.group) -> !async.token
+
+// CHECK: %[[TOKEN:.*]] = async.runtime.create : !async.token
+// CHECK: %[[HDL:.*]] = async.coro.begin
+
+// Suspend the coroutine at the beginning (it is handed to the async runtime).
+// CHECK: async.runtime.resume %[[HDL]]
+// CHECK: async.coro.suspend
+// CHECK-SAME: ^[[SUSPEND:.*]], ^[[RESUME_0:.*]], ^[[CLEANUP:.*]]
+
+// Suspend the coroutine a second time to await the group.
+// CHECK: ^[[RESUME_0]]:
+// CHECK: async.runtime.await_and_resume %[[ARG]], %[[HDL]]
+// CHECK: async.coro.suspend
+// CHECK-SAME: ^[[SUSPEND]], ^[[RESUME_1:.*]], ^[[CLEANUP]]
+
+// Mark the result token available.
+// CHECK: ^[[RESUME_1]]:
+// CHECK: async.runtime.set_available %[[TOKEN]]
+
+// CHECK: ^[[CLEANUP]]:
+// CHECK: ^[[SUSPEND]]:
+
+// -----
+
+// CHECK-LABEL: @execute_and_return_f32
+func @execute_and_return_f32() -> f32 {
+  // CHECK: %[[RET:.*]]:2 = call @async_execute_fn
+  %token, %result = async.execute -> !async.value<f32> {
+    %c0 = constant 123.0 : f32
+    async.yield %c0 : f32
+  }
+
+  // CHECK: async.runtime.await %[[RET]]#1 : !async.value<f32>
+  // CHECK: %[[VALUE:.*]] = async.runtime.load %[[RET]]#1 : !async.value<f32>
+  %0 = async.await %result : !async.value<f32>
+
+  // CHECK: return %[[VALUE]]
+  return %0 : f32
+}
+
+// Function outlined from the async.execute operation.
+// CHECK-LABEL: func private @async_execute_fn()
+// CHECK: %[[TOKEN:.*]] = async.runtime.create : !async.token
+// CHECK: %[[VALUE:.*]] = async.runtime.create : !async.value<f32>
+// CHECK: %[[HDL:.*]] = async.coro.begin
+
+// Suspend coroutine in the beginning.
+// CHECK: async.runtime.resume %[[HDL]]
+// CHECK: async.coro.suspend
+// CHECK-SAME: ^[[SUSPEND:.*]], ^[[RESUME:.*]], ^[[CLEANUP:.*]]
+
+// Emplace result value.
+// CHECK: ^[[RESUME]]:
+// CHECK: %[[CST:.*]] = constant 1.230000e+02 : f32
+// CHECK: async.runtime.store %[[CST]], %[[VALUE]]
+// CHECK: async.runtime.set_available %[[VALUE]]
+// CHECK: async.runtime.set_available %[[TOKEN]]
+
+// CHECK: ^[[CLEANUP]]:
+// CHECK: ^[[SUSPEND]]:
+
+// -----
+
+// CHECK-LABEL: @async_value_operands
+func @async_value_operands() {
+  // CHECK: %[[RET:.*]]:2 = call @async_execute_fn
+  %token, %result = async.execute -> !async.value<f32> {
+    %c0 = constant 123.0 : f32
+    async.yield %c0 : f32
+  }
+
+  // CHECK: %[[TOKEN:.*]] = call @async_execute_fn_0(%[[RET]]#1)
+  %token0 = async.execute(%result as %value: !async.value<f32>) {
+    %0 = addf %value, %value : f32
+    async.yield
+  }
+
+  // CHECK: async.runtime.await %[[TOKEN]] : !async.token
+  async.await %token0 : !async.token
+
+  return
+}
+
+// Function outlined from the first async.execute operation.
+// CHECK-LABEL: func private @async_execute_fn()
+
+// Function outlined from the second async.execute operation.
+// CHECK-LABEL: func private @async_execute_fn_0
+// CHECK-SAME: (%[[ARG:.*]]: !async.value<f32>) -> !async.token
+
+// CHECK: %[[TOKEN:.*]] = async.runtime.create : !async.token
+// CHECK: %[[HDL:.*]] = async.coro.begin
+
+// Suspend coroutine in the beginning.
+// CHECK: async.runtime.resume %[[HDL]]
+// CHECK: async.coro.suspend
+// CHECK-SAME: ^[[SUSPEND:.*]], ^[[RESUME_0:.*]], ^[[CLEANUP:.*]]
+
+// Suspend coroutine second time waiting for the async operand.
+// CHECK: ^[[RESUME_0]]:
+// CHECK: async.runtime.await_and_resume %[[ARG]], %[[HDL]]
+// CHECK: async.coro.suspend
+// CHECK-SAME: ^[[SUSPEND]], ^[[RESUME_1:.*]], ^[[CLEANUP]]
+
+// Load from the async.value argument.
+// CHECK: ^[[RESUME_1]]:
+// CHECK: %[[LOADED:.*]] = async.runtime.load %[[ARG]] : !async.value<f32>