[mlir] AsyncRuntime: resume async tasks on an llvm::ThreadPool

This patch changes mlir/lib/ExecutionEngine/AsyncRuntime.cpp as follows:

  * Adds `#include "llvm/Support/ThreadPool.h"` after the standard headers.
  * Gives the AsyncRuntime class a private `llvm::ThreadPool threadPool`
    member, declared after `numRefCountedObjects`, and a public accessor
    `llvm::ThreadPool &getThreadPool()`.
  * `~AsyncRuntime()` now calls `threadPool.wait()` before the
    "all ref counted objects must be destroyed" assertion. The leak check
    therefore runs only after the tasks submitted to the pool have completed,
    per the in-line comment "wait for the completion of all async tasks".
  * `mlirAsyncRuntimeExecute` no longer calls `(*resume)(handle)` directly on
    the calling thread. It fetches `getDefaultAsyncRuntimeInstance()` and
    submits that call to the instance's thread pool via `async`. The lambda
    captures `handle` and `resume` by value, and the value returned by
    `async` is discarded.
  * `mlirAsyncRuntimeAwaitTokenAndExecute` appears only as trailing context
    and is not modified by this patch.

Review notes:
  * NOTE(review): the patch text below has been flattened onto one line and
    has lost every angle-bracketed token. Affected: the two `#include`
    context lines in the first hunk, and the template argument of
    `std::atomic numRefCountedObjects` in the fourth hunk. It will not apply
    with `git apply` or `patch` as-is; regenerate it from the source tree
    rather than hand-reconstructing the missing tokens.
  * NOTE(review): resumed coroutines now run on pool worker threads instead
    of the caller's thread. The pool may have a bounded number of workers
    (verify llvm::ThreadPool's default size). If so, a task that blocks a
    worker while waiting on work queued to the same pool could exhaust it.
    Confirm that no blocking await entry point is reached from inside a
    pool task.
diff --git a/mlir/lib/ExecutionEngine/AsyncRuntime.cpp b/mlir/lib/ExecutionEngine/AsyncRuntime.cpp --- a/mlir/lib/ExecutionEngine/AsyncRuntime.cpp +++ b/mlir/lib/ExecutionEngine/AsyncRuntime.cpp @@ -24,6 +24,8 @@ #include #include +#include "llvm/Support/ThreadPool.h" + using namespace mlir::runtime; // Shutdown hook registration function will be resolved to the symbol defined by @@ -53,6 +55,7 @@ AsyncRuntime() : numRefCountedObjects(0) {} ~AsyncRuntime() { + threadPool.wait(); // wait for the completion of all async tasks assert(getNumRefCountedObjects() == 0 && "all ref counted objects must be destroyed"); } @@ -61,6 +64,8 @@ return numRefCountedObjects.load(std::memory_order_relaxed); } + llvm::ThreadPool &getThreadPool() { return threadPool; } + private: friend class RefCounted; @@ -74,6 +79,7 @@ } std::atomic numRefCountedObjects; + llvm::ThreadPool threadPool; }; // -------------------------------------------------------------------------- // @@ -311,7 +317,8 @@ } extern "C" void mlirAsyncRuntimeExecute(CoroHandle handle, CoroResume resume) { - (*resume)(handle); + auto *runtime = getDefaultAsyncRuntimeInstance(); + runtime->getThreadPool().async([handle, resume]() { (*resume)(handle); }); } extern "C" void mlirAsyncRuntimeAwaitTokenAndExecute(AsyncToken *token,