diff --git a/mlir/lib/ExecutionEngine/AsyncRuntime.cpp b/mlir/lib/ExecutionEngine/AsyncRuntime.cpp --- a/mlir/lib/ExecutionEngine/AsyncRuntime.cpp +++ b/mlir/lib/ExecutionEngine/AsyncRuntime.cpp @@ -53,6 +53,10 @@ return numRefCountedObjects.load(std::memory_order_relaxed); } +#if LLVM_ENABLE_THREADS + llvm::ThreadPool &getThreadPool() { return threadPool; } +#endif + private: friend class RefCounted; @@ -66,6 +70,10 @@ } std::atomic<int32_t> numRefCountedObjects; + +#if LLVM_ENABLE_THREADS + llvm::ThreadPool threadPool; +#endif }; // Returns the default per-process instance of an async runtime. @@ -143,15 +151,13 @@ }; // Adds references to reference counted runtime object. -extern "C" MLIR_ASYNCRUNTIME_EXPORT void -mlirAsyncRuntimeAddRef(RefCountedObjPtr ptr, int32_t count) { +extern "C" void mlirAsyncRuntimeAddRef(RefCountedObjPtr ptr, int32_t count) { RefCounted *refCounted = static_cast<RefCounted *>(ptr); refCounted->addRef(count); } // Drops references from reference counted runtime object. -extern "C" MLIR_ASYNCRUNTIME_EXPORT void -mlirAsyncRuntimeDropRef(RefCountedObjPtr ptr, int32_t count) { +extern "C" void mlirAsyncRuntimeDropRef(RefCountedObjPtr ptr, int32_t count) { RefCounted *refCounted = static_cast<RefCounted *>(ptr); refCounted->dropRef(count); } @@ -163,13 +169,13 @@ } // Create a new `async.group` in empty state. 
-extern "C" MLIR_ASYNCRUNTIME_EXPORT AsyncGroup *mlirAsyncRuntimeCreateGroup() { +extern "C" AsyncGroup *mlirAsyncRuntimeCreateGroup() { AsyncGroup *group = new AsyncGroup(getDefaultAsyncRuntimeInstance()); return group; } -extern "C" MLIR_ASYNCRUNTIME_EXPORT int64_t -mlirAsyncRuntimeAddTokenToGroup(AsyncToken *token, AsyncGroup *group) { +extern "C" int64_t mlirAsyncRuntimeAddTokenToGroup(AsyncToken *token, + AsyncGroup *group) { std::unique_lock<std::mutex> lockToken(token->mu); std::unique_lock<std::mutex> lockGroup(group->mu); @@ -177,27 +183,35 @@ int rank = group->rank.fetch_add(1); group->pendingTokens.fetch_add(1); - auto onTokenReady = [group, token](bool dropRef) { + auto onTokenReady = [group]() { // Run all group awaiters if it was the last token in the group. if (group->pendingTokens.fetch_sub(1) == 1) { group->cv.notify_all(); for (auto &awaiter : group->awaiters) awaiter(); } - - // We no longer need the token or the group, drop references on them. - if (dropRef) { - group->dropRef(); - token->dropRef(); - } }; if (token->ready) { - onTokenReady(false); + // Update group pending tokens immediately and maybe run awaiters. + onTokenReady(); + } else { + // Update group pending tokens when token will become ready. Because this + // will happen asynchronously we must ensure that `group` and `token` are + // alive until then, and re-acquire the lock. group->addRef(); token->addRef(); - token->awaiters.push_back([onTokenReady]() { onTokenReady(true); }); + + token->awaiters.push_back([group, token, onTokenReady]() { + // Make sure that `dropRef` does not destroy the mutex owned by the lock. + { + std::unique_lock<std::mutex> lockGroup(group->mu); + onTokenReady(); + } + group->dropRef(); + token->dropRef(); + }); } return rank; @@ -205,11 +219,14 @@ // Switches `async.token` to ready state and runs all awaiters. 
extern "C" void mlirAsyncRuntimeEmplaceToken(AsyncToken *token) { - std::unique_lock<std::mutex> lock(token->mu); - token->ready = true; - token->cv.notify_all(); - for (auto &awaiter : token->awaiters) - awaiter(); + // Make sure that `dropRef` does not destroy the mutex owned by the lock. + { + std::unique_lock<std::mutex> lock(token->mu); + token->ready = true; + token->cv.notify_all(); + for (auto &awaiter : token->awaiters) + awaiter(); + } // Async tokens created with a ref count `2` to keep token alive until the // async task completes. Drop this reference explicitly when token emplaced. @@ -222,17 +239,16 @@ token->cv.wait(lock, [token] { return token->ready; }); } -extern "C" MLIR_ASYNCRUNTIME_EXPORT void -mlirAsyncRuntimeAwaitAllInGroup(AsyncGroup *group) { +extern "C" void mlirAsyncRuntimeAwaitAllInGroup(AsyncGroup *group) { std::unique_lock<std::mutex> lock(group->mu); if (group->pendingTokens != 0) group->cv.wait(lock, [group] { return group->pendingTokens == 0; }); } extern "C" void mlirAsyncRuntimeExecute(CoroHandle handle, CoroResume resume) { -#if LLVM_ENABLE_THREADS - static llvm::ThreadPool *threadPool = new llvm::ThreadPool(); - threadPool->async([handle, resume]() { (*resume)(handle); }); +#if LLVM_ENABLE_THREADS + auto *runtime = getDefaultAsyncRuntimeInstance(); + runtime->getThreadPool().async([handle, resume]() { (*resume)(handle); }); #else (*resume)(handle); #endif @@ -246,7 +262,7 @@ auto execute = [handle, resume, token](bool dropRef) { if (dropRef) token->dropRef(); - mlirAsyncRuntimeExecute(handle, resume); + (*resume)(handle); }; if (token->ready) { @@ -257,15 +273,15 @@ } } -extern "C" MLIR_ASYNCRUNTIME_EXPORT void -mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup *group, CoroHandle handle, - CoroResume resume) { +extern "C" void mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup *group, + CoroHandle handle, + CoroResume resume) { std::unique_lock<std::mutex> lock(group->mu); auto execute = [handle, resume, group](bool dropRef) { if (dropRef) group->dropRef(); - 
mlirAsyncRuntimeExecute(handle, resume); + (*resume)(handle); }; if (group->pendingTokens == 0) {