AsyncRuntime: make `ready` atomic and take `mu` only to manage awaiters.

The `ready` flag of AsyncToken and AsyncValue becomes an atomic that can be
read without holding `mu`; `mu` now guards only the list of pending awaiters.

Reviewer notes:
  * The *AndExecute entry points first test the state without the lock and
    then test it again under `mu` before queueing the awaiter. Without the
    second test the state can flip between the unlocked test and the
    push_back, and the queued awaiter is never resumed. This presumes that
    the completion path (not part of this patch) publishes readiness before
    or while it drains `awaiters` under `mu` - TODO confirm.
  * Line breaks were reconstructed from the hunk line counts; indentation
    and template argument lists were lost in extraction. `std::atomic<bool>`
    and `std::unique_lock<std::mutex>` follow from the patch itself. The
    element type of `awaiters` is inferred from the void() lambdas pushed
    into it - TODO confirm. `std::vector storage`, `std::atomic
    pendingTokens` and `std::atomic rank` are left as extracted and must be
    restored from the target file before this patch can apply.

diff --git a/mlir/lib/ExecutionEngine/AsyncRuntime.cpp b/mlir/lib/ExecutionEngine/AsyncRuntime.cpp
--- a/mlir/lib/ExecutionEngine/AsyncRuntime.cpp
+++ b/mlir/lib/ExecutionEngine/AsyncRuntime.cpp
@@ -136,13 +136,14 @@
   // asynchronously executed task. If the caller immediately will drop its
   // reference we must ensure that the token will be alive until the
   // asynchronous operation is completed.
-  AsyncToken(AsyncRuntime *runtime) : RefCounted(runtime, /*count=*/2) {}
+  AsyncToken(AsyncRuntime *runtime)
+      : RefCounted(runtime, /*count=*/2), ready(false) {}
 
-  // Internal state below guarded by a mutex.
+  std::atomic<bool> ready;
+
+  // Pending awaiters are guarded by a mutex.
   std::mutex mu;
   std::condition_variable cv;
-
-  bool ready = false;
   std::vector<std::function<void()>> awaiters;
 };
 
@@ -152,17 +153,17 @@
 struct AsyncValue : public RefCounted {
   // AsyncValue similar to an AsyncToken created with a reference count of 2.
   AsyncValue(AsyncRuntime *runtime, int32_t size)
-      : RefCounted(runtime, /*count=*/2), storage(size) {}
-
-  // Internal state below guarded by a mutex.
-  std::mutex mu;
-  std::condition_variable cv;
+      : RefCounted(runtime, /*count=*/2), ready(false), storage(size) {}
 
-  bool ready = false;
-  std::vector<std::function<void()>> awaiters;
+  std::atomic<bool> ready;
 
   // Use vector of bytes to store async value payload.
   std::vector storage;
+
+  // Pending awaiters are guarded by a mutex.
+  std::mutex mu;
+  std::condition_variable cv;
+  std::vector<std::function<void()>> awaiters;
 };
 
 // Async group provides a mechanism to group together multiple async tokens or
@@ -175,10 +176,9 @@
   std::atomic pendingTokens;
   std::atomic rank;
 
-  // Internal state below guarded by a mutex.
+  // Pending awaiters are guarded by a mutex.
   std::mutex mu;
   std::condition_variable cv;
-
   std::vector<std::function<void()>> awaiters;
 };
 
@@ -291,13 +291,13 @@
 extern "C" void mlirAsyncRuntimeAwaitToken(AsyncToken *token) {
   std::unique_lock<std::mutex> lock(token->mu);
   if (!token->ready)
-    token->cv.wait(lock, [token] { return token->ready; });
+    token->cv.wait(lock, [token] { return token->ready.load(); });
 }
 
 extern "C" void mlirAsyncRuntimeAwaitValue(AsyncValue *value) {
   std::unique_lock<std::mutex> lock(value->mu);
   if (!value->ready)
-    value->cv.wait(lock, [value] { return value->ready; });
+    value->cv.wait(lock, [value] { return value->ready.load(); });
 }
 
 extern "C" void mlirAsyncRuntimeAwaitAllInGroup(AsyncGroup *group) {
@@ -319,34 +319,49 @@
 extern "C" void mlirAsyncRuntimeAwaitTokenAndExecute(AsyncToken *token,
                                                      CoroHandle handle,
                                                      CoroResume resume) {
-  std::unique_lock<std::mutex> lock(token->mu);
   auto execute = [handle, resume]() { (*resume)(handle); };
-  if (token->ready)
-    execute();
-  else
-    token->awaiters.push_back([execute]() { execute(); });
+  if (!token->ready) {
+    std::unique_lock<std::mutex> lock(token->mu);
+    // Re-check under the lock: the token may have become ready after the
+    // unlocked check above, and a queued awaiter would never be resumed.
+    if (!token->ready) {
+      token->awaiters.push_back([execute]() { execute(); });
+      return;
+    }
+  }
+  execute();
 }
 
 extern "C" void mlirAsyncRuntimeAwaitValueAndExecute(AsyncValue *value,
                                                      CoroHandle handle,
                                                      CoroResume resume) {
-  std::unique_lock<std::mutex> lock(value->mu);
   auto execute = [handle, resume]() { (*resume)(handle); };
-  if (value->ready)
-    execute();
-  else
-    value->awaiters.push_back([execute]() { execute(); });
+  if (!value->ready) {
+    std::unique_lock<std::mutex> lock(value->mu);
+    // Re-check under the lock: the value may have become ready after the
+    // unlocked check above, and a queued awaiter would never be resumed.
+    if (!value->ready) {
+      value->awaiters.push_back([execute]() { execute(); });
+      return;
+    }
+  }
+  execute();
 }
 
 extern "C" void mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup *group,
                                                           CoroHandle handle,
                                                           CoroResume resume) {
-  std::unique_lock<std::mutex> lock(group->mu);
   auto execute = [handle, resume]() { (*resume)(handle); };
-  if (group->pendingTokens == 0)
-    execute();
-  else
-    group->awaiters.push_back([execute]() { execute(); });
+  if (group->pendingTokens != 0) {
+    std::unique_lock<std::mutex> lock(group->mu);
+    // Re-check under the lock: the last token may have completed after the
+    // unlocked check above, and a queued awaiter would never be resumed.
+    if (group->pendingTokens != 0) {
+      group->awaiters.push_back([execute]() { execute(); });
+      return;
+    }
+  }
+  execute();
 }
 
 //===----------------------------------------------------------------------===//