diff --git a/mlir/include/mlir/Pass/PassRegistry.h b/mlir/include/mlir/Pass/PassRegistry.h
--- a/mlir/include/mlir/Pass/PassRegistry.h
+++ b/mlir/include/mlir/Pass/PassRegistry.h
@@ -280,13 +280,24 @@
 // Pass Reproducer
 //===----------------------------------------------------------------------===//
 
-/// Attach an assembly resource parser that handles MLIR reproducer
-/// configurations. Any found reproducer information will be attached to the
-/// given pass manager, e.g. the reproducer pipeline, verification flags, etc.
-// FIXME: Remove the `enableThreading` flag when possible. Some tools, e.g.
-// mlir-opt, force disable threading during parsing.
-void attachPassReproducerAsmResource(ParserConfig &config, PassManager &pm,
-                                     bool &enableThreading);
+/// Reproducer configuration collected from an 'mlir_reproducer' resource
+/// during parsing, to be applied to a PassManager via 'apply'.
+struct PassReproducerOptions {
+  /// Attach an assembly resource parser to 'config' that collects the MLIR
+  /// reproducer configuration into this instance. The parser captures 'this',
+  /// so this object must outlive any parse performed with 'config'.
+  void attachResourceParser(ParserConfig &config);
+
+  /// Apply the reproducer options to 'pm' and its context. Options that were
+  /// not present in the reproducer are left untouched. Returns failure if the
+  /// reproducer pipeline could not be parsed into 'pm'.
+  LogicalResult apply(PassManager &pm) const;
+
+private:
+  Optional<std::string> pipeline;
+  Optional<bool> verifyEach;
+  Optional<bool> disableThreading;
+};
 
 } // namespace mlir
 
diff --git a/mlir/lib/Pass/PassCrashRecovery.cpp b/mlir/lib/Pass/PassCrashRecovery.cpp
--- a/mlir/lib/Pass/PassCrashRecovery.cpp
+++ b/mlir/lib/Pass/PassCrashRecovery.cpp
@@ -443,29 +443,24 @@
 // Asm Resource
 //===----------------------------------------------------------------------===//
 
-void mlir::attachPassReproducerAsmResource(ParserConfig &config,
-                                           PassManager &pm,
-                                           bool &enableThreading) {
-  auto parseFn = [&](AsmParsedResourceEntry &entry) -> LogicalResult {
+void PassReproducerOptions::attachResourceParser(ParserConfig &config) {
+  auto parseFn = [this](AsmParsedResourceEntry &entry) -> LogicalResult {
     if (entry.getKey() == "pipeline") {
-      FailureOr<std::string> pipeline = entry.parseAsString();
-      if (failed(pipeline))
-        return failure();
-      return parsePassPipeline(*pipeline, pm);
+      FailureOr<std::string> value = entry.parseAsString();
+      if (succeeded(value))
+        this->pipeline = std::move(*value);
+      return value;
     }
     if (entry.getKey() == "disable_threading") {
       FailureOr<bool> value = entry.parseAsBool();
-
-      // FIXME: We should just update the context directly, but some places
-      // force disable threading during parsing.
       if (succeeded(value))
-        enableThreading = !(*value);
+        this->disableThreading = *value;
       return value;
     }
     if (entry.getKey() == "verify_each") {
       FailureOr<bool> value = entry.parseAsBool();
       if (succeeded(value))
-        pm.enableVerifier(*value);
+        this->verifyEach = *value;
       return value;
     }
     return entry.emitError() << "unknown 'mlir_reproducer' resource key '"
@@ -473,3 +468,17 @@
   };
   config.attachResourceParser("mlir_reproducer", parseFn);
 }
+
+LogicalResult PassReproducerOptions::apply(PassManager &pm) const {
+  if (pipeline.has_value())
+    if (failed(parsePassPipeline(*pipeline, pm)))
+      return failure();
+
+  if (disableThreading.has_value())
+    pm.getContext()->disableMultithreading(*disableThreading);
+
+  if (verifyEach.has_value())
+    pm.enableVerifier(*verifyEach);
+
+  return success();
+}
diff --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
--- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
+++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
@@ -59,19 +59,14 @@
   bool wasThreadingEnabled = context->isMultithreadingEnabled();
   context->disableMultithreading();
 
-  // Prepare the pass manager and apply any command line options.
-  PassManager pm(context, OpPassManager::Nesting::Implicit);
-  pm.enableVerifier(verifyPasses);
-  applyPassManagerCLOptions(pm);
-  pm.enableTiming(timing);
-
   // Prepare the parser config, and attach any useful/necessary resource
   // handlers. Unhandled external resources are treated as passthrough, i.e.
   // they are not processed and will be emitted directly to the output
   // untouched.
+  PassReproducerOptions reproOptions;
   FallbackAsmResourceMap fallbackResourceMap;
   ParserConfig config(context, &fallbackResourceMap);
-  attachPassReproducerAsmResource(config, pm, wasThreadingEnabled);
+  reproOptions.attachResourceParser(config);
 
   // Parse the input file and reset the context threading state.
   TimingScope parserTiming = timing.nest("Parser");
@@ -81,8 +76,15 @@
     return failure();
   parserTiming.stop();
 
-  // Callback to build the pipeline.
-  if (failed(passManagerSetupFn(pm)))
+  // Prepare the pass manager, applying command-line and reproducer options.
+  // The reproducer options are applied last so that they override any
+  // conflicting command-line settings (e.g. the verifier flag).
+  PassManager pm(context, OpPassManager::Nesting::Implicit,
+                 module->getOperationName());
+  pm.enableVerifier(verifyPasses);
+  applyPassManagerCLOptions(pm);
+  pm.enableTiming(timing);
+  if (failed(reproOptions.apply(pm)) || failed(passManagerSetupFn(pm)))
     return failure();
 
   // Run the pipeline.