diff --git a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td --- a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td +++ b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td @@ -373,7 +373,7 @@ resuming after `await_and_resume`. }]; - let arguments = (ins Async_AnyValueOrTokenType:$operand); + let arguments = (ins Async_AnyAsyncType:$operand); let results = (outs I1:$is_error); let assemblyFormat = "$operand attr-dict `:` type($operand)"; diff --git a/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h b/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h --- a/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h +++ b/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h @@ -88,6 +88,10 @@ // Returns true if value is in the error state. extern "C" bool mlirAsyncRuntimeIsValueError(AsyncValue *); +// Returns true if group is in the error state (any of the tokens or values +// added to the group are in the error state). +extern "C" bool mlirAsyncRuntimeIsGroupError(AsyncGroup *); + // Blocks the caller thread until the token becomes ready. extern "C" void mlirAsyncRuntimeAwaitToken(AsyncToken *); diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp --- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp +++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp @@ -18,6 +18,7 @@ #include "mlir/IR/TypeUtilities.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/TypeSwitch.h" #define DEBUG_TYPE "convert-async-to-llvm" @@ -39,6 +40,7 @@ static constexpr const char *kSetValueError = "mlirAsyncRuntimeSetValueError"; static constexpr const char *kIsTokenError = "mlirAsyncRuntimeIsTokenError"; static constexpr const char *kIsValueError = "mlirAsyncRuntimeIsValueError"; +static constexpr const char *kIsGroupError = "mlirAsyncRuntimeIsGroupError"; static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken"; static constexpr const char *kAwaitValue = "mlirAsyncRuntimeAwaitValue"; static constexpr const char *kAwaitGroup = "mlirAsyncRuntimeAwaitAllInGroup"; @@ -125,6 +127,11 @@ return FunctionType::get(ctx, {value}, {i1}); } + static FunctionType isGroupErrorFunctionType(MLIRContext *ctx) { + auto i1 = IntegerType::get(ctx, 1); + return FunctionType::get(ctx, {GroupType::get(ctx)}, {i1}); + } + static FunctionType awaitTokenFunctionType(MLIRContext *ctx) { return FunctionType::get(ctx, {TokenType::get(ctx)}, {}); } @@ -201,6 +208,7 @@ addFuncDecl(kSetValueError, AsyncAPI::setValueErrorFunctionType(ctx)); addFuncDecl(kIsTokenError, AsyncAPI::isTokenErrorFunctionType(ctx)); addFuncDecl(kIsValueError, AsyncAPI::isValueErrorFunctionType(ctx)); + addFuncDecl(kIsGroupError, AsyncAPI::isGroupErrorFunctionType(ctx)); addFuncDecl(kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx)); addFuncDecl(kAwaitValue, AsyncAPI::awaitValueFunctionType(ctx)); addFuncDecl(kAwaitGroup, AsyncAPI::awaitGroupFunctionType(ctx)); @@ -587,10 +595,13 @@ LogicalResult matchAndRewrite(RuntimeSetAvailableOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - Type operandType = op.operand().getType(); - rewriter.replaceOpWithNewOp( - op, operandType.isa() ? kEmplaceToken : kEmplaceValue, - TypeRange(), operands); + StringRef apiFuncName = + TypeSwitch(op.operand().getType()) + .Case([](Type) { return kEmplaceToken; }) + .Case([](Type) { return kEmplaceValue; }); + + rewriter.replaceOpWithNewOp(op, apiFuncName, TypeRange(), operands); + return success(); } }; @@ -609,10 +620,13 @@ LogicalResult matchAndRewrite(RuntimeSetErrorOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - Type operandType = op.operand().getType(); - rewriter.replaceOpWithNewOp( - op, operandType.isa() ? kSetTokenError : kSetValueError, - TypeRange(), operands); + StringRef apiFuncName = + TypeSwitch(op.operand().getType()) + .Case([](Type) { return kSetTokenError; }) + .Case([](Type) { return kSetValueError; }); + + rewriter.replaceOpWithNewOp(op, apiFuncName, TypeRange(), operands); + return success(); } }; @@ -630,10 +644,14 @@ LogicalResult matchAndRewrite(RuntimeIsErrorOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - Type operandType = op.operand().getType(); - rewriter.replaceOpWithNewOp( - op, operandType.isa() ? kIsTokenError : kIsValueError, - rewriter.getI1Type(), operands); + StringRef apiFuncName = + TypeSwitch(op.operand().getType()) + .Case([](Type) { return kIsTokenError; }) + .Case([](Type) { return kIsGroupError; }) + .Case([](Type) { return kIsValueError; }); + + rewriter.replaceOpWithNewOp(op, apiFuncName, rewriter.getI1Type(), + operands); return success(); } }; @@ -651,17 +669,11 @@ LogicalResult matchAndRewrite(RuntimeAwaitOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - Type operandType = op.operand().getType(); - - StringRef apiFuncName; - if (operandType.isa()) - apiFuncName = kAwaitToken; - else if (operandType.isa()) - apiFuncName = kAwaitValue; - else if (operandType.isa()) - apiFuncName = kAwaitGroup; - else - return rewriter.notifyMatchFailure(op, "unsupported async type"); + StringRef apiFuncName = + TypeSwitch(op.operand().getType()) + .Case([](Type) { return kAwaitToken; }) + .Case([](Type) { return kAwaitValue; }) + .Case([](Type) { return kAwaitGroup; }); rewriter.create(op->getLoc(), apiFuncName, TypeRange(), operands); rewriter.eraseOp(op); @@ -684,17 +696,11 @@ LogicalResult matchAndRewrite(RuntimeAwaitAndResumeOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - Type operandType = op.operand().getType(); - - StringRef apiFuncName; - if (operandType.isa()) - apiFuncName = kAwaitTokenAndExecute; - else if (operandType.isa()) - apiFuncName = kAwaitValueAndExecute; - else if (operandType.isa()) - apiFuncName = kAwaitAllAndExecute; - else - return rewriter.notifyMatchFailure(op, "unsupported async type"); + StringRef apiFuncName = + TypeSwitch(op.operand().getType()) + .Case([](Type) { return kAwaitTokenAndExecute; }) + .Case([](Type) { return kAwaitValueAndExecute; }) + .Case([](Type) { return kAwaitAllAndExecute; }); Value operand = RuntimeAwaitAndResumeOpAdaptor(operands).operand(); Value handle = RuntimeAwaitAndResumeOpAdaptor(operands).handle(); diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp --- a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp +++ b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp @@ -399,25 +399,22 @@ builder.create(coroSaveOp.state(), coro.suspend, resume, coro.cleanup); - // TODO: Async groups do not yet support runtime errors. - if (!std::is_same::value) { - // Split the resume block into error checking and continuation. - Block *continuation = rewriter.splitBlock(resume, Block::iterator(op)); - - // Check if the awaited value is in the error state. - builder.setInsertionPointToStart(resume); - auto isError = builder.create( - loc, rewriter.getI1Type(), operand); - builder.create(isError, - /*trueDest=*/setupSetErrorBlock(coro), - /*trueArgs=*/ArrayRef(), - /*falseDest=*/continuation, - /*falseArgs=*/ArrayRef()); - - // Make sure that replacement value will be constructed in the - // continuation block. - rewriter.setInsertionPointToStart(continuation); - } + // Split the resume block into error checking and continuation. + Block *continuation = rewriter.splitBlock(resume, Block::iterator(op)); + + // Check if the awaited value is in the error state. + builder.setInsertionPointToStart(resume); + auto isError = + builder.create(loc, rewriter.getI1Type(), operand); + builder.create(isError, + /*trueDest=*/setupSetErrorBlock(coro), + /*trueArgs=*/ArrayRef(), + /*falseDest=*/continuation, + /*falseArgs=*/ArrayRef()); + + // Make sure that replacement value will be constructed in the + // continuation block. + rewriter.setInsertionPointToStart(continuation); } // Erase or replace the await operation with the new value. diff --git a/mlir/lib/ExecutionEngine/AsyncRuntime.cpp b/mlir/lib/ExecutionEngine/AsyncRuntime.cpp --- a/mlir/lib/ExecutionEngine/AsyncRuntime.cpp +++ b/mlir/lib/ExecutionEngine/AsyncRuntime.cpp @@ -212,9 +212,10 @@ // tokens or values added to the group). struct AsyncGroup : public RefCounted { AsyncGroup(AsyncRuntime *runtime) - : RefCounted(runtime), pendingTokens(0), rank(0) {} + : RefCounted(runtime), pendingTokens(0), numErrors(0), rank(0) {} std::atomic pendingTokens; + std::atomic numErrors; std::atomic rank; // Pending awaiters are guarded by a mutex. @@ -262,7 +263,11 @@ int rank = group->rank.fetch_add(1); group->pendingTokens.fetch_add(1); - auto onTokenReady = [group]() { + auto onTokenReady = [group, token]() { + // Increment the number of errors in the group. + if (State(token->state).isError()) + group->numErrors.fetch_add(1); + // Run all group awaiters if it was the last token in the group. if (group->pendingTokens.fetch_sub(1) == 1) { group->cv.notify_all(); @@ -356,6 +361,10 @@ return State(value->state).isError(); } +extern "C" bool mlirAsyncRuntimeIsGroupError(AsyncGroup *group) { + return group->numErrors.load() > 0; +} + extern "C" void mlirAsyncRuntimeAwaitToken(AsyncToken *token) { std::unique_lock lock(token->mu); if (!State(token->state).isAvailableOrError()) @@ -483,6 +492,8 @@ &mlir::runtime::mlirAsyncRuntimeIsTokenError); exportSymbol("mlirAsyncRuntimeIsValueError", &mlir::runtime::mlirAsyncRuntimeIsValueError); + exportSymbol("mlirAsyncRuntimeIsGroupError", + &mlir::runtime::mlirAsyncRuntimeIsGroupError); exportSymbol("mlirAsyncRuntimeAwaitToken", &mlir::runtime::mlirAsyncRuntimeAwaitToken); exportSymbol("mlirAsyncRuntimeAwaitValue", diff --git a/mlir/test/Dialect/Async/async-to-async-runtime.mlir b/mlir/test/Dialect/Async/async-to-async-runtime.mlir --- a/mlir/test/Dialect/Async/async-to-async-runtime.mlir +++ b/mlir/test/Dialect/Async/async-to-async-runtime.mlir @@ -216,8 +216,13 @@ // CHECK: async.coro.suspend // CHECK-SAME: ^[[SUSPEND]], ^[[RESUME_1:.*]], ^[[CLEANUP]] -// Emplace result token. +// Check the error of the awaited token after resumption. // CHECK: ^[[RESUME_1]]: +// CHECK: %[[ERR:.*]] = async.runtime.is_error %[[ARG]] +// CHECK: cond_br %[[ERR]], ^[[SET_ERROR:.*]], ^[[CONTINUATION:.*]] + +// Emplace result token after error checking. +// CHECK: ^[[CONTINUATION:.*]]: // CHECK: async.runtime.set_available %[[TOKEN]] // CHECK: ^[[CLEANUP]]: diff --git a/mlir/test/Dialect/Async/runtime.mlir b/mlir/test/Dialect/Async/runtime.mlir --- a/mlir/test/Dialect/Async/runtime.mlir +++ b/mlir/test/Dialect/Async/runtime.mlir @@ -66,6 +66,13 @@ return %0 : i1 } +// CHECK-LABEL: @is_group_error +func @is_group_error(%arg0: !async.group) -> i1 { + // CHECK: %[[ERR:.*]] = async.runtime.is_error %arg0 : !async.group + %0 = async.runtime.is_error %arg0 : !async.group + return %0 : i1 +} + // CHECK-LABEL: @await_token func @await_token(%arg0: !async.token) { // CHECK: async.runtime.await %arg0 : !async.token diff --git a/mlir/test/mlir-cpu-runner/async-error.mlir b/mlir/test/mlir-cpu-runner/async-error.mlir --- a/mlir/test/mlir-cpu-runner/async-error.mlir +++ b/mlir/test/mlir-cpu-runner/async-error.mlir @@ -81,5 +81,29 @@ vector.print %err3_0 : i1 vector.print %err3_1 : i1 + // ------------------------------------------------------------------------ // + // Check error propagation from a token to the group. + // ------------------------------------------------------------------------ // + + %group0 = async.create_group + + %token4 = async.execute { + async.yield + } + + %token5 = async.execute { + assert %false, "error" + async.yield + } + + %idx0 = async.add_to_group %token4, %group0 : !async.token + %idx1 = async.add_to_group %token5, %group0 : !async.token + + async.runtime.await %group0 : !async.group + + // CHECK: 1 + %err4 = async.runtime.is_error %group0 : !async.group + vector.print %err4 : i1 + return }