diff --git a/mlir/lib/Support/MlirOptMain.cpp b/mlir/lib/Support/MlirOptMain.cpp --- a/mlir/lib/Support/MlirOptMain.cpp +++ b/mlir/lib/Support/MlirOptMain.cpp @@ -94,7 +94,7 @@ bool verifyDiagnostics, bool verifyPasses, bool allowUnregisteredDialects, bool preloadDialectsInContext, PassPipelineFn passManagerSetupFn, DialectRegistry &registry, - llvm::ThreadPool &threadPool) { + llvm::ThreadPool *threadPool) { // Tell sourceMgr about this buffer, which is what the parser will pick up. SourceMgr sourceMgr; sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), SMLoc()); @@ -102,7 +102,8 @@ // Create a context just for the current buffer. Disable threading on creation // since we'll inject the thread-pool separately. MLIRContext context(registry, MLIRContext::Threading::DISABLED); - context.setThreadPool(threadPool); + if (threadPool) + context.setThreadPool(*threadPool); // Parse the input file. if (preloadDialectsInContext) @@ -144,7 +145,15 @@ // up into small pieces and checks each independently. // We use an explicit threadpool to avoid creating and joining/destroying // threads for each of the split. - llvm::ThreadPool threadPool; + llvm::ThreadPool *threadPool = nullptr; + // Create a temporary context solely to check whether + // --mlir-disable-threading was passed on the command line. When threading + // is enabled, reuse the thread pool owned by this context; otherwise pass a + // null pool down so that no thread is ever created.
+ MLIRContext threadPoolCtx; + if (threadPoolCtx.isMultithreadingEnabled()) + threadPool = &threadPoolCtx.getThreadPool(); + if (splitInputFile) return splitAndProcessBuffer( std::move(buffer),