diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -57,7 +57,8 @@
 struct MLIRContextOptions {
   llvm::cl::opt<bool> disableThreading{
       "mlir-disable-threading",
-      llvm::cl::desc("Disabling multi-threading within MLIR")};
+      llvm::cl::desc("Disable multi-threading within MLIR, overrides any "
+                     "further call to MLIRContext::enableMultithreading()")};
 
   llvm::cl::opt<bool> printOpOnDiagnostic{
       "mlir-print-op-on-diagnostic",
@@ -74,6 +75,17 @@
 
 static llvm::ManagedStatic<MLIRContextOptions> clOptions;
 
+/// Return true when multi-threading must stay disabled for every MLIRContext:
+/// either LLVM was built without thread support, or the MLIRContext
+/// command-line options were constructed and --mlir-disable-threading is set.
+static bool isThreadingGloballyDisabled() {
+#if LLVM_ENABLE_THREADS != 0
+  return clOptions.isConstructed() && clOptions->disableThreading;
+#else
+  return true;
+#endif
+}
+
 /// Register a set of useful command-line options that can be used to configure
 /// various flags within the MLIRContext. These flags are used when constructing
 /// an MLIR context for initialization.
@@ -362,10 +374,10 @@
     : MLIRContext(DialectRegistry(), setting) {}
 
 MLIRContext::MLIRContext(const DialectRegistry &registry, Threading setting)
-    : impl(new MLIRContextImpl(setting == Threading::ENABLED)) {
+    : impl(new MLIRContextImpl(setting == Threading::ENABLED &&
+                               !isThreadingGloballyDisabled())) {
   // Initialize values based on the command line flags if they were provided.
   if (clOptions.isConstructed()) {
-    disableMultithreading(clOptions->disableThreading);
     printOpOnDiagnostic(clOptions->printOpOnDiagnostic);
     printStackTraceOnDiagnostic(clOptions->printStackTraceOnDiagnostic);
   }
@@ -582,6 +594,11 @@
 
 /// Set the flag specifying if multi-threading is disabled by the context.
 void MLIRContext::disableMultithreading(bool disable) {
+  // This API is overridden by the global debugging flag
+  // --mlir-disable-threading: when it is set, threading stays disabled.
+  if (isThreadingGloballyDisabled())
+    return;
+
   impl->threadingIsEnabled = !disable;
 
   // Update the threading mode for each of the uniquers.
diff --git a/mlir/lib/Support/MlirOptMain.cpp b/mlir/lib/Support/MlirOptMain.cpp
--- a/mlir/lib/Support/MlirOptMain.cpp
+++ b/mlir/lib/Support/MlirOptMain.cpp
@@ -32,6 +32,7 @@
 #include "llvm/Support/Regex.h"
 #include "llvm/Support/SourceMgr.h"
 #include "llvm/Support/StringSaver.h"
+#include "llvm/Support/ThreadPool.h"
 #include "llvm/Support/ToolOutputFile.h"
 
 using namespace mlir;
@@ -93,19 +94,23 @@
 
 /// Parses the memory buffer. If successfully, run a series of passes against
 /// it and print the result.
-static LogicalResult processBuffer(raw_ostream &os,
-                                   std::unique_ptr<MemoryBuffer> ownedBuffer,
-                                   bool verifyDiagnostics, bool verifyPasses,
-                                   bool allowUnregisteredDialects,
-                                   bool preloadDialectsInContext,
-                                   const PassPipelineCLParser &passPipeline,
-                                   DialectRegistry &registry) {
+static LogicalResult
+processBuffer(raw_ostream &os, std::unique_ptr<MemoryBuffer> ownedBuffer,
+              bool verifyDiagnostics, bool verifyPasses,
+              bool allowUnregisteredDialects, bool preloadDialectsInContext,
+              const PassPipelineCLParser &passPipeline,
+              DialectRegistry &registry, llvm::ThreadPool *threadPool) {
   // Tell sourceMgr about this buffer, which is what the parser will pick up.
   SourceMgr sourceMgr;
   sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), SMLoc());
 
+  // Create a context just for the current buffer. Disable threading on creation
+  // since the shared thread-pool, if any, is injected separately.
+  MLIRContext context(registry, MLIRContext::Threading::DISABLED);
+  if (threadPool)
+    context.setThreadPool(*threadPool);
+
   // Parse the input file.
-  MLIRContext context(registry);
   if (preloadDialectsInContext)
     context.loadAllAvailableDialects();
   context.allowUnregisteredDialects(allowUnregisteredDialects);
@@ -143,20 +148,30 @@
                                 bool preloadDialectsInContext) {
   // The split-input-file mode is a very specific mode that slices the file
   // up into small pieces and checks each independently.
+  // We use an explicit threadpool to avoid creating and joining/destroying
+  // threads for each split chunk.
+  // The pool is borrowed from a temporary context, which honors
+  // --mlir-disable-threading: when threading is disabled no pool is injected
+  // into the per-buffer contexts. This context must outlive all uses below.
+  llvm::ThreadPool *threadPool = nullptr;
+  MLIRContext threadPoolCtx;
+  if (threadPoolCtx.isMultithreadingEnabled())
+    threadPool = &threadPoolCtx.getThreadPool();
   if (splitInputFile)
     return splitAndProcessBuffer(
         std::move(buffer),
         [&](std::unique_ptr<MemoryBuffer> chunkBuffer, raw_ostream &os) {
           return processBuffer(os, std::move(chunkBuffer), verifyDiagnostics,
                                verifyPasses, allowUnregisteredDialects,
-                               preloadDialectsInContext, passPipeline,
-                               registry);
+                               preloadDialectsInContext, passPipeline, registry,
+                               threadPool);
         },
         outputStream);
 
   return processBuffer(outputStream, std::move(buffer), verifyDiagnostics,
                        verifyPasses, allowUnregisteredDialects,
-                       preloadDialectsInContext, passPipeline, registry);
+                       preloadDialectsInContext, passPipeline, registry,
+                       threadPool);
 }
 
 LogicalResult mlir::MlirOptMain(int argc, char **argv, llvm::StringRef toolName,