diff --git a/mlir/include/mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h b/mlir/include/mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h
@@ -0,0 +1,25 @@
+//===- AsyncToLLVM.h - Convert Async to LLVM dialect ------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_ASYNCTOLLVM_ASYNCTOLLVM_H
+#define MLIR_CONVERSION_ASYNCTOLLVM_ASYNCTOLLVM_H
+
+#include <memory>
+
+namespace mlir {
+
+class ModuleOp;
+template <typename T>
+class OperationPass;
+
+/// Create a pass to convert Async operations to the LLVM dialect.
+std::unique_ptr<OperationPass<ModuleOp>> createConvertAsyncToLLVMPass();
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_ASYNCTOLLVM_ASYNCTOLLVM_H
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -11,6 +11,7 @@
 
 #include "mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h"
 #include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
+#include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
 #include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
 #include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
 #include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -84,6 +84,21 @@
   let dependentDialects = ["LLVM::LLVMDialect", "LLVM::LLVMAVX512Dialect"];
 }
 
+//===----------------------------------------------------------------------===//
+// AsyncToLLVM
+//===----------------------------------------------------------------------===//
+
+def ConvertAsyncToLLVM : Pass<"convert-async-to-llvm", "ModuleOp"> {
+  let summary = "Convert the operations from the async dialect into the LLVM "
+                "dialect";
+  let description = [{
+    Convert `async.execute` operations to LLVM coroutines and use async runtime
+    API to execute them.
+  }];
+  let constructor = "mlir::createConvertAsyncToLLVMPass()";
+  let dependentDialects = ["LLVM::LLVMDialect"];
+}
+
 //===----------------------------------------------------------------------===//
 // GPUCommon
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
--- a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
+++ b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
@@ -24,7 +24,9 @@
 class Async_Op<string mnemonic, list<OpTrait> traits = []> :
     Op<AsyncDialect, mnemonic, traits>;
 
-def Async_ExecuteOp : Async_Op<"execute", [AttrSizedOperandSegments]> {
+def Async_ExecuteOp :
+  Async_Op<"execute", [SingleBlockImplicitTerminator<"YieldOp">,
+                       AttrSizedOperandSegments]> {
   let summary = "Asynchronous execute operation";
   let description = [{
     The `body` region attached to the `async.execute` operation semantically
diff --git a/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h b/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h
@@ -0,0 +1,73 @@
+//===- AsyncRuntime.h - Async runtime reference implementation ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares basic Async runtime API for supporting Async dialect
+// to LLVM dialect lowering.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_EXECUTIONENGINE_ASYNCRUNTIME_H_
+#define MLIR_EXECUTIONENGINE_ASYNCRUNTIME_H_
+
+#ifdef _WIN32
+#ifndef MLIR_ASYNCRUNTIME_EXPORT
+#ifdef mlir_c_runner_utils_EXPORTS
+/* We are building this library */
+#define MLIR_ASYNCRUNTIME_EXPORT __declspec(dllexport)
+#define MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS
+#else
+/* We are using this library */
+#define MLIR_ASYNCRUNTIME_EXPORT __declspec(dllimport)
+#endif // mlir_c_runner_utils_EXPORTS
+#endif // MLIR_ASYNCRUNTIME_EXPORT
+#else
+#define MLIR_ASYNCRUNTIME_EXPORT
+#define MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS
+#endif // _WIN32
+
+//===----------------------------------------------------------------------===//
+// Async runtime API.
+//===----------------------------------------------------------------------===//
+
+// Opaque runtime implementation of the `async.token` data type.
+typedef struct AsyncToken MLIR_AsyncToken;
+
+// Async runtime uses LLVM coroutines to represent asynchronous tasks. A task
+// is a coroutine handle plus a resume function that continues the coroutine
+// execution from a suspension point.
+using CoroHandle = void *;           // coroutine handle
+using CoroResume = void (*)(void *); // coroutine resume function
+
+// Create a new `async.token` in not-ready state.
+extern "C" MLIR_ASYNCRUNTIME_EXPORT AsyncToken *mlirAsyncRuntimeCreateToken();
+
+// Switches `async.token` to ready state and runs all awaiters.
+extern "C" MLIR_ASYNCRUNTIME_EXPORT void
+mlirAsyncRuntimeEmplaceToken(AsyncToken *);
+
+// Blocks the caller thread until the token becomes ready, then destroys it.
+extern "C" MLIR_ASYNCRUNTIME_EXPORT void
+mlirAsyncRuntimeAwaitToken(AsyncToken *);
+
+// Executes the task (coro handle + resume function) in one of the threads
+// managed by the runtime.
+extern "C" MLIR_ASYNCRUNTIME_EXPORT void mlirAsyncRuntimeExecute(CoroHandle,
+                                                                 CoroResume);
+
+// Executes the task (coro handle + resume function) in one of the threads
+// managed by the runtime after the token becomes ready; destroys the token.
+extern "C" MLIR_ASYNCRUNTIME_EXPORT void
+mlirAsyncRuntimeAwaitTokenAndExecute(AsyncToken *, CoroHandle, CoroResume);
+
+//===----------------------------------------------------------------------===//
+// Small async runtime support library for testing.
+//===----------------------------------------------------------------------===//
+
+extern "C" MLIR_ASYNCRUNTIME_EXPORT void mlirAsyncRuntimePrintCurrentThreadId();
+
+#endif // MLIR_EXECUTIONENGINE_ASYNCRUNTIME_H_
diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
@@ -0,0 +1,733 @@
+//===- AsyncToLLVM.cpp - Convert Async to LLVM dialect --------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
+
+#include "../PassDetail.h"
+#include "mlir/Dialect/Async/IR/Async.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/Support/FormatVariadic.h"
+
+#define DEBUG_TYPE "convert-async-to-llvm"
+
+using namespace mlir;
+using namespace mlir::async;
+
+// Name (prefix) of functions outlined from `async.execute` op regions.
+static constexpr const char kAsyncFnPrefix[] = "async_execute_fn";
+
+//===----------------------------------------------------------------------===//
+// Async Runtime C API declaration.
+//===----------------------------------------------------------------------===//
+
+static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken";
+static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken";
+static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken";
+static constexpr const char *kExecute = "mlirAsyncRuntimeExecute";
+static constexpr const char *kAwaitAndExecute =
+    "mlirAsyncRuntimeAwaitTokenAndExecute";
+
+namespace {
+// Signatures of the Async Runtime C API functions named above.
+struct AsyncAPI {
+  static FunctionType createTokenFunctionType(MLIRContext *ctx) {
+    return FunctionType::get(/*inputs=*/{}, {TokenType::get(ctx)}, ctx);
+  }
+
+  static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) {
+    return FunctionType::get({TokenType::get(ctx)}, /*results=*/{}, ctx);
+  }
+
+  static FunctionType awaitTokenFunctionType(MLIRContext *ctx) {
+    return emplaceTokenFunctionType(ctx); // identical `(token) -> ()` type
+  }
+
+  static FunctionType executeFunctionType(MLIRContext *ctx) {
+    auto hdlTy = LLVM::LLVMType::getInt8PtrTy(ctx);
+    auto resumeTy = resumeFunctionType(ctx).getPointerTo();
+    return FunctionType::get({hdlTy, resumeTy}, /*results=*/{}, ctx);
+  }
+
+  static FunctionType awaitAndExecuteFunctionType(MLIRContext *ctx) {
+    auto hdlTy = LLVM::LLVMType::getInt8PtrTy(ctx);
+    auto resumeTy = resumeFunctionType(ctx).getPointerTo();
+    return FunctionType::get({TokenType::get(ctx), hdlTy, resumeTy}, {}, ctx);
+  }
+
+  // `void (i8*)`: type of the wrapper that resumes a coroutine by handle.
+  static LLVM::LLVMType resumeFunctionType(MLIRContext *ctx) {
+    auto retTy = LLVM::LLVMType::getVoidTy(ctx);
+    auto hdlTy = LLVM::LLVMType::getInt8PtrTy(ctx);
+    return LLVM::LLVMType::getFunctionTy(retTy, {hdlTy}, false);
+  }
+};
+} // namespace
+
+// Adds Async Runtime C API declarations to the module.
+static void addAsyncRuntimeApiDeclarations(ModuleOp module) {
+  auto builder = OpBuilder::atBlockTerminator(module.getBody());
+
+  MLIRContext *ctx = module.getContext();
+  Location loc = module.getLoc();
+
+  if (!module.lookupSymbol(kCreateToken))
+    builder.create<FuncOp>(loc, kCreateToken,
+                           AsyncAPI::createTokenFunctionType(ctx));
+
+  if (!module.lookupSymbol(kEmplaceToken))
+    builder.create<FuncOp>(loc, kEmplaceToken,
+                           AsyncAPI::emplaceTokenFunctionType(ctx));
+
+  if (!module.lookupSymbol(kAwaitToken))
+    builder.create<FuncOp>(loc, kAwaitToken,
+                           AsyncAPI::awaitTokenFunctionType(ctx));
+
+  if (!module.lookupSymbol(kExecute))
+    builder.create<FuncOp>(loc, kExecute, AsyncAPI::executeFunctionType(ctx));
+
+  if (!module.lookupSymbol(kAwaitAndExecute))
+    builder.create<FuncOp>(loc, kAwaitAndExecute,
+                           AsyncAPI::awaitAndExecuteFunctionType(ctx));
+}
+
+//===----------------------------------------------------------------------===//
+// LLVM coroutines intrinsics declarations.
+//===----------------------------------------------------------------------===//
+
+static constexpr const char *kCoroId = "llvm.coro.id";
+static constexpr const char *kCoroSizeI64 = "llvm.coro.size.i64";
+static constexpr const char *kCoroBegin = "llvm.coro.begin";
+static constexpr const char *kCoroSave = "llvm.coro.save";
+static constexpr const char *kCoroSuspend = "llvm.coro.suspend";
+static constexpr const char *kCoroEnd = "llvm.coro.end";
+static constexpr const char *kCoroFree = "llvm.coro.free";
+static constexpr const char *kCoroResume = "llvm.coro.resume";
+
+/// Adds declarations of the LLVM coroutine intrinsics used by the lowering.
+static void addCoroutineIntrinsicsDeclarations(ModuleOp module) {
+  using namespace mlir::LLVM;
+
+  MLIRContext *ctx = module.getContext();
+  Location loc = module.getLoc();
+
+  OpBuilder builder(module.getBody()->getTerminator());
+
+  auto token = LLVMTokenType::get(ctx);
+  auto voidTy = LLVMType::getVoidTy(ctx);
+
+  auto i8 = LLVMType::getInt8Ty(ctx);
+  auto i1 = LLVMType::getInt1Ty(ctx);
+  auto i32 = LLVMType::getInt32Ty(ctx);
+  auto i64 = LLVMType::getInt64Ty(ctx);
+  auto i8Ptr = LLVMType::getInt8PtrTy(ctx);
+
+  if (!module.lookupSymbol(kCoroId))
+    builder.create<LLVMFuncOp>(
+        loc, kCoroId,
+        LLVMType::getFunctionTy(token, {i32, i8Ptr, i8Ptr, i8Ptr}, false));
+
+  if (!module.lookupSymbol(kCoroSizeI64))
+    builder.create<LLVMFuncOp>(loc, kCoroSizeI64,
+                               LLVMType::getFunctionTy(i64, false));
+
+  if (!module.lookupSymbol(kCoroBegin))
+    builder.create<LLVMFuncOp>(
+        loc, kCoroBegin, LLVMType::getFunctionTy(i8Ptr, {token, i8Ptr}, false));
+
+  if (!module.lookupSymbol(kCoroSave))
+    builder.create<LLVMFuncOp>(loc, kCoroSave,
+                               LLVMType::getFunctionTy(token, i8Ptr, false));
+
+  if (!module.lookupSymbol(kCoroSuspend))
+    builder.create<LLVMFuncOp>(loc, kCoroSuspend,
+                               LLVMType::getFunctionTy(i8, {token, i1}, false));
+
+  if (!module.lookupSymbol(kCoroEnd))
+    builder.create<LLVMFuncOp>(loc, kCoroEnd,
+                               LLVMType::getFunctionTy(i1, {i8Ptr, i1}, false));
+
+  if (!module.lookupSymbol(kCoroFree))
+    builder.create<LLVMFuncOp>(
+        loc, kCoroFree, LLVMType::getFunctionTy(i8Ptr, {token, i8Ptr}, false));
+
+  if (!module.lookupSymbol(kCoroResume))
+    builder.create<LLVMFuncOp>(loc, kCoroResume,
+                               LLVMType::getFunctionTy(voidTy, i8Ptr, false));
+}
+
+//===----------------------------------------------------------------------===//
+// Add malloc/free declarations to the module.
+//===----------------------------------------------------------------------===//
+
+static constexpr const char *kMalloc = "malloc";
+static constexpr const char *kFree = "free";
+
+/// Adds `malloc`/`free` declarations (used for the coroutine frame memory).
+static void addCRuntimeDeclarations(ModuleOp module) {
+  using namespace mlir::LLVM;
+
+  MLIRContext *ctx = module.getContext();
+  Location loc = module.getLoc();
+
+  OpBuilder builder(module.getBody()->getTerminator());
+
+  auto voidTy = LLVMType::getVoidTy(ctx);
+  auto i64 = LLVMType::getInt64Ty(ctx);
+  auto i8Ptr = LLVMType::getInt8PtrTy(ctx);
+
+  if (!module.lookupSymbol(kMalloc))
+    builder.create<LLVMFuncOp>(
+        loc, kMalloc, LLVMType::getFunctionTy(i8Ptr, {i64}, false));
+
+  if (!module.lookupSymbol(kFree))
+    builder.create<LLVMFuncOp>(
+        loc, kFree, LLVMType::getFunctionTy(voidTy, i8Ptr, false));
+}
+
+//===----------------------------------------------------------------------===//
+// Coroutine resume function wrapper.
+//===----------------------------------------------------------------------===//
+
+static constexpr const char *kResume = "__resume";
+
+// Adds a private function that takes a coroutine handle and calls the
+// `llvm.coro.resume` intrinsic on it. We need this function to be able to pass
+// its address to the async runtime execute API.
+static void addResumeFunction(ModuleOp module) {
+  MLIRContext *ctx = module.getContext();
+
+  OpBuilder moduleBuilder(module.getBody()->getTerminator());
+  Location loc = module.getLoc();
+
+  if (module.lookupSymbol(kResume))
+    return;
+
+  auto voidTy = LLVM::LLVMType::getVoidTy(ctx);
+  auto i8Ptr = LLVM::LLVMType::getInt8PtrTy(ctx);
+
+  auto resumeOp = moduleBuilder.create<LLVM::LLVMFuncOp>(
+      loc, kResume, LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}));
+  SymbolTable::setSymbolVisibility(resumeOp, SymbolTable::Visibility::Private);
+
+  auto *block = resumeOp.addEntryBlock();
+  OpBuilder blockBuilder = OpBuilder::atBlockEnd(block);
+
+  blockBuilder.create<LLVM::CallOp>(loc, Type(),
+                                    blockBuilder.getSymbolRefAttr(kCoroResume),
+                                    resumeOp.getArgument(0));
+
+  blockBuilder.create<LLVM::ReturnOp>(loc, ValueRange());
+}
+
+//===----------------------------------------------------------------------===//
+// async.execute op outlining to the coroutine functions.
+//===----------------------------------------------------------------------===//
+
+// Function targeted for coroutine transformation has two additional blocks at
+// the end: coroutine cleanup and coroutine suspension.
+//
+// async.await op lowering additionally creates a resume block for each
+// operation to enable non-blocking waiting via coroutine suspension.
+namespace {
+struct CoroMachinery {
+  Value asyncToken; // token from `mlirAsyncRuntimeCreateToken` (returned)
+  Value coroHandle; // coroutine handle returned by @llvm.coro.begin
+  Block *cleanup;   // frees the coroutine frame, then branches to `suspend`
+  Block *suspend;   // calls @llvm.coro.end and returns `asyncToken`
+};
+} // namespace
+
+// Builds a coroutine template compatible with LLVM coroutines lowering.
+//
+//  - `entry` block sets up the coroutine.
+//  - `cleanup` block cleans up the coroutine state.
+//  - `suspend` block after the @llvm.coro.end() defines what value will be
+//    returned to the initial caller of a coroutine. Everything before the
+//    @llvm.coro.end() will be executed at every suspension point.
+//
+// Coroutine structure (only the important bits):
+//
+//   func @async_execute_fn() -> !async.token {
+//     ^entryBlock():
+//       %token = <async token> : !async.token  // create async runtime token
+//       %hdl = llvm.call @llvm.coro.begin(...) // create a coroutine handle
+//       br ^cleanup
+//
+//     ^cleanup:
+//       llvm.call @llvm.coro.free(...)         // delete coroutine state
+//       br ^suspend
+//
+//     ^suspend:
+//       llvm.call @llvm.coro.end(...)          // marks the end of a coroutine
+//       return %token : !async.token
+//   }
+//
+// The actual code for the async.execute operation body region will be inserted
+// before the entry block terminator.
+//
+//
+static CoroMachinery setupCoroMachinery(FuncOp func) {
+  assert(func.getBody().empty() && "Function must have empty body");
+
+  MLIRContext *ctx = func.getContext();
+
+  auto token = LLVM::LLVMTokenType::get(ctx);
+  auto i1 = LLVM::LLVMType::getInt1Ty(ctx);
+  auto i32 = LLVM::LLVMType::getInt32Ty(ctx);
+  auto i64 = LLVM::LLVMType::getInt64Ty(ctx);
+  auto i8Ptr = LLVM::LLVMType::getInt8PtrTy(ctx);
+
+  Block *entryBlock = func.addEntryBlock();
+  Location loc = func.getBody().getLoc();
+
+  OpBuilder builder = OpBuilder::atBlockBegin(entryBlock);
+
+  // ------------------------------------------------------------------------ //
+  // Allocate async tokens/values that we will return from a ramp function.
+  // ------------------------------------------------------------------------ //
+  auto createToken =
+      builder.create<CallOp>(loc, kCreateToken, TokenType::get(ctx));
+
+  // ------------------------------------------------------------------------ //
+  // Initialize coroutine: allocate frame, get coroutine handle.
+  // ------------------------------------------------------------------------ //
+
+  // Constants for initializing coroutine frame.
+  auto constZero =
+      builder.create<LLVM::ConstantOp>(loc, i32, builder.getI32IntegerAttr(0));
+  auto constFalse =
+      builder.create<LLVM::ConstantOp>(loc, i1, builder.getBoolAttr(false));
+  auto nullPtr = builder.create<LLVM::NullOp>(loc, i8Ptr);
+
+  // Get coroutine id: @llvm.coro.id
+  auto coroId = builder.create<LLVM::CallOp>(
+      loc, token, builder.getSymbolRefAttr(kCoroId),
+      ValueRange({constZero, nullPtr, nullPtr, nullPtr}));
+
+  // Get coroutine frame size: @llvm.coro.size.i64
+  auto coroSize = builder.create<LLVM::CallOp>(
+      loc, i64, builder.getSymbolRefAttr(kCoroSizeI64), ValueRange());
+
+  // Allocate memory for coroutine frame.
+  auto coroAlloc = builder.create<LLVM::CallOp>(
+      loc, i8Ptr, builder.getSymbolRefAttr(kMalloc),
+      ValueRange(coroSize.getResult(0)));
+
+  // Begin a coroutine: @llvm.coro.begin
+  auto coroHdl = builder.create<LLVM::CallOp>(
+      loc, i8Ptr, builder.getSymbolRefAttr(kCoroBegin),
+      ValueRange({coroId.getResult(0), coroAlloc.getResult(0)}));
+
+  Block *cleanupBlock = func.addBlock();
+  Block *suspendBlock = func.addBlock();
+
+  // ------------------------------------------------------------------------ //
+  // Coroutine cleanup block: deallocate coroutine frame, free the memory.
+  // ------------------------------------------------------------------------ //
+  builder.setInsertionPointToStart(cleanupBlock);
+
+  // Get a pointer to the coroutine frame memory: @llvm.coro.free.
+  auto coroMem = builder.create<LLVM::CallOp>(
+      loc, i8Ptr, builder.getSymbolRefAttr(kCoroFree),
+      ValueRange({coroId.getResult(0), coroHdl.getResult(0)}));
+
+  // Free the memory.
+  builder.create<LLVM::CallOp>(loc, Type(), builder.getSymbolRefAttr(kFree),
+                               ValueRange(coroMem.getResult(0)));
+  // Branch into the suspend block.
+  builder.create<BranchOp>(loc, suspendBlock);
+
+  // ------------------------------------------------------------------------ //
+  // Coroutine suspend block: mark the end of a coroutine and return allocated
+  // async token.
+  // ------------------------------------------------------------------------ //
+  builder.setInsertionPointToStart(suspendBlock);
+
+  // Mark the end of a coroutine: @llvm.coro.end.
+  builder.create<LLVM::CallOp>(loc, i1, builder.getSymbolRefAttr(kCoroEnd),
+                               ValueRange({coroHdl.getResult(0), constFalse}));
+
+  // Return created `async.token` from the suspend block. This will be the
+  // return value of a coroutine ramp function.
+  builder.create<ReturnOp>(loc, createToken.getResult(0));
+
+  // Branch from the entry block to the cleanup block to create a valid CFG.
+  builder.setInsertionPointToEnd(entryBlock);
+
+  builder.create<BranchOp>(loc, cleanupBlock);
+
+  // `async.await` op lowering will create resume blocks for async
+  // continuations, and will conditionally branch to cleanup or suspend blocks.
+
+  return {createToken.getResult(0), coroHdl.getResult(0), cleanupBlock,
+          suspendBlock};
+}
+
+// Adds a suspension point before the `op`, and moves `op` and all operations
+// after it into the resume block. Returns a pointer to the resume block.
+//
+// `coroState` must be a value returned from the call to @llvm.coro.save(...)
+// intrinsic (saved coroutine state).
+//
+// Before:
+//
+//   ^bb0:
+//     "opBefore"(...)
+//     "op"(...)
+//   ^cleanup: ...
+//   ^suspend: ...
+//
+// After:
+//
+//   ^bb0:
+//     "opBefore"(...)
+//     %suspend = llvm.call @llvm.coro.suspend(...)
+//     switch %suspend [-1: ^suspend, 0: ^resume, 1: ^cleanup]
+//   ^resume:
+//     "op"(...)
+//   ^cleanup: ...
+//   ^suspend: ...
+//
+static Block *addSuspensionPoint(CoroMachinery coro, Value coroState,
+                                 Operation *op) {
+  MLIRContext *ctx = op->getContext();
+  auto i1 = LLVM::LLVMType::getInt1Ty(ctx);
+  auto i8 = LLVM::LLVMType::getInt8Ty(ctx);
+
+  Location loc = op->getLoc();
+  Block *splitBlock = op->getBlock();
+
+  // Split the block before `op`, newly added block is the resume block.
+  Block *resume = splitBlock->splitBlock(op);
+
+  // Add a coroutine suspension in place of original `op` in the split block.
+  OpBuilder builder = OpBuilder::atBlockEnd(splitBlock);
+
+  auto constFalse =
+      builder.create<LLVM::ConstantOp>(loc, i1, builder.getBoolAttr(false));
+
+  // Suspend a coroutine: @llvm.coro.suspend
+  auto coroSuspend = builder.create<LLVM::CallOp>(
+      loc, i8, builder.getSymbolRefAttr(kCoroSuspend),
+      ValueRange({coroState, constFalse}));
+
+  // After a suspension point decide if we should branch into resume, cleanup
+  // or suspend block of the coroutine (see @llvm.coro.suspend return code
+  // documentation).
+  auto constZero =
+      builder.create<LLVM::ConstantOp>(loc, i8, builder.getI8IntegerAttr(0));
+  auto constNegOne =
+      builder.create<LLVM::ConstantOp>(loc, i8, builder.getI8IntegerAttr(-1));
+
+  Block *resumeOrCleanup = builder.createBlock(resume);
+
+  // Suspend the coroutine ...?
+  builder.setInsertionPointToEnd(splitBlock);
+  auto isNegOne = builder.create<LLVM::ICmpOp>(
+      loc, LLVM::ICmpPredicate::eq, coroSuspend.getResult(0), constNegOne);
+  builder.create<LLVM::CondBrOp>(loc, isNegOne, /*trueDest=*/coro.suspend,
+                                 /*falseDest=*/resumeOrCleanup);
+
+  // ... or resume or cleanup the coroutine?
+  builder.setInsertionPointToStart(resumeOrCleanup);
+  auto isZero = builder.create<LLVM::ICmpOp>(
+      loc, LLVM::ICmpPredicate::eq, coroSuspend.getResult(0), constZero);
+  builder.create<LLVM::CondBrOp>(loc, isZero, /*trueDest=*/resume,
+                                 /*falseDest=*/coro.cleanup);
+
+  return resume;
+}
+
+// Outline the body region attached to the `async.execute` op into a standalone
+// function.
+static std::pair<FuncOp, CoroMachinery>
+outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
+  ModuleOp module = execute.getParentOfType<ModuleOp>();
+
+  MLIRContext *ctx = module.getContext();
+  Location loc = execute.getLoc();
+
+  OpBuilder moduleBuilder(module.getBody()->getTerminator());
+
+  // Get values captured by the async region.
+  llvm::SetVector<mlir::Value> usedAbove;
+  getUsedValuesDefinedAbove(execute.body(), usedAbove);
+
+  // Collect types of the captured values.
+  auto usedAboveTypes =
+      llvm::map_range(usedAbove, [](Value value) { return value.getType(); });
+  SmallVector<Type, 4> inputTypes(usedAboveTypes.begin(), usedAboveTypes.end());
+  auto outputTypes = execute.getResultTypes();
+
+  auto funcType = moduleBuilder.getFunctionType(inputTypes, outputTypes);
+  auto funcAttrs = ArrayRef<NamedAttribute>();
+
+  // TODO: Derive outlined function name from the parent FuncOp (support
+  // multiple nested async.execute operations).
+  FuncOp func = FuncOp::create(loc, kAsyncFnPrefix, funcType, funcAttrs);
+  symbolTable.insert(func, moduleBuilder.getInsertionPoint());
+
+  SymbolTable::setSymbolVisibility(func, SymbolTable::Visibility::Private);
+
+  // Prepare a function for coroutine lowering by adding entry/cleanup/suspend
+  // blocks, adding llvm.coro intrinsics and setting up control flow.
+  CoroMachinery coro = setupCoroMachinery(func);
+
+  // Suspend async function at the end of an entry block, and resume it using
+  // Async execute API (execution will be resumed in a thread managed by the
+  // async runtime).
+  Block *entryBlock = &func.getBlocks().front();
+  OpBuilder builder = OpBuilder::atBlockTerminator(entryBlock);
+
+  // A pointer to coroutine resume intrinsic wrapper.
+  auto resumeFnTy = AsyncAPI::resumeFunctionType(ctx);
+  auto resumePtr = builder.create<LLVM::AddressOfOp>(
+      loc, resumeFnTy.getPointerTo(), kResume);
+
+  // Save the coroutine state: @llvm.coro.save
+  auto coroSave = builder.create<LLVM::CallOp>(
+      loc, LLVM::LLVMTokenType::get(ctx), builder.getSymbolRefAttr(kCoroSave),
+      ValueRange({coro.coroHandle}));
+
+  // Call async runtime API to execute a coroutine in the managed thread.
+  SmallVector<Value, 2> executeArgs = {coro.coroHandle, resumePtr.res()};
+  builder.create<CallOp>(loc, Type(), kExecute, executeArgs);
+
+  // Split the entry block before the terminator.
+  Block *resume = addSuspensionPoint(coro, coroSave.getResult(0),
+                                     entryBlock->getTerminator());
+
+  // Map from values defined above the execute op to the function arguments.
+  BlockAndValueMapping valueMapping;
+  valueMapping.map(usedAbove, func.getArguments());
+
+  // Clone all operations from the execute operation body into the outlined
+  // function body, and replace all `async.yield` operations with a call
+  // to async runtime to emplace the result token.
+  builder.setInsertionPointToStart(resume);
+  for (Operation &op : execute.body().getOps()) {
+    if (isa<async::YieldOp>(op)) {
+      builder.create<CallOp>(loc, kEmplaceToken, Type(), coro.asyncToken);
+      continue;
+    }
+    builder.clone(op, valueMapping);
+  }
+
+  // Replace the original `async.execute` with a call to outlined function.
+  OpBuilder callBuilder(execute);
+  SmallVector<Value, 4> usedAboveArgs(usedAbove.begin(), usedAbove.end());
+  auto callOutlinedFunc = callBuilder.create<CallOp>(
+      loc, func.getName(), execute.getResultTypes(), usedAboveArgs);
+  execute.replaceAllUsesWith(callOutlinedFunc.getResults());
+  execute.erase();
+
+  return {func, coro};
+}
+
+//===----------------------------------------------------------------------===//
+// Convert Async dialect types to LLVM types.
+//===----------------------------------------------------------------------===//
+
+namespace {
+class AsyncRuntimeTypeConverter : public TypeConverter {
+public:
+  AsyncRuntimeTypeConverter() { addConversion(convertType); }
+
+  static Type convertType(Type type) {
+    MLIRContext *ctx = type.getContext();
+    // Convert async tokens to opaque pointers.
+    if (type.isa<TokenType>())
+      return LLVM::LLVMType::getInt8PtrTy(ctx);
+    return type;
+  }
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Convert types for all call operations to lowered async types.
+//===----------------------------------------------------------------------===// + +namespace { +class CallOpOpConversion : public ConversionPattern { +public: + explicit CallOpOpConversion(MLIRContext *ctx) + : ConversionPattern(CallOp::getOperationName(), 1, ctx) {} + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + AsyncRuntimeTypeConverter converter; + + SmallVector resultTypes; + converter.convertTypes(op->getResultTypes(), resultTypes); + + CallOp call = cast(op); + rewriter.replaceOpWithNewOp(op, resultTypes, call.callee(), + call.getOperands()); + + return success(); + } +}; +} // namespace + +//===----------------------------------------------------------------------===// +// async.await op lowering to mlirAsyncRuntimeAwaitToken function call. +//===----------------------------------------------------------------------===// + +namespace { +class AwaitOpLowering : public ConversionPattern { +public: + explicit AwaitOpLowering( + MLIRContext *ctx, + const llvm::DenseMap &outlinedFunctions) + : ConversionPattern(AwaitOp::getOperationName(), 1, ctx), + outlinedFunctions(outlinedFunctions) {} + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + // We can only await on the token operand. Async valus are not supported. + auto await = cast(op); + if (!await.operand().getType().isa()) + return failure(); + + // Check if `async.await` is inside the outlined coroutine function. + auto func = await.getParentOfType(); + auto outlined = outlinedFunctions.find(func); + const bool isInCoroutine = outlined != outlinedFunctions.end(); + + Location loc = op->getLoc(); + + // Inside regular function we convert await operation to the blocking + // async API await function call. 
+ if (!isInCoroutine) + rewriter.create(loc, Type(), kAwaitToken, + ValueRange(op->getOperand(0))); + + // Inside the coroutine we convert await operation into coroutine suspension + // point, and resume execution asynchronously. + if (isInCoroutine) { + const CoroMachinery &coro = outlined->getSecond(); + + OpBuilder builder(op); + MLIRContext *ctx = op->getContext(); + + // A pointer to coroutine resume intrinsic wrapper. + auto resumeFnTy = AsyncAPI::resumeFunctionType(ctx); + auto resumePtr = builder.create( + loc, resumeFnTy.getPointerTo(), kResume); + + // Save the coroutine state: @llvm.coro.save + auto coroSave = builder.create( + loc, LLVM::LLVMTokenType::get(ctx), + builder.getSymbolRefAttr(kCoroSave), ValueRange(coro.coroHandle)); + + // Call async runtime API to resume a coroutine in the managed thread when + // the async await argument becomes ready. + SmallVector awaitAndExecuteArgs = { + await.getOperand(), coro.coroHandle, resumePtr.res()}; + builder.create(loc, Type(), kAwaitAndExecute, + awaitAndExecuteArgs); + + // Split the entry block before the await operation. + addSuspensionPoint(coro, coroSave.getResult(0), op); + } + + // Original operation was replaced by function call or suspension point. + rewriter.eraseOp(op); + + return success(); + } + +private: + const llvm::DenseMap &outlinedFunctions; +}; +} // namespace + +//===----------------------------------------------------------------------===// + +namespace { +struct ConvertAsyncToLLVMPass + : public ConvertAsyncToLLVMBase { + void runOnOperation() override; +}; + +void ConvertAsyncToLLVMPass::runOnOperation() { + ModuleOp module = getOperation(); + SymbolTable symbolTable(module); + + // Outline all `async.execute` body regions into async functions (coroutines). 
+ llvm::DenseMap outlinedFunctions; + + WalkResult outlineResult = module.walk([&](ExecuteOp execute) { + // We currently do not support execute operations that take async + // token dependencies, async value arguments or produce async results. + if (!execute.dependencies().empty() || !execute.operands().empty() || + !execute.results().empty()) { + execute.emitOpError( + "Can't outline async.execute op with async dependencies, arguments " + "or returned async results"); + return WalkResult::interrupt(); + } + + outlinedFunctions.insert(outlineExecuteOp(symbolTable, execute)); + + return WalkResult::advance(); + }); + + // Failed to outline all async execute operations. + if (outlineResult.wasInterrupted()) { + signalPassFailure(); + return; + } + + LLVM_DEBUG({ + llvm::dbgs() << "Outlined " << outlinedFunctions.size() + << " async functions\n"; + }); + + // Add declarations for all functions required by the coroutines lowering. + addResumeFunction(module); + addAsyncRuntimeApiDeclarations(module); + addCoroutineIntrinsicsDeclarations(module); + addCRuntimeDeclarations(module); + + MLIRContext *ctx = &getContext(); + + // Convert async dialect types and operations to LLVM dialect. 
+ AsyncRuntimeTypeConverter converter; + OwningRewritePatternList patterns; + + populateFuncOpTypeConversionPattern(patterns, ctx, converter); + patterns.insert(ctx); + patterns.insert(ctx, outlinedFunctions); + + ConversionTarget target(*ctx); + target.addLegalDialect(); + target.addIllegalDialect(); + target.addDynamicallyLegalOp( + [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); }); + target.addDynamicallyLegalOp( + [&](CallOp op) { return converter.isLegal(op.getResultTypes()); }); + + if (failed(applyPartialConversion(module, target, patterns))) + signalPassFailure(); +} +} // namespace + +std::unique_ptr> mlir::createConvertAsyncToLLVMPass() { + return std::make_unique(); +} diff --git a/mlir/lib/Conversion/AsyncToLLVM/CMakeLists.txt b/mlir/lib/Conversion/AsyncToLLVM/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/AsyncToLLVM/CMakeLists.txt @@ -0,0 +1,16 @@ +add_mlir_conversion_library(MLIRAsyncToLLVM + AsyncToLLVM.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/AsyncToLLVM + + DEPENDS + MLIRConversionPassIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRLLVMIR + MLIRTransforms + ) diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -1,4 +1,5 @@ add_subdirectory(AffineToStandard) +add_subdirectory(AsyncToLLVM) add_subdirectory(AVX512ToLLVM) add_subdirectory(GPUCommon) add_subdirectory(GPUToNVVM) diff --git a/mlir/lib/ExecutionEngine/AsyncRuntime.cpp b/mlir/lib/ExecutionEngine/AsyncRuntime.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/ExecutionEngine/AsyncRuntime.cpp @@ -0,0 +1,92 @@ +//===- AsyncRuntime.cpp - Async runtime reference implementation ----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements basic Async runtime API for supporting Async dialect
+// to LLVM dialect lowering.
+//
+// Token ownership: both await functions (mlirAsyncRuntimeAwaitToken and
+// mlirAsyncRuntimeAwaitTokenAndExecute) delete the token once it is ready, so
+// every token must be awaited at most once and must not be used afterwards.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/ExecutionEngine/AsyncRuntime.h"
+
+#ifdef MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS
+
+#include <condition_variable>
+#include <functional>
+#include <iostream>
+#include <mutex>
+#include <thread>
+#include <utility>
+#include <vector>
+
+//===----------------------------------------------------------------------===//
+// Async runtime API.
+//===----------------------------------------------------------------------===//
+
+// Runtime state behind an `!async.token` value. `ready` and `awaiters` are
+// only accessed with `mu` held.
+struct AsyncToken {
+  bool ready = false;
+  std::mutex mu;
+  std::condition_variable cv;
+  // Continuations registered by mlirAsyncRuntimeAwaitTokenAndExecute while the
+  // token was not ready yet; run by mlirAsyncRuntimeEmplaceToken.
+  std::vector<std::function<void()>> awaiters;
+};
+
+// Create a new `async.token` in not-ready state.
+extern "C" MLIR_ASYNCRUNTIME_EXPORT AsyncToken *mlirAsyncRuntimeCreateToken() {
+  AsyncToken *token = new AsyncToken;
+  return token;
+}
+
+// Switches `async.token` to ready state, wakes up threads blocked in
+// mlirAsyncRuntimeAwaitToken and runs all registered awaiters.
+//
+// Registered awaiters delete the token, and a woken blocking awaiter deletes
+// it as soon as it can reacquire the mutex. Therefore the awaiters are moved
+// out under the lock and run only after the lock is released, and `token` is
+// never accessed once the lock has been dropped.
+extern "C" MLIR_ASYNCRUNTIME_EXPORT void
+mlirAsyncRuntimeEmplaceToken(AsyncToken *token) {
+  std::vector<std::function<void()>> awaiters;
+  {
+    std::unique_lock<std::mutex> lock(token->mu);
+    token->ready = true;
+    // Notify before releasing the lock: once it is released a woken
+    // mlirAsyncRuntimeAwaitToken call may delete the token (and `cv`).
+    token->cv.notify_all();
+    awaiters.swap(token->awaiters);
+  }
+  for (auto &awaiter : awaiters)
+    awaiter();
+}
+
+// Blocks the calling thread until `token` is ready, then deletes it.
+extern "C" MLIR_ASYNCRUNTIME_EXPORT void
+mlirAsyncRuntimeAwaitToken(AsyncToken *token) {
+  {
+    std::unique_lock<std::mutex> lock(token->mu);
+    // The predicate overload returns immediately if the token is ready.
+    token->cv.wait(lock, [token] { return token->ready; });
+  }
+  // Delete only after `lock` has released the mutex owned by the token.
+  delete token;
+}
+
+// Resumes the coroutine `handle` by calling `resume(handle)` on a new,
+// detached thread.
+extern "C" MLIR_ASYNCRUNTIME_EXPORT void
+mlirAsyncRuntimeExecute(CoroHandle handle, CoroResume resume) {
+  std::thread thread([handle, resume]() { (*resume)(handle); });
+  thread.detach();
+}
+
+// Once `token` is ready, resumes the coroutine `handle` through
+// mlirAsyncRuntimeExecute and deletes the token. If the token is already
+// ready this happens immediately; otherwise the continuation is registered
+// and later run by mlirAsyncRuntimeEmplaceToken.
+extern "C" MLIR_ASYNCRUNTIME_EXPORT void
+mlirAsyncRuntimeAwaitTokenAndExecute(AsyncToken *token, CoroHandle handle,
+                                     CoroResume resume) {
+  auto execute = [token, handle, resume]() {
+    mlirAsyncRuntimeExecute(handle, resume);
+    delete token;
+  };
+
+  std::unique_lock<std::mutex> lock(token->mu);
+  if (!token->ready) {
+    token->awaiters.push_back(std::move(execute));
+    return;
+  }
+
+  // The token's mutex must be unlocked before `execute` deletes the token.
+  lock.unlock();
+  execute();
+}
+
+//===----------------------------------------------------------------------===//
+// Small async runtime support library for testing.
+//===----------------------------------------------------------------------===//
+
+// Prints the id of the calling thread to stdout.
+extern "C" MLIR_ASYNCRUNTIME_EXPORT void
+mlirAsyncRuntimePrintCurrentThreadId() {
+  static thread_local std::thread::id thisId = std::this_thread::get_id();
+  std::cout << "Current thread id: " << thisId << "\n";
+}
+
+#endif // MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS
diff --git a/mlir/lib/ExecutionEngine/CMakeLists.txt b/mlir/lib/ExecutionEngine/CMakeLists.txt
--- a/mlir/lib/ExecutionEngine/CMakeLists.txt
+++ b/mlir/lib/ExecutionEngine/CMakeLists.txt
@@ -2,6 +2,7 @@ # is a big dependency which most don't need.
set(LLVM_OPTIONAL_SOURCES + AsyncRuntime.cpp CRunnerUtils.cpp SparseUtils.cpp ExecutionEngine.cpp @@ -96,3 +97,14 @@ mlir_c_runner_utils_static ) target_compile_definitions(mlir_runner_utils PRIVATE mlir_runner_utils_EXPORTS) + +add_mlir_library(mlir_async_runtime + SHARED + AsyncRuntime.cpp + + EXCLUDE_FROM_LIBMLIR + + LINK_LIBS PUBLIC + mlir_c_runner_utils_static +) +target_compile_definitions(mlir_async_runtime PRIVATE mlir_async_runtime_EXPORTS) diff --git a/mlir/lib/ExecutionEngine/OptUtils.cpp b/mlir/lib/ExecutionEngine/OptUtils.cpp --- a/mlir/lib/ExecutionEngine/OptUtils.cpp +++ b/mlir/lib/ExecutionEngine/OptUtils.cpp @@ -25,6 +25,7 @@ #include "llvm/Support/Error.h" #include "llvm/Support/StringSaver.h" #include "llvm/Target/TargetMachine.h" +#include "llvm/Transforms/Coroutines.h" #include "llvm/Transforms/IPO.h" #include "llvm/Transforms/IPO/PassManagerBuilder.h" #include @@ -56,6 +57,7 @@ llvm::initializeAggressiveInstCombine(registry); llvm::initializeAnalysis(registry); llvm::initializeVectorization(registry); + llvm::initializeCoroutines(registry); } // Populate pass managers according to the optimization and size levels. @@ -73,6 +75,9 @@ builder.SLPVectorize = optLevel > 1 && sizeLevel < 2; builder.DisableUnrollLoops = (optLevel == 0); + // Add all coroutine passes to the builder. + addCoroutinePassesToExtensionPoints(builder); + if (targetMachine) { // Add pass to initialize TTI for this specific target. Otherwise, TTI will // be initialized to NoTTIImpl by default. 
diff --git a/mlir/test/CMakeLists.txt b/mlir/test/CMakeLists.txt --- a/mlir/test/CMakeLists.txt +++ b/mlir/test/CMakeLists.txt @@ -54,6 +54,7 @@ mlir_test_cblas_interface mlir_runner_utils mlir_c_runner_utils + mlir_async_runtime ) if(LLVM_BUILD_EXAMPLES) diff --git a/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir b/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir @@ -0,0 +1,111 @@ +// RUN: mlir-opt %s -split-input-file -convert-async-to-llvm | FileCheck %s + +// CHECK-LABEL: execute_no_async_args +func @execute_no_async_args(%arg0: f32, %arg1: memref<1xf32>) { + // CHECK: %[[TOKEN:.*]] = call @async_execute_fn(%arg0, %arg1) + %token = async.execute { + %c0 = constant 0 : index + store %arg0, %arg1[%c0] : memref<1xf32> + async.yield + } + // CHECK: call @mlirAsyncRuntimeAwaitToken(%[[TOKEN]]) + // CHECK-NEXT: return + async.await %token : !async.token + return +} + +// Function outlined from the async.execute operation. +// CHECK: func @async_execute_fn(%arg0: f32, %arg1: memref<1xf32>) +// CHECK-SAME: -> !llvm.ptr attributes {sym_visibility = "private"} + +// Create token for return op, and mark a function as a coroutine. +// CHECK: %[[RET:.*]] = call @mlirAsyncRuntimeCreateToken() +// CHECK: %[[HDL:.*]] = llvm.call @llvm.coro.begin + +// Pass a suspended coroutine to the async runtime. +// CHECK: %[[RESUME:.*]] = llvm.mlir.addressof @__resume +// CHECK: %[[STATE:.*]] = llvm.call @llvm.coro.save +// CHECK: call @mlirAsyncRuntimeExecute(%[[HDL]], %[[RESUME]]) +// CHECK: %[[SUSPENDED:.*]] = llvm.call @llvm.coro.suspend(%[[STATE]] + +// Decide the next block based on the code returned from suspend. 
+// CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i8) +// CHECK: %[[NONE:.*]] = llvm.mlir.constant(-1 : i8) +// CHECK: %[[IS_NONE:.*]] = llvm.icmp "eq" %[[SUSPENDED]], %[[NONE]] +// CHECK: llvm.cond_br %[[IS_NONE]], ^[[SUSPEND:.*]], ^[[RESUME_OR_CLEANUP:.*]] + +// Decide if branch to resume or cleanup block. +// CHECK: ^[[RESUME_OR_CLEANUP]]: +// CHECK: %[[IS_ZERO:.*]] = llvm.icmp "eq" %[[SUSPENDED]], %[[ZERO]] +// CHECK: llvm.cond_br %[[IS_ZERO]], ^[[RESUME:.*]], ^[[CLEANUP:.*]] + +// Resume coroutine after suspension. +// CHECK: ^[[RESUME]]: +// CHECK: store %arg0, %arg1[%c0] : memref<1xf32> +// CHECK: call @mlirAsyncRuntimeEmplaceToken(%[[RET]]) + +// Delete coroutine. +// CHECK: ^[[CLEANUP]]: +// CHECK: %[[MEM:.*]] = llvm.call @llvm.coro.free +// CHECK: llvm.call @free(%[[MEM]]) + +// Suspend coroutine, and also a return statement for ramp function. +// CHECK: ^[[SUSPEND]]: +// CHECK: llvm.call @llvm.coro.end +// CHECK: return %[[RET]] + +// ----- + +// CHECK-LABEL: nested_async_execute +func @nested_async_execute(%arg0: f32, %arg1: f32, %arg2: memref<1xf32>) { + // CHECK: %[[TOKEN:.*]] = call @async_execute_fn_0(%arg0, %arg2, %arg1) + %token0 = async.execute { + %c0 = constant 0 : index + + %token1 = async.execute { + %c1 = constant 1: index + store %arg0, %arg2[%c0] : memref<1xf32> + async.yield + } + async.await %token1 : !async.token + + store %arg1, %arg2[%c0] : memref<1xf32> + async.yield + } + // CHECK: call @mlirAsyncRuntimeAwaitToken(%[[TOKEN]]) + // CHECK-NEXT: return + async.await %token0 : !async.token + return +} + +// Function outlined from the inner async.execute operation. 
+// CHECK: func @async_execute_fn(%arg0: f32, %arg1: memref<1xf32>, %arg2: index) +// CHECK-SAME: -> !llvm.ptr attributes {sym_visibility = "private"} +// CHECK: %[[RET_0:.*]] = call @mlirAsyncRuntimeCreateToken() +// CHECK: %[[HDL_0:.*]] = llvm.call @llvm.coro.begin +// CHECK: call @mlirAsyncRuntimeExecute +// CHECK: llvm.call @llvm.coro.suspend +// CHECK: store %arg0, %arg1[%arg2] : memref<1xf32> +// CHECK: call @mlirAsyncRuntimeEmplaceToken(%[[RET_0]]) + +// Function outlined from the outer async.execute operation. +// CHECK: func @async_execute_fn_0(%arg0: f32, %arg1: memref<1xf32>, %arg2: f32) +// CHECK-SAME: -> !llvm.ptr attributes {sym_visibility = "private"} +// CHECK: %[[RET_1:.*]] = call @mlirAsyncRuntimeCreateToken() +// CHECK: %[[HDL_1:.*]] = llvm.call @llvm.coro.begin + +// Suspend coroutine in the beginning. +// CHECK: call @mlirAsyncRuntimeExecute +// CHECK: llvm.call @llvm.coro.suspend + +// Suspend coroutine second time waiting for the completion of inner execute op. +// CHECK: %[[TOKEN_1:.*]] = call @async_execute_fn +// CHECK: llvm.call @llvm.coro.save +// CHECK: call @mlirAsyncRuntimeAwaitTokenAndExecute(%[[TOKEN_1]], %[[HDL_1]] +// CHECK: llvm.call @llvm.coro.suspend + +// Emplace result token after second resumption. 
+// CHECK: store %arg2, %arg1[%c0] : memref<1xf32> +// CHECK: call @mlirAsyncRuntimeEmplaceToken(%[[RET_1]]) + + diff --git a/mlir/test/mlir-cpu-runner/async.mlir b/mlir/test/mlir-cpu-runner/async.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-cpu-runner/async.mlir @@ -0,0 +1,77 @@ +// RUN: mlir-opt %s -convert-async-to-llvm \ +// RUN: -convert-linalg-to-loops \ +// RUN: -convert-linalg-to-llvm \ +// RUN: -convert-std-to-llvm \ +// RUN: | mlir-cpu-runner \ +// RUN: -e main -entry-point-result=void -O0 \ +// RUN: -shared-libs=%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext \ +// RUN: -shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext \ +// RUN: -shared-libs=%linalg_test_lib_dir/libmlir_async_runtime%shlibext \ +// RUN: | FileCheck %s + +func @main() { + %i0 = constant 0 : index + %i1 = constant 1 : index + %i2 = constant 2 : index + %i3 = constant 3 : index + + %c0 = constant 0.0 : f32 + %c1 = constant 1.0 : f32 + %c2 = constant 2.0 : f32 + %c3 = constant 3.0 : f32 + %c4 = constant 4.0 : f32 + + %A = alloc() : memref<4xf32> + linalg.fill(%A, %c0) : memref<4xf32>, f32 + + // CHECK: [0, 0, 0, 0] + %U = memref_cast %A : memref<4xf32> to memref<*xf32> + call @print_memref_f32(%U): (memref<*xf32>) -> () + + // CHECK: Current thread id: [[MAIN:.*]] + // CHECK: [1, 0, 0, 0] + store %c1, %A[%i0]: memref<4xf32> + call @mlirAsyncRuntimePrintCurrentThreadId(): () -> () + call @print_memref_f32(%U): (memref<*xf32>) -> () + + %outer = async.execute { + // CHECK: Current thread id: [[THREAD0:.*]] + // CHECK: [1, 2, 0, 0] + store %c2, %A[%i1]: memref<4xf32> + call @mlirAsyncRuntimePrintCurrentThreadId(): () -> () + call @print_memref_f32(%U): (memref<*xf32>) -> () + + %inner = async.execute { + // CHECK: Current thread id: [[THREAD1:.*]] + // CHECK: [1, 2, 3, 0] + store %c3, %A[%i2]: memref<4xf32> + call @mlirAsyncRuntimePrintCurrentThreadId(): () -> () + call @print_memref_f32(%U): (memref<*xf32>) -> () + + async.yield + } + async.await %inner : 
!async.token + + // CHECK: Current thread id: [[THREAD2:.*]] + // CHECK: [1, 2, 3, 4] + store %c4, %A[%i3]: memref<4xf32> + call @mlirAsyncRuntimePrintCurrentThreadId(): () -> () + call @print_memref_f32(%U): (memref<*xf32>) -> () + + async.yield + } + async.await %outer : !async.token + + // CHECK: Current thread id: [[MAIN]] + // CHECK: [1, 2, 3, 4] + call @mlirAsyncRuntimePrintCurrentThreadId(): () -> () + call @print_memref_f32(%U): (memref<*xf32>) -> () + + dealloc %A : memref<4xf32> + + return +} + +func @mlirAsyncRuntimePrintCurrentThreadId() -> () + +func @print_memref_f32(memref<*xf32>) attributes { llvm.emit_c_interface }