diff --git a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td --- a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td +++ b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td @@ -343,7 +343,7 @@ } def Async_RuntimeSetAvailableOp : Async_Op<"runtime.set_available"> { - let summary = "switches token or value available state"; + let summary = "switches token or value to available state"; let description = [{ The `async.runtime.set_available` operation switches async token or value state to available. @@ -353,11 +353,37 @@ let assemblyFormat = "$operand attr-dict `:` type($operand)"; } +def Async_RuntimeSetErrorOp : Async_Op<"runtime.set_error"> { + let summary = "switches token or value to error state"; + let description = [{ + The `async.runtime.set_error` operation switches async token or value + state to error. + }]; + + let arguments = (ins Async_AnyValueOrTokenType:$operand); + let assemblyFormat = "$operand attr-dict `:` type($operand)"; +} + +def Async_RuntimeIsErrorOp : Async_Op<"runtime.is_error"> { + let summary = "returns true if token or value is in error state"; + let description = [{ + The `async.runtime.is_error` operation returns true if the token or value + is in the error state. It is the caller's responsibility to check the + error state after the call to `await` or after resuming from + `await_and_resume`. + }]; + + let arguments = (ins Async_AnyValueOrTokenType:$operand); + let results = (outs I1:$is_error); + + let assemblyFormat = "$operand attr-dict `:` type($operand)"; +} + def Async_RuntimeAwaitOp : Async_Op<"runtime.await"> { let summary = "blocks the caller thread until the operand becomes available"; let description = [{ The `async.runtime.await` operation blocks the caller thread until the - operand becomes available. + operand becomes available or enters the error state.
}]; let arguments = (ins Async_AnyAsyncType:$operand); @@ -379,8 +405,8 @@ let summary = "awaits the async operand and resumes the coroutine"; let description = [{ The `async.runtime.await_and_resume` operation awaits for the operand to - become available and resumes the coroutine on a thread managed by the - runtime. + become available or enter the error state, and resumes the coroutine on a + thread managed by the runtime. }]; let arguments = (ins Async_AnyAsyncType:$operand, diff --git a/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h b/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h --- a/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h +++ b/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h @@ -76,6 +76,18 @@ // Switches `async.value` to ready state and runs all awaiters. extern "C" void mlirAsyncRuntimeEmplaceValue(AsyncValue *); +// Switches `async.token` to error state and runs all awaiters. +extern "C" void mlirAsyncRuntimeSetTokenError(AsyncToken *); + +// Switches `async.value` to error state and runs all awaiters. +extern "C" void mlirAsyncRuntimeSetValueError(AsyncValue *); + +// Returns true if token is in the error state. +extern "C" bool mlirAsyncRuntimeIsTokenError(AsyncToken *); + +// Returns true if value is in the error state. +extern "C" bool mlirAsyncRuntimeIsValueError(AsyncValue *); + // Blocks the caller thread until the token becomes ready.
extern "C" void mlirAsyncRuntimeAwaitToken(AsyncToken *); diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp --- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp +++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp @@ -35,6 +35,10 @@ static constexpr const char *kCreateGroup = "mlirAsyncRuntimeCreateGroup"; static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken"; static constexpr const char *kEmplaceValue = "mlirAsyncRuntimeEmplaceValue"; +static constexpr const char *kSetTokenError = "mlirAsyncRuntimeSetTokenError"; +static constexpr const char *kSetValueError = "mlirAsyncRuntimeSetValueError"; +static constexpr const char *kIsTokenError = "mlirAsyncRuntimeIsTokenError"; +static constexpr const char *kIsValueError = "mlirAsyncRuntimeIsValueError"; static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken"; static constexpr const char *kAwaitValue = "mlirAsyncRuntimeAwaitValue"; static constexpr const char *kAwaitGroup = "mlirAsyncRuntimeAwaitAllInGroup"; @@ -101,6 +105,26 @@ return FunctionType::get(ctx, {value}, {}); } + static FunctionType setTokenErrorFunctionType(MLIRContext *ctx) { + return FunctionType::get(ctx, {TokenType::get(ctx)}, {}); + } + + static FunctionType setValueErrorFunctionType(MLIRContext *ctx) { + auto value = opaquePointerType(ctx); + return FunctionType::get(ctx, {value}, {}); + } + + static FunctionType isTokenErrorFunctionType(MLIRContext *ctx) { + auto i1 = IntegerType::get(ctx, 1); + return FunctionType::get(ctx, {TokenType::get(ctx)}, {i1}); + } + + static FunctionType isValueErrorFunctionType(MLIRContext *ctx) { + auto value = opaquePointerType(ctx); + auto i1 = IntegerType::get(ctx, 1); + return FunctionType::get(ctx, {value}, {i1}); + } + static FunctionType awaitTokenFunctionType(MLIRContext *ctx) { return FunctionType::get(ctx, {TokenType::get(ctx)}, {}); } @@ -173,6 +197,10 @@ addFuncDecl(kCreateGroup,
AsyncAPI::createGroupFunctionType(ctx)); addFuncDecl(kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx)); addFuncDecl(kEmplaceValue, AsyncAPI::emplaceValueFunctionType(ctx)); + addFuncDecl(kSetTokenError, AsyncAPI::setTokenErrorFunctionType(ctx)); + addFuncDecl(kSetValueError, AsyncAPI::setValueErrorFunctionType(ctx)); + addFuncDecl(kIsTokenError, AsyncAPI::isTokenErrorFunctionType(ctx)); + addFuncDecl(kIsValueError, AsyncAPI::isValueErrorFunctionType(ctx)); addFuncDecl(kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx)); addFuncDecl(kAwaitValue, AsyncAPI::awaitValueFunctionType(ctx)); addFuncDecl(kAwaitGroup, AsyncAPI::awaitGroupFunctionType(ctx)); @@ -560,17 +588,53 @@ matchAndRewrite(RuntimeSetAvailableOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { Type operandType = op.operand().getType(); + rewriter.replaceOpWithNewOp( + op, operandType.isa() ? kEmplaceToken : kEmplaceValue, + TypeRange(), operands); + return success(); + } +}; +} // namespace - if (operandType.isa() || operandType.isa()) { - rewriter.create(op->getLoc(), - operandType.isa() ? kEmplaceToken - : kEmplaceValue, - TypeRange(), operands); - rewriter.eraseOp(op); - return success(); - } +//===----------------------------------------------------------------------===// +// Convert async.runtime.set_error to the corresponding runtime API call. +//===----------------------------------------------------------------------===// - return rewriter.notifyMatchFailure(op, "unsupported async type"); +namespace { +class RuntimeSetErrorOpLowering + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(RuntimeSetErrorOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + Type operandType = op.operand().getType(); + rewriter.replaceOpWithNewOp( + op, operandType.isa() ?
kSetTokenError : kSetValueError, + TypeRange(), operands); + return success(); + } +}; +} // namespace + +//===----------------------------------------------------------------------===// +// Convert async.runtime.is_error to the corresponding runtime API call. +//===----------------------------------------------------------------------===// + +namespace { +class RuntimeIsErrorOpLowering : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(RuntimeIsErrorOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + Type operandType = op.operand().getType(); + rewriter.replaceOpWithNewOp( + op, operandType.isa() ? kIsTokenError : kIsValueError, + rewriter.getI1Type(), operands); + return success(); } }; } // namespace @@ -889,7 +953,8 @@ patterns.add(converter, ctx); // Lower async.runtime operations to the async runtime API calls. - patterns.add(converter, ctx); diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp --- a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp +++ b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp @@ -52,6 +52,8 @@ /// operation to enable non-blocking waiting via coroutine suspension. namespace { struct CoroMachinery { + FuncOp func; + // Async execute region returns a completion token, and an async value for // each yielded value. // @@ -63,6 +65,7 @@ llvm::SmallVector returnValues; // returned async values Value coroHandle; // coroutine handle (!async.coro.handle value) + Block *setError; // switch completion token and all values to error state Block *cleanup; // coroutine cleanup block Block *suspend; // coroutine suspension block }; @@ -74,6 +77,7 @@ /// See LLVM coroutines documentation: https://llvm.org/docs/Coroutines.html /// /// - `entry` block sets up the coroutine. +/// - `set_error` block sets completion token and async values state to error.
/// - `cleanup` block cleans up the coroutine state. /// - `suspend block after the @llvm.coro.end() defines what value will be /// returned to the initial caller of a coroutine. Everything before the @@ -91,6 +95,11 @@ /// %hdl = async.coro.begin %id // create a coroutine handle /// br ^cleanup /// +/// ^set_error: // this block is created lazily only if needed (see code below) +/// async.runtime.set_error %token : !async.token +/// async.runtime.set_error %value : !async.value +/// br ^cleanup +/// /// ^cleanup: /// async.coro.free %hdl // delete the coroutine state /// br ^suspend @@ -163,14 +172,39 @@ // continuations, and will conditionally branch to cleanup or suspend blocks. CoroMachinery machinery; + machinery.func = func; machinery.asyncToken = retToken; machinery.returnValues = retValues; machinery.coroHandle = coroHdlOp.handle(); + machinery.setError = nullptr; // created lazily only if needed machinery.cleanup = cleanupBlock; machinery.suspend = suspendBlock; return machinery; } +// Lazily creates `set_error` block only if it is required for lowering to the +// runtime operations (see for example lowering of assert operation). +static Block *setupSetErrorBlock(CoroMachinery &coro) { + if (coro.setError) + return coro.setError; + + coro.setError = coro.func.addBlock(); + coro.setError->moveBefore(coro.cleanup); + + auto builder = + ImplicitLocOpBuilder::atBlockBegin(coro.func->getLoc(), coro.setError); + + // Coroutine set_error block: set error on token and all returned values. + builder.create(coro.asyncToken); + for (Value retValue : coro.returnValues) + builder.create(retValue); + + // Branch into the cleanup block. + builder.create(coro.cleanup); + + return coro.setError; +} + /// Outline the body region attached to the `async.execute` op into a standalone /// function.
/// @@ -316,9 +350,8 @@ using AwaitAdaptor = typename AwaitType::Adaptor; public: - AwaitOpLoweringBase( - MLIRContext *ctx, - const llvm::DenseMap &outlinedFunctions) + AwaitOpLoweringBase(MLIRContext *ctx, + llvm::DenseMap &outlinedFunctions) : OpConversionPattern(ctx), outlinedFunctions(outlinedFunctions) {} @@ -346,7 +379,7 @@ // Inside the coroutine we convert await operation into coroutine suspension // point, and resume execution asynchronously. if (isInCoroutine) { - const CoroMachinery &coro = outlined->getSecond(); + CoroMachinery &coro = outlined->getSecond(); Block *suspended = op->getBlock(); ImplicitLocOpBuilder builder(loc, op, rewriter.getListener()); @@ -366,8 +399,25 @@ builder.create(coroSaveOp.state(), coro.suspend, resume, coro.cleanup); - // Make sure that replacement value will be constructed in resume block. - rewriter.setInsertionPointToStart(resume); + // TODO: Async groups do not yet support runtime errors. + if (!std::is_same::value) { + // Split the resume block into error checking and continuation. + Block *continuation = rewriter.splitBlock(resume, Block::iterator(op)); + + // Check if the awaited value is in the error state. + builder.setInsertionPointToStart(resume); + auto isError = builder.create( + loc, rewriter.getI1Type(), operand); + builder.create(isError, + /*trueDest=*/setupSetErrorBlock(coro), + /*trueArgs=*/ArrayRef(), + /*falseDest=*/continuation, + /*falseArgs=*/ArrayRef()); + + // Make sure that replacement value will be constructed in the + // continuation block. + rewriter.setInsertionPointToStart(continuation); + } } // Erase or replace the await operation with the new value. @@ -385,7 +435,7 @@ } private: - const llvm::DenseMap &outlinedFunctions; + llvm::DenseMap &outlinedFunctions; }; /// Lowering for `async.await` with a token operand.
@@ -437,12 +487,12 @@ LogicalResult matchAndRewrite(async::YieldOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - // Check if yield operation is inside the outlined coroutine function. + // Check if yield operation is inside the async coroutine function. auto func = op->template getParentOfType(); auto outlined = outlinedFunctions.find(func); if (outlined == outlinedFunctions.end()) return rewriter.notifyMatchFailure( - op, "operation is not inside the outlined async.execute function"); + op, "operation is not inside the async coroutine function"); Location loc = op->getLoc(); const CoroMachinery &coro = outlined->getSecond(); @@ -466,6 +516,48 @@ const llvm::DenseMap &outlinedFunctions; }; +//===----------------------------------------------------------------------===// +// Convert std.assert operation to cond_br into `set_error` block. +//===----------------------------------------------------------------------===// + +class AssertOpLowering : public OpConversionPattern { +public: + AssertOpLowering(MLIRContext *ctx, + llvm::DenseMap &outlinedFunctions) + : OpConversionPattern(ctx), + outlinedFunctions(outlinedFunctions) {} + + LogicalResult + matchAndRewrite(AssertOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + // Check if assert operation is inside the async coroutine function.
+ auto func = op->template getParentOfType(); + auto outlined = outlinedFunctions.find(func); + if (outlined == outlinedFunctions.end()) + return rewriter.notifyMatchFailure( + op, "operation is not inside the async coroutine function"); + + Location loc = op->getLoc(); + CoroMachinery &coro = outlined->getSecond(); + + // NOTE: `set_error` lives in the function body region, so the `cond_br` + // below assumes `op` is not nested inside another region (e.g. scf.for). + Block *cont = rewriter.splitBlock(op->getBlock(), Block::iterator(op)); + rewriter.setInsertionPointToEnd(cont->getPrevNode()); + rewriter.create(loc, AssertOpAdaptor(operands).arg(), + /*trueDest=*/cont, + /*trueArgs=*/ArrayRef(), + /*falseDest=*/setupSetErrorBlock(coro), + /*falseArgs=*/ArrayRef()); + rewriter.eraseOp(op); + + return success(); + } + +private: + llvm::DenseMap &outlinedFunctions; +}; + //===----------------------------------------------------------------------===// void AsyncToAsyncRuntimePass::runOnOperation() { @@ -495,12 +587,23 @@ AwaitAllOpLowering, YieldOpLowering>(ctx, outlinedFunctions); + // Lower assertions to conditional branches into error blocks. + asyncPatterns.add(ctx, outlinedFunctions); + // All high level async operations must be lowered to the runtime operations. ConversionTarget runtimeTarget(*ctx); runtimeTarget.addLegalDialect(); runtimeTarget.addIllegalOp(); runtimeTarget.addIllegalOp(); + // Assertions inside async functions must be converted to runtime errors; + // asserts in regular functions have no `set_error` block and stay legal. + runtimeTarget.addDynamicallyLegalOp<AssertOp>([&](AssertOp op) -> bool { + auto func = op->getParentOfType<FuncOp>(); + return outlinedFunctions.find(func) == outlinedFunctions.end(); + }); + runtimeTarget.addLegalOp(); + if (failed(applyPartialConversion(module, runtimeTarget, std::move(asyncPatterns)))) { signalPassFailure(); diff --git a/mlir/lib/ExecutionEngine/AsyncRuntime.cpp b/mlir/lib/ExecutionEngine/AsyncRuntime.cpp --- a/mlir/lib/ExecutionEngine/AsyncRuntime.cpp +++ b/mlir/lib/ExecutionEngine/AsyncRuntime.cpp @@ -77,6 +77,46 @@ llvm::ThreadPool threadPool; }; +// -------------------------------------------------------------------------- // +// A state of the async runtime value (token, value or group).
+// -------------------------------------------------------------------------- // + +class State { +public: + enum StateEnum : int8_t { + // The underlying value is not yet available for consumption. + kUnavailable = 0, + // The underlying value is available for consumption (terminal state). + kAvailable = 1, + // The underlying value is available but holds an error (terminal state). + kError = 2, + }; + + /* implicit */ State(StateEnum s) : state(s) {} + /* implicit */ operator StateEnum() const { return state; } + + bool isUnavailable() const { return state == kUnavailable; } + bool isAvailable() const { return state == kAvailable; } + bool isError() const { return state == kError; } + bool isAvailableOrError() const { return isAvailable() || isError(); } + + const char *debug() const { + switch (state) { + case kUnavailable: + return "unavailable"; + case kAvailable: + return "available"; + case kError: + return "error"; + } + // Not reachable for valid states; avoids falling off a non-void function. + return "unknown"; + } + +private: + StateEnum state; +}; + // -------------------------------------------------------------------------- // // A base class for all reference counted objects created by the async runtime. // -------------------------------------------------------------------------- // @@ -137,9 +177,9 @@ // reference we must ensure that the token will be alive until the // asynchronous operation is completed. AsyncToken(AsyncRuntime *runtime) - : RefCounted(runtime, /*count=*/2), ready(false) {} + : RefCounted(runtime, /*refCount=*/2), state(State::kUnavailable) {} - std::atomic ready; + std::atomic state; // Pending awaiters are guarded by a mutex. std::mutex mu; @@ -153,9 +193,10 @@ struct AsyncValue : public RefCounted { // AsyncValue similar to an AsyncToken created with a reference count of 2.
AsyncValue(AsyncRuntime *runtime, int32_t size) - : RefCounted(runtime, /*count=*/2), ready(false), storage(size) {} + : RefCounted(runtime, /*refCount=*/2), state(State::kUnavailable), + storage(size) {} - std::atomic ready; + std::atomic state; // Use vector of bytes to store async value payload. std::vector storage; @@ -182,7 +223,6 @@ std::vector> awaiters; }; - // Adds references to reference counted runtime object. extern "C" void mlirAsyncRuntimeAddRef(RefCountedObjPtr ptr, int32_t count) { RefCounted *refCounted = static_cast(ptr); @@ -231,7 +271,7 @@ } }; - if (token->ready) { + if (State(token->state).isAvailableOrError()) { // Update group pending tokens immediately and maybe run awaiters. onTokenReady(); @@ -254,12 +294,16 @@ return rank; } -// Switches `async.token` to ready state and runs all awaiters. -extern "C" void mlirAsyncRuntimeEmplaceToken(AsyncToken *token) { +// Switches `async.token` to available or error state (terminal state) and runs +// all awaiters. +static void setTokenState(AsyncToken *token, State state) { + assert(state.isAvailableOrError() && "must be terminal state"); + assert(State(token->state).isUnavailable() && "token must be unavailable"); + // Make sure that `dropRef` does not destroy the mutex owned by the lock. { std::unique_lock lock(token->mu); - token->ready = true; + token->state = state; token->cv.notify_all(); for (auto &awaiter : token->awaiters) awaiter(); @@ -270,12 +314,14 @@ token->dropRef(); } -// Switches `async.value` to ready state and runs all awaiters. -extern "C" void mlirAsyncRuntimeEmplaceValue(AsyncValue *value) { +static void setValueState(AsyncValue *value, State state) { + assert(state.isAvailableOrError() && "must be terminal state"); + assert(State(value->state).isUnavailable() && "value must be unavailable"); + // Make sure that `dropRef` does not destroy the mutex owned by the lock.
{ std::unique_lock lock(value->mu); - value->ready = true; + value->state = state; value->cv.notify_all(); for (auto &awaiter : value->awaiters) awaiter(); @@ -286,16 +332,42 @@ value->dropRef(); } +extern "C" void mlirAsyncRuntimeEmplaceToken(AsyncToken *token) { + setTokenState(token, State::kAvailable); +} + +extern "C" void mlirAsyncRuntimeEmplaceValue(AsyncValue *value) { + setValueState(value, State::kAvailable); +} + +extern "C" void mlirAsyncRuntimeSetTokenError(AsyncToken *token) { + setTokenState(token, State::kError); +} + +extern "C" void mlirAsyncRuntimeSetValueError(AsyncValue *value) { + setValueState(value, State::kError); +} + +extern "C" bool mlirAsyncRuntimeIsTokenError(AsyncToken *token) { + return State(token->state).isError(); +} + +extern "C" bool mlirAsyncRuntimeIsValueError(AsyncValue *value) { + return State(value->state).isError(); +} + extern "C" void mlirAsyncRuntimeAwaitToken(AsyncToken *token) { std::unique_lock lock(token->mu); - if (!token->ready) - token->cv.wait(lock, [token] { return token->ready.load(); }); + if (!State(token->state).isAvailableOrError()) + token->cv.wait( + lock, [token] { return State(token->state).isAvailableOrError(); }); } extern "C" void mlirAsyncRuntimeAwaitValue(AsyncValue *value) { std::unique_lock lock(value->mu); - if (!value->ready) - value->cv.wait(lock, [value] { return value->ready.load(); }); + if (!State(value->state).isAvailableOrError()) + value->cv.wait( + lock, [value] { return State(value->state).isAvailableOrError(); }); } extern "C" void mlirAsyncRuntimeAwaitAllInGroup(AsyncGroup *group) { @@ -306,6 +378,7 @@ // Returns a pointer to the storage owned by the async value.
extern "C" ValueStorage mlirAsyncRuntimeGetValueStorage(AsyncValue *value) { + assert(!State(value->state).isError() && "unexpected error state"); return value->storage.data(); } @@ -319,7 +392,7 @@ CoroResume resume) { auto execute = [handle, resume]() { (*resume)(handle); }; std::unique_lock lock(token->mu); - if (token->ready) { + if (State(token->state).isAvailableOrError()) { lock.unlock(); execute(); } else { @@ -332,7 +405,7 @@ CoroResume resume) { auto execute = [handle, resume]() { (*resume)(handle); }; std::unique_lock lock(value->mu); - if (value->ready) { + if (State(value->state).isAvailableOrError()) { lock.unlock(); execute(); } else { @@ -402,6 +475,14 @@ &mlir::runtime::mlirAsyncRuntimeEmplaceToken); exportSymbol("mlirAsyncRuntimeEmplaceValue", &mlir::runtime::mlirAsyncRuntimeEmplaceValue); + exportSymbol("mlirAsyncRuntimeSetTokenError", + &mlir::runtime::mlirAsyncRuntimeSetTokenError); + exportSymbol("mlirAsyncRuntimeSetValueError", + &mlir::runtime::mlirAsyncRuntimeSetValueError); + exportSymbol("mlirAsyncRuntimeIsTokenError", + &mlir::runtime::mlirAsyncRuntimeIsTokenError); + exportSymbol("mlirAsyncRuntimeIsValueError", + &mlir::runtime::mlirAsyncRuntimeIsValueError); exportSymbol("mlirAsyncRuntimeAwaitToken", &mlir::runtime::mlirAsyncRuntimeAwaitToken); exportSymbol("mlirAsyncRuntimeAwaitValue", diff --git a/mlir/test/Conversion/AsyncToLLVM/convert-runtime-to-llvm.mlir b/mlir/test/Conversion/AsyncToLLVM/convert-runtime-to-llvm.mlir --- a/mlir/test/Conversion/AsyncToLLVM/convert-runtime-to-llvm.mlir +++ b/mlir/test/Conversion/AsyncToLLVM/convert-runtime-to-llvm.mlir @@ -43,6 +43,24 @@ return } +// CHECK-LABEL: @is_token_error +func @is_token_error() -> i1 { + // CHECK: %[[TOKEN:.*]] = call @mlirAsyncRuntimeCreateToken + %0 = async.runtime.create : !async.token + // CHECK: %[[ERR:.*]] = call @mlirAsyncRuntimeIsTokenError(%[[TOKEN]]) + %1 = async.runtime.is_error %0 : !async.token + return %1 : i1 +} + +// CHECK-LABEL: @is_value_error +func
@is_value_error() -> i1 { + // CHECK: %[[VALUE:.*]] = call @mlirAsyncRuntimeCreateValue + %0 = async.runtime.create : !async.value + // CHECK: %[[ERR:.*]] = call @mlirAsyncRuntimeIsValueError(%[[VALUE]]) + %1 = async.runtime.is_error %0 : !async.value + return %1 : i1 +} + // CHECK-LABEL: @await_token func @await_token() { // CHECK: %[[TOKEN:.*]] = call @mlirAsyncRuntimeCreateToken diff --git a/mlir/test/Dialect/Async/async-to-async-runtime.mlir b/mlir/test/Dialect/Async/async-to-async-runtime.mlir --- a/mlir/test/Dialect/Async/async-to-async-runtime.mlir +++ b/mlir/test/Dialect/Async/async-to-async-runtime.mlir @@ -1,4 +1,5 @@ -// RUN: mlir-opt %s -split-input-file -async-to-async-runtime -print-ir-after-all | FileCheck %s --dump-input=always +// RUN: mlir-opt %s -split-input-file -async-to-async-runtime \ +// RUN: | FileCheck %s --dump-input=always // CHECK-LABEL: @execute_no_async_args func @execute_no_async_args(%arg0: f32, %arg1: memref<1xf32>) { @@ -101,11 +102,17 @@ // CHECK: async.coro.suspend %[[SAVED]] // CHECK-SAME: ^[[SUSPEND]], ^[[RESUME_1:.*]], ^[[CLEANUP]] -// Set token available after second resumption. +// Check the error of the awaited token after resumption. // CHECK: ^[[RESUME_1]]: +// CHECK: %[[ERR:.*]] = async.runtime.is_error %[[INNER_TOKEN]] +// CHECK: cond_br %[[ERR]], ^[[SET_ERROR:.*]], ^[[CONTINUATION:.*]] + +// Set token available if the token is not in the error state. +// CHECK: ^[[CONTINUATION]]: // CHECK: memref.store // CHECK: async.runtime.set_available %[[TOKEN]] +// CHECK: ^[[SET_ERROR]]: // CHECK: ^[[CLEANUP]]: // CHECK: ^[[SUSPEND]]: @@ -155,8 +162,13 @@ // CHECK: async.coro.suspend %[[SAVED]] // CHECK-SAME: ^[[SUSPEND]], ^[[RESUME_1:.*]], ^[[CLEANUP]] -// Emplace result token after second resumption. +// Check the error of the awaited token after resumption.
// CHECK: ^[[RESUME_1]]: +// CHECK: %[[ERR:.*]] = async.runtime.is_error %[[ARG0]] +// CHECK: cond_br %[[ERR]], ^[[SET_ERROR:.*]], ^[[CONTINUATION:.*]] + +// Emplace result token after second resumption and error checking. +// CHECK: ^[[CONTINUATION]]: // CHECK: memref.store // CHECK: async.runtime.set_available %[[TOKEN]] @@ -293,11 +305,65 @@ // CHECK: async.coro.suspend // CHECK-SAME: ^[[SUSPEND]], ^[[RESUME_1:.*]], ^[[CLEANUP]] -// Load from the async.value argument. +// Check the error of the awaited value after resumption. // CHECK: ^[[RESUME_1]]: +// CHECK: %[[ERR:.*]] = async.runtime.is_error %[[ARG]] +// CHECK: cond_br %[[ERR]], ^[[SET_ERROR:.*]], ^[[CONTINUATION:.*]] + +// Load from the async.value argument after error checking. +// CHECK: ^[[CONTINUATION]]: // CHECK: %[[LOADED:.*]] = async.runtime.load %[[ARG]] : !async.value !async.token + +// Create token for return op, and mark a function as a coroutine. +// CHECK: %[[TOKEN:.*]] = async.runtime.create : !async.token +// CHECK: %[[ID:.*]] = async.coro.id +// CHECK: %[[HDL:.*]] = async.coro.begin + +// Initial coroutine suspension. +// CHECK: async.coro.suspend +// CHECK-SAME: ^[[SUSPEND:.*]], ^[[RESUME:.*]], ^[[CLEANUP:.*]] + +// Resume coroutine after suspension. +// CHECK: ^[[RESUME]]: +// CHECK: cond_br %[[ARG0]], ^[[SET_AVAILABLE:.*]], ^[[SET_ERROR:.*]] + +// Set coroutine completion token to available state. +// CHECK: ^[[SET_AVAILABLE]]: +// CHECK: async.runtime.set_available %[[TOKEN]] +// CHECK: br ^[[CLEANUP]] + +// Set coroutine completion token to error state. +// CHECK: ^[[SET_ERROR]]: +// CHECK: async.runtime.set_error %[[TOKEN]] +// CHECK: br ^[[CLEANUP]] + +// Delete coroutine. +// CHECK: ^[[CLEANUP]]: +// CHECK: async.coro.free %[[ID]], %[[HDL]] + +// Suspend coroutine, and also a return statement for ramp function.
+// CHECK: ^[[SUSPEND]]: +// CHECK: async.coro.end %[[HDL]] +// CHECK: return %[[TOKEN]] diff --git a/mlir/test/Dialect/Async/runtime.mlir b/mlir/test/Dialect/Async/runtime.mlir --- a/mlir/test/Dialect/Async/runtime.mlir +++ b/mlir/test/Dialect/Async/runtime.mlir @@ -38,6 +38,34 @@ return } +// CHECK-LABEL: @set_token_error +func @set_token_error(%arg0: !async.token) { + // CHECK: async.runtime.set_error %arg0 : !async.token + async.runtime.set_error %arg0 : !async.token + return +} + +// CHECK-LABEL: @set_value_error +func @set_value_error(%arg0: !async.value) { + // CHECK: async.runtime.set_error %arg0 : !async.value + async.runtime.set_error %arg0 : !async.value + return +} + +// CHECK-LABEL: @is_token_error +func @is_token_error(%arg0: !async.token) -> i1 { + // CHECK: %[[ERR:.*]] = async.runtime.is_error %arg0 : !async.token + %0 = async.runtime.is_error %arg0 : !async.token + return %0 : i1 +} + +// CHECK-LABEL: @is_value_error +func @is_value_error(%arg0: !async.value) -> i1 { + // CHECK: %[[ERR:.*]] = async.runtime.is_error %arg0 : !async.value + %0 = async.runtime.is_error %arg0 : !async.value + return %0 : i1 +} + // CHECK-LABEL: @await_token func @await_token(%arg0: !async.token) { // CHECK: async.runtime.await %arg0 : !async.token diff --git a/mlir/test/mlir-cpu-runner/async-error.mlir b/mlir/test/mlir-cpu-runner/async-error.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-cpu-runner/async-error.mlir @@ -0,0 +1,85 @@ +// RUN: mlir-opt %s -async-to-async-runtime \ +// RUN: -async-runtime-ref-counting \ +// RUN: -async-runtime-ref-counting-opt \ +// RUN: -convert-async-to-llvm \ +// RUN: -convert-linalg-to-loops \ +// RUN: -convert-scf-to-std \ +// RUN: -convert-linalg-to-llvm \ +// RUN: -convert-vector-to-llvm \ +// RUN: -convert-std-to-llvm \ +// RUN: | mlir-cpu-runner \ +// RUN: -e main -entry-point-result=void -O0 \ +// RUN: -shared-libs=%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext \ +// RUN:
-shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext \ +// RUN: -shared-libs=%linalg_test_lib_dir/libmlir_async_runtime%shlibext \ +// RUN: | FileCheck %s --dump-input=always + +func @main() { + %false = constant 0 : i1 + + // ------------------------------------------------------------------------ // + // Check that a simple async region completes without errors. + // ------------------------------------------------------------------------ // + %token0 = async.execute { + async.yield + } + async.runtime.await %token0 : !async.token + + // CHECK: 0 + %err0 = async.runtime.is_error %token0 : !async.token + vector.print %err0 : i1 + + // ------------------------------------------------------------------------ // + // Check that an assertion in the async region is converted to an async error. + // ------------------------------------------------------------------------ // + %token1 = async.execute { + assert %false, "error" + async.yield + } + async.runtime.await %token1 : !async.token + + // CHECK: 1 + %err1 = async.runtime.is_error %token1 : !async.token + vector.print %err1 : i1 + + // ------------------------------------------------------------------------ // + // Check error propagation from the nested region. + // ------------------------------------------------------------------------ // + %token2 = async.execute { + %token = async.execute { + assert %false, "error" + async.yield + } + async.await %token : !async.token + async.yield + } + async.runtime.await %token2 : !async.token + + // CHECK: 1 + %err2 = async.runtime.is_error %token2 : !async.token + vector.print %err2 : i1 + + // ------------------------------------------------------------------------ // + // Check error propagation from the nested region with async values.
+ // ------------------------------------------------------------------------ // + %token3, %value3 = async.execute -> !async.value { + %token, %value = async.execute -> !async.value { + assert %false, "error" + %0 = constant 123.45 : f32 + async.yield %0 : f32 + } + %ret = async.await %value : !async.value + async.yield %ret : f32 + } + async.runtime.await %token3 : !async.token + async.runtime.await %value3 : !async.value + + // CHECK: 1 + // CHECK: 1 + %err3_0 = async.runtime.is_error %token3 : !async.token + %err3_1 = async.runtime.is_error %value3 : !async.value + vector.print %err3_0 : i1 + vector.print %err3_1 : i1 + + return +}