diff --git a/mlir/lib/ExecutionEngine/AsyncRuntime.cpp b/mlir/lib/ExecutionEngine/AsyncRuntime.cpp --- a/mlir/lib/ExecutionEngine/AsyncRuntime.cpp +++ b/mlir/lib/ExecutionEngine/AsyncRuntime.cpp @@ -45,6 +45,7 @@ AsyncRuntime() : numRefCountedObjects(0) {} ~AsyncRuntime() { + threadPool.wait(); // wait for the completion of all async tasks assert(getNumRefCountedObjects() == 0 && "all ref counted objects must be destroyed"); } @@ -53,6 +54,8 @@ return numRefCountedObjects.load(std::memory_order_relaxed); } + llvm::ThreadPool &getThreadPool() { return threadPool; } + private: friend class RefCounted; @@ -66,6 +69,8 @@ } std::atomic<int32_t> numRefCountedObjects; + + llvm::ThreadPool threadPool; }; // Returns the default per-process instance of an async runtime. @@ -143,15 +148,13 @@ }; // Adds references to reference counted runtime object. -extern "C" MLIR_ASYNCRUNTIME_EXPORT void -mlirAsyncRuntimeAddRef(RefCountedObjPtr ptr, int32_t count) { +extern "C" void mlirAsyncRuntimeAddRef(RefCountedObjPtr ptr, int32_t count) { RefCounted *refCounted = static_cast<RefCounted *>(ptr); refCounted->addRef(count); } // Drops references from reference counted runtime object. -extern "C" MLIR_ASYNCRUNTIME_EXPORT void -mlirAsyncRuntimeDropRef(RefCountedObjPtr ptr, int32_t count) { +extern "C" void mlirAsyncRuntimeDropRef(RefCountedObjPtr ptr, int32_t count) { RefCounted *refCounted = static_cast<RefCounted *>(ptr); refCounted->dropRef(count); } @@ -163,13 +166,13 @@ } // Create a new `async.group` in empty state.
-extern "C" MLIR_ASYNCRUNTIME_EXPORT AsyncGroup *mlirAsyncRuntimeCreateGroup() { +extern "C" AsyncGroup *mlirAsyncRuntimeCreateGroup() { AsyncGroup *group = new AsyncGroup(getDefaultAsyncRuntimeInstance()); return group; } -extern "C" MLIR_ASYNCRUNTIME_EXPORT int64_t -mlirAsyncRuntimeAddTokenToGroup(AsyncToken *token, AsyncGroup *group) { +extern "C" int64_t mlirAsyncRuntimeAddTokenToGroup(AsyncToken *token, + AsyncGroup *group) { std::unique_lock<std::mutex> lockToken(token->mu); std::unique_lock<std::mutex> lockGroup(group->mu); @@ -177,27 +180,33 @@ int rank = group->rank.fetch_add(1); group->pendingTokens.fetch_add(1); - auto onTokenReady = [group, token](bool dropRef) { + auto onTokenReady = [group]() { // Run all group awaiters if it was the last token in the group. if (group->pendingTokens.fetch_sub(1) == 1) { group->cv.notify_all(); for (auto &awaiter : group->awaiters) awaiter(); } - - // We no longer need the token or the group, drop references on them. - if (dropRef) { - group->dropRef(); - token->dropRef(); - } }; if (token->ready) { - onTokenReady(false); + // Update group pending tokens immediately and maybe run awaiters. + onTokenReady(); + } else { + // Update group pending tokens when the token becomes ready. Because this + // will happen asynchronously we must ensure that `group` is alive until + // then, and re-acquire the lock. group->addRef(); - token->addRef(); - token->awaiters.push_back([onTokenReady]() { onTokenReady(true); }); + + token->awaiters.push_back([group, onTokenReady]() { + // Make sure that `dropRef` does not destroy the mutex owned by the lock. + { + std::unique_lock<std::mutex> lockGroup(group->mu); + onTokenReady(); + } + group->dropRef(); + }); } return rank; @@ -205,11 +214,14 @@ // Switches `async.token` to ready state and runs all awaiters.
extern "C" void mlirAsyncRuntimeEmplaceToken(AsyncToken *token) { - std::unique_lock<std::mutex> lock(token->mu); - token->ready = true; - token->cv.notify_all(); - for (auto &awaiter : token->awaiters) - awaiter(); + // Make sure that `dropRef` does not destroy the mutex owned by the lock. + { + std::unique_lock<std::mutex> lock(token->mu); + token->ready = true; + token->cv.notify_all(); + for (auto &awaiter : token->awaiters) + awaiter(); + } // Async tokens created with a ref count `2` to keep token alive until the // async task completes. Drop this reference explicitly when token emplaced. @@ -222,58 +234,37 @@ token->cv.wait(lock, [token] { return token->ready; }); } -extern "C" MLIR_ASYNCRUNTIME_EXPORT void -mlirAsyncRuntimeAwaitAllInGroup(AsyncGroup *group) { +extern "C" void mlirAsyncRuntimeAwaitAllInGroup(AsyncGroup *group) { std::unique_lock<std::mutex> lock(group->mu); if (group->pendingTokens != 0) group->cv.wait(lock, [group] { return group->pendingTokens == 0; }); } extern "C" void mlirAsyncRuntimeExecute(CoroHandle handle, CoroResume resume) { -#if LLVM_ENABLE_THREADS - static llvm::ThreadPool *threadPool = new llvm::ThreadPool(); - threadPool->async([handle, resume]() { (*resume)(handle); }); -#else - (*resume)(handle); -#endif + auto *runtime = getDefaultAsyncRuntimeInstance(); + runtime->getThreadPool().async([handle, resume]() { (*resume)(handle); }); } extern "C" void mlirAsyncRuntimeAwaitTokenAndExecute(AsyncToken *token, CoroHandle handle, CoroResume resume) { std::unique_lock<std::mutex> lock(token->mu); - - auto execute = [handle, resume, token](bool dropRef) { - if (dropRef) - token->dropRef(); - mlirAsyncRuntimeExecute(handle, resume); - }; - - if (token->ready) { - execute(false); - } else { - token->addRef(); - token->awaiters.push_back([execute]() { execute(true); }); - } + auto execute = [handle, resume]() { (*resume)(handle); }; + if (!token->ready) + return token->awaiters.push_back([execute]() { execute(); }); + lock.unlock(); // Resumed coroutine may drop the last ref and destroy `mu`. + execute(); } -extern "C" MLIR_ASYNCRUNTIME_EXPORT void
-mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup *group, CoroHandle handle, - CoroResume resume) { +extern "C" void mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup *group, + CoroHandle handle, + CoroResume resume) { std::unique_lock<std::mutex> lock(group->mu); - - auto execute = [handle, resume, group](bool dropRef) { - if (dropRef) - group->dropRef(); - mlirAsyncRuntimeExecute(handle, resume); - }; - - if (group->pendingTokens == 0) { - execute(false); - } else { - group->addRef(); - group->awaiters.push_back([execute]() { execute(true); }); - } + auto execute = [handle, resume]() { (*resume)(handle); }; + if (group->pendingTokens != 0) + return group->awaiters.push_back([execute]() { execute(); }); + lock.unlock(); // Resumed coroutine may drop the last ref and destroy `mu`. + execute(); } //===----------------------------------------------------------------------===// @@ -282,7 +273,7 @@ extern "C" void mlirAsyncRuntimePrintCurrentThreadId() { static thread_local std::thread::id thisId = std::this_thread::get_id(); - std::cout << "Current thread id: " << thisId << "\n"; + std::cout << "Current thread id: " << thisId << std::endl; } #endif // MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS