diff --git a/mlir/lib/ExecutionEngine/AsyncRuntime.cpp b/mlir/lib/ExecutionEngine/AsyncRuntime.cpp
--- a/mlir/lib/ExecutionEngine/AsyncRuntime.cpp
+++ b/mlir/lib/ExecutionEngine/AsyncRuntime.cpp
@@ -25,6 +25,7 @@
 #include
 
 #include "llvm/ADT/StringMap.h"
+#include "llvm/Support/ThreadPool.h"
 
 using namespace mlir::runtime;
 
@@ -49,6 +50,7 @@
   AsyncRuntime() : numRefCountedObjects(0) {}
 
   ~AsyncRuntime() {
+    threadPool.wait(); // wait for the completion of all async tasks
     assert(getNumRefCountedObjects() == 0 &&
            "all ref counted objects must be destroyed");
   }
@@ -57,6 +59,8 @@
     return numRefCountedObjects.load(std::memory_order_relaxed);
   }
 
+  llvm::ThreadPool &getThreadPool() { return threadPool; }
+
 private:
   friend class RefCounted;
 
@@ -70,6 +74,7 @@
   }
 
   std::atomic numRefCountedObjects;
+  llvm::ThreadPool threadPool;
 };
 
 // -------------------------------------------------------------------------- //
@@ -307,7 +312,8 @@
 }
 
 extern "C" void mlirAsyncRuntimeExecute(CoroHandle handle, CoroResume resume) {
-  (*resume)(handle);
+  auto *runtime = getDefaultAsyncRuntime();
+  runtime->getThreadPool().async([handle, resume]() { (*resume)(handle); });
 }
 
 extern "C" void mlirAsyncRuntimeAwaitTokenAndExecute(AsyncToken *token,