diff --git a/mlir/include/mlir/Dialect/Async/IR/Async.h b/mlir/include/mlir/Dialect/Async/IR/Async.h --- a/mlir/include/mlir/Dialect/Async/IR/Async.h +++ b/mlir/include/mlir/Dialect/Async/IR/Async.h @@ -29,6 +29,10 @@ struct ValueTypeStorage; } // namespace detail +//===----------------------------------------------------------------------===// +// Async dialect types. +//===----------------------------------------------------------------------===// + /// The token type to represent asynchronous operation completion. class TokenType : public Type::TypeBase { public: @@ -53,9 +57,32 @@ using Base::Base; }; -// -------------------------------------------------------------------------- // +//===----------------------------------------------------------------------===// +// LLVM coroutine types. +//===----------------------------------------------------------------------===// + +/// The type identifying a switched-resume coroutine (lowered to LLVM `token`). +class CoroIdType : public Type::TypeBase { +public: + using Base::Base; +}; + +/// The coroutine handle type: a pointer to the coroutine frame (LLVM `i8*`). +class CoroHandleType + : public Type::TypeBase { +public: + using Base::Base; +}; + +/// The saved coroutine state type (result of `async.coro.save`). +class CoroStateType : public Type::TypeBase { +public: + using Base::Base; +}; + +//===----------------------------------------------------------------------===// // Helper functions of Async dialect transformations. -// -------------------------------------------------------------------------- // +//===----------------------------------------------------------------------===// /// Returns true if the type is reference counted. All async dialect types are /// reference counted at runtime.
diff --git a/mlir/include/mlir/Dialect/Async/IR/AsyncBase.td b/mlir/include/mlir/Dialect/Async/IR/AsyncBase.td --- a/mlir/include/mlir/Dialect/Async/IR/AsyncBase.td +++ b/mlir/include/mlir/Dialect/Async/IR/AsyncBase.td @@ -77,4 +77,39 @@ Async_TokenType, Async_GroupType]>; +//===----------------------------------------------------------------------===// +// Types for lowering to LLVM + Async Runtime via the LLVM Coroutines. +//===----------------------------------------------------------------------===// + +// LLVM coroutine intrinsics use `token` and `i8*` types to represent coroutine +// identifiers and handles. To define type-safe Async Runtime operations and +// build a properly typed intermediate IR during the Async to LLVM lowering we +// define separate types for values that can be produced by LLVM intrinsics. + +def Async_CoroIdType : DialectType()">, "coro.id type">, + BuildableType<"$_builder.getType<::mlir::async::CoroIdType>()"> { + let description = [{ + `async.coro.id` is a type identifying a switched-resume coroutine. + }]; +} + +def Async_CoroHandleType : DialectType()">, "coro.handle type">, + BuildableType<"$_builder.getType<::mlir::async::CoroHandleType>()"> { + let description = [{ + `async.coro.handle` is a handle to the coroutine (pointer to the coroutine + frame) that can be passed around to resume or destroy the coroutine. + }]; +} + +def Async_CoroStateType : DialectType()">, "coro.state type">, + BuildableType<"$_builder.getType<::mlir::async::CoroStateType>()"> { + let description = [{ + `async.coro.state` is a saved coroutine state that should be passed to the + coroutine suspension operation.
+ }]; +} + #endif // ASYNC_BASE_TD diff --git a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td --- a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td +++ b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td @@ -170,7 +170,6 @@ ``` }]; - let arguments = (ins ); let results = (outs Async_GroupType:$result); let assemblyFormat = "attr-dict"; @@ -228,9 +227,210 @@ } //===----------------------------------------------------------------------===// -// Async Dialect Automatic Reference Counting Operations. +// Async Dialect LLVM Coroutines Operations. //===----------------------------------------------------------------------===// +// Async to LLVM dialect lowering converts async tasks (regions inside async +// execute operations) to LLVM coroutines [1], and relies on switched-resume +// lowering [2] to produce an asynchronous executable. +// +// We define LLVM coro intrinsics in the async dialect to facilitate progressive +// lowering with verifiable and type-safe IR during the multi-step lowering +// pipeline. First we convert from high level async operations (e.g. execute) to +// the explicit calls to coro intrinsics and runtime API, and then finalize +// lowering to LLVM with a simple dialect conversion pass. +// +// [1] https://llvm.org/docs/Coroutines.html +// [2] https://llvm.org/docs/Coroutines.html#switched-resume-lowering + +def Async_CoroIdOp : Async_Op<"coro.id"> { + let summary = "returns a switched-resume coroutine identifier"; + let description = [{ + The `async.coro.id` returns a switched-resume coroutine identifier. + }]; + + let results = (outs Async_CoroIdType:$id); + let assemblyFormat = "attr-dict"; +} + +def Async_CoroBeginOp : Async_Op<"coro.begin"> { + let summary = "returns a handle to the coroutine"; + let description = [{ + The `async.coro.begin` allocates a coroutine frame and returns a handle to + the coroutine.
+ }]; + + let arguments = (ins Async_CoroIdType:$id); + let results = (outs Async_CoroHandleType:$handle); + let assemblyFormat = "$id attr-dict"; +} + +def Async_CoroFreeOp : Async_Op<"coro.free"> { + let summary = "deallocates the coroutine frame"; + let description = [{ + The `async.coro.free` deallocates the coroutine frame created by the + async.coro.begin operation. + }]; + + let arguments = (ins Async_CoroIdType:$id, + Async_CoroHandleType:$handle); + let assemblyFormat = "$id `,` $handle attr-dict"; +} + +def Async_CoroEndOp : Async_Op<"coro.end"> { + let summary = "marks the end of the coroutine in the suspend block"; + let description = [{ + The `async.coro.end` marks the point where a coroutine needs to return + control back to the caller if it is not an initial invocation of the + coroutine. In the start part of the coroutine it is a no-op. + }]; + + let arguments = (ins Async_CoroHandleType:$handle); + let assemblyFormat = "$handle attr-dict"; +} + +def Async_CoroSaveOp : Async_Op<"coro.save"> { + let summary = "saves the coroutine state"; + let description = [{ + The `async.coro.save` saves the coroutine state. + }]; + + let arguments = (ins Async_CoroHandleType:$handle); + let results = (outs Async_CoroStateType:$state); + let assemblyFormat = "$handle attr-dict"; +} + +def Async_CoroSuspendOp : Async_Op<"coro.suspend", [Terminator]> { + let summary = "suspends the coroutine"; + let description = [{ + The `async.coro.suspend` suspends the coroutine and transfers control to the + `suspend` successor. If the suspended coroutine is later resumed, control is + transferred to the `resume` successor. If it is destroyed, control is + transferred to the `cleanup` successor. + + In switched-resume lowering the coroutine can already be in the resumed + state when the suspend operation is called; in this case control is + transferred to the `resume` successor, skipping the `suspend` successor.
+ }]; + + let arguments = (ins Async_CoroStateType:$state); + let successors = (successor AnySuccessor:$suspendDest, + AnySuccessor:$resumeDest, + AnySuccessor:$cleanupDest); + let assemblyFormat = + "$state `,` $suspendDest `,` $resumeDest `,` $cleanupDest attr-dict"; +} + +//===----------------------------------------------------------------------===// +// Async Dialect Runtime Operations. +//===----------------------------------------------------------------------===// + +// The following operations are intermediate async dialect operations to help +// lowering from high level async operations like `async.execute` to the Async +// Runtime API defined in the `ExecutionEngine/AsyncRuntime.h`. + +def Async_RuntimeCreateOp : Async_Op<"runtime.create"> { + let summary = "creates an async runtime value (token, value or group)"; + let description = [{ + The `async.runtime.create` operation creates an async dialect value + (token, value or group). Tokens and values are created in non-ready state. + Groups are created in empty state. + }]; + + let results = (outs Async_AnyAsyncType:$result); + let assemblyFormat = "attr-dict `:` type($result)"; +} + +def Async_RuntimeSetAvailableOp : Async_Op<"runtime.set_available"> { + let summary = "switches token or value state to available"; + let description = [{ + The `async.runtime.set_available` operation switches async token or value + state to available. + }]; + + let arguments = (ins Async_AnyValueOrTokenType:$operand); + let assemblyFormat = "$operand attr-dict `:` type($operand)"; +} + +def Async_RuntimeAwaitOp : Async_Op<"runtime.await"> { + let summary = "blocks the caller thread until the operand becomes available"; + let description = [{ + The `async.runtime.await` operation blocks the caller thread until the + operand becomes available.
+ }]; + + let arguments = (ins Async_AnyAsyncType:$operand); + let assemblyFormat = "$operand attr-dict `:` type($operand)"; +} + +def Async_RuntimeResumeOp : Async_Op<"runtime.resume"> { + let summary = "resumes the coroutine on a thread managed by the runtime"; + let description = [{ + The `async.runtime.resume` operation resumes the coroutine on a thread + managed by the runtime. + }]; + + let arguments = (ins Async_CoroHandleType:$handle); + let assemblyFormat = "$handle attr-dict"; +} + +def Async_RuntimeAwaitAndResumeOp : Async_Op<"runtime.await_and_resume"> { + let summary = "awaits the async operand and resumes the coroutine"; + let description = [{ + The `async.runtime.await_and_resume` operation awaits for the operand to + become available and resumes the coroutine on a thread managed by the + runtime. + }]; + + let arguments = (ins Async_AnyAsyncType:$operand, + Async_CoroHandleType:$handle); + let assemblyFormat = "$operand `,` $handle attr-dict `:` type($operand)"; +} + +def Async_RuntimeStoreOp : Async_Op<"runtime.store", + [TypesMatchWith<"type of 'value' matches element type of 'storage'", + "storage", "value", + "$_self.cast().getValueType()">]> { + let summary = "stores the value into the runtime async.value"; + let description = [{ + The `async.runtime.store` operation stores the value into the runtime + async.value storage. + }]; + + let arguments = (ins AnyType:$value, + Async_AnyValueType:$storage); + let assemblyFormat = "$value `,` $storage attr-dict `:` type($storage)"; +} + +def Async_RuntimeLoadOp : Async_Op<"runtime.load", + [TypesMatchWith<"type of 'result' matches element type of 'storage'", + "storage", "result", + "$_self.cast().getValueType()">]> { + let summary = "loads the value from the runtime async.value"; + let description = [{ + The `async.runtime.load` operation loads the value from the runtime + async.value storage.
+ }]; + + let arguments = (ins Async_AnyValueType:$storage); + let results = (outs AnyType:$result); + let assemblyFormat = "$storage attr-dict `:` type($storage)"; +} + +def Async_RuntimeAddToGroupOp : Async_Op<"runtime.add_to_group", []> { + let summary = "adds an async token or value to the group"; + let description = [{ + The `async.runtime.add_to_group` adds an async token or value to the async + group. Returns the rank of the added element in the group. + }]; + + let arguments = (ins Async_AnyValueOrTokenType:$operand, + Async_GroupType:$group); + let results = (outs Index:$rank); + + let assemblyFormat = "$operand `,` $group attr-dict `:` type($operand)"; +} + // All async values (values, tokens, groups) are reference counted at runtime // and automatically destructed when reference count drops to 0. // @@ -253,32 +453,30 @@ // // See `AsyncRefCountingPass` documentation for the implementation details. -def Async_AddRefOp : Async_Op<"add_ref"> { +def Async_RuntimeAddRefOp : Async_Op<"runtime.add_ref"> { let summary = "adds a reference to async value"; let description = [{ - The `async.add_ref` operation adds a reference(s) to async value (token, - value or group). + The `async.runtime.add_ref` operation adds a reference(s) to async value + (token, value or group). }]; let arguments = (ins Async_AnyAsyncType:$operand, Confined:$count); - let results = (outs ); let assemblyFormat = [{ $operand attr-dict `:` type($operand) }]; } -def Async_DropRefOp : Async_Op<"drop_ref"> { +def Async_RuntimeDropRefOp : Async_Op<"runtime.drop_ref"> { let summary = "drops a reference to async value"; let description = [{ - The `async.drop_ref` operation drops a reference(s) to async value (token, - value or group). + The `async.runtime.drop_ref` operation drops a reference(s) to async value + (token, value or group).
}]; let arguments = (ins Async_AnyAsyncType:$operand, Confined:$count); - let results = (outs ); let assemblyFormat = [{ $operand attr-dict `:` type($operand) diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp --- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp +++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp @@ -69,6 +69,10 @@ return LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8)); } + static LLVM::LLVMTokenType tokenType(MLIRContext *ctx) { + return LLVM::LLVMTokenType::get(ctx); + } + static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) { auto ref = opaquePointerType(ctx); auto count = IntegerType::get(ctx, 32); @@ -317,13 +321,16 @@ Value asyncToken; // token representing completion of the async region llvm::SmallVector returnValues; // returned async values - Value coroHandle; - Block *cleanup; - Block *suspend; + Value coroHandle; // coroutine handle (!async.coro.handle value) + Block *cleanup; // coroutine cleanup block + Block *suspend; // coroutine suspension block }; } // namespace -/// Builds an coroutine template compatible with LLVM coroutines lowering. +/// Builds a coroutine template compatible with LLVM coroutines switched-resume +/// lowering using `async.runtime.*` and `async.coro.*` operations. +/// +/// See LLVM coroutines documentation: https://llvm.org/docs/Coroutines.html /// /// - `entry` block sets up the coroutine. /// - `cleanup` block cleans up the coroutine state. @@ -336,18 +343,19 @@ /// func @async_execute_fn() /// -> (!async.token, !async.value) /// { -/// ^entryBlock(): +/// ^entry(): /// %token = : !async.token // create async runtime token /// %value = : !async.value // create async value -/// %hdl = llvm.call @llvm.coro.id(...) // create a coroutine handle +/// %id = async.coro.id // create a coroutine id +/// %hdl = async.coro.begin %id // create a coroutine handle /// br ^cleanup /// /// ^cleanup: -/// llvm.call @llvm.coro.free(...)
// delete coroutine state +/// async.coro.free %id, %hdl // delete the coroutine state /// br ^suspend /// /// ^suspend: -/// llvm.call @llvm.coro.end(...) // marks the end of a coroutine +/// async.coro.end %hdl // marks the end of a coroutine /// return %token, %value : !async.token, !async.value /// } /// @@ -359,85 +367,25 @@ assert(func.getBody().empty() && "Function must have empty body"); MLIRContext *ctx = func.getContext(); - - auto token = LLVM::LLVMTokenType::get(ctx); - auto i1 = IntegerType::get(ctx, 1); - auto i32 = IntegerType::get(ctx, 32); - auto i64 = IntegerType::get(ctx, 64); - auto i8Ptr = LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8)); - Block *entryBlock = func.addEntryBlock(); - Location loc = func.getBody().getLoc(); - auto builder = ImplicitLocOpBuilder::atBlockBegin(loc, entryBlock); + auto builder = ImplicitLocOpBuilder::atBlockBegin(func->getLoc(), entryBlock); // ------------------------------------------------------------------------ // - // Allocate async tokens/values that we will return from a ramp function. + // Allocate async token/values that we will return from a ramp function. // ------------------------------------------------------------------------ // - auto createToken = builder.create(kCreateToken, TokenType::get(ctx)); - - // Async value operands and results must be convertible to LLVM types. This is - // verified before the function outlining. - LLVMTypeConverter converter(ctx); - - // Returns the size requirements for the async value storage.
- // http://nondot.org/sabre/LLVMNotes/SizeOf-OffsetOf-VariableSizedStructs.txt - auto sizeOf = [&](ValueType valueType) -> Value { - auto storedType = converter.convertType(valueType.getValueType()); - auto storagePtrType = LLVM::LLVMPointerType::get(storedType); - - // %Size = getelementptr %T* null, int 1 - // %SizeI = ptrtoint %T* %Size to i32 - auto nullPtr = builder.create(loc, storagePtrType); - auto one = builder.create(loc, i32, - builder.getI32IntegerAttr(1)); - auto gep = builder.create(loc, storagePtrType, nullPtr, - one.getResult()); - return builder.create(loc, i32, gep); - }; - - // We use the `async.value` type as a return type although it does not match - // the `kCreateValue` function signature, because it will be later lowered to - // the runtime type (opaque i8* pointer). - llvm::SmallVector createValues; - for (auto resultType : func.getCallableResults().drop_front(1)) - createValues.emplace_back(builder.create( - loc, kCreateValue, resultType, sizeOf(resultType.cast()))); + auto retToken = builder.create(TokenType::get(ctx)).result(); - auto createdValues = llvm::map_range( - createValues, [](CallOp call) { return call.getResult(0); }); - llvm::SmallVector returnValues(createdValues.begin(), - createdValues.end()); + llvm::SmallVector retValues; + for (auto resType : func.getCallableResults().drop_front()) + retValues.emplace_back(builder.create(resType).result()); // ------------------------------------------------------------------------ // - // Initialize coroutine: allocate frame, get coroutine handle. + // Initialize coroutine: get coroutine id and coroutine handle. // ------------------------------------------------------------------------ // - - // Constants for initializing coroutine frame.
- auto constZero = - builder.create(i32, builder.getI32IntegerAttr(0)); - auto constFalse = - builder.create(i1, builder.getBoolAttr(false)); - auto nullPtr = builder.create(i8Ptr); - - // Get coroutine id: @llvm.coro.id - auto coroId = builder.create( - token, builder.getSymbolRefAttr(kCoroId), - ValueRange({constZero, nullPtr, nullPtr, nullPtr})); - - // Get coroutine frame size: @llvm.coro.size.i64 - auto coroSize = builder.create( - i64, builder.getSymbolRefAttr(kCoroSizeI64), ValueRange()); - - // Allocate memory for coroutine frame. - auto coroAlloc = - builder.create(i8Ptr, builder.getSymbolRefAttr(kMalloc), - ValueRange(coroSize.getResult(0))); - - // Begin a coroutine: @llvm.coro.begin - auto coroHdl = builder.create( - i8Ptr, builder.getSymbolRefAttr(kCoroBegin), - ValueRange({coroId.getResult(0), coroAlloc.getResult(0)})); + auto coroIdOp = builder.create(CoroIdType::get(ctx)); + auto coroHdlOp = + builder.create(CoroHandleType::get(ctx), coroIdOp.id()); Block *cleanupBlock = func.addBlock(); Block *suspendBlock = func.addBlock(); @@ -446,15 +394,8 @@ // Coroutine cleanup block: deallocate coroutine frame, free the memory. // ------------------------------------------------------------------------ // builder.setInsertionPointToStart(cleanupBlock); + builder.create(coroIdOp.id(), coroHdlOp.handle()); - // Get a pointer to the coroutine frame memory: @llvm.coro.free. - auto coroMem = builder.create( - i8Ptr, builder.getSymbolRefAttr(kCoroFree), - ValueRange({coroId.getResult(0), coroHdl.getResult(0)})); - - // Free the memory. - builder.create(TypeRange(), builder.getSymbolRefAttr(kFree), - ValueRange(coroMem.getResult(0))); // Branch into the suspend block. builder.create(suspendBlock); @@ -464,107 +405,31 @@ // ------------------------------------------------------------------------ // builder.setInsertionPointToStart(suspendBlock); - // Mark the end of a coroutine: @llvm.coro.end.
- builder.create(i1, builder.getSymbolRefAttr(kCoroEnd), - ValueRange({coroHdl.getResult(0), constFalse})); + // Mark the end of a coroutine: async.coro.end + builder.create(coroHdlOp.handle()); // Return created `async.token` and `async.values` from the suspend block. // This will be the return value of a coroutine ramp function. - SmallVector ret{createToken.getResult(0)}; - ret.insert(ret.end(), returnValues.begin(), returnValues.end()); - builder.create(loc, ret); + SmallVector ret{retToken}; + ret.insert(ret.end(), retValues.begin(), retValues.end()); + builder.create(ret); // Branch from the entry block to the cleanup block to create a valid CFG. builder.setInsertionPointToEnd(entryBlock); - builder.create(cleanupBlock); // `async.await` op lowering will create resume blocks for async // continuations, and will conditionally branch to cleanup or suspend blocks. CoroMachinery machinery; - machinery.asyncToken = createToken.getResult(0); - machinery.returnValues = returnValues; - machinery.coroHandle = coroHdl.getResult(0); + machinery.asyncToken = retToken; + machinery.returnValues = retValues; + machinery.coroHandle = coroHdlOp.handle(); machinery.cleanup = cleanupBlock; machinery.suspend = suspendBlock; return machinery; } -/// Add a LLVM coroutine suspension point to the end of suspended block, to -/// resume execution in resume block. The caller is responsible for creating the -/// two suspended/resume blocks with the desired ops contained in each block. -/// This function merely provides the required control flow logic. -/// -/// `coroState` must be a value returned from the call to @llvm.coro.save(...) -/// intrinsic (saved coroutine state). -/// -/// Before: -/// -/// ^bb0: -/// "opBefore"(...) -/// "op"(...) -/// ^cleanup: ... -/// ^suspend: ... -/// ^resume: -/// "op"(...) -/// -/// After: -/// -/// ^bb0: -/// "opBefore"(...) -/// %suspend = llmv.call @llvm.coro.suspend(...)
-/// switch %suspend [-1: ^suspend, 0: ^resume, 1: ^cleanup] -/// ^resume: -/// "op"(...) -/// ^cleanup: ... -/// ^suspend: ... -/// -static void addSuspensionPoint(CoroMachinery coro, Value coroState, - Operation *op, Block *suspended, Block *resume, - OpBuilder &builder) { - Location loc = op->getLoc(); - MLIRContext *ctx = op->getContext(); - auto i1 = IntegerType::get(ctx, 1); - auto i8 = IntegerType::get(ctx, 8); - - // Add a coroutine suspension in place of original `op` in the split block. - OpBuilder::InsertionGuard guard(builder); - builder.setInsertionPointToEnd(suspended); - - auto constFalse = - builder.create(loc, i1, builder.getBoolAttr(false)); - - // Suspend a coroutine: @llvm.coro.suspend - auto coroSuspend = builder.create( - loc, i8, builder.getSymbolRefAttr(kCoroSuspend), - ValueRange({coroState, constFalse})); - - // After a suspension point decide if we should branch into resume, cleanup - // or suspend block of the coroutine (see @llvm.coro.suspend return code - // documentation). - auto constZero = - builder.create(loc, i8, builder.getI8IntegerAttr(0)); - auto constNegOne = - builder.create(loc, i8, builder.getI8IntegerAttr(-1)); - - Block *resumeOrCleanup = builder.createBlock(resume); - - // Suspend the coroutine ...? - builder.setInsertionPointToEnd(suspended); - auto isNegOne = builder.create( - loc, LLVM::ICmpPredicate::eq, coroSuspend.getResult(0), constNegOne); - builder.create(loc, isNegOne, /*trueDest=*/coro.suspend, - /*falseDest=*/resumeOrCleanup); - - // ... or resume or cleanup the coroutine? - builder.setInsertionPointToStart(resumeOrCleanup); - auto isZero = builder.create( - loc, LLVM::ICmpPredicate::eq, coroSuspend.getResult(0), constZero); - builder.create(loc, isZero, /*trueDest=*/resume, - /*falseDest=*/coro.cleanup); -} - /// Outline the body region attached to the `async.execute` op into a standalone /// function.
/// @@ -599,35 +464,31 @@ SymbolTable::setSymbolVisibility(func, SymbolTable::Visibility::Private); // Prepare a function for coroutine lowering by adding entry/cleanup/suspend - // blocks, adding llvm.coro instrinsics and setting up control flow. + // blocks, adding async.coro operations and setting up control flow. CoroMachinery coro = setupCoroMachinery(func); // Suspend async function at the end of an entry block, and resume it using - // Async execute API (execution will be resumed in a thread managed by the - // async runtime). + // Async resume operation (execution will be resumed in a thread managed by + // the async runtime). Block *entryBlock = &func.getBlocks().front(); auto builder = ImplicitLocOpBuilder::atBlockTerminator(loc, entryBlock); - // A pointer to coroutine resume intrinsic wrapper. - auto resumeFnTy = AsyncAPI::resumeFunctionType(ctx); - auto resumePtr = builder.create( - LLVM::LLVMPointerType::get(resumeFnTy), kResume); + // Save the coroutine state: async.coro.save + auto coroSaveOp = + builder.create(CoroStateType::get(ctx), coro.coroHandle); - // Save the coroutine state: @llvm.coro.save - auto coroSave = builder.create( - LLVM::LLVMTokenType::get(ctx), builder.getSymbolRefAttr(kCoroSave), - ValueRange({coro.coroHandle})); + // Pass coroutine to the runtime to be resumed on a runtime managed thread. + builder.create(coro.coroHandle); - // Call async runtime API to execute a coroutine in the managed thread. - SmallVector executeArgs = {coro.coroHandle, resumePtr.res()}; - builder.create(TypeRange(), kExecute, executeArgs); - - // Split the entry block before the terminator. + // Split the entry block before the terminator (branch to cleanup block).
auto *terminatorOp = entryBlock->getTerminator(); Block *suspended = terminatorOp->getBlock(); Block *resume = suspended->splitBlock(terminatorOp); - addSuspensionPoint(coro, coroSave.getResult(0), terminatorOp, suspended, - resume, builder); + + // Add async.coro.suspend as a suspended block terminator. + builder.setInsertionPointToEnd(suspended); + builder.create(coroSaveOp.state(), coro.suspend, resume, + coro.cleanup); size_t numDependencies = execute.dependencies().size(); size_t numOperands = execute.operands().size(); @@ -670,7 +531,6 @@ //===----------------------------------------------------------------------===// namespace { - /// AsyncRuntimeTypeConverter only converts types from the Async dialect to /// their runtime type (opaque pointers) and does not convert any other types. class AsyncRuntimeTypeConverter : public TypeConverter { @@ -683,56 +543,547 @@ static Optional convertAsyncTypes(Type type) { if (type.isa()) return AsyncAPI::opaquePointerType(type.getContext()); + + if (type.isa()) + return AsyncAPI::tokenType(type.getContext()); + if (type.isa()) + return AsyncAPI::opaquePointerType(type.getContext()); + return llvm::None; } }; } // namespace //===----------------------------------------------------------------------===// -// Convert return operations that return async values from async regions. +// Convert async.coro.id to @llvm.coro.id intrinsic.
//===----------------------------------------------------------------------===// namespace { -class ReturnOpOpConversion : public ConversionPattern { +class CoroIdOpConversion : public OpConversionPattern { public: - explicit ReturnOpOpConversion(TypeConverter &converter, MLIRContext *ctx) - : ConversionPattern(ReturnOp::getOperationName(), 1, converter, ctx) {} + using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(CoroIdOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp(op, operands); + auto token = AsyncAPI::tokenType(op->getContext()); + auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext()); + auto loc = op->getLoc(); + + // Constant arguments for the @llvm.coro.id call below. + auto constZero = rewriter.create( + loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0)); + auto nullPtr = rewriter.create(loc, i8Ptr); + + // Get coroutine id: @llvm.coro.id. + rewriter.replaceOpWithNewOp( + op, token, rewriter.getSymbolRefAttr(kCoroId), + ValueRange({constZero, nullPtr, nullPtr, nullPtr})); + + return success(); + } +}; +} // namespace + +//===----------------------------------------------------------------------===// +// Convert async.coro.begin to @llvm.coro.begin intrinsic. +//===----------------------------------------------------------------------===// + +namespace { +class CoroBeginOpConversion : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(CoroBeginOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext()); + auto loc = op->getLoc(); + + // Get coroutine frame size: @llvm.coro.size.i64. + auto coroSize = rewriter.create( + loc, rewriter.getI64Type(), rewriter.getSymbolRefAttr(kCoroSizeI64), + ValueRange()); + + // Allocate memory for the coroutine frame.
+ auto coroAlloc = rewriter.create( + loc, i8Ptr, rewriter.getSymbolRefAttr(kMalloc), + ValueRange(coroSize.getResult(0))); + + // Begin a coroutine: @llvm.coro.begin. + auto coroId = CoroBeginOpAdaptor(operands).id(); + rewriter.replaceOpWithNewOp( + op, i8Ptr, rewriter.getSymbolRefAttr(kCoroBegin), + ValueRange({coroId, coroAlloc.getResult(0)})); + + return success(); + } +}; +} // namespace + +//===----------------------------------------------------------------------===// +// Convert async.coro.free to @llvm.coro.free intrinsic. +//===----------------------------------------------------------------------===// + +namespace { +class CoroFreeOpConversion : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(CoroFreeOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext()); + auto loc = op->getLoc(); + + // Get a pointer to the coroutine frame memory: @llvm.coro.free. + auto coroMem = rewriter.create( + loc, i8Ptr, rewriter.getSymbolRefAttr(kCoroFree), operands); + + // Free the memory. + rewriter.replaceOpWithNewOp(op, TypeRange(), + rewriter.getSymbolRefAttr(kFree), + ValueRange(coroMem.getResult(0))); + + return success(); + } +}; +} // namespace + +//===----------------------------------------------------------------------===// +// Convert async.coro.end to @llvm.coro.end intrinsic. +//===----------------------------------------------------------------------===// + +namespace { +class CoroEndOpConversion : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(CoroEndOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + // We are not in the block that is part of the unwind sequence.
+ auto constFalse = rewriter.create( + op->getLoc(), rewriter.getI1Type(), rewriter.getBoolAttr(false)); + + // Mark the end of a coroutine: @llvm.coro.end. + auto coroHdl = CoroEndOpAdaptor(operands).handle(); + rewriter.create(op->getLoc(), rewriter.getI1Type(), + rewriter.getSymbolRefAttr(kCoroEnd), + ValueRange({coroHdl, constFalse})); + rewriter.eraseOp(op); + + return success(); + } +}; +} // namespace + +//===----------------------------------------------------------------------===// +// Convert async.coro.save to @llvm.coro.save intrinsic. +//===----------------------------------------------------------------------===// + +namespace { +class CoroSaveOpConversion : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(CoroSaveOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + // Save the coroutine state: @llvm.coro.save + rewriter.replaceOpWithNewOp( + op, AsyncAPI::tokenType(op->getContext()), + rewriter.getSymbolRefAttr(kCoroSave), operands); + + return success(); + } +}; +} // namespace + +//===----------------------------------------------------------------------===// +// Convert async.coro.suspend to @llvm.coro.suspend intrinsic. +//===----------------------------------------------------------------------===// + +namespace { + +/// Convert async.coro.suspend to the @llvm.coro.suspend intrinsic call, and +/// branch to the appropriate block based on the return code. +/// +/// Before: +/// +/// ^suspended: +/// "opBefore"(...) +/// async.coro.suspend %state, ^suspend, ^resume, ^cleanup +/// ^resume: +/// "op"(...) +/// ^cleanup: ... +/// ^suspend: ... +/// +/// After: +/// +/// ^suspended: +/// "opBefore"(...) +/// %suspend = llvm.call @llvm.coro.suspend(...) +/// switch %suspend [-1: ^suspend, 0: ^resume, 1: ^cleanup] +/// ^resume: +/// "op"(...) +/// ^cleanup: ... +/// ^suspend: ...
+/// +class CoroSuspendOpConversion : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(CoroSuspendOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto i8 = rewriter.getIntegerType(8); + auto i32 = rewriter.getI32Type(); + auto loc = op->getLoc(); + + // This is not a final suspension point. + auto constFalse = rewriter.create( + loc, rewriter.getI1Type(), rewriter.getBoolAttr(false)); + + // Suspend a coroutine: @llvm.coro.suspend + auto coroState = CoroSuspendOpAdaptor(operands).state(); + auto coroSuspend = rewriter.create( + loc, i8, rewriter.getSymbolRefAttr(kCoroSuspend), + ValueRange({coroState, constFalse})); + + // The i8 return code is converted to i32 inline, as the switch condition. + + // After a suspension point decide if we should branch into resume, cleanup + // or suspend block of the coroutine (see @llvm.coro.suspend return code + // documentation): 0 -> resume, 1 -> cleanup, anything else -> suspend. + llvm::SmallVector caseValues = {0, 1}; + llvm::SmallVector caseDest = {op.resumeDest(), + op.cleanupDest()}; + rewriter.replaceOpWithNewOp( + op, rewriter.create(loc, i32, coroSuspend.getResult(0)), + /*defaultDestination=*/op.suspendDest(), + /*defaultOperands=*/ValueRange(), + /*caseValues=*/caseValues, + /*caseDestinations=*/caseDest, + /*caseOperands=*/ArrayRef(), + /*branchWeights=*/ArrayRef()); + return success(); } }; } // namespace //===----------------------------------------------------------------------===// -// Async reference counting ops lowering (`async.add_ref` and `async.drop_ref` -// to the corresponding API calls). +// Convert async.runtime.create to the corresponding runtime API call.
+//
+// To allocate storage for the async values we use the getelementptr trick:
+// http://nondot.org/sabre/LLVMNotes/SizeOf-OffsetOf-VariableSizedStructs.txt
+//===----------------------------------------------------------------------===//
+
+namespace {
+class RuntimeCreateOpLowering : public OpConversionPattern {
+public:
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(RuntimeCreateOp op, ArrayRef operands,
+ ConversionPatternRewriter &rewriter) const override {
+ TypeConverter *converter = getTypeConverter();
+ Type resultType = op->getResultTypes()[0];
+
+ // Tokens and Groups are lowered to function calls without arguments.
+ if (resultType.isa() || resultType.isa()) {
+ rewriter.replaceOpWithNewOp(
+ op, resultType.isa() ? kCreateToken : kCreateGroup,
+ converter->convertType(resultType));
+ return success();
+ }
+
+ // To create a value we need to compute the storage requirement.
+ if (auto value = resultType.dyn_cast()) {
+ // Returns the size requirements for the async value storage.
+ auto sizeOf = [&](ValueType valueType) -> Value {
+ auto loc = op->getLoc();
+ auto i32 = rewriter.getI32Type();
+
+ auto storedType = converter->convertType(valueType.getValueType());
+ auto storagePtrType = LLVM::LLVMPointerType::get(storedType);
+
+ // %Size = getelementptr %T* null, int 1
+ // %SizeI = ptrtoint %T* %Size to i32
+ auto nullPtr = rewriter.create(loc, storagePtrType);
+ auto one = rewriter.create(
+ loc, i32, rewriter.getI32IntegerAttr(1));
+ auto gep = rewriter.create(loc, storagePtrType, nullPtr,
+ one.getResult());
+ return rewriter.create(loc, i32, gep);
+ };
+
+ rewriter.replaceOpWithNewOp(op, kCreateValue, resultType,
+ sizeOf(value));
+
+ return success();
+ }
+
+ return rewriter.notifyMatchFailure(op, "unsupported async type");
+ }
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Convert async.runtime.set_available to the corresponding runtime API call.
//===----------------------------------------------------------------------===//
 
 namespace {
+class RuntimeSetAvailableOpLowering
+ : public OpConversionPattern {
+public:
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(RuntimeSetAvailableOp op, ArrayRef operands,
+ ConversionPatternRewriter &rewriter) const override {
+ Type operandType = op.operand().getType();
+
+ if (operandType.isa() || operandType.isa()) {
+ rewriter.create(op->getLoc(),
+ operandType.isa() ? kEmplaceToken
+ : kEmplaceValue,
+ TypeRange(), operands);
+ rewriter.eraseOp(op);
+ return success();
+ }
+ return rewriter.notifyMatchFailure(op, "unsupported async type");
+ }
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Convert async.runtime.await to the corresponding (blocking) runtime API call.
+//===----------------------------------------------------------------------===//
+
+namespace {
+class RuntimeAwaitOpLowering : public OpConversionPattern {
+public:
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(RuntimeAwaitOp op, ArrayRef operands,
+ ConversionPatternRewriter &rewriter) const override {
+ Type operandType = op.operand().getType();
+
+ StringRef apiFuncName;
+ if (operandType.isa())
+ apiFuncName = kAwaitToken;
+ else if (operandType.isa())
+ apiFuncName = kAwaitValue;
+ else if (operandType.isa())
+ apiFuncName = kAwaitGroup;
+ else
+ return rewriter.notifyMatchFailure(op, "unsupported async type");
+
+ rewriter.create(op->getLoc(), apiFuncName, TypeRange(), operands);
+ rewriter.eraseOp(op);
+
+ return success();
+ }
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Convert async.runtime.await_and_resume to the corresponding runtime API call.
+//===----------------------------------------------------------------------===//
+
+namespace {
+class RuntimeAwaitAndResumeOpLowering
+ : public OpConversionPattern {
+public:
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(RuntimeAwaitAndResumeOp op, ArrayRef operands,
+ ConversionPatternRewriter &rewriter) const override {
+ Type operandType = op.operand().getType();
+
+ StringRef apiFuncName;
+ if (operandType.isa())
+ apiFuncName = kAwaitTokenAndExecute;
+ else if (operandType.isa())
+ apiFuncName = kAwaitValueAndExecute;
+ else if (operandType.isa())
+ apiFuncName = kAwaitAllAndExecute;
+ else
+ return rewriter.notifyMatchFailure(op, "unsupported async type");
+
+ Value operand = RuntimeAwaitAndResumeOpAdaptor(operands).operand();
+ Value handle = RuntimeAwaitAndResumeOpAdaptor(operands).handle();
+
+ // A pointer to coroutine resume intrinsic wrapper.
+ auto resumeFnTy = AsyncAPI::resumeFunctionType(op->getContext());
+ auto resumePtr = rewriter.create(
+ op->getLoc(), LLVM::LLVMPointerType::get(resumeFnTy), kResume);
+
+ rewriter.create(op->getLoc(), apiFuncName, TypeRange(),
+ ValueRange({operand, handle, resumePtr.res()}));
+ rewriter.eraseOp(op);
+
+ return success();
+ }
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Convert async.runtime.resume to the corresponding runtime API call.
+//===----------------------------------------------------------------------===//
+
+namespace {
+class RuntimeResumeOpLowering : public OpConversionPattern {
+public:
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(RuntimeResumeOp op, ArrayRef operands,
+ ConversionPatternRewriter &rewriter) const override {
+ // A pointer to coroutine resume intrinsic wrapper.
+ auto resumeFnTy = AsyncAPI::resumeFunctionType(op->getContext());
+ auto resumePtr = rewriter.create(
+ op->getLoc(), LLVM::LLVMPointerType::get(resumeFnTy), kResume);
+
+ // Call async runtime API to execute a coroutine in the managed thread.
+ auto coroHdl = RuntimeResumeOpAdaptor(operands).handle();
+ rewriter.replaceOpWithNewOp(op, TypeRange(), kExecute,
+ ValueRange({coroHdl, resumePtr.res()}));
+
+ return success();
+ }
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Convert async.runtime.store to the corresponding runtime API call.
+//===----------------------------------------------------------------------===//
+
+namespace {
+class RuntimeStoreOpLowering : public OpConversionPattern {
+public:
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(RuntimeStoreOp op, ArrayRef operands,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = op->getLoc();
+
+ // Get a pointer to the async value storage from the runtime.
+ auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext());
+ auto storage = RuntimeStoreOpAdaptor(operands).storage();
+ auto storagePtr = rewriter.create(loc, kGetValueStorage,
+ TypeRange(i8Ptr), storage);
+
+ // Cast from i8* to the LLVM pointer type.
+ auto valueType = op.value().getType();
+ auto llvmValueType = getTypeConverter()->convertType(valueType);
+ auto castedStoragePtr = rewriter.create(
+ loc, LLVM::LLVMPointerType::get(llvmValueType),
+ storagePtr.getResult(0));
+
+ // Store the (type-converted) value operand into the async value storage.
+ auto value = RuntimeStoreOpAdaptor(operands).value();
+ rewriter.create(loc, value, castedStoragePtr.getResult());
+
+ // Erase the original runtime store operation.
+ rewriter.eraseOp(op);
+
+ return success();
+ }
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Convert async.runtime.load to the corresponding runtime API call.
+//===----------------------------------------------------------------------===//
+
+namespace {
+class RuntimeLoadOpLowering : public OpConversionPattern {
+public:
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(RuntimeLoadOp op, ArrayRef operands,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = op->getLoc();
+
+ // Get a pointer to the async value storage from the runtime.
+ auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext());
+ auto storage = RuntimeLoadOpAdaptor(operands).storage();
+ auto storagePtr = rewriter.create(loc, kGetValueStorage,
+ TypeRange(i8Ptr), storage);
+
+ // Cast from i8* to the LLVM pointer type.
+ auto valueType = op.result().getType();
+ auto llvmValueType = getTypeConverter()->convertType(valueType);
+ auto castedStoragePtr = rewriter.create(
+ loc, LLVM::LLVMPointerType::get(llvmValueType),
+ storagePtr.getResult(0));
+
+ // Replace the runtime load with a load from the casted pointer.
+ rewriter.replaceOpWithNewOp(op, castedStoragePtr.getResult());
+
+ return success();
+ }
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Convert async.runtime.add_to_group to the corresponding runtime API call.
+//===----------------------------------------------------------------------===//
+
+namespace {
+class RuntimeAddToGroupOpLowering
+ : public OpConversionPattern {
+public:
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(RuntimeAddToGroupOp op, ArrayRef operands,
+ ConversionPatternRewriter &rewriter) const override {
+ // Currently we can only add tokens to the group.
+ if (!op.operand().getType().isa())
+ return rewriter.notifyMatchFailure(op, "only token type is supported");
+
+ // Replace with a runtime API function call.
+ rewriter.replaceOpWithNewOp(op, kAddTokenToGroup,
+ rewriter.getI64Type(), operands);
+
+ return success();
+ }
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Async reference counting ops lowering (`async.runtime.add_ref` and
+// `async.runtime.drop_ref` to the corresponding API calls).
+//===----------------------------------------------------------------------===//
+
+namespace {
 template 
-class RefCountingOpLowering : public ConversionPattern {
+class RefCountingOpLowering : public OpConversionPattern {
 public:
 explicit RefCountingOpLowering(TypeConverter &converter, MLIRContext *ctx,
 StringRef apiFunctionName)
- : ConversionPattern(RefCountingOp::getOperationName(), 1, converter, ctx),
+ : OpConversionPattern(converter, ctx),
 apiFunctionName(apiFunctionName) {}
 
 LogicalResult
- matchAndRewrite(Operation *op, ArrayRef operands,
+ matchAndRewrite(RefCountingOp op, ArrayRef operands,
 ConversionPatternRewriter &rewriter) const override {
- RefCountingOp refCountingOp = cast(op);
-
- auto count = rewriter.create(
- op->getLoc(), rewriter.getI32Type(),
- rewriter.getI32IntegerAttr(refCountingOp.count()));
+ auto count =
+ rewriter.create(op->getLoc(), rewriter.getI32Type(),
+ rewriter.getI32IntegerAttr(op.count()));
+ auto operand = typename RefCountingOp::Adaptor(operands).operand();
 rewriter.replaceOpWithNewOp(op, TypeRange(), apiFunctionName,
- ValueRange({operands[0], count}));
+ ValueRange({operand, count}));
 
 return success();
 }
@@ -741,149 +1092,143 @@
 StringRef apiFunctionName;
 };
 
-/// async.drop_ref op lowering to mlirAsyncRuntimeDropRef function call.
-class AddRefOpLowering : public RefCountingOpLowering {
+class RuntimeAddRefOpLowering : public RefCountingOpLowering {
 public:
- explicit AddRefOpLowering(TypeConverter &converter, MLIRContext *ctx)
+ explicit RuntimeAddRefOpLowering(TypeConverter &converter, MLIRContext *ctx)
 : RefCountingOpLowering(converter, ctx, kAddRef) {}
 };
 
-/// async.create_group op lowering to mlirAsyncRuntimeCreateGroup function call.
-class DropRefOpLowering : public RefCountingOpLowering {
+class RuntimeDropRefOpLowering
+ : public RefCountingOpLowering {
 public:
- explicit DropRefOpLowering(TypeConverter &converter, MLIRContext *ctx)
+ explicit RuntimeDropRefOpLowering(TypeConverter &converter, MLIRContext *ctx)
 : RefCountingOpLowering(converter, ctx, kDropRef) {}
 };
-
 } // namespace
 
 //===----------------------------------------------------------------------===//
-// async.create_group op lowering to mlirAsyncRuntimeCreateGroup function call.
+// Convert return operations that return async values from async regions.
 //===----------------------------------------------------------------------===//
 
 namespace {
-class CreateGroupOpLowering : public ConversionPattern {
+class ReturnOpOpConversion : public OpConversionPattern {
 public:
- explicit CreateGroupOpLowering(TypeConverter &converter, MLIRContext *ctx)
- : ConversionPattern(CreateGroupOp::getOperationName(), 1, converter,
- ctx) {}
+ using OpConversionPattern::OpConversionPattern;
 
 LogicalResult
- matchAndRewrite(Operation *op, ArrayRef operands,
+ matchAndRewrite(ReturnOp op, ArrayRef operands,
 ConversionPatternRewriter &rewriter) const override {
- auto retTy = GroupType::get(op->getContext());
- rewriter.replaceOpWithNewOp(op, kCreateGroup, retTy);
+ rewriter.replaceOpWithNewOp(op, operands);
 return success();
 }
 };
 } // namespace
 
 //===----------------------------------------------------------------------===//
-// async.add_to_group op lowering to runtime function call.
+// Convert async.create_group operation to async.runtime.create.
 //===----------------------------------------------------------------------===//
 
 namespace {
-class AddToGroupOpLowering : public ConversionPattern {
+class CreateGroupOpLowering : public OpConversionPattern {
 public:
- explicit AddToGroupOpLowering(TypeConverter &converter, MLIRContext *ctx)
- : ConversionPattern(AddToGroupOp::getOperationName(), 1, converter, ctx) {
- }
+ using OpConversionPattern::OpConversionPattern;
 
 LogicalResult
- matchAndRewrite(Operation *op, ArrayRef operands,
+ matchAndRewrite(CreateGroupOp op, ArrayRef operands,
 ConversionPatternRewriter &rewriter) const override {
- // Currently we can only add tokens to the group.
- auto addToGroup = cast(op);
- if (!addToGroup.operand().getType().isa())
- return failure();
-
- auto i64 = IntegerType::get(op->getContext(), 64);
- rewriter.replaceOpWithNewOp(op, kAddTokenToGroup, i64, operands);
+ rewriter.replaceOpWithNewOp(
+ op, GroupType::get(op->getContext()));
 return success();
 }
 };
 } // namespace
 
 //===----------------------------------------------------------------------===//
-// async.await and async.await_all op lowerings to the corresponding async
-// runtime function calls.
+// Convert async.add_to_group operation to async.runtime.add_to_group.
 //===----------------------------------------------------------------------===//
 
 namespace {
+class AddToGroupOpLowering : public OpConversionPattern {
+public:
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(AddToGroupOp op, ArrayRef operands,
+ ConversionPatternRewriter &rewriter) const override {
+ rewriter.replaceOpWithNewOp(
+ op, rewriter.getIndexType(), operands);
+ return success();
+ }
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Convert async.await and async.await_all operations to the async.runtime.await
+// or async.runtime.await_and_resume operations.
+//===----------------------------------------------------------------------===//
+namespace {
 template 
-class AwaitOpLoweringBase : public ConversionPattern {
-protected:
- explicit AwaitOpLoweringBase(
- TypeConverter &converter, MLIRContext *ctx,
- const llvm::DenseMap &outlinedFunctions,
- StringRef blockingAwaitFuncName, StringRef coroAwaitFuncName)
- : ConversionPattern(AwaitType::getOperationName(), 1, converter, ctx),
- outlinedFunctions(outlinedFunctions),
- blockingAwaitFuncName(blockingAwaitFuncName),
- coroAwaitFuncName(coroAwaitFuncName) {}
+class AwaitOpLoweringBase : public OpConversionPattern {
+ using AwaitAdaptor = typename AwaitType::Adaptor;
 
 public:
+ AwaitOpLoweringBase(
+ MLIRContext *ctx,
+ const llvm::DenseMap &outlinedFunctions)
+ : OpConversionPattern(ctx),
+ outlinedFunctions(outlinedFunctions) {}
+
 LogicalResult
- matchAndRewrite(Operation *op, ArrayRef operands,
+ matchAndRewrite(AwaitType op, ArrayRef operands,
 ConversionPatternRewriter &rewriter) const override {
 // We can only await on one the `AwaitableType` (for `await` it can be
 // a `token` or a `value`, for `await_all` it must be a `group`).
- auto await = cast(op);
- if (!await.operand().getType().template isa())
- return failure();
+ if (!op.operand().getType().template isa())
+ return rewriter.notifyMatchFailure(op, "unsupported awaitable type");
 
 // Check if await operation is inside the outlined coroutine function.
- auto func = await->template getParentOfType();
+ auto func = op->template getParentOfType();
 auto outlined = outlinedFunctions.find(func);
 const bool isInCoroutine = outlined != outlinedFunctions.end();
 
 Location loc = op->getLoc();
+ Value operand = AwaitAdaptor(operands).operand();
 
- // Inside regular function we convert await operation to the blocking
- // async API await function call.
+ // Inside regular functions we use the blocking wait operation to wait for
+ // the async object (token, value or group) to become available.
if (!isInCoroutine)
- rewriter.create(loc, TypeRange(), blockingAwaitFuncName,
- ValueRange(operands[0]));
+ rewriter.create(loc, operand);
 
 // Inside the coroutine we convert await operation into coroutine suspension
 // point, and resume execution asynchronously.
 if (isInCoroutine) {
 const CoroMachinery &coro = outlined->getSecond();
+ Block *suspended = op->getBlock();
 
 ImplicitLocOpBuilder builder(loc, op, rewriter.getListener());
 MLIRContext *ctx = op->getContext();
 
- // A pointer to coroutine resume intrinsic wrapper.
- auto resumeFnTy = AsyncAPI::resumeFunctionType(ctx);
- auto resumePtr = builder.create(
- LLVM::LLVMPointerType::get(resumeFnTy), kResume);
-
- // Save the coroutine state: @llvm.coro.save
- auto coroSave = builder.create(
- LLVM::LLVMTokenType::get(ctx), builder.getSymbolRefAttr(kCoroSave),
- ValueRange(coro.coroHandle));
-
- // Call async runtime API to resume a coroutine in the managed thread when
- // the async await argument becomes ready.
- SmallVector awaitAndExecuteArgs = {operands[0], coro.coroHandle,
- resumePtr.res()};
- builder.create(TypeRange(), coroAwaitFuncName,
- awaitAndExecuteArgs);
-
- Block *suspended = op->getBlock();
+ // Save the coroutine state and resume on a runtime managed thread when
+ // the operand becomes available.
+ auto coroSaveOp =
+ builder.create(CoroStateType::get(ctx), coro.coroHandle);
+ builder.create(operand, coro.coroHandle);
 
 // Split the entry block before the await operation.
 Block *resume = rewriter.splitBlock(suspended, Block::iterator(op));
- addSuspensionPoint(coro, coroSave.getResult(0), op, suspended, resume,
- builder);
+
+ // Add async.coro.suspend as the terminator of the suspended block.
+ builder.setInsertionPointToEnd(suspended);
+ builder.create(coroSaveOp.state(), coro.suspend, resume,
+ coro.cleanup);
 
 // Make sure that replacement value will be constructed in resume block.
 rewriter.setInsertionPointToStart(resume);
 }
 
- // Replace or erase the await operation with the new value.
- if (Value replaceWith = getReplacementValue(op, operands[0], rewriter))
+ // Erase or replace the await operation with the new value.
+ if (Value replaceWith = getReplacementValue(op, operand, rewriter))
 rewriter.replaceOp(op, replaceWith);
 else
 rewriter.eraseOp(op);
@@ -891,15 +1236,13 @@
 return success();
 }
 
- virtual Value getReplacementValue(Operation *op, Value operand,
+ virtual Value getReplacementValue(AwaitType op, Value operand,
 ConversionPatternRewriter &rewriter) const {
 return Value();
 }
 
 private:
 const llvm::DenseMap &outlinedFunctions;
- StringRef blockingAwaitFuncName;
- StringRef coroAwaitFuncName;
 };
 
 /// Lowering for `async.await` with a token operand.
@@ -907,11 +1250,7 @@
 using Base = AwaitOpLoweringBase;
 
 public:
- explicit AwaitTokenOpLowering(
- TypeConverter &converter, MLIRContext *ctx,
- const llvm::DenseMap &outlinedFunctions)
- : Base(converter, ctx, outlinedFunctions, kAwaitToken,
- kAwaitTokenAndExecute) {}
+ using Base::Base;
 };
 
 /// Lowering for `async.await` with a value operand.
@@ -919,33 +1258,14 @@
 using Base = AwaitOpLoweringBase;
 
 public:
- explicit AwaitValueOpLowering(
- TypeConverter &converter, MLIRContext *ctx,
- const llvm::DenseMap &outlinedFunctions)
- : Base(converter, ctx, outlinedFunctions, kAwaitValue,
- kAwaitValueAndExecute) {}
+ using Base::Base;
 
 Value
- getReplacementValue(Operation *op, Value operand,
+ getReplacementValue(AwaitOp op, Value operand,
 ConversionPatternRewriter &rewriter) const override {
- Location loc = op->getLoc();
- auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext());
-
- // Get the underlying value type from the `async.value`.
- auto await = cast(op);
- auto valueType = await.operand().getType().cast().getValueType();
-
- // Get a pointer to an async value storage from the runtime.
- auto storage = rewriter.create(loc, kGetValueStorage,
- TypeRange(i8Ptr), operand);
-
- // Cast from i8* to the pointer pointer to LLVM type.
- auto llvmValueType = getTypeConverter()->convertType(valueType);
- auto castedStorage = rewriter.create(
- loc, LLVM::LLVMPointerType::get(llvmValueType), storage.getResult(0));
-
 // Load from the async value storage.
- return rewriter.create(loc, castedStorage.getResult());
+ auto valueType = operand.getType().cast().getValueType();
+ return rewriter.create(op->getLoc(), valueType, operand);
 }
 };
 
@@ -954,71 +1274,47 @@
 using Base = AwaitOpLoweringBase;
 
 public:
- explicit AwaitAllOpLowering(
- TypeConverter &converter, MLIRContext *ctx,
- const llvm::DenseMap &outlinedFunctions)
- : Base(converter, ctx, outlinedFunctions, kAwaitGroup,
- kAwaitAllAndExecute) {}
+ using Base::Base;
 };
 
 } // namespace
 
 //===----------------------------------------------------------------------===//
-// async.yield op lowerings to the corresponding async runtime function calls.
+// Convert async.yield operation to async.runtime operations.
 //===----------------------------------------------------------------------===//
 
-class YieldOpLowering : public ConversionPattern {
+class YieldOpLowering : public OpConversionPattern {
 public:
- explicit YieldOpLowering(
- TypeConverter &converter, MLIRContext *ctx,
+ YieldOpLowering(
+ MLIRContext *ctx,
 const llvm::DenseMap &outlinedFunctions)
- : ConversionPattern(async::YieldOp::getOperationName(), 1, converter,
- ctx),
+ : OpConversionPattern(ctx),
 outlinedFunctions(outlinedFunctions) {}
 
 LogicalResult
- matchAndRewrite(Operation *op, ArrayRef operands,
+ matchAndRewrite(async::YieldOp op, ArrayRef operands,
 ConversionPatternRewriter &rewriter) const override {
 // Check if yield operation is inside the outlined coroutine function.
auto func = op->template getParentOfType();
 auto outlined = outlinedFunctions.find(func);
 if (outlined == outlinedFunctions.end())
- return op->emitOpError(
- "async.yield is not inside the outlined coroutine function");
+ return rewriter.notifyMatchFailure(
+ op, "operation is not inside the outlined async.execute function");
 
 Location loc = op->getLoc();
 const CoroMachinery &coro = outlined->getSecond();
 
- // Store yielded values into the async values storage and emplace them.
- auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext());
-
+ // Store yielded values into the async values storage and switch async
+ // values state to available.
 for (auto tuple : llvm::zip(operands, coro.returnValues)) {
- // Store `yieldValue` into the `asyncValue` storage.
 Value yieldValue = std::get<0>(tuple);
 Value asyncValue = std::get<1>(tuple);
-
- // Get an opaque i8* pointer to an async value storage from the runtime.
- auto storage = rewriter.create(loc, kGetValueStorage,
- TypeRange(i8Ptr), asyncValue);
-
- // Cast storage pointer to the yielded value type.
- auto castedStorage = rewriter.create(
- loc, LLVM::LLVMPointerType::get(yieldValue.getType()),
- storage.getResult(0));
-
- // Store the yielded value into the async value storage.
- rewriter.create(loc, yieldValue,
- castedStorage.getResult());
-
- // Emplace the `async.value` to mark it ready.
- rewriter.create(loc, kEmplaceValue, TypeRange(), asyncValue);
+ rewriter.create(loc, yieldValue, asyncValue);
+ rewriter.create(loc, asyncValue);
 }
 
- // Emplace the completion token to mark it ready.
- rewriter.create(loc, kEmplaceToken, TypeRange(), coro.asyncToken);
-
- // Original operation was replaced by the function call(s).
- rewriter.eraseOp(op);
+ // Switch the coroutine completion token to the available state.
+ rewriter.replaceOpWithNewOp(op, coro.asyncToken);
 
 return success();
 }
@@ -1034,6 +1330,7 @@
 : public ConvertAsyncToLLVMBase {
 void runOnOperation() override;
 };
+} // namespace
 
 void ConvertAsyncToLLVMPass::runOnOperation() {
 ModuleOp module = getOperation();
@@ -1088,6 +1385,35 @@
 addCoroutineIntrinsicsDeclarations(module);
 addCRuntimeDeclarations(module);
 
+ // ------------------------------------------------------------------------ //
+ // Lower async operations to async.runtime operations.
+ // ------------------------------------------------------------------------ //
+ OwningRewritePatternList asyncPatterns;
+
+ // Async lowering does not use a type converter because it must preserve all
+ // types for async.runtime operations.
+ asyncPatterns.insert(ctx);
+ asyncPatterns.insert(ctx,
+ outlinedFunctions);
+
+ // All high level async operations must be lowered to the runtime operations.
+ ConversionTarget runtimeTarget(*ctx);
+ runtimeTarget.addLegalDialect();
+ runtimeTarget.addIllegalOp();
+ runtimeTarget.addIllegalOp();
+
+ if (failed(applyPartialConversion(module, runtimeTarget,
+ std::move(asyncPatterns)))) {
+ signalPassFailure();
+ return;
+ }
+
+ // ------------------------------------------------------------------------ //
+ // Lower async.runtime and async.coro operations to Async Runtime API and
+ // LLVM coroutine intrinsics.
+ // ------------------------------------------------------------------------ //
+
 // Convert async dialect types and operations to LLVM dialect.
 AsyncRuntimeTypeConverter converter;
 OwningRewritePatternList patterns;
@@ -1099,23 +1425,29 @@
 // Convert return operations inside async.execute regions.
 patterns.insert(converter, ctx);
 
- // Lower async operations to async runtime API calls.
- patterns.insert(converter, ctx);
- patterns.insert(converter, ctx);
+ // Lower async.runtime operations to the async runtime API calls.
+ patterns.insert(converter, ctx);
 
- // Use LLVM type converter to automatically convert between the async value
- // payload type and LLVM type when loading/storing from/to the async
- // value storage which is an opaque i8* pointer using LLVM load/store ops.
- patterns
- .insert(
- llvmConverter, ctx, outlinedFunctions);
- patterns.insert(llvmConverter, ctx, outlinedFunctions);
+ // Lower async.runtime operations that rely on the LLVM type converter to
+ // convert from the async value payload type to the LLVM type.
+ patterns.insert(llvmConverter, ctx);
+
+ // Lower async coroutine operations to LLVM coroutine intrinsics.
+ patterns.insert(converter,
+ ctx);
 
 ConversionTarget target(*ctx);
 target.addLegalOp();
 target.addLegalDialect();
 
- // All operations from Async dialect must be lowered to the runtime API calls.
+ // All operations from Async dialect must be lowered to the runtime API and
+ // LLVM intrinsic calls.
 target.addIllegalDialect();
 
 // Add dynamic legality constraints to apply conversions defined above.
@@ -1130,7 +1462,10 @@
 if (failed(applyPartialConversion(module, target, std::move(patterns))))
 signalPassFailure();
 }
-} // namespace
+
+//===----------------------------------------------------------------------===//
+// Patterns for structural type conversions for the Async dialect operations.
+//===----------------------------------------------------------------------===//
 
 namespace {
 class ConvertExecuteOpTypes : public OpConversionPattern {
diff --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp
--- a/mlir/lib/Dialect/Async/IR/Async.cpp
+++ b/mlir/lib/Dialect/Async/IR/Async.cpp
@@ -19,9 +19,8 @@
 #define GET_OP_LIST
 #include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc"
 >();
- addTypes();
- addTypes();
- addTypes();
+ addTypes(); // async types
+ addTypes(); // coro types
 }
 
 /// Parse a type registered to this dialect.
@@ -33,6 +32,9 @@
 if (keyword == "token")
 return TokenType::get(getContext());
 
+ if (keyword == "group")
+ return GroupType::get(getContext());
+
 if (keyword == "value") {
 Type ty;
 if (parser.parseLess() || parser.parseType(ty) || parser.parseGreater()) {
@@ -42,6 +44,15 @@
 return ValueType::get(ty);
 }
 
+ if (keyword == "coro.id")
+ return CoroIdType::get(getContext());
+
+ if (keyword == "coro.handle")
+ return CoroHandleType::get(getContext());
+
+ if (keyword == "coro.state")
+ return CoroStateType::get(getContext());
+
 parser.emitError(parser.getNameLoc(), "unknown async type: ") << keyword;
 return Type();
 }
@@ -56,6 +67,9 @@
 os << '>';
 })
 .Case([&](GroupType) { os << "group"; })
+ .Case([&](CoroIdType) { os << "coro.id"; })
+ .Case([&](CoroHandleType) { os << "coro.handle"; })
+ .Case([&](CoroStateType) { os << "coro.state"; })
 .Default([](Type) { llvm_unreachable("unexpected 'async' type kind"); });
 }
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncRefCounting.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncRefCounting.cpp
--- a/mlir/lib/Dialect/Async/Transforms/AsyncRefCounting.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncRefCounting.cpp
@@ -122,7 +122,7 @@
 
 // Drop the reference count immediately if the value has no uses.
 if (value.getUses().empty()) {
- builder.create(loc, value, IntegerAttr::get(i32, 1));
+ builder.create(loc, value, IntegerAttr::get(i32, 1));
 return success();
 }
 
@@ -200,7 +200,7 @@
 
 // Add a drop_ref immediately after the last user.
 builder.setInsertionPointAfter(lastUser);
- builder.create(loc, value, IntegerAttr::get(i32, 1));
+ builder.create(loc, value, IntegerAttr::get(i32, 1));
 }
 
 // ------------------------------------------------------------------------ //
@@ -232,7 +232,7 @@
 // their `liveIn` set.
for (Block *dropRefSuccessor : dropRefSuccessors) {
 builder.setInsertionPointToStart(dropRefSuccessor);
- builder.create(loc, value, IntegerAttr::get(i32, 1));
+ builder.create(loc, value, IntegerAttr::get(i32, 1));
 }
 
 // ------------------------------------------------------------------------ //
@@ -267,11 +267,12 @@
 // Add a reference before the execute operation to keep the reference
 // counted alive before the async region completes execution.
 builder.setInsertionPoint(execute.getOperation());
- builder.create(loc, value, IntegerAttr::get(i32, 1));
+ builder.create(loc, value, IntegerAttr::get(i32, 1));
 
 // Drop the reference inside the async region before completion.
 OpBuilder executeBuilder = OpBuilder::atBlockTerminator(execute.getBody());
- executeBuilder.create(loc, value, IntegerAttr::get(i32, 1));
+ executeBuilder.create(loc, value,
+ IntegerAttr::get(i32, 1));
 }
 
 return success();
@@ -284,7 +285,7 @@
 // because otherwise automatic reference counting will produce incorrect
 // results.
WalkResult refCountingWalk = func.walk([&](Operation *op) -> WalkResult { - if (isa(op)) + if (isa(op)) return op->emitError() << "explicit reference counting is not supported"; return WalkResult::advance(); }); diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncRefCountingOptimization.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncRefCountingOptimization.cpp --- a/mlir/lib/Dialect/Async/Transforms/AsyncRefCountingOptimization.cpp +++ b/mlir/lib/Dialect/Async/Transforms/AsyncRefCountingOptimization.cpp @@ -63,8 +63,8 @@ }; struct BlockUsersInfo { - llvm::SmallVector addRefs; - llvm::SmallVector dropRefs; + llvm::SmallVector addRefs; + llvm::SmallVector dropRefs; llvm::SmallVector users; }; @@ -74,9 +74,9 @@ BlockUsersInfo &info = blockUsers[user.operation->getBlock()]; info.users.push_back(user); - if (auto addRef = dyn_cast(user.operation)) + if (auto addRef = dyn_cast(user.operation)) info.addRefs.push_back(addRef); - if (auto dropRef = dyn_cast(user.operation)) + if (auto dropRef = dyn_cast(user.operation)) info.dropRefs.push_back(dropRef); }; @@ -118,8 +118,8 @@ // Mapping from `dropRef.getOperation()` to `addRef.getOperation()`. llvm::SmallDenseMap cancellable; - for (AddRefOp addRef : info.addRefs) { - for (DropRefOp dropRef : info.dropRefs) { + for (RuntimeAddRefOp addRef : info.addRefs) { + for (RuntimeDropRefOp dropRef : info.dropRefs) { // `drop_ref` operation after the `add_ref` with matching count. 
if (dropRef.count() != addRef.count() || dropRef->isBeforeInBlock(addRef.getOperation())) diff --git a/mlir/test/Conversion/AsyncToLLVM/convert-coro-to-llvm.mlir b/mlir/test/Conversion/AsyncToLLVM/convert-coro-to-llvm.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Conversion/AsyncToLLVM/convert-coro-to-llvm.mlir @@ -0,0 +1,83 @@ +// RUN: mlir-opt %s -convert-async-to-llvm | FileCheck %s + +// CHECK-LABEL: @coro_id +func @coro_id() { + // CHECK: %0 = llvm.mlir.constant(0 : i32) : i32 + // CHECK: %1 = llvm.mlir.null : !llvm.ptr + // CHECK: %2 = llvm.call @llvm.coro.id(%0, %1, %1, %1) + %0 = async.coro.id + return +} + +// CHECK-LABEL: @coro_begin +func @coro_begin() { + // CHECK: %[[ID:.*]] = llvm.call @llvm.coro.id + %0 = async.coro.id + // CHECK: %[[SIZE:.*]] = llvm.call @llvm.coro.size.i64() + // CHECK: %[[ALLOC:.*]] = llvm.call @malloc(%[[SIZE]]) + // CHECK: %[[HDL:.*]] = llvm.call @llvm.coro.begin(%[[ID]], %[[ALLOC]]) + %1 = async.coro.begin %0 + return +} + +// CHECK-LABEL: @coro_free +func @coro_free() { + // CHECK: %[[ID:.*]] = llvm.call @llvm.coro.id + %0 = async.coro.id + // CHECK: %[[HDL:.*]] = llvm.call @llvm.coro.begin + %1 = async.coro.begin %0 + // CHECK: %[[MEM:.*]] = llvm.call @llvm.coro.free(%[[ID]], %[[HDL]]) + // CHECK: llvm.call @free(%[[MEM]]) + async.coro.free %0, %1 + return +} + +// CHECK-LABEL: @coro_end +func @coro_end() { + %0 = async.coro.id + // CHECK: %[[HDL:.*]] = llvm.call @llvm.coro.begin + %1 = async.coro.begin %0 + // CHECK: %[[FALSE:.*]] = llvm.mlir.constant(false) : i1 + // CHECK: llvm.call @llvm.coro.end(%[[HDL]], %[[FALSE]]) + async.coro.end %1 + return +} + +// CHECK-LABEL: @coro_save +func @coro_save() { + %0 = async.coro.id + // CHECK: %[[HDL:.*]] = llvm.call @llvm.coro.begin + %1 = async.coro.begin %0 + // CHECK: llvm.call @llvm.coro.save(%[[HDL]]) + %2 = async.coro.save %1 + return +} + +// CHECK-LABEL: @coro_suspend +func @coro_suspend() { + %0 = async.coro.id + // CHECK: %[[HDL:.*]] = llvm.call @llvm.coro.begin 
+ %1 = async.coro.begin %0 + // CHECK: %[[STATE:.*]] = llvm.call @llvm.coro.save(%[[HDL]]) + %2 = async.coro.save %1 + + // CHECK: %[[FINAL:.*]] = llvm.mlir.constant(false) : i1 + // CHECK: %[[RET:.*]] = llvm.call @llvm.coro.suspend(%[[STATE]], %[[FINAL]]) + // CHECK: %[[SEXT:.*]] = llvm.sext %[[RET]] : i8 to i32 + // CHECK: llvm.switch %[[SEXT]], ^[[SUSPEND:[b0-9]+]] + // CHECK-NEXT: 0: ^[[RESUME:[b0-9]+]] + // CHECK-NEXT: 1: ^[[CLEANUP:[b0-9]+]] + async.coro.suspend %2, ^suspend, ^resume, ^cleanup +^resume: + // CHECK: ^[[RESUME]] + // CHECK: return {coro.resume} + return { coro.resume } +^cleanup: + // CHECK: ^[[CLEANUP]] + // CHECK: return {coro.cleanup} + return { coro.cleanup } +^suspend: + // CHECK: ^[[SUSPEND]] + // CHECK: return {coro.suspend} + return { coro.suspend } +} diff --git a/mlir/test/Conversion/AsyncToLLVM/convert-runtime-to-llvm.mlir b/mlir/test/Conversion/AsyncToLLVM/convert-runtime-to-llvm.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Conversion/AsyncToLLVM/convert-runtime-to-llvm.mlir @@ -0,0 +1,160 @@ +// RUN: mlir-opt %s -convert-async-to-llvm | FileCheck %s + +// CHECK-LABEL: @create_token +func @create_token() { + // CHECK: %[[TOKEN:.*]] = call @mlirAsyncRuntimeCreateToken + %0 = async.runtime.create : !async.token + return +} + +// CHECK-LABEL: @create_value +func @create_value() { + // CHECK: %[[NULL:.*]] = llvm.mlir.null : !llvm.ptr + // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1 : i32) : i32 + // CHECK: %[[OFFSET:.*]] = llvm.getelementptr %[[NULL]][%[[ONE]]] + // CHECK: %[[SIZE:.*]] = llvm.ptrtoint %[[OFFSET]] + // CHECK: %[[VALUE:.*]] = call @mlirAsyncRuntimeCreateValue(%[[SIZE]]) + %0 = async.runtime.create : !async.value + return +} + +// CHECK-LABEL: @create_group +func @create_group() { + // CHECK: %[[GROUP:.*]] = call @mlirAsyncRuntimeCreateGroup + %0 = async.runtime.create : !async.group + return +} + +// CHECK-LABEL: @set_token_available +func @set_token_available() { + // CHECK: %[[TOKEN:.*]] = call 
@mlirAsyncRuntimeCreateToken + %0 = async.runtime.create : !async.token + // CHECK: call @mlirAsyncRuntimeEmplaceToken(%[[TOKEN]]) + async.runtime.set_available %0 : !async.token + return +} + +// CHECK-LABEL: @set_value_available +func @set_value_available() { + // CHECK: %[[VALUE:.*]] = call @mlirAsyncRuntimeCreateValue + %0 = async.runtime.create : !async.value + // CHECK: call @mlirAsyncRuntimeEmplaceValue(%[[VALUE]]) + async.runtime.set_available %0 : !async.value + return +} + +// CHECK-LABEL: @await_token +func @await_token() { + // CHECK: %[[TOKEN:.*]] = call @mlirAsyncRuntimeCreateToken + %0 = async.runtime.create : !async.token + // CHECK: call @mlirAsyncRuntimeAwaitToken(%[[TOKEN]]) + async.runtime.await %0 : !async.token + return +} + +// CHECK-LABEL: @await_value +func @await_value() { + // CHECK: %[[VALUE:.*]] = call @mlirAsyncRuntimeCreateValue + %0 = async.runtime.create : !async.value + // CHECK: call @mlirAsyncRuntimeAwaitValue(%[[VALUE]]) + async.runtime.await %0 : !async.value + return +} + +// CHECK-LABEL: @await_group +func @await_group() { + // CHECK: %[[GROUP:.*]] = call @mlirAsyncRuntimeCreateGroup + %0 = async.runtime.create : !async.group + // CHECK: call @mlirAsyncRuntimeAwaitAllInGroup(%[[GROUP]]) + async.runtime.await %0 : !async.group + return +} + +// CHECK-LABEL: @await_and_resume_token +func @await_and_resume_token() { + %0 = async.coro.id + // CHECK: %[[HDL:.*]] = llvm.call @llvm.coro.begin + %1 = async.coro.begin %0 + // CHECK: %[[TOKEN:.*]] = call @mlirAsyncRuntimeCreateToken + %2 = async.runtime.create : !async.token + // CHECK: %[[RESUME:.*]] = llvm.mlir.addressof @__resume + // CHECK: call @mlirAsyncRuntimeAwaitTokenAndExecute + // CHECK-SAME: (%[[TOKEN]], %[[HDL]], %[[RESUME]]) + async.runtime.await_and_resume %2, %1 : !async.token + return +} + +// CHECK-LABEL: @await_and_resume_value +func @await_and_resume_value() { + %0 = async.coro.id + // CHECK: %[[HDL:.*]] = llvm.call @llvm.coro.begin + %1 = async.coro.begin %0 + // 
CHECK: %[[VALUE:.*]] = call @mlirAsyncRuntimeCreateValue + %2 = async.runtime.create : !async.value + // CHECK: %[[RESUME:.*]] = llvm.mlir.addressof @__resume + // CHECK: call @mlirAsyncRuntimeAwaitValueAndExecute + // CHECK-SAME: (%[[VALUE]], %[[HDL]], %[[RESUME]]) + async.runtime.await_and_resume %2, %1 : !async.value + return +} + +// CHECK-LABEL: @await_and_resume_group +func @await_and_resume_group() { + %0 = async.coro.id + // CHECK: %[[HDL:.*]] = llvm.call @llvm.coro.begin + %1 = async.coro.begin %0 + // CHECK: %[[TOKEN:.*]] = call @mlirAsyncRuntimeCreateGroup + %2 = async.runtime.create : !async.group + // CHECK: %[[RESUME:.*]] = llvm.mlir.addressof @__resume + // CHECK: call @mlirAsyncRuntimeAwaitAllInGroupAndExecute + // CHECK-SAME: (%[[TOKEN]], %[[HDL]], %[[RESUME]]) + async.runtime.await_and_resume %2, %1 : !async.group + return +} + +// CHECK-LABEL: @resume +func @resume() { + %0 = async.coro.id + // CHECK: %[[HDL:.*]] = llvm.call @llvm.coro.begin + %1 = async.coro.begin %0 + // CHECK: %[[RESUME:.*]] = llvm.mlir.addressof @__resume + // CHECK: call @mlirAsyncRuntimeExecute(%[[HDL]], %[[RESUME]]) + async.runtime.resume %1 + return +} + +// CHECK-LABEL: @store +func @store() { + // CHECK: %[[CST:.*]] = constant 1.0 + %0 = constant 1.0 : f32 + // CHECK: %[[VALUE:.*]] = call @mlirAsyncRuntimeCreateValue + %1 = async.runtime.create : !async.value + // CHECK: %[[P0:.*]] = call @mlirAsyncRuntimeGetValueStorage(%[[VALUE]]) + // CHECK: %[[P1:.*]] = llvm.bitcast %[[P0]] : !llvm.ptr to !llvm.ptr + // CHECK: llvm.store %[[CST]], %[[P1]] + async.runtime.store %0, %1 : !async.value + return +} + +// CHECK-LABEL: @load +func @load() -> f32 { + // CHECK: %[[VALUE:.*]] = call @mlirAsyncRuntimeCreateValue + %0 = async.runtime.create : !async.value + // CHECK: %[[P0:.*]] = call @mlirAsyncRuntimeGetValueStorage(%[[VALUE]]) + // CHECK: %[[P1:.*]] = llvm.bitcast %[[P0]] : !llvm.ptr to !llvm.ptr + // CHECK: %[[VALUE:.*]] = llvm.load %[[P1]] + %1 = async.runtime.load %0 : 
!async.value + // CHECK: return %[[VALUE]] : f32 + return %1 : f32 +} + +// CHECK-LABEL: @add_token_to_group +func @add_token_to_group() { + // CHECK: %[[TOKEN:.*]] = call @mlirAsyncRuntimeCreateToken + %0 = async.runtime.create : !async.token + // CHECK: %[[GROUP:.*]] = call @mlirAsyncRuntimeCreateGroup + %1 = async.runtime.create : !async.group + // CHECK: call @mlirAsyncRuntimeAddTokenToGroup(%[[TOKEN]], %[[GROUP]]) + async.runtime.add_to_group %0, %1 : !async.token + return +} diff --git a/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir b/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir --- a/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir +++ b/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir @@ -4,11 +4,11 @@ func @reference_counting(%arg0: !async.token) { // CHECK: %[[C2:.*]] = constant 2 : i32 // CHECK: call @mlirAsyncRuntimeAddRef(%arg0, %[[C2]]) - async.add_ref %arg0 {count = 2 : i32} : !async.token + async.runtime.add_ref %arg0 {count = 2 : i32} : !async.token // CHECK: %[[C1:.*]] = constant 1 : i32 // CHECK: call @mlirAsyncRuntimeDropRef(%arg0, %[[C1]]) - async.drop_ref %arg0 {count = 1 : i32} : !async.token + async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token return } @@ -38,21 +38,16 @@ // CHECK: %[[HDL:.*]] = llvm.call @llvm.coro.begin // Pass a suspended coroutine to the async runtime. -// CHECK: %[[RESUME:.*]] = llvm.mlir.addressof @__resume // CHECK: %[[STATE:.*]] = llvm.call @llvm.coro.save +// CHECK: %[[RESUME:.*]] = llvm.mlir.addressof @__resume // CHECK: call @mlirAsyncRuntimeExecute(%[[HDL]], %[[RESUME]]) // CHECK: %[[SUSPENDED:.*]] = llvm.call @llvm.coro.suspend(%[[STATE]] // Decide the next block based on the code returned from suspend. 
-// CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i8) -// CHECK: %[[NONE:.*]] = llvm.mlir.constant(-1 : i8) -// CHECK: %[[IS_NONE:.*]] = llvm.icmp "eq" %[[SUSPENDED]], %[[NONE]] -// CHECK: llvm.cond_br %[[IS_NONE]], ^[[SUSPEND:.*]], ^[[RESUME_OR_CLEANUP:.*]] - -// Decide if branch to resume or cleanup block. -// CHECK: ^[[RESUME_OR_CLEANUP]]: -// CHECK: %[[IS_ZERO:.*]] = llvm.icmp "eq" %[[SUSPENDED]], %[[ZERO]] -// CHECK: llvm.cond_br %[[IS_ZERO]], ^[[RESUME:.*]], ^[[CLEANUP:.*]] +// CHECK: %[[SEXT:.*]] = llvm.sext %[[SUSPENDED]] : i8 to i32 +// CHECK: llvm.switch %[[SEXT]], ^[[SUSPEND:[b0-9]+]] +// CHECK-NEXT: 0: ^[[RESUME:[b0-9]+]] +// CHECK-NEXT: 1: ^[[CLEANUP:[b0-9]+]] // Resume coroutine after suspension. // CHECK: ^[[RESUME]]: diff --git a/mlir/test/Dialect/Async/async-ref-counting-optimization.mlir b/mlir/test/Dialect/Async/async-ref-counting-optimization.mlir --- a/mlir/test/Dialect/Async/async-ref-counting-optimization.mlir +++ b/mlir/test/Dialect/Async/async-ref-counting-optimization.mlir @@ -2,27 +2,27 @@ // CHECK-LABEL: @cancellable_operations_0 func @cancellable_operations_0(%arg0: !async.token) { - // CHECK-NOT: async.add_ref - // CHECK-NOT: async.drop_ref - async.add_ref %arg0 {count = 1 : i32} : !async.token - async.drop_ref %arg0 {count = 1 : i32} : !async.token + // CHECK-NOT: async.runtime.add_ref + // CHECK-NOT: async.runtime.drop_ref + async.runtime.add_ref %arg0 {count = 1 : i32} : !async.token + async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token // CHECK: return return } // CHECK-LABEL: @cancellable_operations_1 func @cancellable_operations_1(%arg0: !async.token) { - // CHECK-NOT: async.add_ref + // CHECK-NOT: async.runtime.add_ref // CHECK: async.execute - async.add_ref %arg0 {count = 1 : i32} : !async.token + async.runtime.add_ref %arg0 {count = 1 : i32} : !async.token async.execute [%arg0] { - // CHECK: async.drop_ref - async.drop_ref %arg0 {count = 1 : i32} : !async.token + // CHECK: async.runtime.drop_ref + 
async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token // CHECK-NEXT: async.yield async.yield } - // CHECK-NOT: async.drop_ref - async.drop_ref %arg0 {count = 1 : i32} : !async.token + // CHECK-NOT: async.runtime.drop_ref + async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token // CHECK: return return } @@ -33,28 +33,28 @@ // CHECK-NEXT: async.await // CHECK-NEXT: async.await // CHECK-NEXT: return - async.add_ref %arg0 {count = 1 : i32} : !async.token + async.runtime.add_ref %arg0 {count = 1 : i32} : !async.token async.await %arg0 : !async.token - async.drop_ref %arg0 {count = 1 : i32} : !async.token + async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token async.await %arg0 : !async.token - async.add_ref %arg0 {count = 1 : i32} : !async.token + async.runtime.add_ref %arg0 {count = 1 : i32} : !async.token async.await %arg0 : !async.token - async.drop_ref %arg0 {count = 1 : i32} : !async.token + async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token return } // CHECK-LABEL: @cancellable_operations_3 func @cancellable_operations_3(%arg0: !async.token) { // CHECK-NOT: add_ref - async.add_ref %arg0 {count = 1 : i32} : !async.token + async.runtime.add_ref %arg0 {count = 1 : i32} : !async.token %token = async.execute { async.await %arg0 : !async.token - // CHECK: async.drop_ref - async.drop_ref %arg0 {count = 1 : i32} : !async.token + // CHECK: async.runtime.drop_ref + async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token async.yield } - // CHECK-NOT: async.drop_ref - async.drop_ref %arg0 {count = 1 : i32} : !async.token + // CHECK-NOT: async.runtime.drop_ref + async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token // CHECK: async.await async.await %arg0 : !async.token // CHECK: return @@ -67,19 +67,19 @@ // that the body of the `async.execute` operation will run before the await // operation in the function body, and will destroy the `%arg0` token. 
// CHECK: add_ref - async.add_ref %arg0 {count = 1 : i32} : !async.token + async.runtime.add_ref %arg0 {count = 1 : i32} : !async.token %token = async.execute { // CHECK: async.await async.await %arg0 : !async.token - // CHECK: async.drop_ref - async.drop_ref %arg0 {count = 1 : i32} : !async.token + // CHECK: async.runtime.drop_ref + async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token // CHECK: async.yield async.yield } // CHECK: async.await async.await %arg0 : !async.token // CHECK: drop_ref - async.drop_ref %arg0 {count = 1 : i32} : !async.token + async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token // CHECK: return return } @@ -92,22 +92,23 @@ // NOTE: This test is not correct w.r.t. reference counting, and at runtime // would leak %arg0 value if %arg1 is false. IR like this will not be // constructed by automatic reference counting pass, because it would - // place `async.add_ref` right before the `async.execute` inside `scf.if`. + // place `async.runtime.add_ref` right before the `async.execute` + // inside `scf.if`. 
- // CHECK: async.add_ref - async.add_ref %arg0 {count = 1 : i32} : !async.token + // CHECK: async.runtime.add_ref + async.runtime.add_ref %arg0 {count = 1 : i32} : !async.token scf.if %arg1 { %token = async.execute { async.await %arg0 : !async.token - // CHECK: async.drop_ref - async.drop_ref %arg0 {count = 1 : i32} : !async.token + // CHECK: async.runtime.drop_ref + async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token async.yield } } // CHECK: async.await async.await %arg0 : !async.token - // CHECK: async.drop_ref - async.drop_ref %arg0 {count = 1 : i32} : !async.token + // CHECK: async.runtime.drop_ref + async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token // CHECK: return return } diff --git a/mlir/test/Dialect/Async/async-ref-counting.mlir b/mlir/test/Dialect/Async/async-ref-counting.mlir --- a/mlir/test/Dialect/Async/async-ref-counting.mlir +++ b/mlir/test/Dialect/Async/async-ref-counting.mlir @@ -5,7 +5,7 @@ // CHECK-LABEL: @token_arg_no_uses func @token_arg_no_uses(%arg0: !async.token) { - // CHECK: async.drop_ref %arg0 {count = 1 : i32} + // CHECK: async.runtime.drop_ref %arg0 {count = 1 : i32} return } @@ -13,11 +13,11 @@ func @token_arg_conditional_await(%arg0: !async.token, %arg1: i1) { cond_br %arg1, ^bb1, ^bb2 ^bb1: - // CHECK: async.drop_ref %arg0 {count = 1 : i32} + // CHECK: async.runtime.drop_ref %arg0 {count = 1 : i32} return ^bb2: // CHECK: async.await %arg0 - // CHECK: async.drop_ref %arg0 {count = 1 : i32} + // CHECK: async.runtime.drop_ref %arg0 {count = 1 : i32} async.await %arg0 : !async.token return } @@ -25,7 +25,7 @@ // CHECK-LABEL: @token_no_uses func @token_no_uses() { // CHECK: %[[TOKEN:.*]] = async.execute - // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32} + // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32} %token = async.execute { async.yield } @@ -50,7 +50,7 @@ } // CHECK: async.await %[[TOKEN]] async.await %token : !async.token - // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32} + // CHECK: 
async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32} // CHECK: return return } @@ -62,7 +62,7 @@ async.yield } // CHECK: async.await %[[TOKEN]] - // CHECK-NOT: async.drop_ref + // CHECK-NOT: async.runtime.drop_ref async.await %token : !async.token // CHECK: return %[[TOKEN]] return %token : !async.token @@ -80,7 +80,7 @@ async.await %token : !async.token } // CHECK: } - // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32} + // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32} // CHECK: return return } @@ -93,11 +93,11 @@ } cond_br %arg0, ^bb1, ^bb2 ^bb1: - // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32} + // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32} return ^bb2: // CHECK: async.await %[[TOKEN]] - // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32} + // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32} async.await %token : !async.token return } @@ -115,7 +115,7 @@ %0 = call @cond(): () -> (i1) cond_br %0, ^bb1, ^bb2 ^bb2: - // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32} + // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32} return } @@ -128,7 +128,7 @@ async.yield } // CHECK: async.await %[[TOKEN]] - // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32} + // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32} async.await %token : !async.token %0 = call @cond(): () -> (i1) cond_br %0, ^bb1, ^bb2 @@ -143,16 +143,16 @@ async.yield } - // CHECK: async.add_ref %[[TOKEN]] {count = 1 : i32} + // CHECK: async.runtime.add_ref %[[TOKEN]] {count = 1 : i32} // CHECK: %[[TOKEN_0:.*]] = async.execute %token_0 = async.execute { - // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32} + // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32} // CHECK-NEXT: async.yield async.await %token : !async.token async.yield } - // CHECK: async.drop_ref %[[TOKEN_0]] {count = 1 : i32} - // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32} + // CHECK: async.runtime.drop_ref %[[TOKEN_0]] {count = 1 : i32} + // CHECK: 
async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32} // CHECK: return return } @@ -164,30 +164,30 @@ async.yield } - // CHECK: async.add_ref %[[TOKEN]] {count = 1 : i32} + // CHECK: async.runtime.add_ref %[[TOKEN]] {count = 1 : i32} // CHECK: %[[TOKEN_0:.*]] = async.execute %token_0 = async.execute { - // CHECK: async.add_ref %[[TOKEN]] {count = 1 : i32} + // CHECK: async.runtime.add_ref %[[TOKEN]] {count = 1 : i32} // CHECK: %[[TOKEN_1:.*]] = async.execute %token_1 = async.execute { - // CHECK: async.add_ref %[[TOKEN]] {count = 1 : i32} + // CHECK: async.runtime.add_ref %[[TOKEN]] {count = 1 : i32} // CHECK: %[[TOKEN_2:.*]] = async.execute %token_2 = async.execute { // CHECK: async.await %[[TOKEN]] - // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32} + // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32} async.await %token : !async.token async.yield } - // CHECK: async.drop_ref %[[TOKEN_2]] {count = 1 : i32} - // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32} + // CHECK: async.runtime.drop_ref %[[TOKEN_2]] {count = 1 : i32} + // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32} async.yield } - // CHECK: async.drop_ref %[[TOKEN_1]] {count = 1 : i32} - // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32} + // CHECK: async.runtime.drop_ref %[[TOKEN_1]] {count = 1 : i32} + // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32} async.yield } - // CHECK: async.drop_ref %[[TOKEN_0]] {count = 1 : i32} - // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32} + // CHECK: async.runtime.drop_ref %[[TOKEN_0]] {count = 1 : i32} + // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32} // CHECK: return return } @@ -199,19 +199,19 @@ async.yield } - // CHECK: async.add_ref %[[TOKEN]] {count = 1 : i32} + // CHECK: async.runtime.add_ref %[[TOKEN]] {count = 1 : i32} // CHECK: %[[TOKEN_0:.*]] = async.execute %token_0 = async.execute[%token] { - // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32} + // CHECK: async.runtime.drop_ref %[[TOKEN]] 
{count = 1 : i32} // CHECK-NEXT: async.yield async.yield } // CHECK: async.await %[[TOKEN]] - // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32} + // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32} async.await %token : !async.token // CHECK: async.await %[[TOKEN_0]] - // CHECK: async.drop_ref %[[TOKEN_0]] {count = 1 : i32} + // CHECK: async.runtime.drop_ref %[[TOKEN_0]] {count = 1 : i32} async.await %token_0 : !async.token // CHECK: return @@ -226,26 +226,26 @@ async.yield %0 : f32 } - // CHECK: async.add_ref %[[TOKEN]] {count = 1 : i32} - // CHECK: async.add_ref %[[RESULTS]] {count = 1 : i32} + // CHECK: async.runtime.add_ref %[[TOKEN]] {count = 1 : i32} + // CHECK: async.runtime.add_ref %[[RESULTS]] {count = 1 : i32} // CHECK: %[[TOKEN_0:.*]] = async.execute %token_0 = async.execute[%token](%results as %arg0 : !async.value) { - // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32} - // CHECK: async.drop_ref %[[RESULTS]] {count = 1 : i32} + // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32} + // CHECK: async.runtime.drop_ref %[[RESULTS]] {count = 1 : i32} // CHECK: async.yield async.yield } // CHECK: async.await %[[TOKEN]] - // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32} + // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32} async.await %token : !async.token // CHECK: async.await %[[TOKEN_0]] - // CHECK: async.drop_ref %[[TOKEN_0]] {count = 1 : i32} + // CHECK: async.runtime.drop_ref %[[TOKEN_0]] {count = 1 : i32} async.await %token_0 : !async.token // CHECK: async.await %[[RESULTS]] - // CHECK: async.drop_ref %[[RESULTS]] {count = 1 : i32} + // CHECK: async.runtime.drop_ref %[[RESULTS]] {count = 1 : i32} %0 = async.await %results : !async.value // CHECK: return diff --git a/mlir/test/Dialect/Async/coro.mlir b/mlir/test/Dialect/Async/coro.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Async/coro.mlir @@ -0,0 +1,58 @@ +// RUN: mlir-opt %s | FileCheck %s --dump-input=always + +// CHECK-LABEL: @coro_id +func 
@coro_id() -> !async.coro.id { + // CHECK: %0 = async.coro.id + // CHECK: return %0 : !async.coro.id + %0 = async.coro.id + return %0 : !async.coro.id +} + +// CHECK-LABEL: @coro_handle +func @coro_handle(%arg0: !async.coro.id) -> !async.coro.handle { + // CHECK: %0 = async.coro.begin %arg0 + // CHECK: return %0 : !async.coro.handle + %0 = async.coro.begin %arg0 + return %0 : !async.coro.handle +} + +// CHECK-LABEL: @coro_free +func @coro_free(%arg0: !async.coro.id, %arg1: !async.coro.handle) { + // CHECK: async.coro.free %arg0, %arg1 + async.coro.free %arg0, %arg1 + return +} + +// CHECK-LABEL: @coro_end +func @coro_end(%arg0: !async.coro.handle) { + // CHECK: async.coro.end %arg0 + async.coro.end %arg0 + return +} + +// CHECK-LABEL: @coro_save +func @coro_save(%arg0: !async.coro.handle) -> !async.coro.state { + // CHECK: %0 = async.coro.save %arg0 + %0 = async.coro.save %arg0 + // CHECK: return %0 : !async.coro.state + return %0 : !async.coro.state +} + +// CHECK-LABEL: @coro_suspend +func @coro_suspend(%arg0: !async.coro.state) { + // CHECK: async.coro.suspend %arg0 + // CHECK-SAME: ^[[SUSPEND:.*]], ^[[RESUME:.*]], ^[[CLEANUP:.*]] + async.coro.suspend %arg0, ^suspend, ^resume, ^cleanup +^resume: + // CHECK: ^[[RESUME]] + // CHECK: return {coro.resume} + return { coro.resume } +^cleanup: + // CHECK: ^[[CLEANUP]] + // CHECK: return {coro.cleanup} + return { coro.cleanup } +^suspend: + // CHECK: ^[[SUSPEND]] + // CHECK: return {coro.suspend} + return { coro.suspend } +} diff --git a/mlir/test/Dialect/Async/ops.mlir b/mlir/test/Dialect/Async/ops.mlir --- a/mlir/test/Dialect/Async/ops.mlir +++ b/mlir/test/Dialect/Async/ops.mlir @@ -134,17 +134,3 @@ %3 = addi %1, %2 : index return %3 : index } - -// CHECK-LABEL: @add_ref -func @add_ref(%arg0: !async.token) { - // CHECK: async.add_ref %arg0 {count = 1 : i32} - async.add_ref %arg0 {count = 1 : i32} : !async.token - return -} - -// CHECK-LABEL: @drop_ref -func @drop_ref(%arg0: !async.token) { - // CHECK: async.drop_ref 
%arg0 {count = 1 : i32} - async.drop_ref %arg0 {count = 1 : i32} : !async.token - return -} diff --git a/mlir/test/Dialect/Async/runtime.mlir b/mlir/test/Dialect/Async/runtime.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Async/runtime.mlir @@ -0,0 +1,130 @@ +// RUN: mlir-opt %s | FileCheck %s --dump-input=always + +// CHECK-LABEL: @create_token +func @create_token() -> !async.token { + // CHECK: %0 = async.runtime.create : !async.token + %0 = async.runtime.create : !async.token + // CHECK: return %0 : !async.token + return %0 : !async.token +} + +// CHECK-LABEL: @create_value +func @create_value() -> !async.value { + // CHECK: %0 = async.runtime.create : !async.value + %0 = async.runtime.create : !async.value + // CHECK: return %0 : !async.value + return %0 : !async.value +} + +// CHECK-LABEL: @create_group +func @create_group() -> !async.group { + // CHECK: %0 = async.runtime.create : !async.group + %0 = async.runtime.create : !async.group + // CHECK: return %0 : !async.group + return %0 : !async.group +} + +// CHECK-LABEL: @set_token_available +func @set_token_available(%arg0: !async.token) { + // CHECK: async.runtime.set_available %arg0 : !async.token + async.runtime.set_available %arg0 : !async.token + return +} + +// CHECK-LABEL: @set_value_available +func @set_value_available(%arg0: !async.value) { + // CHECK: async.runtime.set_available %arg0 : !async.value + async.runtime.set_available %arg0 : !async.value + return +} + +// CHECK-LABEL: @await_token +func @await_token(%arg0: !async.token) { + // CHECK: async.runtime.await %arg0 : !async.token + async.runtime.await %arg0 : !async.token + return +} + +// CHECK-LABEL: @await_value +func @await_value(%arg0: !async.value) { + // CHECK: async.runtime.await %arg0 : !async.value + async.runtime.await %arg0 : !async.value + return +} + +// CHECK-LABEL: @await_group +func @await_group(%arg0: !async.group) { + // CHECK: async.runtime.await %arg0 : !async.group + async.runtime.await %arg0 : 
!async.group + return +} + +// CHECK-LABEL: @await_and_resume_token +func @await_and_resume_token(%arg0: !async.token, + %arg1: !async.coro.handle) { + // CHECK: async.runtime.await_and_resume %arg0, %arg1 : !async.token + async.runtime.await_and_resume %arg0, %arg1 : !async.token + return +} + +// CHECK-LABEL: @await_and_resume_value +func @await_and_resume_value(%arg0: !async.value, + %arg1: !async.coro.handle) { + // CHECK: async.runtime.await_and_resume %arg0, %arg1 : !async.value + async.runtime.await_and_resume %arg0, %arg1 : !async.value + return +} + +// CHECK-LABEL: @await_and_resume_group +func @await_and_resume_group(%arg0: !async.group, + %arg1: !async.coro.handle) { + // CHECK: async.runtime.await_and_resume %arg0, %arg1 : !async.group + async.runtime.await_and_resume %arg0, %arg1 : !async.group + return +} + +// CHECK-LABEL: @resume +func @resume(%arg0: !async.coro.handle) { + // CHECK: async.runtime.resume %arg0 + async.runtime.resume %arg0 + return +} + +// CHECK-LABEL: @store +func @store(%arg0: f32, %arg1: !async.value) { + // CHECK: async.runtime.store %arg0, %arg1 : !async.value + async.runtime.store %arg0, %arg1 : !async.value + return +} + +// CHECK-LABEL: @load +func @load(%arg0: !async.value) -> f32 { + // CHECK: %0 = async.runtime.load %arg0 : !async.value + // CHECK: return %0 : f32 + %0 = async.runtime.load %arg0 : !async.value + return %0 : f32 +} + +// CHECK-LABEL: @add_to_group +func @add_to_group(%arg0: !async.token, %arg1: !async.value, + %arg2: !async.group) { + // CHECK: async.runtime.add_to_group %arg0, %arg2 : !async.token + async.runtime.add_to_group %arg0, %arg2 : !async.token + // CHECK: async.runtime.add_to_group %arg1, %arg2 : !async.value + async.runtime.add_to_group %arg1, %arg2 : !async.value + return +} + +// CHECK-LABEL: @add_ref +func @add_ref(%arg0: !async.token) { + // CHECK: async.runtime.add_ref %arg0 {count = 1 : i32} + async.runtime.add_ref %arg0 {count = 1 : i32} : !async.token + return +} + +// CHECK-LABEL: 
@drop_ref +func @drop_ref(%arg0: !async.token) { + // CHECK: async.runtime.drop_ref %arg0 {count = 1 : i32} + async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token + return +}