diff --git a/mlir/include/mlir/IR/MLIRContext.h b/mlir/include/mlir/IR/MLIRContext.h --- a/mlir/include/mlir/IR/MLIRContext.h +++ b/mlir/include/mlir/IR/MLIRContext.h @@ -40,9 +40,11 @@ /// class MLIRContext { public: + enum class Threading { disabled, enabled }; /// Create a new Context. - explicit MLIRContext(); - explicit MLIRContext(const DialectRegistry &registry); + explicit MLIRContext(Threading multithreading = Threading::enabled); + explicit MLIRContext(const DialectRegistry &registry, + Threading multithreading = Threading::enabled); ~MLIRContext(); /// Return information about all IR dialects loaded in the context. @@ -118,7 +120,13 @@ disableMultithreading(!enable); } - /// Return the thread pool owned by this context. This method requires that + /// Set a new thread pool to be used in this context. This method requires + /// that multithreading is disabled for this context prior to the call. This + /// allows sharing a thread pool across multiple contexts, as well as + /// decoupling the lifetime of the threads from the contexts. + void setThreadPool(llvm::ThreadPool &pool); + + /// Return the thread pool used by this context. This method requires that /// multithreading be enabled within the context, and should generally not be /// used directly. Users should instead prefer the threading utilities within /// Threading.h. diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -262,7 +262,13 @@ //===--------------------------------------------------------------------===// /// The thread pool to use when processing MLIR tasks in parallel. - llvm::Optional<llvm::ThreadPool> threadPool; + /// It will be nullptr when multi-threading is disabled and the thread pool is + /// owned/private to the context, but not when it is externally set. + llvm::ThreadPool *threadPool = nullptr; + + /// In the case where the thread pool is owned by the context, this ensures + /// destruction with the context. 
+ std::unique_ptr<llvm::ThreadPool> ownedThreadPool; /// This is a list of dialects that are created referring to this context. /// The MLIRContext owns the objects. @@ -334,9 +340,13 @@ StringAttr emptyStringAttr; public: - MLIRContextImpl() : identifiers(identifierAllocator) { - if (threadingIsEnabled) - threadPool.emplace(); + MLIRContextImpl(bool threadingIsEnabled) + : threadingIsEnabled(threadingIsEnabled), + identifiers(identifierAllocator) { + if (threadingIsEnabled) { + ownedThreadPool = std::make_unique<llvm::ThreadPool>(); + threadPool = ownedThreadPool.get(); + } } ~MLIRContextImpl() { for (auto typeMapping : registeredTypes) @@ -347,10 +357,11 @@ }; } // end namespace mlir -MLIRContext::MLIRContext() : MLIRContext(DialectRegistry()) {} +MLIRContext::MLIRContext(Threading setting) + : MLIRContext(DialectRegistry(), setting) {} -MLIRContext::MLIRContext(const DialectRegistry &registry) - : impl(new MLIRContextImpl) { +MLIRContext::MLIRContext(const DialectRegistry &registry, Threading setting) + : impl(new MLIRContextImpl(setting == Threading::enabled)) { // Initialize values based on the command line flags if they were provided. if (clOptions.isConstructed()) { disableMultithreading(clOptions->disableThreading); @@ -579,15 +590,30 @@ // Destroy thread pool (stop all threads) if it is no longer needed, or create // a new one if multithreading was re-enabled. - if (!impl->threadingIsEnabled) - impl->threadPool.reset(); - else if (!impl->threadPool.hasValue()) - impl->threadPool.emplace(); + if (disable) { + // If the thread pool is owned, explicitly set it to nullptr to avoid + // keeping a dangling pointer around. If the thread pool is externally + // owned, we don't do anything. 
+ if (impl->threadPool == impl->ownedThreadPool.get()) + impl->threadPool = nullptr; + impl->ownedThreadPool.reset(); + } else if (!impl->threadPool) { + if (!impl->ownedThreadPool) + impl->ownedThreadPool = std::make_unique<llvm::ThreadPool>(); + impl->threadPool = impl->ownedThreadPool.get(); + } +} + +void MLIRContext::setThreadPool(llvm::ThreadPool &pool) { + impl->threadPool = &pool; + impl->ownedThreadPool.reset(); } llvm::ThreadPool &MLIRContext::getThreadPool() { assert(isMultithreadingEnabled() && "expected multi-threading to be enabled within the context"); + assert(impl->threadPool && + "multi-threading is enabled but threadpool not set"); return *impl->threadPool; }