diff --git a/mlir/include/mlir/IR/MLIRContext.h b/mlir/include/mlir/IR/MLIRContext.h
--- a/mlir/include/mlir/IR/MLIRContext.h
+++ b/mlir/include/mlir/IR/MLIRContext.h
@@ -38,11 +38,27 @@
 /// a very generic name ("Context") and because it is uncommon for clients to
 /// interact with it.
 ///
+/// The context wraps some multi-threading facilities, and in particular by
+/// default it will implicitly create a thread pool.
+/// This can be undesirable if multiple contexts exist at the same time or if a
+/// process will be long-lived and create and destroy contexts.
+/// To better control thread spawning, an externally owned ThreadPool can be
+/// injected into the context. For example:
+///
+///   llvm::ThreadPool myThreadPool;
+///   while (auto *request = nextCompilationRequests()) {
+///     MLIRContext ctx(registry, MLIRContext::Threading::DISABLED);
+///     ctx.setThreadPool(myThreadPool);
+///     processRequest(request, ctx);
+///   }
+///
 class MLIRContext {
 public:
+  enum class Threading { DISABLED, ENABLED };
   /// Create a new Context.
-  explicit MLIRContext();
-  explicit MLIRContext(const DialectRegistry &registry);
+  explicit MLIRContext(Threading multithreading = Threading::ENABLED);
+  explicit MLIRContext(const DialectRegistry &registry,
+                       Threading multithreading = Threading::ENABLED);
   ~MLIRContext();
 
   /// Return information about all IR dialects loaded in the context.
@@ -118,7 +134,15 @@
     disableMultithreading(!enable);
   }
 
-  /// Return the thread pool owned by this context. This method requires that
+  /// Set a new thread pool to be used in this context. This method requires
+  /// that multithreading is disabled for this context prior to the call. This
+  /// allows sharing a thread pool across multiple contexts, as well as
+  /// decoupling the lifetime of the threads from the contexts. The thread pool
+  /// must outlive the context. Multi-threading will be enabled as part of this
+  /// method.
+  void setThreadPool(llvm::ThreadPool &pool);
+
+  /// Return the thread pool used by this context. This method requires that
   /// multithreading be enabled within the context, and should generally not be
   /// used directly. Users should instead prefer the threading utilities within
   /// Threading.h.
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -261,8 +261,15 @@
   // Other
   //===--------------------------------------------------------------------===//
 
-  /// The thread pool to use when processing MLIR tasks in parallel.
-  llvm::Optional<llvm::ThreadPool> threadPool;
+  /// This points to the ThreadPool used when processing MLIR tasks in parallel.
+  /// It can't be nullptr when multi-threading is enabled. Otherwise, if
+  /// multi-threading is disabled and the thread pool wasn't externally
+  /// provided using `setThreadPool`, this will be nullptr.
+  llvm::ThreadPool *threadPool = nullptr;
+
+  /// When the thread pool is owned by the context, this ensures it is
+  /// destroyed along with the context.
+  std::unique_ptr<llvm::ThreadPool> ownedThreadPool;
 
   /// This is a list of dialects that are created referring to this context.
   /// The MLIRContext owns the objects.
@@ -334,9 +341,13 @@
   StringAttr emptyStringAttr;
 
 public:
-  MLIRContextImpl() : identifiers(identifierAllocator) {
-    if (threadingIsEnabled)
-      threadPool.emplace();
+  MLIRContextImpl(bool threadingIsEnabled)
+      : threadingIsEnabled(threadingIsEnabled),
+        identifiers(identifierAllocator) {
+    if (threadingIsEnabled) {
+      ownedThreadPool = std::make_unique<llvm::ThreadPool>();
+      threadPool = ownedThreadPool.get();
+    }
   }
   ~MLIRContextImpl() {
     for (auto typeMapping : registeredTypes)
@@ -347,13 +358,17 @@
 };
 } // end namespace mlir
 
-MLIRContext::MLIRContext() : MLIRContext(DialectRegistry()) {}
+MLIRContext::MLIRContext(Threading setting)
+    : MLIRContext(DialectRegistry(), setting) {}
 
-MLIRContext::MLIRContext(const DialectRegistry &registry)
-    : impl(new MLIRContextImpl) {
+MLIRContext::MLIRContext(const DialectRegistry &registry, Threading setting)
+    : impl(new MLIRContextImpl(setting == Threading::ENABLED)) {
   // Initialize values based on the command line flags if they were provided.
   if (clOptions.isConstructed()) {
-    disableMultithreading(clOptions->disableThreading);
+    // Only honor the flag when it requests disabling: forwarding `false` would
+    // override a context explicitly created with Threading::DISABLED.
+    if (clOptions->disableThreading)
+      disableMultithreading(/*disable=*/true);
     printOpOnDiagnostic(clOptions->printOpOnDiagnostic);
     printStackTraceOnDiagnostic(clOptions->printStackTraceOnDiagnostic);
   }
@@ -579,15 +594,36 @@
 
   // Destroy thread pool (stop all threads) if it is no longer needed, or create
   // a new one if multithreading was re-enabled.
-  if (!impl->threadingIsEnabled)
-    impl->threadPool.reset();
-  else if (!impl->threadPool.hasValue())
-    impl->threadPool.emplace();
+  if (disable) {
+    // If the thread pool is owned, explicitly set it to nullptr to avoid
+    // keeping a dangling pointer around. If the thread pool is externally
+    // owned, we don't do anything.
+    if (impl->ownedThreadPool) {
+      assert(impl->threadPool);
+      impl->threadPool = nullptr;
+      impl->ownedThreadPool.reset();
+    }
+  } else if (!impl->threadPool) {
+    // The thread pool isn't externally provided.
+    assert(!impl->ownedThreadPool);
+    impl->ownedThreadPool = std::make_unique<llvm::ThreadPool>();
+    impl->threadPool = impl->ownedThreadPool.get();
+  }
+}
+
+void MLIRContext::setThreadPool(llvm::ThreadPool &pool) {
+  assert(!isMultithreadingEnabled() &&
+         "expected multi-threading to be disabled when setting a ThreadPool");
+  impl->threadPool = &pool;
+  impl->ownedThreadPool.reset();
+  enableMultithreading();
 }
 
 llvm::ThreadPool &MLIRContext::getThreadPool() {
   assert(isMultithreadingEnabled() &&
          "expected multi-threading to be enabled within the context");
+  assert(impl->threadPool &&
+         "multi-threading is enabled but threadpool not set");
   return *impl->threadPool;
 }
 