Index: llvm/include/llvm/Support/ThreadPool.h
===================================================================
--- llvm/include/llvm/Support/ThreadPool.h
+++ llvm/include/llvm/Support/ThreadPool.h
@@ -27,6 +27,8 @@
 #include <queue>
 #include <utility>
 
+#include <type_traits>
+
 namespace llvm {
 
 /// A ThreadPool for asynchronous parallel execution on a defined number of
@@ -35,6 +37,24 @@
 /// The pool keeps a vector of threads alive, waiting on a condition variable
 /// for some work to become available.
 class ThreadPool {
+  /// Type-erased interface for a queued task, so that tasks with different
+  /// result types can be stored in the same queue.
+  struct TaskBase {
+    virtual ~TaskBase() {}
+    virtual void execute() = 0;
+  };
+
+  /// Concrete task returning \p ResultTy. execute() runs the packaged_task,
+  /// which makes the result available through its associated future.
+  template <typename ResultTy> struct TypedTask : public TaskBase {
+    explicit TypedTask(std::packaged_task<ResultTy()> Task)
+        : Task(std::move(Task)) {}
+
+    void execute() override { Task(); }
+
+    std::packaged_task<ResultTy()> Task;
+  };
+
 public:
   using TaskTy = std::function<void()>;
   using PackagedTaskTy = std::packaged_task<void()>;
@@ -52,7 +72,8 @@
   /// Asynchronous submission of a task to the pool. The returned future can be
   /// used to wait for the task to finish and is *non-blocking* on destruction.
   template <typename Function, typename... Args>
-  inline std::shared_future<void> async(Function &&F, Args &&... ArgList) {
+  inline std::shared_future<typename std::result_of<Function(Args...)>::type>
+  async(Function &&F, Args &&... ArgList) {
     auto Task =
         std::bind(std::forward<Function>(F), std::forward<Args>(ArgList)...);
     return asyncImpl(std::move(Task));
@@ -61,7 +82,8 @@
   /// Asynchronous submission of a task to the pool. The returned future can be
   /// used to wait for the task to finish and is *non-blocking* on destruction.
   template <typename Function>
-  inline std::shared_future<void> async(Function &&F) {
+  inline std::shared_future<typename std::result_of<Function()>::type>
+  async(Function &&F) {
     return asyncImpl(std::forward<Function>(F));
   }
 
@@ -72,13 +94,38 @@
 private:
   /// Asynchronous submission of a task to the pool. The returned future can be
   /// used to wait for the task to finish and is *non-blocking* on destruction.
-  std::shared_future<void> asyncImpl(TaskTy F);
+  ///
+  /// \p Task is taken by value: an lvalue callable passed to async() is copied
+  /// rather than moved-from, and an rvalue is moved in.
+  template <typename Callable>
+  std::shared_future<typename std::result_of<Callable()>::type>
+  asyncImpl(Callable Task) {
+    using ResultTy = typename std::result_of<Callable()>::type;
+
+    /// Wrap the Task in a packaged_task to return a future object.
+    std::packaged_task<ResultTy()> PackagedTask(std::move(Task));
+    auto Future = PackagedTask.get_future();
+    std::unique_ptr<TaskBase> TB =
+        llvm::make_unique<TypedTask<ResultTy>>(std::move(PackagedTask));
+
+    {
+      // Lock the queue and push the new task
+      std::unique_lock<std::mutex> LockGuard(QueueLock);
+
+      // Don't allow enqueueing after disabling the pool
+      assert(EnableFlag && "Queuing a thread during ThreadPool destruction");
+
+      Tasks.push(std::move(TB));
+    }
+    QueueCondition.notify_one();
+    return Future.share();
+  }
 
   /// Threads in flight
   std::vector<llvm::thread> Threads;
 
   /// Tasks waiting for execution in the pool.
-  std::queue<PackagedTaskTy> Tasks;
+  std::queue<std::unique_ptr<TaskBase>> Tasks;
 
   /// Locking and signaling for accessing the Tasks queue.
   std::mutex QueueLock;
Index: llvm/lib/Support/ThreadPool.cpp
===================================================================
--- llvm/lib/Support/ThreadPool.cpp
+++ llvm/lib/Support/ThreadPool.cpp
@@ -32,7 +32,7 @@
   for (unsigned ThreadID = 0; ThreadID < ThreadCount; ++ThreadID) {
     Threads.emplace_back([&] {
       while (true) {
-        PackagedTaskTy Task;
+        std::unique_ptr<TaskBase> Task;
         {
           std::unique_lock<std::mutex> LockGuard(QueueLock);
           // Wait for tasks to be pushed in the queue
@@ -54,7 +54,7 @@
           Tasks.pop();
         }
         // Run the task we just grabbed
-        Task();
+        Task->execute();
 
         {
           // Adjust `ActiveThreads`, in case someone waits on ThreadPool::wait()
@@ -79,23 +79,6 @@
                            [&] { return !ActiveThreads && Tasks.empty(); });
 }
 
-std::shared_future<void> ThreadPool::asyncImpl(TaskTy Task) {
-  /// Wrap the Task in a packaged_task to return a future object.
-  PackagedTaskTy PackagedTask(std::move(Task));
-  auto Future = PackagedTask.get_future();
-  {
-    // Lock the queue and push the new task
-    std::unique_lock<std::mutex> LockGuard(QueueLock);
-
-    // Don't allow enqueueing after disabling the pool
-    assert(EnableFlag && "Queuing a thread during ThreadPool destruction");
-
-    Tasks.push(std::move(PackagedTask));
-  }
-  QueueCondition.notify_one();
-  return Future.share();
-}
-
 // The destructor joins all threads, waiting for completion.
 ThreadPool::~ThreadPool() {
   {
Index: llvm/unittests/Support/ThreadPool.cpp
===================================================================
--- llvm/unittests/Support/ThreadPool.cpp
+++ llvm/unittests/Support/ThreadPool.cpp
@@ -147,6 +147,35 @@
   ASSERT_EQ(2, i.load());
 }
 
+TEST_F(ThreadPoolTest, TaskWithResult) {
+  CHECK_UNSUPPORTED();
+  // By making only 1 thread in the pool the two tasks are serialized with
+  // respect to each other, which means that the second one must return 2.
+  ThreadPool Pool{1};
+  std::atomic_int i{0};
+  Pool.async([this, &i] {
+    waitForMainThread();
+    ++i;
+  });
+  // Force the future using get()
+  std::shared_future<int> Future = Pool.async([&i] { return ++i; });
+  ASSERT_EQ(0, i.load());
+  setMainThreadReady();
+  int Result = Future.get();
+  ASSERT_EQ(2, i.load());
+  ASSERT_EQ(2, Result);
+}
+
+TEST_F(ThreadPoolTest, LvalueTaskIsNotMovedFrom) {
+  CHECK_UNSUPPORTED();
+  ThreadPool Pool;
+  std::string Payload = "payload";
+  auto Task = [Payload] { return Payload.size(); };
+  ASSERT_EQ(Payload.size(), Pool.async(Task).get());
+  // The pool must have copied the lvalue callable, leaving Task intact.
+  ASSERT_EQ(Payload.size(), Task());
+}
+
 TEST_F(ThreadPoolTest, PoolDestruction) {
   CHECK_UNSUPPORTED();
   // Test that we are waiting on destruction