diff --git a/mlir/include/mlir/IR/Threading.h b/mlir/include/mlir/IR/Threading.h --- a/mlir/include/mlir/IR/Threading.h +++ b/mlir/include/mlir/IR/Threading.h @@ -41,10 +41,7 @@ // If multithreading is disabled or there is a small number of elements, // process the elements directly on this thread. - // FIXME: ThreadPool should allow work stealing to avoid deadlocks when - // scheduling work within a worker thread. - if (!context->isMultithreadingEnabled() || numElements <= 1 || - context->getThreadPool().isWorkerThread()) { + if (!context->isMultithreadingEnabled() || numElements <= 1) { for (; begin != end; ++begin) if (failed(func(*begin))) return failure(); @@ -72,9 +69,10 @@ llvm::ThreadPool &threadPool = context->getThreadPool(); size_t numActions = std::min(numElements, threadPool.getThreadCount()); SmallVector> threadFutures; + llvm::ThreadPoolTaskGroup tasksGroup(threadPool); threadFutures.reserve(numActions - 1); for (unsigned i = 1; i < numActions; ++i) - threadFutures.emplace_back(threadPool.async(processFn)); + threadFutures.emplace_back(tasksGroup.async(processFn)); processFn(); // Wait for all of the threads to finish.