diff --git a/mlir/include/mlir/Conversion/AsyncToLLVM/ConvertAsyncToLLVM.h b/mlir/include/mlir/Conversion/AsyncToLLVM/ConvertAsyncToLLVM.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Conversion/AsyncToLLVM/ConvertAsyncToLLVM.h @@ -0,0 +1,28 @@ +//===- ConvertAsyncToLLVM.h - Conversion Patterns from Async to LLVM ------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_ASYNCTOLLVM_CONVERTASYNCTOLLVM_H_ +#define MLIR_CONVERSION_ASYNCTOLLVM_CONVERTASYNCTOLLVM_H_ + +#include + +namespace mlir { +class ModuleOp; +template +class OperationPass; +class OwningRewritePatternList; + +/// Collect a set of patterns to convert from the Async dialect to LLVM. +void populateAsyncToLLVMConversionPatterns(OwningRewritePatternList &patterns); + +/// Create a pass to convert Async operations to the LLVMIR dialect. 
+std::unique_ptr> createConvertAsyncToLLVMPass(); + +} // namespace mlir + +#endif // MLIR_CONVERSION_ASYNCTOLLVM_CONVERTASYNCTOLLVM_H_ diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -78,6 +78,16 @@ let constructor = "mlir::createConvertAVX512ToLLVMPass()"; } +//===----------------------------------------------------------------------===// +// AsyncToLLVM +//===----------------------------------------------------------------------===// + +def ConvertAsyncToLLVM : Pass<"convert-async-to-llvm", "ModuleOp"> { + let summary = "Convert the operations from the async dialect into the LLVM " + "dialect"; + let constructor = "mlir::createConvertAsyncToLLVMPass()"; +} + //===----------------------------------------------------------------------===// // GPUCommon //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Async/Async.h b/mlir/include/mlir/Dialect/Async/Async.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Async/Async.h @@ -0,0 +1,61 @@ +//===- Async.h - Asynchronous Operations ------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines asynchronous operations. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_ASYNC_H_ +#define MLIR_DIALECT_ASYNC_H_ + +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" + +namespace mlir { +namespace async { + +namespace AsyncTypes { +enum Kind { + Runtime = Type::FIRST_SHAPE_TYPE, + Handle, + LAST_SHAPE_TYPE = Handle +}; +} // namespace AsyncTypes + +class RuntimeType : public Type::TypeBase { +public: + using Base::Base; + + static RuntimeType get(MLIRContext *context) { + return Base::get(context, AsyncTypes::Runtime); + } + /// Used to implement llvm-style cast. + static bool kindof(unsigned kind) { return kind == AsyncTypes::Runtime; } +}; + +class HandleType : public Type::TypeBase { +public: + using Base::Base; + + static HandleType get(MLIRContext *context) { + return Base::get(context, AsyncTypes::Handle); + } + /// Used to implement llvm-style cast. + static bool kindof(unsigned kind) { return kind == AsyncTypes::Handle; } +}; + +#define GET_OP_CLASSES +#include "mlir/Dialect/Async/AsyncOps.h.inc" + +#include "mlir/Dialect/Async/AsyncOpsDialect.h.inc" + +} // end namespace async +} // end namespace mlir +#endif // MLIR_DIALECT_ASYNC_H_ diff --git a/mlir/include/mlir/Dialect/Async/AsyncOps.td b/mlir/include/mlir/Dialect/Async/AsyncOps.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Async/AsyncOps.td @@ -0,0 +1,131 @@ +//===- AsyncOps.td - Asynchronous operations ---------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Defines MLIR asynchronous operations +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_ASYNC_ASYNCOPS +#define MLIR_DIALECT_ASYNC_ASYNCOPS + +include "mlir/IR/OpBase.td" + +//===----------------------------------------------------------------------===// +// Async Dialect definition +//===----------------------------------------------------------------------===// + +def Async_Dialect : Dialect { + let name = "async"; + let cppNamespace = ""; +} + +//===----------------------------------------------------------------------===// +// Async type definitions +//===----------------------------------------------------------------------===// + +def AsyncIsRuntimeTypePred : CPred<"$_self.isa()">; +def AsyncRuntimeType : DialectType; + +def AsyncIsHandleTypePred : CPred<"$_self.isa()">; +def AsyncHandleType : DialectType; + +//===----------------------------------------------------------------------===// +// Async op definitions +//===----------------------------------------------------------------------===// + +class Async_Op traits = []> : + Op { + + // Each registered op in the Async dialect needs to provide all + // of a printer, parser and verifier. 
+ let parser = [{ return mlir::tfrt::parse$cppClass(parser, result); }]; +} + +def AsyncDefaultRuntimeOp : Async_Op<"default_runtime"> { + let summary = "Async default_runtime operation"; + let description = [{ + }]; + + let arguments = (ins ); + let results = (outs AsyncRuntimeType:$runtime); + + let assemblyFormat = [{ + attr-dict `:` type($runtime) + }]; +} + +def AsyncCreateHandleOp : Async_Op<"create_handle"> { + let summary = "Async create_handle operation"; + let description = [{ + }]; + + let arguments = (ins AsyncRuntimeType:$runtime); + let results = (outs AsyncHandleType:$handle); + + let assemblyFormat = [{ + $runtime attr-dict `:` functional-type($runtime, $handle) + }]; +} + +def AsyncEmplaceHandleOp : Async_Op<"emplace_handle"> { + let summary = "Async emplace_handle operation"; + let description = [{ + }]; + + let arguments = (ins AsyncHandleType:$handle); + let results = (outs); + + let assemblyFormat = [{ + $handle attr-dict `:` type($handle) + }]; +} + +def AsyncCallOp : Async_Op<"call"> { + let summary = "Async call operation"; + let description = [{ + }]; + + let arguments = (ins AsyncRuntimeType:$runtime, + FlatSymbolRefAttr:$callee, + Variadic:$operands); + let results = (outs AsyncHandleType:$handle); + + let assemblyFormat = [{ + $runtime `:` type($runtime) $callee `(` $operands `)` attr-dict `:` functional-type($operands, results) + }]; +} + +def AsyncAwaitOp : Async_Op<"await"> { + let summary = "Async await operation"; + let description = [{ + }]; + + let arguments = (ins AsyncHandleType:$handle); + let results = (outs); + + let assemblyFormat = [{ + $handle attr-dict `:` type($handle) + }]; +} + +def AsyncSyncAwaitOp : Async_Op<"sync_await"> { + let summary = "Async sync_await operation"; + let description = [{ + }]; + + let arguments = (ins AsyncHandleType:$handle); + let results = (outs); + + let assemblyFormat = [{ + $handle attr-dict `:` type($handle) + }]; +} + +#endif // MLIR_DIALECT_ASYNC_ASYNCOPS diff --git 
a/mlir/include/mlir/Dialect/Async/CMakeLists.txt b/mlir/include/mlir/Dialect/Async/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Async/CMakeLists.txt @@ -0,0 +1,2 @@ +add_mlir_dialect(AsyncOps async Ops) +add_mlir_doc(AsyncOps -gen-dialect-doc AsyncDialect Dialects/) diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h @@ -151,6 +151,10 @@ static LLVMType getVoidTy(LLVMDialect *dialect); bool isVoidTy(); + /// Token type utilities. + static LLVMType getTokenTy(LLVMDialect *dialect); + bool isTokenTy(); + // Creation and setting of LLVM's identified struct types static LLVMType createStructTy(LLVMDialect *dialect, ArrayRef elements, diff --git a/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h b/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h @@ -0,0 +1,63 @@ +//===- AsyncRuntime.h - Async runtime reference implementation ------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares basic Async runtime API for supporting Async dialect. 
+// +//===----------------------------------------------------------------------===// + +#ifndef EXECUTIONENGINE_ASYNCRUNTIME_H_ +#define EXECUTIONENGINE_ASYNCRUNTIME_H_ + +#ifdef _WIN32 +#ifndef MLIR_ASYNCRUNTIME_EXPORT +#ifdef mlir_c_runner_utils_EXPORTS +/* We are building this library */ +#define MLIR_ASYNCRUNTIME_EXPORT __declspec(dllexport) +#define MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS +#else +/* We are using this library */ +#define MLIR_ASYNCRUNTIME_EXPORT __declspec(dllimport) +#endif // mlir_c_runner_utils_EXPORTS +#endif // MLIR_ASYNCRUNTIME_EXPORT +#else +#define MLIR_ASYNCRUNTIME_EXPORT +#define MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS +#endif // _WIN32 + +//===----------------------------------------------------------------------===// +// MLIRAsync runtime API. +//===----------------------------------------------------------------------===// +// clang-format off + +typedef struct MLIR_AsyncRuntime MLIR_AsyncRuntime; +typedef struct MLIR_AsyncHandle MLIR_AsyncHandle; + +using TaskFunction = void (*)(); // asynchronous task function +using CoroHandle = void *; // coroutine handle +using CoroResume = void (*)(void *); // coroutine resume function + +extern "C" MLIR_ASYNCRUNTIME_EXPORT +MLIR_AsyncRuntime *MLIR_AsyncRT_DefaultRuntime(); + +extern "C" MLIR_ASYNCRUNTIME_EXPORT +MLIR_AsyncHandle *MLIR_AsyncRT_Call(MLIR_AsyncRuntime *, TaskFunction); + +extern "C" MLIR_ASYNCRUNTIME_EXPORT +void MLIR_AsyncRT_Await(MLIR_AsyncHandle *, CoroHandle, CoroResume); + +extern "C" MLIR_ASYNCRUNTIME_EXPORT +void MLIR_AsyncRT_SyncAwait(MLIR_AsyncHandle *); + +extern "C" MLIR_ASYNCRUNTIME_EXPORT +MLIR_AsyncHandle * MLIR_AsyncRT_CreateHandle(MLIR_AsyncRuntime *); + +extern "C" MLIR_ASYNCRUNTIME_EXPORT +void MLIR_AsyncRT_EmplaceHandle(MLIR_AsyncHandle *); + +// clang-format on +#endif // EXECUTIONENGINE_ASYNCRUNTIME_H_ diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h --- a/mlir/include/mlir/InitAllDialects.h +++
b/mlir/include/mlir/InitAllDialects.h @@ -16,6 +16,7 @@ #include "mlir/Dialect/AVX512/AVX512Dialect.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Async/Async.h" #include "mlir/Dialect/GPU/GPUDialect.h" #include "mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" @@ -40,6 +41,7 @@ static bool init_once = []() { registerDialect(); registerDialect(); + registerDialect(); registerDialect(); registerDialect(); registerDialect(); diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h --- a/mlir/include/mlir/InitAllPasses.h +++ b/mlir/include/mlir/InitAllPasses.h @@ -15,6 +15,7 @@ #define MLIR_INITALLPASSES_H_ #include "mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h" +#include "mlir/Conversion/AsyncToLLVM/ConvertAsyncToLLVM.h" #include "mlir/Conversion/GPUCommon/GPUCommonPass.h" #include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" #include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h" diff --git a/mlir/lib/Conversion/AsyncToLLVM/CMakeLists.txt b/mlir/lib/Conversion/AsyncToLLVM/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/AsyncToLLVM/CMakeLists.txt @@ -0,0 +1,18 @@ +add_mlir_conversion_library(MLIRASYNCToLLVM + ConvertAsyncToLLVM.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/AsyncToLLVM + + DEPENDS + MLIRConversionPassIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRASYNC + MLIRLLVMIR + MLIRStandardToLLVM + MLIRTransforms + ) diff --git a/mlir/lib/Conversion/AsyncToLLVM/ConvertAsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/ConvertAsyncToLLVM.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/AsyncToLLVM/ConvertAsyncToLLVM.cpp @@ -0,0 +1,702 @@ +//===- ConvertAsyncToLLVM.cpp - Convert Async to the LLVM dialect ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Each function that has an `async.await` operation in its body is converted to +// a coroutine. See https://llvm.org/docs/Coroutines.html for documentation. +// +//===----------------------------------------------------------------------===// +#include "mlir/Conversion/AsyncToLLVM/ConvertAsyncToLLVM.h" + +#include "../PassDetail.h" +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" +#include "mlir/Dialect/Async/Async.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/Module.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Pass/Pass.h" + +using namespace mlir; +using namespace mlir::async; + +namespace { +static constexpr const char *kResume = "__resume"; +static constexpr const char *kMalloc = "malloc"; +static constexpr const char *kFree = "free"; + +// clang-format off +static constexpr const char *kAsyncCall = "MLIR_AsyncRT_Call"; +static constexpr const char *kAwait = "MLIR_AsyncRT_Await"; +static constexpr const char *kSyncAwait = "MLIR_AsyncRT_SyncAwait"; +static constexpr const char *kCreateHandle = "MLIR_AsyncRT_CreateHandle"; +static constexpr const char *kEmplaceHandle = "MLIR_AsyncRT_EmplaceHandle"; +static constexpr const char *kDefaultRuntime = "MLIR_AsyncRT_DefaultRuntime"; + +static constexpr const char* kCoroId = "llvm.coro.id"; +static constexpr const char* kCoroSizeI64 = "llvm.coro.size.i64"; +static constexpr const char* kCoroBegin = "llvm.coro.begin"; +static constexpr const char* kCoroSave = "llvm.coro.save"; +static constexpr const char* kCoroSuspend = "llvm.coro.suspend"; +static constexpr const char* kCoroEnd = "llvm.coro.end"; +static constexpr const char* kCoroFree = "llvm.coro.free"; +static constexpr const char* kCoroResume = "llvm.coro.resume"; +// clang-format on + +struct ConvertAsyncToLLVMPass + :
public ConvertAsyncToLLVMBase { + void runOnOperation() override; +}; + +// Function targeted for coroutine transformation has two additional blocks at +// the end: coroutine cleanup and coroutine suspension. +// +// async.await op lowering additionally creates a resume block for each +// operation. +struct CoroMachinery { + Value coroHandle; + Block *cleanup; + Block *suspend; +}; + +/// LLVM function types for Async Runtime API. +struct API { + static LLVM::LLVMType resumeFunctionType(LLVM::LLVMDialect *llvm) { + auto voidTy = LLVM::LLVMType::getVoidTy(llvm); + auto i8Ptr = LLVM::LLVMType::getInt8PtrTy(llvm); + return LLVM::LLVMType::getFunctionTy(voidTy, {i8Ptr}, false); + } + + static LLVM::LLVMType asyncFunctionType(LLVM::LLVMDialect *llvm) { + auto voidTy = LLVM::LLVMType::getVoidTy(llvm); + return LLVM::LLVMType::getFunctionTy(voidTy, {}, false); + } + + static LLVM::LLVMType asyncCallFunctionType(LLVM::LLVMDialect *llvm) { + auto i8Ptr = LLVM::LLVMType::getInt8PtrTy(llvm); + auto calleePtr = asyncFunctionType(llvm).getPointerTo(); + return LLVM::LLVMType::getFunctionTy(i8Ptr, {i8Ptr, calleePtr}, false); + } + + static LLVM::LLVMType awaitFunctionType(LLVM::LLVMDialect *llvm) { + auto voidTy = LLVM::LLVMType::getVoidTy(llvm); + auto i8Ptr = LLVM::LLVMType::getInt8PtrTy(llvm); + auto resumePtr = resumeFunctionType(llvm).getPointerTo(); + return LLVM::LLVMType::getFunctionTy(voidTy, {i8Ptr, i8Ptr, resumePtr}, + false); + } + + static LLVM::LLVMType syncAwaitFunctionType(LLVM::LLVMDialect *llvm) { + auto voidTy = LLVM::LLVMType::getVoidTy(llvm); + auto i8Ptr = LLVM::LLVMType::getInt8PtrTy(llvm); + return LLVM::LLVMType::getFunctionTy(voidTy, {i8Ptr}, false); + } + + static LLVM::LLVMType createHandleFunctionType(LLVM::LLVMDialect *llvm) { + auto i8Ptr = LLVM::LLVMType::getInt8PtrTy(llvm); + return LLVM::LLVMType::getFunctionTy(i8Ptr, {i8Ptr}, false); + } + + static LLVM::LLVMType emplaceHandleFunctionType(LLVM::LLVMDialect *llvm) { + auto voidTy =
LLVM::LLVMType::getVoidTy(llvm); + auto i8Ptr = LLVM::LLVMType::getInt8PtrTy(llvm); + return LLVM::LLVMType::getFunctionTy(voidTy, {i8Ptr}, false); + } + + static LLVM::LLVMType defaultRuntimeFunctionType(LLVM::LLVMDialect *llvm) { + auto i8Ptr = LLVM::LLVMType::getInt8PtrTy(llvm); + return LLVM::LLVMType::getFunctionTy(i8Ptr, {}, false); + } +}; + +LLVM::LLVMDialect *LlvmDialect(MLIRContext *ctx) { + return ctx->getRegisteredDialect(); +} + +//===----------------------------------------------------------------------===// +// Convert Async dialect types to opaque LLVM pointers. +//===----------------------------------------------------------------------===// + +class AsyncRuntimeTypeConverted : public TypeConverter { +public: + AsyncRuntimeTypeConverted() { + addConversion([](Type type) { + auto *llvm = LlvmDialect(type.getContext()); + auto i8Ptr = LLVM::LLVMType::getInt8PtrTy(llvm); + + if (type.isa() || type.isa()) + return Type(i8Ptr); + + return type; + }); + } +}; + +//===----------------------------------------------------------------------===// +// async.call op lowering to MLIR_AsyncRT_Call function call. 
+//===----------------------------------------------------------------------===// + +class AsyncCallOpLowering : public ConversionPattern { +public: + explicit AsyncCallOpLowering(MLIRContext *ctx) + : ConversionPattern(AsyncCallOp::getOperationName(), 1, ctx) {} + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + AsyncCallOp async_call = cast(op); + + MLIRContext *ctx = op->getContext(); + Location loc = op->getLoc(); + + ModuleOp module = op->getParentOfType(); + FuncOp callee = module.lookupSymbol(async_call.callee()); + + auto *llvm = LlvmDialect(ctx); + auto i8Ptr = LLVM::LLVMType::getInt8PtrTy(llvm); + + auto asyncFnTy = API::asyncFunctionType(llvm); + auto callePtr = rewriter.create( + loc, asyncFnTy.getPointerTo(), + rewriter.getSymbolRefAttr(callee.getOperation())); + + rewriter.replaceOpWithNewOp( + op, i8Ptr, rewriter.getSymbolRefAttr(kAsyncCall), + ArrayRef({op->getOperand(0), callePtr})); + + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// async.await op lowering to MLIR_AsyncRT_Await function call. 
+//===----------------------------------------------------------------------===// + +class AsyncAwaitOpLowering : public ConversionPattern { +public: + explicit AsyncAwaitOpLowering(MLIRContext *ctx, CoroMachinery coroMachinery) + : ConversionPattern(AsyncAwaitOp::getOperationName(), 1, ctx), + coroMachinery(coroMachinery) {} + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + MLIRContext *ctx = op->getContext(); + Location loc = op->getLoc(); + + auto *llvm = LlvmDialect(ctx); + + auto token_ty = LLVM::LLVMType::getTokenTy(llvm); + auto i1 = LLVM::LLVMType::getInt1Ty(llvm); + auto i8 = LLVM::LLVMType::getInt8Ty(llvm); + + auto resumeFnTy = API::resumeFunctionType(llvm); + auto resumePtr = rewriter.create( + loc, resumeFnTy.getPointerTo(), rewriter.getSymbolRefAttr(kResume)); + + // Save coroutine state: @llvm.coro.save + auto coroSave = rewriter.create( + loc, token_ty, rewriter.getSymbolRefAttr(kCoroSave), + ArrayRef({coroMachinery.coroHandle})); + + // Call MLIR async runtime await API with an async handle, coroutine handle, + // and a pointer to resume function. + rewriter.create( + loc, Type(), rewriter.getSymbolRefAttr(kAwait), + ArrayRef( + {op->getOperand(0), coroMachinery.coroHandle, resumePtr})); + rewriter.eraseOp(op); + + // Everything after await call goes into the coroutine resume block. + Block *resume = op->getBlock()->splitBlock(op->getNextNode()); + + // Branch into the resume block unconditionally. This branch will be + // replaced with a conditional branch after we'll add suspension point. + OpBuilder builder = OpBuilder::atBlockEnd(resume->getPrevNode()); + builder.create(loc, resume); + + // Add a coroutine suspension point right after the await call. 
+ builder = OpBuilder::atBlockTerminator(resume->getPrevNode()); + + auto constFalse = + builder.create(loc, i1, builder.getBoolAttr(false)); + + // Suspend a coroutine: @llvm.coro.suspend + auto coroSuspend = builder.create( + loc, i8, builder.getSymbolRefAttr(kCoroSuspend), + ArrayRef({coroSave.getResult(0), constFalse})); + + // After a suspension point decide if we should branch into resume, cleanup + // or suspend block of the coroutine (see @llvm.coro.suspend return code + // documentation). + auto constZero = + builder.create(loc, i8, builder.getI8IntegerAttr(0)); + auto constNegOne = + builder.create(loc, i8, builder.getI8IntegerAttr(-1)); + + auto isZero = builder.create( + loc, LLVM::ICmpPredicate::eq, coroSuspend.getResult(0), constZero); + auto isNegOne = builder.create( + loc, LLVM::ICmpPredicate::eq, coroSuspend.getResult(0), constNegOne); + + Block *resumeOrCleanup = builder.createBlock(resume); + // Suspend the coroutine ...? + builder = OpBuilder::atBlockTerminator(resumeOrCleanup->getPrevNode()); + builder.create(loc, isNegOne, coroMachinery.suspend, + ArrayRef(), resumeOrCleanup, + ArrayRef()); + resumeOrCleanup->getPrevNode()->getTerminator()->erase(); + + // ... or resume or cleanup the coroutine? + builder = OpBuilder::atBlockBegin(resumeOrCleanup); + builder.create(loc, isZero, resume, ArrayRef(), + coroMachinery.cleanup, ArrayRef()); + + return success(); + } + +private: + CoroMachinery coroMachinery; +}; + +//===----------------------------------------------------------------------===// +// async.sync_await op lowering to MLIR_AsyncRT_SyncAwait function call. 
+//===----------------------------------------------------------------------===// + +class AsyncSyncAwaitOpLowering : public ConversionPattern { +public: + explicit AsyncSyncAwaitOpLowering(MLIRContext *ctx) + : ConversionPattern(AsyncSyncAwaitOp::getOperationName(), 1, ctx) {} + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + rewriter.create(op->getLoc(), Type(), + rewriter.getSymbolRefAttr(kSyncAwait), + ArrayRef({op->getOperand(0)})); + rewriter.eraseOp(op); + + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// async.create_handle op lowering to MLIR_AsyncRT_CreateHandle function call. +//===----------------------------------------------------------------------===// + +class AsyncCreateHandleOpLowering : public ConversionPattern { +public: + explicit AsyncCreateHandleOpLowering(MLIRContext *ctx) + : ConversionPattern(AsyncCreateHandleOp::getOperationName(), 1, ctx) {} + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + MLIRContext *ctx = op->getContext(); + + auto *llvm = LlvmDialect(ctx); + auto i8Ptr = LLVM::LLVMType::getInt8PtrTy(llvm); + rewriter.replaceOpWithNewOp( + op, i8Ptr, rewriter.getSymbolRefAttr(kCreateHandle), + ArrayRef({op->getOperand(0)})); + + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// async.emplace_handle op lowering to MLIR_AsyncRT_EmplaceHandle function call. 
+//===----------------------------------------------------------------------===// + +class AsyncEmplaceHandleOpLowering : public ConversionPattern { +public: + explicit AsyncEmplaceHandleOpLowering(MLIRContext *ctx) + : ConversionPattern(AsyncEmplaceHandleOp::getOperationName(), 1, ctx) {} + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + rewriter.create(op->getLoc(), Type(), + rewriter.getSymbolRefAttr(kEmplaceHandle), + ArrayRef({op->getOperand(0)})); + rewriter.eraseOp(op); + + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// async.default_runtime op lowering to MLIR_AsyncRT_DefaultRuntime function +// call. +//===----------------------------------------------------------------------===// + +class AsyncDefaultRuntimeOpLowering : public ConversionPattern { +public: + explicit AsyncDefaultRuntimeOpLowering(MLIRContext *ctx) + : ConversionPattern(AsyncDefaultRuntimeOp::getOperationName(), 1, ctx) {} + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + MLIRContext *ctx = op->getContext(); + + auto *llvm = LlvmDialect(ctx); + auto i8Ptr = LLVM::LLVMType::getInt8PtrTy(llvm); + rewriter.replaceOpWithNewOp( + op, i8Ptr, rewriter.getSymbolRefAttr(kDefaultRuntime), + ArrayRef()); + + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// Convert types for all call operations to lower async types. 
+//===----------------------------------------------------------------------===// + +class CallOpOpConversion : public ConversionPattern { +public: + explicit CallOpOpConversion(MLIRContext *ctx) + : ConversionPattern(CallOp::getOperationName(), 1, ctx) {} + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + AsyncRuntimeTypeConverted converter; + + SmallVector resultTypes; + converter.convertTypes(op->getResultTypes(), resultTypes); + + CallOp call = cast(op); + rewriter.replaceOpWithNewOp(op, resultTypes, call.callee(), + call.getOperands()); + + return success(); + } +}; + +//===----------------------------------------------------------------------===// + +/// Add Async Runtime C API declarations to the module. +void AddAsyncRuntimeApiDeclarations(ModuleOp module) { + MLIRContext *ctx = module.getContext(); + Location loc = module.getLoc(); + + OpBuilder builder(module.getBody()->getTerminator()); + auto *llvm = LlvmDialect(ctx); + + if (!module.lookupSymbol(kAsyncCall)) + builder.create(loc, kAsyncCall, + API::asyncCallFunctionType(llvm)); + + if (!module.lookupSymbol(kAwait)) + builder.create(loc, kAwait, API::awaitFunctionType(llvm)); + + if (!module.lookupSymbol(kSyncAwait)) + builder.create(loc, kSyncAwait, + API::syncAwaitFunctionType(llvm)); + + if (!module.lookupSymbol(kCreateHandle)) + builder.create(loc, kCreateHandle, + API::createHandleFunctionType(llvm)); + + if (!module.lookupSymbol(kEmplaceHandle)) + builder.create(loc, kEmplaceHandle, + API::emplaceHandleFunctionType(llvm)); + + if (!module.lookupSymbol(kDefaultRuntime)) + builder.create(loc, kDefaultRuntime, + API::defaultRuntimeFunctionType(llvm)); +} + +/// Adds malloc/free declarations to the module. 
+void AddCRuntimeDeclarations(ModuleOp module) { + using namespace mlir::LLVM; + + MLIRContext *ctx = module.getContext(); + Location loc = module.getLoc(); + + OpBuilder builder(module.getBody()->getTerminator()); + auto *llvm = LlvmDialect(ctx); + + auto voidTy = LLVMType::getVoidTy(llvm); + auto i64 = LLVMType::getInt64Ty(llvm); + auto i8Ptr = LLVMType::getInt8PtrTy(llvm); + + if (!module.lookupSymbol(kMalloc)) + builder.create( + loc, kMalloc, LLVMType::getFunctionTy(i8Ptr, {i64}, false)); + + if (!module.lookupSymbol(kFree)) + builder.create( + loc, kFree, LLVMType::getFunctionTy(voidTy, i8Ptr, false)); +} + +/// Adds coroutine intrinsics declarations to the module. +void AddCoroutineIntrinsicsDeclarations(ModuleOp module) { + using namespace mlir::LLVM; + + MLIRContext *ctx = module.getContext(); + Location loc = module.getLoc(); + + OpBuilder builder(module.getBody()->getTerminator()); + auto *llvm = LlvmDialect(ctx); + + auto token = LLVMType::getTokenTy(llvm); + auto voidTy = LLVMType::getVoidTy(llvm); + + auto i8 = LLVMType::getInt8Ty(llvm); + auto i1 = LLVMType::getInt1Ty(llvm); + auto i32 = LLVMType::getInt32Ty(llvm); + auto i64 = LLVMType::getInt64Ty(llvm); + auto i8Ptr = LLVMType::getInt8PtrTy(llvm); + + if (!module.lookupSymbol(kCoroId)) { + builder.create( + loc, kCoroId, + LLVMType::getFunctionTy(token, {i32, i8Ptr, i8Ptr, i8Ptr}, false)); + } + + if (!module.lookupSymbol(kCoroSizeI64)) { + builder.create(loc, kCoroSizeI64, + LLVMType::getFunctionTy(i64, false)); + } + + if (!module.lookupSymbol(kCoroBegin)) { + builder.create( + loc, kCoroBegin, LLVMType::getFunctionTy(i8Ptr, {token, i8Ptr}, false)); + } + + if (!module.lookupSymbol(kCoroSave)) { + builder.create(loc, kCoroSave, + LLVMType::getFunctionTy(token, i8Ptr, false)); + } + + if (!module.lookupSymbol(kCoroSuspend)) { + builder.create(loc, kCoroSuspend, + LLVMType::getFunctionTy(i8, {token, i1}, false)); + } + + if (!module.lookupSymbol(kCoroEnd)) { + builder.create(loc, kCoroEnd, + 
LLVMType::getFunctionTy(i1, {i8Ptr, i1}, false)); + } + + if (!module.lookupSymbol(kCoroFree)) { + builder.create( + loc, kCoroFree, LLVMType::getFunctionTy(i8Ptr, {token, i8Ptr}, false)); + } + + if (!module.lookupSymbol(kCoroResume)) { + builder.create(loc, kCoroResume, + LLVMType::getFunctionTy(voidTy, i8Ptr, false)); + } +} + +/// Adds a function that passes input coroutine handle to llvm.coro.resume +/// intrinsic. We need this function to be able to take an address and pass +/// it to async.await callback. +void AddResumeFunction(ModuleOp module) { + OpBuilder moduleBuilder(module.getBody()->getTerminator()); + Location loc = module.getLoc(); + + if (module.lookupSymbol(kResume)) + return; + + auto *llvm = LlvmDialect(module.getContext()); + auto i8Ptr = LLVM::LLVMType::getInt8PtrTy(llvm); + + auto resumeOp = moduleBuilder.create( + loc, kResume, moduleBuilder.getFunctionType({i8Ptr}, {}), + ArrayRef{}); + SymbolTable::setSymbolVisibility(resumeOp, SymbolTable::Visibility::Private); + + auto *block = resumeOp.addEntryBlock(); + OpBuilder blockBuilder = OpBuilder::atBlockEnd(block); + + SmallVector args = {resumeOp.getArgument(0)}; + blockBuilder.create( + loc, Type(), blockBuilder.getSymbolRefAttr("llvm.coro.resume"), args); + + blockBuilder.create(loc); +} + +/// Returns `true` if the function has operations from the Async dialect in the +/// body. +bool IsAsyncFunction(FuncOp func) { + auto *async = func.getContext()->getRegisteredDialect(); + for (auto &block : func.getBlocks()) { + for (auto &op : block.getOperations()) + if (op.getDialect() == async) + return true; + } + return false; +} + +/// Returns `true` if the function has `async.await` operation in the body, +/// which requires the function to be converted into the coroutine. 
+bool IsAsyncCoroFunction(FuncOp func) { + for (auto &block : func.getBlocks()) { + for (auto &op : block.getOperations()) + if (isa(op)) + return true; + } + return false; +} + +/// Adds coroutine initialization, cleanup and suspension points to the +/// function. This allows LLVM coroutine passes to convert the +/// function into a coroutine. +CoroMachinery AddCoroMachinery(FuncOp func) { + using namespace mlir::LLVM; + + OpBuilder builder(func.getBody()); + Location loc = func.getBody().getLoc(); + + auto *llvm = LlvmDialect(func.getContext()); + + auto token = LLVMType::getTokenTy(llvm); + auto i1 = LLVMType::getInt1Ty(llvm); + auto i32 = LLVMType::getInt32Ty(llvm); + auto i64 = LLVMType::getInt64Ty(llvm); + auto i8Ptr = LLVMType::getInt8PtrTy(llvm); + + // ------------------------------------------------------------------------ // + // Initialize coroutine: allocate frame, get coroutine handle. + // ------------------------------------------------------------------------ // + + // Constants for initializing coroutine frame. + auto constZero = + builder.create(loc, i32, builder.getI32IntegerAttr(0)); + auto constFalse = + builder.create(loc, i1, builder.getBoolAttr(false)); + auto nullPtr = builder.create(loc, i8Ptr); + + // Get coroutine id: @llvm.coro.id + auto coroId = builder.create( + loc, token, builder.getSymbolRefAttr(kCoroId), + ArrayRef({constZero, nullPtr, nullPtr, nullPtr})); + + // Get coroutine frame size: @llvm.coro.size.i64 + auto coroSize = builder.create( + loc, i64, builder.getSymbolRefAttr(kCoroSizeI64), ArrayRef()); + + // Allocate memory for coroutine frame.
+ auto coroAlloc = builder.create( + loc, i8Ptr, builder.getSymbolRefAttr(kMalloc), + ArrayRef({coroSize.getResult(0)})); + + // Begin a coroutine: @llvm.coro.begin + auto coroHdl = builder.create( + loc, i8Ptr, builder.getSymbolRefAttr(kCoroBegin), + ArrayRef({coroId.getResult(0), coroAlloc.getResult(0)})); + + Block *cleanupBlock = func.addBlock(); + Block *suspendBlock = func.addBlock(); + + // ------------------------------------------------------------------------ // + // Coroutine cleanup block: deallocate coroutine frame, free the memory. + // ------------------------------------------------------------------------ // + builder = OpBuilder::atBlockBegin(cleanupBlock); + + // Get a pointer to the coroutine frame memory: @llvm.coro.free. + auto coroMem = builder.create( + loc, i8Ptr, builder.getSymbolRefAttr(kCoroFree), + ArrayRef({coroId.getResult(0), coroHdl.getResult(0)})); + + // Free the memory. + builder.create(loc, Type(), builder.getSymbolRefAttr(kFree), + ArrayRef({coroMem.getResult(0)})); + // Branch into the return block. + builder.create(loc, suspendBlock); + + // ------------------------------------------------------------------------ // + // Coroutine suspend block: return allocated async handle, and mark the + // end of a coroutine. + // ------------------------------------------------------------------------ // + builder = OpBuilder::atBlockBegin(suspendBlock); + + // Mark the end of a coroutine: @llvm.coro.end. + builder.create( + loc, i1, builder.getSymbolRefAttr(kCoroEnd), + ArrayRef({coroHdl.getResult(0), constFalse})); + + // Clone return statement into the suspend block. + auto *ret = func.getBody().front().getTerminator(); + builder.clone(*ret); + + // And replace the original return with a branch to cleanup. 
+ ret->erase(); + builder = OpBuilder::atBlockEnd(&func.getBody().front()); + builder.create(loc, cleanupBlock); + + // `async.await` op lowering will create resume blocks for async + // continuations, and will conditionally branch to cleanup or suspend blocks. + + return {coroHdl.getResult(0), cleanupBlock, suspendBlock}; +} + +} // namespace + +/// Populate the given list with patterns that convert from Async to LLVM. +void mlir::populateAsyncToLLVMConversionPatterns( + OwningRewritePatternList &patterns) {} + +void ConvertAsyncToLLVMPass::runOnOperation() { + ModuleOp module = getOperation(); + MLIRContext *ctx = module.getContext(); + + AddAsyncRuntimeApiDeclarations(module); + AddCRuntimeDeclarations(module); + AddCoroutineIntrinsicsDeclarations(module); + AddResumeFunction(module); + + module.walk([&](FuncOp func) -> WalkResult { + // Skip functions that do not have operations from async dialect. + if (!IsAsyncFunction(func)) + return WalkResult::advance(); + + AsyncRuntimeTypeConverted converter; + OwningRewritePatternList patterns; + + // Trivial lowerings for async types and operations. + populateFuncOpTypeConversionPattern(patterns, ctx, converter); + populateAsyncToLLVMConversionPatterns(patterns); + patterns.insert(ctx); + patterns.insert(ctx); + patterns.insert(ctx); + patterns.insert(ctx); + patterns.insert(ctx); + patterns.insert(ctx); + + // Add lowering from async.await to LLVM coroutines. + if (IsAsyncCoroFunction(func)) { + auto coro = AddCoroMachinery(func); + patterns.insert(ctx, coro); + } + + // Apply all the conversion patterns. 
+ ConversionTarget target(*ctx); + target.addLegalDialect(); + target.addIllegalDialect(); + target.addDynamicallyLegalOp( + [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); }); + target.addDynamicallyLegalOp( + [&](CallOp op) { return converter.isLegal(op.getResultTypes()); }); + + if (failed(applyPartialConversion(func, target, patterns))) { + signalPassFailure(); + return WalkResult::interrupt(); + } + + return WalkResult::advance(); + }); +} + +std::unique_ptr> mlir::createConvertAsyncToLLVMPass() { + return std::make_unique(); +} diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -1,5 +1,6 @@ add_subdirectory(AffineToStandard) add_subdirectory(AVX512ToLLVM) +add_subdirectory(AsyncToLLVM) add_subdirectory(GPUCommon) add_subdirectory(GPUToNVVM) add_subdirectory(GPUToROCDL) diff --git a/mlir/lib/Dialect/Async/Async.cpp b/mlir/lib/Dialect/Async/Async.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Async/Async.cpp @@ -0,0 +1,73 @@ +//===- Async.cpp - Asynchronous Operations --------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Async/Async.h"
+
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/IR/Types.h"
+#include "mlir/Support/LogicalResult.h"
+
+using namespace mlir;
+using namespace mlir::async;
+
+/// Registers the types and the generated operation list of the Async dialect.
+// NOTE(review): the type list of `addTypes()` is not visible in this text
+// (template arguments appear to have been stripped) -- confirm against the
+// checked-in file.
+async::AsyncDialect::AsyncDialect(MLIRContext *context)
+    : Dialect(getDialectNamespace(), context) {
+  addTypes();
+  addOperations<
+#define GET_OP_LIST
+#include "mlir/Dialect/Async/AsyncOps.cpp.inc"
+      >();
+}
+
+/// Parses an Async dialect type from its keyword: `runtime` or `handle`.
+/// Any other keyword emits an "unknown Async type" error and yields a null
+/// Type; a null Type is also returned when no keyword can be parsed.
+Type mlir::async::AsyncDialect::parseType(DialectAsmParser &parser) const {
+  StringRef keyword;
+  if (parser.parseKeyword(&keyword))
+    return Type();
+  MLIRContext *context = getContext();
+
+  if (keyword == "runtime")
+    return RuntimeType::get(context);
+  if (keyword == "handle")
+    return HandleType::get(context);
+
+  parser.emitError(parser.getNameLoc(), "unknown Async type: " + keyword);
+  return Type();
+}
+
+// Printers for the individual types; the emitted keywords are the ones that
+// parseType accepts.
+static void print(RuntimeType rt, DialectAsmPrinter &os) { os << "runtime"; }
+static void print(HandleType rt, DialectAsmPrinter &os) { os << "handle"; }
+
+/// Prints an Async dialect type by dispatching on its kind. Any kind other
+/// than Runtime or Handle is treated as unreachable.
+void mlir::async::AsyncDialect::printType(Type type,
+                                          DialectAsmPrinter &os) const {
+  switch (type.getKind()) {
+  default:
+    llvm_unreachable("Unhandled Async type");
+  case AsyncTypes::Runtime:
+    print(type.cast(), os);
+    break;
+  case AsyncTypes::Handle:
+    print(type.cast(), os);
+    break;
+  }
+}
+
+// Definitions of the operation classes generated for the dialect.
+namespace mlir {
+namespace async {
+#define GET_OP_CLASSES
+#include "mlir/Dialect/Async/AsyncOps.cpp.inc"
+} // namespace async
+} // namespace mlir
diff --git a/mlir/lib/Dialect/Async/CMakeLists.txt b/mlir/lib/Dialect/Async/CMakeLists.txt
new
file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Async/CMakeLists.txt @@ -0,0 +1,10 @@ +add_mlir_dialect_library(MLIRASYNC + Async.cpp + + DEPENDS + MLIRAsyncOpsIncGen + + LINK_LIBS PUBLIC + MLIREDSC + MLIRIR + ) diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -1634,6 +1634,7 @@ LLVMType int1Ty, int8Ty, int16Ty, int32Ty, int64Ty, int128Ty; LLVMType doubleTy, floatTy, bfloatTy, halfTy, fp128Ty, x86_fp80Ty; LLVMType voidTy; + LLVMType tokenTy; /// A smart mutex to lock access to the llvm context. Unlike MLIR, LLVM is not /// multi-threaded and requires locked access to prevent race conditions. @@ -1674,6 +1675,7 @@ LLVMType::get(context, llvm::Type::getX86_FP80Ty(llvmContext)); /// Other Types. impl->voidTy = LLVMType::get(context, llvm::Type::getVoidTy(llvmContext)); + impl->tokenTy = LLVMType::get(context, llvm::Type::getTokenTy(llvmContext)); } LLVMDialect::~LLVMDialect() {} @@ -1946,6 +1948,12 @@ bool LLVMType::isVoidTy() { return getUnderlyingType()->isVoidTy(); } +LLVMType LLVMType::getTokenTy(LLVMDialect *dialect) { + return dialect->impl->tokenTy; +} + +bool LLVMType::isTokenTy() { return getUnderlyingType()->isTokenTy(); } + //===----------------------------------------------------------------------===// // Utility functions. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/ExecutionEngine/AsyncRuntime.cpp b/mlir/lib/ExecutionEngine/AsyncRuntime.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/ExecutionEngine/AsyncRuntime.cpp @@ -0,0 +1,96 @@ +//===- AsyncRuntime.cpp - Async runtime reference implementation ----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements basic Async runtime API for supporting Async dialect. +// +//===----------------------------------------------------------------------===// + +#include "mlir/ExecutionEngine/AsyncRuntime.h" + +#ifdef MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS + +#include +#include + +struct MLIR_AsyncRuntime {}; + +struct MLIR_AsyncHandle { + bool completed = false; + std::mutex mu; + std::condition_variable cv; + std::vector> awaiters; +}; + +MLIR_AsyncRuntime *DefaultInstance() { + static MLIR_AsyncRuntime *instance = new MLIR_AsyncRuntime(); + return instance; +} + +extern "C" MLIR_ASYNCRUNTIME_EXPORT MLIR_AsyncRuntime * +MLIR_AsyncRT_DefaultRuntime() { + return DefaultInstance(); +} + +extern "C" MLIR_ASYNCRUNTIME_EXPORT MLIR_AsyncHandle * +MLIR_AsyncRT_Call(MLIR_AsyncRuntime *runtime, TaskFunction task) { + MLIR_AsyncHandle *handle = new MLIR_AsyncHandle; + + // Launch task in a separate thread. 
+ std::thread worker([handle, task]() { + (*task)(); + MLIR_AsyncRT_EmplaceHandle(handle); + }); + + worker.detach(); + + return handle; +} + +extern "C" MLIR_ASYNCRUNTIME_EXPORT MLIR_AsyncHandle * +MLIR_AsyncRT_CreateHandle(MLIR_AsyncRuntime *) { + MLIR_AsyncHandle *handle = new MLIR_AsyncHandle; + return handle; +} + +extern "C" MLIR_ASYNCRUNTIME_EXPORT void +MLIR_AsyncRT_EmplaceHandle(MLIR_AsyncHandle *handle) { + std::unique_lock lock(handle->mu); + handle->completed = true; + handle->cv.notify_all(); + for (auto &awaiter : handle->awaiters) { + awaiter(); + } +} + +extern "C" MLIR_ASYNCRUNTIME_EXPORT void +MLIR_AsyncRT_Await(MLIR_AsyncHandle *handle, CoroHandle coro_handle, + CoroResume coro_resume) { + std::unique_lock lock(handle->mu); + + if (handle->completed) { + (*coro_resume)(coro_handle); + delete handle; + } else { + handle->awaiters.push_back([coro_handle, coro_resume, handle]() { + (*coro_resume)(coro_handle); + delete handle; + }); + } +} + +extern "C" MLIR_ASYNCRUNTIME_EXPORT void +MLIR_AsyncRT_SyncAwait(MLIR_AsyncHandle *handle) { + std::unique_lock lock(handle->mu); + + if (!handle->completed) + handle->cv.wait(lock, [handle] { return handle->completed; }); + + delete handle; +} + +#endif diff --git a/mlir/lib/ExecutionEngine/OptUtils.cpp b/mlir/lib/ExecutionEngine/OptUtils.cpp --- a/mlir/lib/ExecutionEngine/OptUtils.cpp +++ b/mlir/lib/ExecutionEngine/OptUtils.cpp @@ -25,6 +25,7 @@ #include "llvm/Support/Error.h" #include "llvm/Support/StringSaver.h" #include "llvm/Target/TargetMachine.h" +#include "llvm/Transforms/Coroutines.h" #include "llvm/Transforms/IPO.h" #include "llvm/Transforms/IPO/PassManagerBuilder.h" #include @@ -56,6 +57,7 @@ llvm::initializeAggressiveInstCombine(registry); llvm::initializeAnalysis(registry); llvm::initializeVectorization(registry); + llvm::initializeCoroutines(registry); } // Populate pass managers according to the optimization and size levels. 
@@ -73,6 +75,9 @@ builder.SLPVectorize = optLevel > 1 && sizeLevel < 2; builder.DisableUnrollLoops = (optLevel == 0); + // Add all coroutine passes to the builder. + addCoroutinePassesToExtensionPoints(builder); + if (targetMachine) { // Add pass to initialize TTI for this specific target. Otherwise, TTI will // be initialized to NoTTIImpl by default. diff --git a/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir b/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir @@ -0,0 +1,53 @@ +// RUN: mlir-opt %s -convert-async-to-llvm | mlir-opt | FileCheck %s + +func @callee() { +} + +// CHECK: func @create_and_emplace_handle +func @create_and_emplace_handle() { + // CHECK: %[[R0:.*]] = llvm.call @MLIR_AsyncRT_DefaultRuntime() + // CHECK: %[[R1:.*]] = llvm.call @MLIR_AsyncRT_CreateHandle(%[[R0]]) + // CHECK: llvm.call @MLIR_AsyncRT_EmplaceHandle(%[[R1]]) + // CHECK: llvm.call @MLIR_AsyncRT_SyncAwait(%[[R1]]) + %runtime = async.default_runtime : !async.runtime + %handle = async.create_handle %runtime : (!async.runtime) -> !async.handle + async.emplace_handle %handle : !async.handle + async.sync_await %handle : !async.handle + return +} + +// CHECK: func @async_call(%arg0: !llvm<"i8*">) +func @async_call(%arg0: !async.runtime) { + // CHECK: %[[R0:.*]] = llvm.mlir.constant(@callee) + // CHECK: %[[R1:.*]] = llvm.call @MLIR_AsyncRT_Call(%arg0, %[[R0]]) + %handle = async.call %arg0 : !async.runtime @callee() : () -> !async.handle + return +} + +// CHECK: func @async_await(%arg0: !llvm<"i8*">) +func @async_await(%arg0: !async.runtime) -> !async.handle { + // CHECK: %[[ID:.*]] = llvm.call @llvm.coro.id + // CHECK: %[[SIZE:.*]] = llvm.call @llvm.coro.size.i64 + // CHECK: %[[HDL:.*]] = llvm.call @llvm.coro.begin + + // CHECK: %[[AHDL:.*]] = llvm.call @MLIR_AsyncRT_CreateHandle + %ret_handle = async.create_handle %arg0 : (!async.runtime) -> !async.handle + + %handle = async.call %arg0 
: !async.runtime @callee() : () -> !async.handle + + // CHECK: %[[SAVED:.*]] = llvm.call @llvm.coro.save + // CHECK: llvm.call @MLIR_AsyncRT_Await + // CHECK: %[[SUSPENDED:.*]] = llvm.call @llvm.coro.suspend + async.await %handle : !async.handle + + // CHECK: llvm.call @MLIR_AsyncRT_EmplaceHandle + async.emplace_handle %ret_handle : !async.handle + return %ret_handle : !async.handle +} + +// CHECK: llvm.func @MLIR_AsyncRT_Call +// CHECK: llvm.func @MLIR_AsyncRT_Await +// CHECK: llvm.func @MLIR_AsyncRT_SyncAwait +// CHECK: llvm.func @MLIR_AsyncRT_CreateHandle +// CHECK: llvm.func @MLIR_AsyncRT_EmplaceHandle +// CHECK: llvm.func @MLIR_AsyncRT_DefaultRuntime diff --git a/mlir/test/mlir-cpu-runner/async.mlir b/mlir/test/mlir-cpu-runner/async.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-cpu-runner/async.mlir @@ -0,0 +1,96 @@ +// RUN: mlir-opt %s -convert-async-to-llvm -convert-std-to-llvm \ +// RUN: | mlir-cpu-runner \ +// RUN: -e main0 -entry-point-result=void -O0 \ +// RUN: -shared-libs=%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext \ +// RUN: -shared-libs=%linalg_test_lib_dir/libmlir_async_runtime%shlibext \ +// RUN: | FileCheck %s -check-prefix=CHECK0 + +// RUN: mlir-opt %s -convert-async-to-llvm -convert-std-to-llvm \ +// RUN: | mlir-cpu-runner \ +// RUN: -e main1 -entry-point-result=void -O0 \ +// RUN: -shared-libs=%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext \ +// RUN: -shared-libs=%linalg_test_lib_dir/libmlir_async_runtime%shlibext \ +// RUN: | FileCheck %s -check-prefix=CHECK1 + +// RUN: mlir-opt %s -convert-async-to-llvm -convert-std-to-llvm \ +// RUN: | mlir-cpu-runner \ +// RUN: -e main2 -entry-point-result=void -O0 \ +// RUN: -shared-libs=%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext \ +// RUN: -shared-libs=%linalg_test_lib_dir/libmlir_async_runtime%shlibext \ +// RUN: | FileCheck %s -check-prefix=CHECK2 --dump-input=always + +func @print_i32(%n : i32) -> () + +func @print_123() { + %c123 = constant 123 : i32 + call 
@print_i32(%c123): (i32) -> () + return +} + +// ===--------------------------------------------------------------------=== // +// Create and emplace handles synchronously. +// ===--------------------------------------------------------------------=== // + +func @main0() { + %runtime = async.default_runtime : !async.runtime + %handle = async.create_handle %runtime : (!async.runtime) -> !async.handle + async.emplace_handle %handle : !async.handle + async.sync_await %handle : !async.handle + + // CHECK0: 123 + %c123 = constant 123 : i32 + call @print_i32(%c123): (i32) -> () + + return +} + +// ===--------------------------------------------------------------------=== // +// Wait for an async function synchronously. +// ===--------------------------------------------------------------------=== // + +func @main1() { + %runtime = async.default_runtime : !async.runtime + %handle = async.call %runtime : !async.runtime + @print_123() : () -> !async.handle + async.sync_await %handle : !async.handle + + // CHECK1: 123456 + %c456 = constant 456 : i32 + call @print_i32(%c456): (i32) -> () + + return +} + +// ===--------------------------------------------------------------------=== // +// Convert `async_function` to LLVM coroutine and call it asynchronously. +// ===--------------------------------------------------------------------=== // + +func @async_function(%arg0: !async.runtime) -> !async.handle { + %c789 = constant 789 : i32 + %ret_handle = async.create_handle %arg0 : (!async.runtime) -> !async.handle + + %handle = async.call %arg0 : !async.runtime @print_123() : () -> !async.handle + + // This will be printed before 123 because it is running in the caller thread. + call @print_i32(%c789): (i32) -> () + + async.await %handle : !async.handle + + // This line will wait for the async task completion. 
+ call @print_i32(%c789): (i32) -> () + + async.emplace_handle %ret_handle : !async.handle + return %ret_handle : !async.handle +} + +func @main2() { + %runtime = async.default_runtime : !async.runtime + %handle = call @async_function(%runtime): (!async.runtime) -> !async.handle + async.sync_await %handle : !async.handle + + // CHECK2: 789123789456 + %c456 = constant 456 : i32 + call @print_i32(%c456): (i32) -> () + + return +} diff --git a/mlir/test/mlir-cpu-runner/coroutines.mlir b/mlir/test/mlir-cpu-runner/coroutines.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-cpu-runner/coroutines.mlir @@ -0,0 +1,89 @@ +// RUN: mlir-opt %s -convert-std-to-llvm \ +// RUN: | mlir-cpu-runner \ +// RUN: -e main -entry-point-result=void -O0 \ +// RUN: -shared-libs=%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext \ +// RUN: -shared-libs=%linalg_test_lib_dir/libmlir_async_runtime%shlibext \ +// RUN: | FileCheck %s + + +// A basic coroutines example in LLVM IR: +// (1) Call a coroutine and wait for the suspension +// (2) Resume coroutine from a suspension point. 
+//
+// Printed numbers:
+//   (1) 123 - in the coroutine, before its suspension point
+//   (2) 456 - in main after coroutine suspension
+//   (3) 789 - in coroutine resume function
+
+func @print_i32(%n : !llvm.i32) -> ()
+
+func @f(%n : !llvm.i32) -> !llvm<"i8*"> {
+  %0 = llvm.mlir.constant(0 : i32) : !llvm.i32
+
+  %false = llvm.mlir.constant(0 : i1) : !llvm.i1
+  %null = llvm.mlir.null : !llvm<"i8*">
+
+  %id = llvm.call @llvm.coro.id(%0, %null, %null, %null): (!llvm.i32, !llvm<"i8*">, !llvm<"i8*">, !llvm<"i8*">)-> !llvm<"token">
+  %size = llvm.call @llvm.coro.size.i64(): () -> !llvm.i64
+  %alloc = llvm.call @malloc(%size): (!llvm.i64) -> !llvm<"i8*">
+  %hdl = llvm.call @llvm.coro.begin(%id, %alloc): (!llvm<"token">, !llvm<"i8*">) -> !llvm<"i8*">
+
+  call @print_i32(%n): (!llvm.i32) -> ()
+
+  %save = llvm.call @llvm.coro.save(%hdl): (!llvm<"i8*">) -> !llvm<"token">
+  %suspend = llvm.call @llvm.coro.suspend(%save, %false): (!llvm<"token">, !llvm.i1) -> !llvm.i8
+
+  %00 = llvm.mlir.constant(0 : i8) : !llvm.i8
+  %n1 = llvm.mlir.constant(-1 : i8) : !llvm.i8
+
+  %is_zero = llvm.icmp "eq" %suspend, %00 : !llvm.i8
+  %is_none = llvm.icmp "eq" %suspend, %n1 : !llvm.i8
+
+  llvm.cond_br %is_none, ^suspend, ^resume_or_cleanup
+^resume_or_cleanup:
+  llvm.cond_br %is_zero, ^resume, ^cleanup
+
+^resume:
+  %c789 = llvm.mlir.constant(789 : i32) : !llvm.i32
+  call @print_i32(%c789): (!llvm.i32)-> ()
+  br ^cleanup
+
+^cleanup:
+  %mem = llvm.call @llvm.coro.free(%id, %hdl): (!llvm<"token">, !llvm<"i8*">) -> !llvm<"i8*">
+  llvm.call @free(%mem): (!llvm<"i8*">) -> ()
+  br ^suspend
+
+^suspend:
+  llvm.call @llvm.coro.end(%hdl, %false): (!llvm<"i8*">, !llvm.i1) -> !llvm.i1
+  return %hdl : !llvm<"i8*">
+}
+
+func @main() {
+  %c123 = llvm.mlir.constant(123 : i32) : !llvm.i32
+  %c456 = llvm.mlir.constant(456 : i32) : !llvm.i32
+
+  %hdl = call @f(%c123): (!llvm.i32) -> !llvm<"i8*">
+  call @print_i32(%c456): (!llvm.i32)-> ()
+  llvm.call @llvm.coro.resume(%hdl): (!llvm<"i8*">) -> ()
+
+  return
+}
+
+// CHECK: 123456789
+
+
+//
Declarations of C library functions. +llvm.func @malloc(!llvm.i64) -> !llvm<"i8*"> +llvm.func @free(!llvm<"i8*">) + +// Declaration of Coroutines LLVM intrinsics. +llvm.func @llvm.coro.id(!llvm.i32, !llvm<"i8*">, !llvm<"i8*">, !llvm<"i8*">) -> !llvm<"token"> +llvm.func @llvm.coro.size.i64() -> !llvm.i64 +llvm.func @llvm.coro.size.i32() -> !llvm.i32 +llvm.func @llvm.coro.begin(!llvm<"token">, !llvm<"i8*">) -> !llvm<"i8*"> +llvm.func @llvm.coro.save(!llvm<"i8*">) -> !llvm<"token"> +llvm.func @llvm.coro.suspend(!llvm<"token">, !llvm.i1) -> !llvm.i8 +llvm.func @llvm.coro.end(!llvm<"i8*">, !llvm.i1) -> !llvm.i1 +llvm.func @llvm.coro.free(!llvm<"token">, !llvm<"i8*">) -> !llvm<"i8*"> +llvm.func @llvm.coro.resume(!llvm<"i8*">) -> () +llvm.func @llvm.coro.destroy(!llvm<"i8*">) -> ()