diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp --- a/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp +++ b/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp @@ -525,10 +525,6 @@ bool isGroup = type.isa(); bool isValue = type.isa(); - // Drop reference after async token or group await (sync await) - if (auto await = dyn_cast(op)) - return (isToken || isGroup) ? -1 : 0; - // Drop reference after async token or group error check (coro await). if (auto await = dyn_cast(op)) return (isToken || isGroup) ? -1 : 0; diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp --- a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp +++ b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp @@ -397,10 +397,23 @@ Location loc = op->getLoc(); Value operand = AwaitAdaptor(operands).operand(); + Type i1 = rewriter.getI1Type(); + // Inside regular functions we use the blocking wait operation to wait for // the async object (token, value or group) to become available. - if (!isInCoroutine) - rewriter.create(loc, operand); + if (!isInCoroutine) { + ImplicitLocOpBuilder builder(loc, op, rewriter.getListener()); + builder.create(loc, operand); + + // Assert that the awaited operands is not in the error state. + Value isError = builder.create(i1, operand); + Value notError = builder.create( + isError, + builder.create(loc, i1, builder.getIntegerAttr(i1, 1))); + + builder.create(notError, + "Awaited async operand is in error state"); + } // Inside the coroutine we convert await operation into coroutine suspension // point, and resume execution asynchronously. @@ -430,8 +443,7 @@ // Check if the awaited value is in the error state. builder.setInsertionPointToStart(resume); - auto isError = - builder.create(loc, rewriter.getI1Type(), operand); + auto isError = builder.create(loc, i1, operand); builder.create(isError, /*trueDest=*/setupSetErrorBlock(coro), /*trueArgs=*/ArrayRef(), @@ -772,7 +784,8 @@ }); return !walkResult.wasInterrupted(); }); - runtimeTarget.addLegalOp(); + runtimeTarget + .addLegalOp(); // Assertions must be converted to runtime errors inside async functions. runtimeTarget.addDynamicallyLegalOp([&](AssertOp op) -> bool { diff --git a/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir b/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir --- a/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir +++ b/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir @@ -24,6 +24,10 @@ async.yield } // CHECK: call @mlirAsyncRuntimeAwaitToken(%[[TOKEN]]) + // CHECK: %[[IS_ERROR:.*]] = call @mlirAsyncRuntimeIsTokenError(%[[TOKEN]]) + // CHECK: %[[TRUE:.*]] = constant true + // CHECK: %[[NOT_ERROR:.*]] = xor %[[IS_ERROR]], %[[TRUE]] : i1 + // CHECK: assert %[[NOT_ERROR]] // CHECK-NEXT: return async.await %token : !async.token return @@ -83,7 +87,10 @@ async.yield } // CHECK: call @mlirAsyncRuntimeAwaitToken(%[[TOKEN]]) - // CHECK-NEXT: return + // CHECK: %[[IS_ERROR:.*]] = call @mlirAsyncRuntimeIsTokenError(%[[TOKEN]]) + // CHECK: %[[TRUE:.*]] = constant true + // CHECK: %[[NOT_ERROR:.*]] = xor %[[IS_ERROR]], %[[TRUE]] : i1 + // CHECK: assert %[[NOT_ERROR]] async.await %token0 : !async.token return } diff --git a/mlir/test/Dialect/Async/async-runtime-policy-based-ref-counting.mlir b/mlir/test/Dialect/Async/async-runtime-policy-based-ref-counting.mlir --- a/mlir/test/Dialect/Async/async-runtime-policy-based-ref-counting.mlir +++ b/mlir/test/Dialect/Async/async-runtime-policy-based-ref-counting.mlir @@ -4,7 +4,7 @@ // CHECK: %[[TOKEN:.*]]: !async.token func @token_await(%arg0: !async.token) { // CHECK: async.runtime.await %[[TOKEN]] - // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32} + // CHECK-NOT: async.runtime.drop_ref async.runtime.await %arg0 : !async.token return } @@ -13,7 +13,7 @@ // CHECK: %[[GROUP:.*]]: !async.group func @group_await(%arg0: !async.group) { // CHECK: async.runtime.await %[[GROUP]] - // CHECK: async.runtime.drop_ref %[[GROUP]] {count = 1 : i32} + // CHECK-NOT: async.runtime.drop_ref async.runtime.await %arg0 : !async.group return } diff --git a/mlir/test/Dialect/Async/async-to-async-runtime.mlir b/mlir/test/Dialect/Async/async-to-async-runtime.mlir --- a/mlir/test/Dialect/Async/async-to-async-runtime.mlir +++ b/mlir/test/Dialect/Async/async-to-async-runtime.mlir @@ -60,6 +60,10 @@ async.yield } // CHECK: async.runtime.await %[[TOKEN]] + // CHECK: %[[IS_ERROR:.*]] = async.runtime.is_error %[[TOKEN]] + // CHECK: %[[TRUE:.*]] = constant true + // CHECK: %[[NOT_ERROR:.*]] = xor %[[IS_ERROR]], %[[TRUE]] : i1 + // CHECK: assert %[[NOT_ERROR]] // CHECK-NEXT: return async.await %token0 : !async.token return