diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -262,7 +262,8 @@
   //===--------------------------------------------------------------------===//
 
   /// The thread pool to use when processing MLIR tasks in parallel.
-  llvm::ThreadPool threadPool;
+  /// This is only engaged while multithreading is enabled for the context.
+  llvm::Optional<llvm::ThreadPool> threadPool;
 
   /// This is a list of dialects that are created referring to this context.
   /// The MLIRContext owns the objects.
@@ -334,7 +335,10 @@
   StringAttr emptyStringAttr;
 
 public:
-  MLIRContextImpl() : identifiers(identifierAllocator) {}
+  MLIRContextImpl() : identifiers(identifierAllocator) {
+    if (threadingIsEnabled)
+      threadPool.emplace();
+  }
   ~MLIRContextImpl() {
     for (auto typeMapping : registeredTypes)
       typeMapping.second->~AbstractType();
@@ -573,12 +577,19 @@
   impl->affineUniquer.disableMultithreading(disable);
   impl->attributeUniquer.disableMultithreading(disable);
   impl->typeUniquer.disableMultithreading(disable);
+
+  // Keep the thread pool alive only while multithreading is enabled: destroy
+  // it (stopping all threads) when disabled, and recreate it when re-enabled.
+  if (!impl->threadingIsEnabled)
+    impl->threadPool.reset();
+  else if (!impl->threadPool.hasValue())
+    impl->threadPool.emplace();
 }
 
 llvm::ThreadPool &MLIRContext::getThreadPool() {
   assert(isMultithreadingEnabled() &&
          "expected multi-threading to be enabled within the context");
-  return impl->threadPool;
+  return *impl->threadPool;
 }
 
 void MLIRContext::enterMultiThreadedExecution() {