diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -262,7 +262,7 @@
   //===--------------------------------------------------------------------===//
 
   /// The thread pool to use when processing MLIR tasks in parallel.
-  llvm::ThreadPool threadPool;
+  std::unique_ptr<llvm::ThreadPool> threadPool;
 
   /// This is a list of dialects that are created referring to this context.
   /// The MLIRContext owns the objects.
@@ -334,7 +334,9 @@
   StringAttr emptyStringAttr;
 
 public:
-  MLIRContextImpl() : identifiers(identifierAllocator) {}
+  MLIRContextImpl()
+      : threadPool(std::make_unique<llvm::ThreadPool>()),
+        identifiers(identifierAllocator) {}
   ~MLIRContextImpl() {
     for (auto typeMapping : registeredTypes)
       typeMapping.second->~AbstractType();
@@ -573,12 +575,20 @@
   impl->affineUniquer.disableMultithreading(disable);
   impl->attributeUniquer.disableMultithreading(disable);
   impl->typeUniquer.disableMultithreading(disable);
+
+  // Keep the thread pool in sync with the threading mode: destroy it (which
+  // stops all of its threads) when multithreading is disabled, and create a
+  // fresh one when multithreading is (re-)enabled and no pool exists.
+  if (!impl->threadingIsEnabled)
+    impl->threadPool.reset();
+  else if (!impl->threadPool)
+    impl->threadPool = std::make_unique<llvm::ThreadPool>();
 }
 
 llvm::ThreadPool &MLIRContext::getThreadPool() {
   assert(isMultithreadingEnabled() &&
          "expected multi-threading to be enabled within the context");
-  return impl->threadPool;
+  return *impl->threadPool;
 }
 
 void MLIRContext::enterMultiThreadedExecution() {