diff --git a/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h b/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h --- a/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h +++ b/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h @@ -45,6 +45,12 @@ // Runtime implementation of `async.group` data type. typedef struct AsyncGroup AsyncGroup; +// Runtime implementation of `async.value` data type. +typedef struct AsyncValue AsyncValue; + +// Async value payload stored in a memory owned by the async.value. +using ValueStorage = void *; + // Async runtime uses LLVM coroutines to represent asynchronous tasks. Task // function is a coroutine handle and a resume function that continue coroutine // execution from a suspension point. @@ -66,6 +72,10 @@ // Create a new `async.token` in not-ready state. extern "C" MLIR_ASYNCRUNTIME_EXPORT AsyncToken *mlirAsyncRuntimeCreateToken(); +// Create a new `async.value` in not-ready state. +extern "C" MLIR_ASYNCRUNTIME_EXPORT AsyncValue * + mlirAsyncRuntimeCreateValue(int32_t); + // Create a new `async.group` in empty state. extern "C" MLIR_ASYNCRUNTIME_EXPORT AsyncGroup *mlirAsyncRuntimeCreateGroup(); @@ -76,14 +86,26 @@ extern "C" MLIR_ASYNCRUNTIME_EXPORT void mlirAsyncRuntimeEmplaceToken(AsyncToken *); +// Switches `async.value` to ready state and runs all awaiters. +extern "C" MLIR_ASYNCRUNTIME_EXPORT void +mlirAsyncRuntimeEmplaceValue(AsyncValue *); + // Blocks the caller thread until the token becomes ready. extern "C" MLIR_ASYNCRUNTIME_EXPORT void mlirAsyncRuntimeAwaitToken(AsyncToken *); +// Blocks the caller thread until the value becomes ready. +extern "C" MLIR_ASYNCRUNTIME_EXPORT void +mlirAsyncRuntimeAwaitValue(AsyncValue *); + // Blocks the caller thread until the elements in the group become ready. extern "C" MLIR_ASYNCRUNTIME_EXPORT void mlirAsyncRuntimeAwaitAllInGroup(AsyncGroup *); +// Returns a pointer to the storage owned by the async value. 
+extern "C" MLIR_ASYNCRUNTIME_EXPORT ValueStorage +mlirAsyncRuntimeGetValueStorage(AsyncValue *); + // Executes the task (coro handle + resume function) in one of the threads // managed by the runtime. extern "C" MLIR_ASYNCRUNTIME_EXPORT void mlirAsyncRuntimeExecute(CoroHandle, @@ -94,6 +116,11 @@ extern "C" MLIR_ASYNCRUNTIME_EXPORT void mlirAsyncRuntimeAwaitTokenAndExecute(AsyncToken *, CoroHandle, CoroResume); +// Executes the task (coro handle + resume function) in one of the threads +// managed by the runtime after the value becomes ready. +extern "C" MLIR_ASYNCRUNTIME_EXPORT void +mlirAsyncRuntimeAwaitValueAndExecute(AsyncValue *, CoroHandle, CoroResume); + // Executes the task (coro handle + resume function) in one of the threads // managed by the runtime after the all members of the group become ready. extern "C" MLIR_ASYNCRUNTIME_EXPORT void diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp --- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp +++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp @@ -9,9 +9,11 @@ #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h" #include "../PassDetail.h" +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" #include "mlir/Dialect/Async/IR/Async.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h" #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Builders.h" #include "mlir/IR/TypeUtilities.h" @@ -36,20 +38,31 @@ static constexpr const char *kAddRef = "mlirAsyncRuntimeAddRef"; static constexpr const char *kDropRef = "mlirAsyncRuntimeDropRef"; static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken"; +static constexpr const char *kCreateValue = "mlirAsyncRuntimeCreateValue"; static constexpr const char *kCreateGroup = "mlirAsyncRuntimeCreateGroup"; static constexpr const char *kEmplaceToken = 
"mlirAsyncRuntimeEmplaceToken"; +static constexpr const char *kEmplaceValue = "mlirAsyncRuntimeEmplaceValue"; static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken"; +static constexpr const char *kAwaitValue = "mlirAsyncRuntimeAwaitValue"; static constexpr const char *kAwaitGroup = "mlirAsyncRuntimeAwaitAllInGroup"; static constexpr const char *kExecute = "mlirAsyncRuntimeExecute"; +static constexpr const char *kGetValueStorage = + "mlirAsyncRuntimeGetValueStorage"; static constexpr const char *kAddTokenToGroup = "mlirAsyncRuntimeAddTokenToGroup"; -static constexpr const char *kAwaitAndExecute = +static constexpr const char *kAwaitTokenAndExecute = "mlirAsyncRuntimeAwaitTokenAndExecute"; +static constexpr const char *kAwaitValueAndExecute = + "mlirAsyncRuntimeAwaitValueAndExecute"; static constexpr const char *kAwaitAllAndExecute = "mlirAsyncRuntimeAwaitAllInGroupAndExecute"; namespace { -// Async Runtime API function types. +/// Async Runtime API function types. +/// +/// Because we can't create API function signature for type parametrized +/// async.value type, we use opaque pointers (!llvm.ptr) instead. After +/// lowering all async data types become opaque pointers at runtime. 
struct AsyncAPI { static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) { auto ref = LLVM::LLVMType::getInt8PtrTy(ctx); @@ -61,18 +74,40 @@ return FunctionType::get({}, {TokenType::get(ctx)}, ctx); } + static FunctionType createValueFunctionType(MLIRContext *ctx) { + auto i32 = IntegerType::get(32, ctx); + auto value = LLVM::LLVMType::getInt8PtrTy(ctx); + return FunctionType::get({i32}, {value}, ctx); + } + static FunctionType createGroupFunctionType(MLIRContext *ctx) { return FunctionType::get({}, {GroupType::get(ctx)}, ctx); } + static FunctionType getValueStorageFunctionType(MLIRContext *ctx) { + auto value = LLVM::LLVMType::getInt8PtrTy(ctx); + auto storage = LLVM::LLVMType::getInt8PtrTy(ctx); + return FunctionType::get({value}, {storage}, ctx); + } + static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) { return FunctionType::get({TokenType::get(ctx)}, {}, ctx); } + static FunctionType emplaceValueFunctionType(MLIRContext *ctx) { + auto value = LLVM::LLVMType::getInt8PtrTy(ctx); + return FunctionType::get({value}, {}, ctx); + } + static FunctionType awaitTokenFunctionType(MLIRContext *ctx) { return FunctionType::get({TokenType::get(ctx)}, {}, ctx); } + static FunctionType awaitValueFunctionType(MLIRContext *ctx) { + auto value = LLVM::LLVMType::getInt8PtrTy(ctx); + return FunctionType::get({value}, {}, ctx); + } + static FunctionType awaitGroupFunctionType(MLIRContext *ctx) { return FunctionType::get({GroupType::get(ctx)}, {}, ctx); } @@ -89,12 +124,19 @@ ctx); } - static FunctionType awaitAndExecuteFunctionType(MLIRContext *ctx) { + static FunctionType awaitTokenAndExecuteFunctionType(MLIRContext *ctx) { auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx); auto resume = resumeFunctionType(ctx).getPointerTo(); return FunctionType::get({TokenType::get(ctx), hdl, resume}, {}, ctx); } + static FunctionType awaitValueAndExecuteFunctionType(MLIRContext *ctx) { + auto value = LLVM::LLVMType::getInt8PtrTy(ctx); + auto hdl = 
LLVM::LLVMType::getInt8PtrTy(ctx); + auto resume = resumeFunctionType(ctx).getPointerTo(); + return FunctionType::get({value, hdl, resume}, {}, ctx); + } + static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) { auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx); auto resume = resumeFunctionType(ctx).getPointerTo(); @@ -110,7 +152,7 @@ }; } // namespace -// Adds Async Runtime C API declarations to the module. +/// Adds Async Runtime C API declarations to the module. static void addAsyncRuntimeApiDeclarations(ModuleOp module) { auto builder = OpBuilder::atBlockTerminator(module.getBody()); @@ -124,13 +166,20 @@ addFuncDecl(kAddRef, AsyncAPI::addOrDropRefFunctionType(ctx)); addFuncDecl(kDropRef, AsyncAPI::addOrDropRefFunctionType(ctx)); addFuncDecl(kCreateToken, AsyncAPI::createTokenFunctionType(ctx)); + addFuncDecl(kCreateValue, AsyncAPI::createValueFunctionType(ctx)); addFuncDecl(kCreateGroup, AsyncAPI::createGroupFunctionType(ctx)); addFuncDecl(kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx)); + addFuncDecl(kEmplaceValue, AsyncAPI::emplaceValueFunctionType(ctx)); addFuncDecl(kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx)); + addFuncDecl(kAwaitValue, AsyncAPI::awaitValueFunctionType(ctx)); addFuncDecl(kAwaitGroup, AsyncAPI::awaitGroupFunctionType(ctx)); addFuncDecl(kExecute, AsyncAPI::executeFunctionType(ctx)); + addFuncDecl(kGetValueStorage, AsyncAPI::getValueStorageFunctionType(ctx)); addFuncDecl(kAddTokenToGroup, AsyncAPI::addTokenToGroupFunctionType(ctx)); - addFuncDecl(kAwaitAndExecute, AsyncAPI::awaitAndExecuteFunctionType(ctx)); + addFuncDecl(kAwaitTokenAndExecute, + AsyncAPI::awaitTokenAndExecuteFunctionType(ctx)); + addFuncDecl(kAwaitValueAndExecute, + AsyncAPI::awaitValueAndExecuteFunctionType(ctx)); addFuncDecl(kAwaitAllAndExecute, AsyncAPI::awaitAllAndExecuteFunctionType(ctx)); } @@ -212,9 +261,9 @@ static constexpr const char *kResume = "__resume"; -// A function that takes a coroutine handle and calls a `llvm.coro.resume` -// 
intrinsics. We need this function to be able to pass it to the async -// runtime execute API. +/// A function that takes a coroutine handle and calls a `llvm.coro.resume` +/// intrinsic. We need this function to be able to pass it to the async +/// runtime execute API. static void addResumeFunction(ModuleOp module) { MLIRContext *ctx = module.getContext(); @@ -245,49 +294,61 @@ // async.execute op outlining to the coroutine functions. //===----------------------------------------------------------------------===// -// Function targeted for coroutine transformation has two additional blocks at -// the end: coroutine cleanup and coroutine suspension. -// -// async.await op lowering additionaly creates a resume block for each -// operation to enable non-blocking waiting via coroutine suspension. +/// Function targeted for coroutine transformation has two additional blocks at +/// the end: coroutine cleanup and coroutine suspension. +/// +/// async.await op lowering additionally creates a resume block for each +/// operation to enable non-blocking waiting via coroutine suspension. namespace { struct CoroMachinery { - Value asyncToken; + // Async execute region returns a completion token, and an async value for + // each yielded value. + // + // %token, %result = async.execute -> !async.value { + // %0 = constant ... : T + // async.yield %0 : T + // } + Value asyncToken; // token representing completion of the async region + llvm::SmallVector returnValues; // returned async values + Value coroHandle; Block *cleanup; Block *suspend; }; } // namespace -// Builds an coroutine template compatible with LLVM coroutines lowering. -// -// - `entry` block sets up the coroutine. -// - `cleanup` block cleans up the coroutine state. -// - `suspend block after the @llvm.coro.end() defines what value will be -// returned to the initial caller of a coroutine. Everything before the -// @llvm.coro.end() will be executed at every suspension point. 
-// -// Coroutine structure (only the important bits): -// -// func @async_execute_fn() -> !async.token { -// ^entryBlock(): -// %token = : !async.token // create async runtime token -// %hdl = llvm.call @llvm.coro.id(...) // create a coroutine handle -// br ^cleanup -// -// ^cleanup: -// llvm.call @llvm.coro.free(...) // delete coroutine state -// br ^suspend -// -// ^suspend: -// llvm.call @llvm.coro.end(...) // marks the end of a coroutine -// return %token : !async.token -// } -// -// The actual code for the async.execute operation body region will be inserted -// before the entry block terminator. -// -// +/// Builds a coroutine template compatible with LLVM coroutines lowering. +/// +/// - `entry` block sets up the coroutine. +/// - `cleanup` block cleans up the coroutine state. +/// - `suspend` block after the @llvm.coro.end() defines what value will be +/// returned to the initial caller of a coroutine. Everything before the +/// @llvm.coro.end() will be executed at every suspension point. +/// +/// Coroutine structure (only the important bits): +/// +/// func @async_execute_fn() +/// -> (!async.token, !async.value) +/// { +/// ^entryBlock(): +/// %token = : !async.token // create async runtime token +/// %value = : !async.value // create async value +/// %hdl = llvm.call @llvm.coro.id(...) // create a coroutine handle +/// br ^cleanup +/// +/// ^cleanup: +/// llvm.call @llvm.coro.free(...) // delete coroutine state +/// br ^suspend +/// +/// ^suspend: +/// llvm.call @llvm.coro.end(...) // marks the end of a coroutine +/// return %token, %value : !async.token, !async.value +/// } +/// +/// The actual code for the async.execute operation body region will be inserted +/// before the entry block terminator. 
+/// +/// static CoroMachinery setupCoroMachinery(FuncOp func) { assert(func.getBody().empty() && "Function must have empty body"); @@ -310,6 +371,44 @@ auto createToken = builder.create(loc, kCreateToken, TokenType::get(ctx)); + // Async value operands and results must be convertible to LLVM types. This is + // verified before the function outlining. + LLVMTypeConverter converter(ctx); + + // Returns the size requirements for the async value storage. + // http://nondot.org/sabre/LLVMNotes/SizeOf-OffsetOf-VariableSizedStructs.txt + auto sizeOf = [&](ValueType valueType) -> Value { + auto storedType = converter.convertType(valueType.getValueType()); + auto storagePtrType = storedType.cast().getPointerTo(); + + // %Size = getelementptr %T* null, int 1 + // %SizeI = ptrtoint %T* %Size to i32 + auto nullPtr = builder.create(loc, storagePtrType); + auto i32Type = LLVM::LLVMType::getInt32Ty(ctx); + auto one = builder.create(loc, i32Type, + builder.getI32IntegerAttr(1)); + auto gep = builder.create(loc, storagePtrType, nullPtr, + one.getResult()); + auto size = builder.create(loc, i32Type, gep); + + // Cast to std type because runtime API defined using std types. + return builder.create(loc, builder.getI32Type(), + size.getResult()); + }; + + // We use the `async.value` type as a return type although it does not match + // the `kCreateValue` function signature, because it will be later lowered to + // the runtime type (opaque i8* pointer). + llvm::SmallVector createValues; + for (auto resultType : func.getCallableResults().drop_front(1)) + createValues.emplace_back(builder.create( + loc, kCreateValue, resultType, sizeOf(resultType.cast()))); + + auto createdValues = llvm::map_range( + createValues, [](CallOp call) { return call.getResult(0); }); + llvm::SmallVector returnValues(createdValues.begin(), + createdValues.end()); + // ------------------------------------------------------------------------ // // Initialize coroutine: allocate frame, get coroutine handle. 
// ------------------------------------------------------------------------ // @@ -370,9 +469,11 @@ builder.create(loc, i1, builder.getSymbolRefAttr(kCoroEnd), ValueRange({coroHdl.getResult(0), constFalse})); - // Return created `async.token` from the suspend block. This will be the - // return value of a coroutine ramp function. - builder.create(loc, createToken.getResult(0)); + // Return created `async.token` and `async.values` from the suspend block. + // This will be the return value of a coroutine ramp function. + SmallVector ret{createToken.getResult(0)}; + ret.insert(ret.end(), returnValues.begin(), returnValues.end()); + builder.create(loc, ret); // Branch from the entry block to the cleanup block to create a valid CFG. builder.setInsertionPointToEnd(entryBlock); @@ -382,39 +483,44 @@ // `async.await` op lowering will create resume blocks for async // continuations, and will conditionally branch to cleanup or suspend blocks. - return {createToken.getResult(0), coroHdl.getResult(0), cleanupBlock, - suspendBlock}; + CoroMachinery machinery; + machinery.asyncToken = createToken.getResult(0); + machinery.returnValues = returnValues; + machinery.coroHandle = coroHdl.getResult(0); + machinery.cleanup = cleanupBlock; + machinery.suspend = suspendBlock; + return machinery; } -// Add a LLVM coroutine suspension point to the end of suspended block, to -// resume execution in resume block. The caller is responsible for creating the -// two suspended/resume blocks with the desired ops contained in each block. -// This function merely provides the required control flow logic. -// -// `coroState` must be a value returned from the call to @llvm.coro.save(...) -// intrinsic (saved coroutine state). -// -// Before: -// -// ^bb0: -// "opBefore"(...) -// "op"(...) -// ^cleanup: ... -// ^suspend: ... -// ^resume: -// "op"(...) -// -// After: -// -// ^bb0: -// "opBefore"(...) -// %suspend = llmv.call @llvm.coro.suspend(...) 
-// switch %suspend [-1: ^suspend, 0: ^resume, 1: ^cleanup] -// ^resume: -// "op"(...) -// ^cleanup: ... -// ^suspend: ... -// +/// Add a LLVM coroutine suspension point to the end of suspended block, to +/// resume execution in resume block. The caller is responsible for creating the +/// two suspended/resume blocks with the desired ops contained in each block. +/// This function merely provides the required control flow logic. +/// +/// `coroState` must be a value returned from the call to @llvm.coro.save(...) +/// intrinsic (saved coroutine state). +/// +/// Before: +/// +/// ^bb0: +/// "opBefore"(...) +/// "op"(...) +/// ^cleanup: ... +/// ^suspend: ... +/// ^resume: +/// "op"(...) +/// +/// After: +/// +/// ^bb0: +/// "opBefore"(...) +/// %suspend = llvm.call @llvm.coro.suspend(...) +/// switch %suspend [-1: ^suspend, 0: ^resume, 1: ^cleanup] +/// ^resume: +/// "op"(...) +/// ^cleanup: ... +/// ^suspend: ... +/// static void addSuspensionPoint(CoroMachinery coro, Value coroState, Operation *op, Block *suspended, Block *resume, OpBuilder &builder) { @@ -460,10 +566,10 @@ /*falseDest=*/coro.cleanup); } -// Outline the body region attached to the `async.execute` op into a standalone -// function. -// -// Note that this is not reversible transformation. +/// Outline the body region attached to the `async.execute` op into a standalone +/// function. +/// +/// Note that this is not a reversible transformation. static std::pair outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) { ModuleOp module = execute->getParentOfType(); @@ -476,6 +582,7 @@ // Collect all outlined function inputs. llvm::SetVector functionInputs(execute.dependencies().begin(), execute.dependencies().end()); + assert(execute.operands().empty() && "operands are not supported"); getUsedValuesDefinedAbove(execute.body(), functionInputs); // Collect types for the outlined function inputs and outputs. 
@@ -536,15 +643,9 @@ valueMapping.map(functionInputs, func.getArguments()); // Clone all operations from the execute operation body into the outlined - // function body, and replace all `async.yield` operations with a call - // to async runtime to emplace the result token. - for (Operation &op : execute.body().getOps()) { - if (isa(op)) { - builder.create(loc, kEmplaceToken, TypeRange(), coro.asyncToken); - continue; - } + // function body. + for (Operation &op : execute.body().getOps()) builder.clone(op, valueMapping); - } // Replace the original `async.execute` with a call to outlined function. OpBuilder callBuilder(execute); @@ -562,42 +663,38 @@ //===----------------------------------------------------------------------===// namespace { + +/// AsyncRuntimeTypeConverter only converts types from the Async dialect to +/// their runtime type (opaque pointers) and does not convert any other types. class AsyncRuntimeTypeConverter : public TypeConverter { public: - AsyncRuntimeTypeConverter() { addConversion(convertType); } - - static Type convertType(Type type) { - MLIRContext *ctx = type.getContext(); - // Convert async tokens and groups to opaque pointers. - if (type.isa()) - return LLVM::LLVMType::getInt8PtrTy(ctx); - return type; + AsyncRuntimeTypeConverter() { + addConversion([](Type type) { return type; }); + addConversion(convertAsyncTypes); + } + + static Optional convertAsyncTypes(Type type) { + if (type.isa()) + return LLVM::LLVMType::getInt8PtrTy(type.getContext()); + return llvm::None; } }; } // namespace //===----------------------------------------------------------------------===// -// Convert types for all call operations to lowered async types. +// Convert return operations that return async values from async regions. 
//===----------------------------------------------------------------------===// namespace { -class CallOpOpConversion : public ConversionPattern { +class ReturnOpOpConversion : public ConversionPattern { public: - explicit CallOpOpConversion(MLIRContext *ctx) - : ConversionPattern(CallOp::getOperationName(), 1, ctx) {} + explicit ReturnOpOpConversion(TypeConverter &converter, MLIRContext *ctx) + : ConversionPattern(ReturnOp::getOperationName(), 1, converter, ctx) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - AsyncRuntimeTypeConverter converter; - - SmallVector resultTypes; - converter.convertTypes(op->getResultTypes(), resultTypes); - - CallOp call = cast(op); - rewriter.replaceOpWithNewOp(op, resultTypes, call.callee(), - operands); - + rewriter.replaceOpWithNewOp(op, operands); return success(); } }; @@ -613,8 +710,9 @@ template class RefCountingOpLowering : public ConversionPattern { public: - explicit RefCountingOpLowering(MLIRContext *ctx, StringRef apiFunctionName) - : ConversionPattern(RefCountingOp::getOperationName(), 1, ctx), + explicit RefCountingOpLowering(TypeConverter &converter, MLIRContext *ctx, + StringRef apiFunctionName) + : ConversionPattern(RefCountingOp::getOperationName(), 1, converter, ctx), apiFunctionName(apiFunctionName) {} LogicalResult @@ -636,18 +734,18 @@ StringRef apiFunctionName; }; -// async.drop_ref op lowering to mlirAsyncRuntimeDropRef function call. +/// async.add_ref op lowering to mlirAsyncRuntimeAddRef function call. class AddRefOpLowering : public RefCountingOpLowering { public: - explicit AddRefOpLowering(MLIRContext *ctx) - : RefCountingOpLowering(ctx, kAddRef) {} + explicit AddRefOpLowering(TypeConverter &converter, MLIRContext *ctx) + : RefCountingOpLowering(converter, ctx, kAddRef) {} }; -// async.create_group op lowering to mlirAsyncRuntimeCreateGroup function call. 
+/// async.drop_ref op lowering to mlirAsyncRuntimeDropRef function call. class DropRefOpLowering : public RefCountingOpLowering { public: - explicit DropRefOpLowering(MLIRContext *ctx) - : RefCountingOpLowering(ctx, kDropRef) {} + explicit DropRefOpLowering(TypeConverter &converter, MLIRContext *ctx) + : RefCountingOpLowering(converter, ctx, kDropRef) {} }; } // namespace @@ -659,8 +757,9 @@ namespace { class CreateGroupOpLowering : public ConversionPattern { public: - explicit CreateGroupOpLowering(MLIRContext *ctx) - : ConversionPattern(CreateGroupOp::getOperationName(), 1, ctx) {} + explicit CreateGroupOpLowering(TypeConverter &converter, MLIRContext *ctx) + : ConversionPattern(CreateGroupOp::getOperationName(), 1, converter, + ctx) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, @@ -679,8 +778,9 @@ namespace { class AddToGroupOpLowering : public ConversionPattern { public: - explicit AddToGroupOpLowering(MLIRContext *ctx) - : ConversionPattern(AddToGroupOp::getOperationName(), 1, ctx) {} + explicit AddToGroupOpLowering(TypeConverter &converter, MLIRContext *ctx) + : ConversionPattern(AddToGroupOp::getOperationName(), 1, converter, ctx) { + } LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, @@ -708,10 +808,10 @@ class AwaitOpLoweringBase : public ConversionPattern { protected: explicit AwaitOpLoweringBase( - MLIRContext *ctx, + TypeConverter &converter, MLIRContext *ctx, const llvm::DenseMap &outlinedFunctions, StringRef blockingAwaitFuncName, StringRef coroAwaitFuncName) - : ConversionPattern(AwaitType::getOperationName(), 1, ctx), + : ConversionPattern(AwaitType::getOperationName(), 1, converter, ctx), outlinedFunctions(outlinedFunctions), blockingAwaitFuncName(blockingAwaitFuncName), coroAwaitFuncName(coroAwaitFuncName) {} @@ -721,7 +821,7 @@ matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { // We can only await on one the `AwaitableType` (for `await` it can be - 
// only a `token`, for `await_all` it is a `group`). + // a `token` or a `value`, for `await_all` it must be a `group`). auto await = cast(op); if (!await.operand().getType().template isa()) return failure(); @@ -770,44 +870,161 @@ Block *resume = rewriter.splitBlock(suspended, Block::iterator(op)); addSuspensionPoint(coro, coroSave.getResult(0), op, suspended, resume, builder); + + // Make sure that replacement value will be constructed in resume block. + rewriter.setInsertionPointToStart(resume); } - // Original operation was replaced by function call or suspension point. - rewriter.eraseOp(op); + // Replace or erase the await operation with the new value. + if (Value replaceWith = getReplacementValue(op, operands[0], rewriter)) + rewriter.replaceOp(op, replaceWith); + else + rewriter.eraseOp(op); return success(); } + virtual Value getReplacementValue(Operation *op, Value operand, + ConversionPatternRewriter &rewriter) const { + return Value(); + } + private: const llvm::DenseMap &outlinedFunctions; StringRef blockingAwaitFuncName; StringRef coroAwaitFuncName; }; -// Lowering for `async.await` operation (only token operands are supported). -class AwaitOpLowering : public AwaitOpLoweringBase { +/// Lowering for `async.await` with a token operand. +class AwaitTokenOpLowering : public AwaitOpLoweringBase { using Base = AwaitOpLoweringBase; public: - explicit AwaitOpLowering( - MLIRContext *ctx, + explicit AwaitTokenOpLowering( + TypeConverter &converter, MLIRContext *ctx, const llvm::DenseMap &outlinedFunctions) - : Base(ctx, outlinedFunctions, kAwaitToken, kAwaitAndExecute) {} + : Base(converter, ctx, outlinedFunctions, kAwaitToken, + kAwaitTokenAndExecute) {} +}; + +/// Lowering for `async.await` with a value operand. 
+class AwaitValueOpLowering : public AwaitOpLoweringBase { + using Base = AwaitOpLoweringBase; + +public: + explicit AwaitValueOpLowering( + TypeConverter &converter, MLIRContext *ctx, + const llvm::DenseMap &outlinedFunctions) + : Base(converter, ctx, outlinedFunctions, kAwaitValue, + kAwaitValueAndExecute) {} + + Value + getReplacementValue(Operation *op, Value operand, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + auto i8Ptr = LLVM::LLVMType::getInt8PtrTy(rewriter.getContext()); + + // Get the underlying value type from the `async.value`. + auto await = cast(op); + auto valueType = await.operand().getType().cast().getValueType(); + + // Get a pointer to an async value storage from the runtime. + auto storage = rewriter.create(loc, kGetValueStorage, + TypeRange(i8Ptr), operand); + + // Cast from i8* to a pointer to the LLVM value type. + auto llvmValueType = getTypeConverter()->convertType(valueType); + auto castedStorage = rewriter.create( + loc, llvmValueType.cast().getPointerTo(), + storage.getResult(0)); + + // Load from the async value storage. + auto loaded = rewriter.create(loc, castedStorage.getResult()); + + // Cast from LLVM type to the expected value type. This cast will become + // no-op after lowering to LLVM. + return rewriter.create(loc, valueType, loaded); + } }; -// Lowering for `async.await_all` operation. +/// Lowering for `async.await_all` operation. 
class AwaitAllOpLowering : public AwaitOpLoweringBase { using Base = AwaitOpLoweringBase; public: explicit AwaitAllOpLowering( - MLIRContext *ctx, + TypeConverter &converter, MLIRContext *ctx, const llvm::DenseMap &outlinedFunctions) - : Base(ctx, outlinedFunctions, kAwaitGroup, kAwaitAllAndExecute) {} + : Base(converter, ctx, outlinedFunctions, kAwaitGroup, + kAwaitAllAndExecute) {} }; } // namespace +//===----------------------------------------------------------------------===// +// async.yield op lowerings to the corresponding async runtime function calls. +//===----------------------------------------------------------------------===// + +class YieldOpLowering : public ConversionPattern { +public: + explicit YieldOpLowering( + TypeConverter &converter, MLIRContext *ctx, + const llvm::DenseMap &outlinedFunctions) + : ConversionPattern(async::YieldOp::getOperationName(), 1, converter, + ctx), + outlinedFunctions(outlinedFunctions) {} + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + // Check if yield operation is inside the outlined coroutine function. + auto func = op->template getParentOfType(); + auto outlined = outlinedFunctions.find(func); + if (outlined == outlinedFunctions.end()) + return op->emitOpError( + "async.yield is not inside the outlined coroutine function"); + + Location loc = op->getLoc(); + const CoroMachinery &coro = outlined->getSecond(); + + // Store yielded values into the async values storage and emplace them. + auto i8Ptr = LLVM::LLVMType::getInt8PtrTy(rewriter.getContext()); + + for (auto tuple : llvm::zip(operands, coro.returnValues)) { + // Store `yieldValue` into the `asyncValue` storage. + Value yieldValue = std::get<0>(tuple); + Value asyncValue = std::get<1>(tuple); + + // Get an opaque i8* pointer to an async value storage from the runtime. 
+ auto storage = rewriter.create(loc, kGetValueStorage, + TypeRange(i8Ptr), asyncValue); + + // Cast storage pointer to the yielded value type. + auto castedStorage = rewriter.create( + loc, yieldValue.getType().cast().getPointerTo(), + storage.getResult(0)); + + // Store the yielded value into the async value storage. + rewriter.create(loc, yieldValue, + castedStorage.getResult()); + + // Emplace the `async.value` to mark it ready. + rewriter.create(loc, kEmplaceValue, TypeRange(), asyncValue); + } + + // Emplace the completion token to mark it ready. + rewriter.create(loc, kEmplaceToken, TypeRange(), coro.asyncToken); + + // Original operation was replaced by the function call(s). + rewriter.eraseOp(op); + + return success(); + } + +private: + const llvm::DenseMap &outlinedFunctions; +}; + //===----------------------------------------------------------------------===// namespace { @@ -820,15 +1037,38 @@ ModuleOp module = getOperation(); SymbolTable symbolTable(module); + MLIRContext *ctx = &getContext(); + // Outline all `async.execute` body regions into async functions (coroutines). llvm::DenseMap outlinedFunctions; + // We use conversion to LLVM type to ensure that all `async.value` operands + // and results can be lowered to LLVM load and store operations. + LLVMTypeConverter llvmConverter(ctx); + llvmConverter.addConversion(AsyncRuntimeTypeConverter::convertAsyncTypes); + + // Returns true if the `async.value` payload is convertible to LLVM. + auto isConvertibleToLlvm = [&](Type type) -> bool { + auto valueType = type.cast().getValueType(); + return static_cast(llvmConverter.convertType(valueType)); + }; + WalkResult outlineResult = module.walk([&](ExecuteOp execute) { + // All operands and results must be convertible to LLVM. 
+ if (!llvm::all_of(execute.operands().getTypes(), isConvertibleToLlvm)) { + execute.emitOpError("operands payload must be convertible to LLVM type"); + return WalkResult::interrupt(); + } + if (!llvm::all_of(execute.results().getTypes(), isConvertibleToLlvm)) { + execute.emitOpError("results payload must be convertible to LLVM type"); + return WalkResult::interrupt(); + } + // We currently do not support execute operations that have async value // operands or produce async results. - if (!execute.operands().empty() || !execute.results().empty()) { - execute.emitOpError("can't outline async.execute op with async value " - "operands or returned async results"); + if (!execute.operands().empty()) { + execute.emitOpError( + "can't outline async.execute op with async value operands"); return WalkResult::interrupt(); } @@ -854,26 +1094,44 @@ addCoroutineIntrinsicsDeclarations(module); addCRuntimeDeclarations(module); - MLIRContext *ctx = &getContext(); - // Convert async dialect types and operations to LLVM dialect. AsyncRuntimeTypeConverter converter; OwningRewritePatternList patterns; + // Convert async types in function signatures and function calls. populateFuncOpTypeConversionPattern(patterns, ctx, converter); - patterns.insert(ctx); - patterns.insert(ctx); - patterns.insert(ctx); - patterns.insert(ctx, outlinedFunctions); + populateCallOpTypeConversionPattern(patterns, ctx, converter); + + // Convert return operations inside async.execute regions. + patterns.insert(converter, ctx); + + // Lower async operations to async runtime API calls. + patterns.insert(converter, ctx); + patterns.insert(converter, ctx); + + // Use LLVM type converter to automatically convert between the async value + // payload type and LLVM type when loading/storing from/to the async + // value storage which is an opaque i8* pointer using LLVM load/store ops. 
+ patterns + .insert( + llvmConverter, ctx, outlinedFunctions); + patterns.insert(llvmConverter, ctx, outlinedFunctions); ConversionTarget target(*ctx); target.addLegalOp(); target.addLegalDialect(); + + // All operations from Async dialect must be lowered to the runtime API calls. target.addIllegalDialect(); + + // Add dynamic legality constraints to apply conversions defined above. target.addDynamicallyLegalOp( [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); }); - target.addDynamicallyLegalOp( - [&](CallOp op) { return converter.isLegal(op.getResultTypes()); }); + target.addDynamicallyLegalOp( + [&](ReturnOp op) { return converter.isLegal(op.getOperandTypes()); }); + target.addDynamicallyLegalOp([&](CallOp op) { + return converter.isSignatureLegal(op.getCalleeType()); + }); if (failed(applyPartialConversion(module, target, std::move(patterns)))) signalPassFailure(); diff --git a/mlir/lib/ExecutionEngine/AsyncRuntime.cpp b/mlir/lib/ExecutionEngine/AsyncRuntime.cpp --- a/mlir/lib/ExecutionEngine/AsyncRuntime.cpp +++ b/mlir/lib/ExecutionEngine/AsyncRuntime.cpp @@ -130,6 +130,22 @@ std::vector> awaiters; }; +struct AsyncValue : public RefCounted { + // AsyncValue similar to an AsyncToken created with a reference count of 2. + AsyncValue(AsyncRuntime *runtime, int32_t size) + : RefCounted(runtime, /*count=*/2), storage(size) {} + + // Internal state below guarded by a mutex. + std::mutex mu; + std::condition_variable cv; + + bool ready = false; + std::vector> awaiters; + + // Use vector of bytes to store async value payload. + std::vector storage; +}; + struct AsyncGroup : public RefCounted { AsyncGroup(AsyncRuntime *runtime) : RefCounted(runtime), pendingTokens(0), rank(0) {} @@ -159,12 +175,18 @@ refCounted->dropRef(count); } -// Create a new `async.token` in not-ready state. +// Creates a new `async.token` in not-ready state. 
extern "C" AsyncToken *mlirAsyncRuntimeCreateToken() { AsyncToken *token = new AsyncToken(getDefaultAsyncRuntimeInstance()); return token; } +// Creates a new `async.value` in not-ready state. +extern "C" AsyncValue *mlirAsyncRuntimeCreateValue(int32_t size) { + AsyncValue *value = new AsyncValue(getDefaultAsyncRuntimeInstance(), size); + return value; +} + // Create a new `async.group` in empty state. extern "C" AsyncGroup *mlirAsyncRuntimeCreateGroup() { AsyncGroup *group = new AsyncGroup(getDefaultAsyncRuntimeInstance()); @@ -228,18 +250,45 @@ token->dropRef(); } +// Switches `async.value` to ready state and runs all awaiters. +extern "C" void mlirAsyncRuntimeEmplaceValue(AsyncValue *value) { + // Make sure that `dropRef` does not destroy the mutex owned by the lock. + { + std::unique_lock lock(value->mu); + value->ready = true; + value->cv.notify_all(); + for (auto &awaiter : value->awaiters) + awaiter(); + } + + // Async values created with a ref count `2` to keep value alive until the + // async task completes. Drop this reference explicitly when value emplaced. + value->dropRef(); +} + extern "C" void mlirAsyncRuntimeAwaitToken(AsyncToken *token) { std::unique_lock lock(token->mu); if (!token->ready) token->cv.wait(lock, [token] { return token->ready; }); } +extern "C" void mlirAsyncRuntimeAwaitValue(AsyncValue *value) { + std::unique_lock lock(value->mu); + if (!value->ready) + value->cv.wait(lock, [value] { return value->ready; }); +} + extern "C" void mlirAsyncRuntimeAwaitAllInGroup(AsyncGroup *group) { std::unique_lock lock(group->mu); if (group->pendingTokens != 0) group->cv.wait(lock, [group] { return group->pendingTokens == 0; }); } +// Returns a pointer to the storage owned by the async value. 
+extern "C" ValueStorage mlirAsyncRuntimeGetValueStorage(AsyncValue *value) { + return value->storage.data(); +} + extern "C" void mlirAsyncRuntimeExecute(CoroHandle handle, CoroResume resume) { (*resume)(handle); } @@ -255,6 +304,17 @@ token->awaiters.push_back([execute]() { execute(); }); } +extern "C" void mlirAsyncRuntimeAwaitValueAndExecute(AsyncValue *value, + CoroHandle handle, + CoroResume resume) { + std::unique_lock lock(value->mu); + auto execute = [handle, resume]() { (*resume)(handle); }; + if (value->ready) + execute(); + else + value->awaiters.push_back([execute]() { execute(); }); +} + extern "C" void mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup *group, CoroHandle handle, CoroResume resume) { diff --git a/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir b/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir --- a/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir +++ b/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir @@ -211,3 +211,44 @@ // Emplace result token. // CHECK: call @mlirAsyncRuntimeEmplaceToken(%[[RET_1]]) + +// ----- + +// CHECK-LABEL: execute_and_return_f32 +func @execute_and_return_f32() -> f32 { + // CHECK: %[[RET:.*]]:2 = call @async_execute_fn + %token, %result = async.execute -> !async.value { + %c0 = constant 123.0 : f32 + async.yield %c0 : f32 + } + + // CHECK: %[[STORAGE:.*]] = call @mlirAsyncRuntimeGetValueStorage(%[[RET]]#1) + // CHECK: %[[ST_F32:.*]] = llvm.bitcast %[[STORAGE]] + // CHECK: %[[LOADED:.*]] = llvm.load %[[ST_F32]] : !llvm.ptr + // CHECK: %[[CASTED:.*]] = llvm.mlir.cast %[[LOADED]] : !llvm.float to f32 + %0 = async.await %result : !async.value + + return %0 : f32 +} + +// Function outlined from the async.execute operation. +// CHECK-LABEL: func private @async_execute_fn() +// CHECK: %[[TOKEN:.*]] = call @mlirAsyncRuntimeCreateToken() +// CHECK: %[[VALUE:.*]] = call @mlirAsyncRuntimeCreateValue +// CHECK: %[[HDL:.*]] = llvm.call @llvm.coro.begin + +// Suspend coroutine in the beginning. 
+// CHECK: call @mlirAsyncRuntimeExecute(%[[HDL]], +// CHECK: llvm.call @llvm.coro.suspend + +// Emplace result value. +// CHECK: %[[CST:.*]] = constant 1.230000e+02 : f32 +// CHECK: %[[LLVM_CST:.*]] = llvm.mlir.cast %[[CST]] : f32 to !llvm.float +// CHECK: %[[STORAGE:.*]] = call @mlirAsyncRuntimeGetValueStorage(%[[VALUE]]) +// CHECK: %[[ST_F32:.*]] = llvm.bitcast %[[STORAGE]] +// CHECK: llvm.store %[[LLVM_CST]], %[[ST_F32]] : !llvm.ptr +// CHECK: call @mlirAsyncRuntimeEmplaceValue(%[[VALUE]]) + +// Emplace result token. +// CHECK: call @mlirAsyncRuntimeEmplaceToken(%[[TOKEN]]) + diff --git a/mlir/test/mlir-cpu-runner/async-value.mlir b/mlir/test/mlir-cpu-runner/async-value.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-cpu-runner/async-value.mlir @@ -0,0 +1,64 @@ +// RUN: mlir-opt %s -async-ref-counting \ +// RUN: -convert-async-to-llvm \ +// RUN: -convert-vector-to-llvm \ +// RUN: -convert-std-to-llvm \ +// RUN: | mlir-cpu-runner \ +// RUN: -e main -entry-point-result=void -O0 \ +// RUN: -shared-libs=%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext \ +// RUN: -shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext \ +// RUN: -shared-libs=%linalg_test_lib_dir/libmlir_async_runtime%shlibext \ +// RUN: | FileCheck %s --dump-input=always + +func @main() { + + // ------------------------------------------------------------------------ // + // Blocking async.await outside of the async.execute. 
+ // ------------------------------------------------------------------------ // + %token, %result = async.execute -> !async.value { + %0 = constant 123.456 : f32 + async.yield %0 : f32 + } + %1 = async.await %result : !async.value + + // CHECK: 123.456 + vector.print %1 : f32 + + // ------------------------------------------------------------------------ // + // Non-blocking async.await inside the async.execute + // ------------------------------------------------------------------------ // + %token0, %result0 = async.execute -> !async.value { + %token1, %result2 = async.execute -> !async.value { + %2 = constant 456.789 : f32 + async.yield %2 : f32 + } + %3 = async.await %result2 : !async.value + async.yield %3 : f32 + } + %4 = async.await %result0 : !async.value + + // CHECK: 456.789 + vector.print %4 : f32 + + // ------------------------------------------------------------------------ // + // Memref allocated inside async.execute region. + // ------------------------------------------------------------------------ // + %token2, %result2 = async.execute[%token0] -> !async.value> { + %5 = alloc() : memref + %c0 = constant 987.654 : f32 + store %c0, %5[]: memref + async.yield %5 : memref + } + %6 = async.await %result2 : !async.value> + %7 = memref_cast %6 : memref to memref<*xf32> + + // CHECK: Unranked Memref + // CHECK-SAME: rank = 0 offset = 0 sizes = [] strides = [] + // CHECK-NEXT: [987.654] + call @print_memref_f32(%7): (memref<*xf32>) -> () + dealloc %6 : memref + + return +} + +func private @print_memref_f32(memref<*xf32>) + attributes { llvm.emit_c_interface }