diff --git a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h --- a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h +++ b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h @@ -29,31 +29,153 @@ class PassPipelineCLParser; class PassManager; +/// Configuration options for the mlir-opt tool. +/// This is intended to help building tools like mlir-opt by collecting the +/// supported options. +/// The API is fluent, and the options are sorted in alphabetical order below. +class MlirOptMainConfig { +public: + /// Allow operation with no registered dialects. + /// This option is for convenience during testing only and discouraged in + /// general. + MlirOptMainConfig &allowUnregisteredDialects(bool allow) { + allowUnregisteredDialectsFlag = allow; + return *this; + } + bool shouldAllowUnregisteredDialects() const { + return allowUnregisteredDialectsFlag; + } + + /// Print the pass-pipeline as text before executing. + MlirOptMainConfig &dumpPassPipeline(bool dump) { + dumpPassPipelineFlag = dump; + return *this; + } + bool shouldDumpPassPipeline() const { return dumpPassPipelineFlag; } + + /// Set the output format to bytecode instead of textual IR. + MlirOptMainConfig &emitBytecode(bool emit) { + emitBytecodeFlag = emit; + return *this; + } + bool shouldEmitBytecode() const { return emitBytecodeFlag; } + + /// Set the callback to populate the pass manager. + MlirOptMainConfig & + setPassPipelineSetupFn(std::function callback) { + passPipelineCallback = std::move(callback); + return *this; + } + + /// Set the parser to use to populate the pass manager. + MlirOptMainConfig &setPassPipelineParser(const PassPipelineCLParser &parser); + + /// Populate the passmanager, if any callback was set. + LogicalResult setupPassPipeline(PassManager &pm) const { + if (passPipelineCallback) + return passPipelineCallback(pm); + return success(); + } + + // Deprecated. 
+ MlirOptMainConfig &preloadDialectsInContext(bool preload) { + preloadDialectsInContextFlag = preload; + return *this; + } + bool shouldPreloadDialectsInContext() const { + return preloadDialectsInContextFlag; + } + + /// Show the registered dialects before trying to load the input file. + MlirOptMainConfig &showDialects(bool show) { + showDialectsFlag = show; + return *this; + } + bool shouldShowDialects() const { return showDialectsFlag; } + + /// Set whether to split the input file based on the `// -----` marker into + /// pieces and process each chunk independently. + MlirOptMainConfig &splitInputFile(bool split = true) { + splitInputFileFlag = split; + return *this; + } + bool shouldSplitInputFile() const { return splitInputFileFlag; } + + /// Disable implicit addition of a top-level module op during parsing. + MlirOptMainConfig &useImplicitModule(bool useImplicitModule) { + useImplicitModuleFlag = useImplicitModule; + return *this; + } + bool shouldUseImplicitModule() const { return useImplicitModuleFlag; } + + /// Set whether to check that emitted diagnostics match `expected-*` lines on + /// the corresponding line. This is meant for implementing diagnostic tests. + MlirOptMainConfig &verifyDiagnostics(bool verify) { + verifyDiagnosticsFlag = verify; + return *this; + } + bool shouldVerifyDiagnostics() const { return verifyDiagnosticsFlag; } + + /// Set whether to run the verifier after each transformation pass. + MlirOptMainConfig &verifyPasses(bool verify) { + verifyPassesFlag = verify; + return *this; + } + bool shouldVerifyPasses() const { return verifyPassesFlag; } + +private: + /// Allow operation with no registered dialects. + /// This option is for convenience during testing only and discouraged in + /// general. + bool allowUnregisteredDialectsFlag = false; + + /// Print the pipeline that will be run. + bool dumpPassPipelineFlag = false; + + /// Emit bytecode instead of textual assembly when generating output. 
+ bool emitBytecodeFlag = false; + + /// The callback to populate the pass manager. + std::function passPipelineCallback; + + /// Deprecated. + bool preloadDialectsInContextFlag = false; + + /// Show the registered dialects before trying to load the input file. + bool showDialectsFlag = false; + + /// Split the input file based on the `// -----` marker into pieces and + /// process each chunk independently. + bool splitInputFileFlag = false; + + /// Use an implicit top-level module op during parsing. + bool useImplicitModuleFlag = true; + + /// Set whether to check that emitted diagnostics match `expected-*` lines on + /// the corresponding line. This is meant for implementing diagnostic tests. + bool verifyDiagnosticsFlag = false; + + /// Run the verifier after each transformation pass. + bool verifyPassesFlag = true; +}; + /// This defines the function type used to setup the pass manager. This can be /// used to pass in a callback to setup a default pass pipeline to be applied on /// the loaded IR. using PassPipelineFn = llvm::function_ref; -/// Perform the core processing behind `mlir-opt`: +/// Perform the core processing behind `mlir-opt`. /// - outputStream is the stream where the resulting IR is printed. /// - buffer is the in-memory file to parser and process. -/// - passPipeline is the specification of the pipeline that will be applied. /// - registry should contain all the dialects that can be parsed in the source. -/// - splitInputFile will look for a "-----" marker in the input file, and load -/// each chunk in an individual ModuleOp processed separately. -/// - verifyDiagnostics enables a verification mode where comments starting with -/// "expected-(error|note|remark|warning)" are parsed in the input and matched -/// against emitted diagnostics. -/// - verifyPasses enables the IR verifier in-between each pass in the pipeline. -/// - allowUnregisteredDialects allows to parse and create operation without -/// registering the Dialect in the MLIRContext. 
-/// - preloadDialectsInContext will trigger the upfront loading of all -/// dialects from the global registry in the MLIRContext. This option is -/// deprecated and will be removed soon. -/// - emitBytecode will generate bytecode output instead of text. -/// - implicitModule will enable implicit addition of a top-level -/// 'builtin.module' if one doesn't already exist. -/// - dumpPassPipeline will dump the pipeline being run to stderr +/// - config contains the configuration options for the tool. +LogicalResult MlirOptMain(llvm::raw_ostream &outputStream, + std::unique_ptr buffer, + DialectRegistry &registry, + const MlirOptMainConfig &config); + +/// Perform the core processing behind `mlir-opt`. +/// This API is deprecated, use the MlirOptMainConfig version above instead. LogicalResult MlirOptMain(llvm::raw_ostream &outputStream, std::unique_ptr buffer, @@ -63,9 +185,8 @@ bool preloadDialectsInContext = false, bool emitBytecode = false, bool implicitModule = false, bool dumpPassPipeline = false); -/// Support a callback to setup the pass manager. -/// - passManagerSetupFn is the callback invoked to setup the pass manager to -/// apply on the loaded IR. +/// Perform the core processing behind `mlir-opt`. +/// This API is deprecated, use the MlirOptMainConfig version above instead.
LogicalResult MlirOptMain( llvm::raw_ostream &outputStream, std::unique_ptr buffer, PassPipelineFn passManagerSetupFn, DialectRegistry &registry, diff --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp --- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp +++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp @@ -40,6 +40,25 @@ using namespace mlir; using namespace llvm; +MlirOptMainConfig &MlirOptMainConfig::setPassPipelineParser( + const PassPipelineCLParser &passPipeline) { + passPipelineCallback = [&](PassManager &pm) { + auto errorHandler = [&](const Twine &msg) { + emitError(UnknownLoc::get(pm.getContext())) << msg; + return failure(); + }; + if (failed(passPipeline.addToPipeline(pm, errorHandler))) + return failure(); + if (this->shouldDumpPassPipeline()) { + + pm.dump(); + llvm::errs() << "\n"; + } + return success(); + }; + return *this; +} + /// Perform the actions on the input file indicated by the command line flags /// within the specified context. /// @@ -47,10 +66,9 @@ /// passes, then prints the output. /// static LogicalResult -performActions(raw_ostream &os, bool verifyDiagnostics, bool verifyPasses, +performActions(raw_ostream &os, const std::shared_ptr &sourceMgr, - MLIRContext *context, PassPipelineFn passManagerSetupFn, - bool emitBytecode, bool implicitModule) { + MLIRContext *context, const MlirOptMainConfig &config) { DefaultTimingManager tm; applyDefaultTimingManagerCLOptions(tm); TimingScope timing = tm.getRootScope(); @@ -66,13 +84,14 @@ // untouched. PassReproducerOptions reproOptions; FallbackAsmResourceMap fallbackResourceMap; - ParserConfig config(context, /*verifyAfterParse=*/true, &fallbackResourceMap); - reproOptions.attachResourceParser(config); + ParserConfig parseConfig(context, /*verifyAfterParse=*/true, + &fallbackResourceMap); + reproOptions.attachResourceParser(parseConfig); // Parse the input file and reset the context threading state.
TimingScope parserTiming = timing.nest("Parser"); - OwningOpRef op = - parseSourceFileForTool(sourceMgr, config, implicitModule); + OwningOpRef op = parseSourceFileForTool( + sourceMgr, parseConfig, config.shouldUseImplicitModule()); context->enableMultithreading(wasThreadingEnabled); if (!op) return failure(); @@ -80,10 +99,10 @@ // Prepare the pass manager, applying command-line and reproducer options. PassManager pm(op.get()->getName(), PassManager::Nesting::Implicit); - pm.enableVerifier(verifyPasses); + pm.enableVerifier(config.shouldVerifyPasses()); applyPassManagerCLOptions(pm); pm.enableTiming(timing); - if (failed(reproOptions.apply(pm)) || failed(passManagerSetupFn(pm))) + if (failed(reproOptions.apply(pm)) || failed(config.setupPassPipeline(pm))) return failure(); // Run the pipeline. @@ -92,7 +111,7 @@ // Print the output. TimingScope outputTiming = timing.nest("Output"); - if (emitBytecode) { + if (config.shouldEmitBytecode()) { BytecodeWriterConfig writerConfig(fallbackResourceMap); writeBytecodeToFile(op.get(), os, writerConfig); } else { @@ -106,13 +125,11 @@ /// Parses the memory buffer. If successfully, run a series of passes against /// it and print the result. -static LogicalResult -processBuffer(raw_ostream &os, std::unique_ptr ownedBuffer, - bool verifyDiagnostics, bool verifyPasses, - bool allowUnregisteredDialects, bool preloadDialectsInContext, - bool emitBytecode, bool implicitModule, - PassPipelineFn passManagerSetupFn, DialectRegistry &registry, - llvm::ThreadPool *threadPool) { +static LogicalResult processBuffer(raw_ostream &os, + std::unique_ptr ownedBuffer, + const MlirOptMainConfig &config, + DialectRegistry &registry, + llvm::ThreadPool *threadPool) { // Tell sourceMgr about this buffer, which is what the parser will pick up. auto sourceMgr = std::make_shared(); sourceMgr->AddNewSourceBuffer(std::move(ownedBuffer), SMLoc()); @@ -124,20 +141,18 @@ context.setThreadPool(*threadPool); // Parse the input file.
- if (preloadDialectsInContext) + if (config.shouldPreloadDialectsInContext()) context.loadAllAvailableDialects(); - context.allowUnregisteredDialects(allowUnregisteredDialects); - if (verifyDiagnostics) + context.allowUnregisteredDialects(config.shouldAllowUnregisteredDialects()); + if (config.shouldVerifyDiagnostics()) context.printOpOnDiagnostic(false); context.getDebugActionManager().registerActionHandler(); // If we are in verify diagnostics mode then we have a lot of work to do, // otherwise just perform the actions without worrying about it. - if (!verifyDiagnostics) { + if (!config.shouldVerifyDiagnostics()) { SourceMgrDiagnosticHandler sourceMgrHandler(*sourceMgr, &context); - return performActions(os, verifyDiagnostics, verifyPasses, sourceMgr, - &context, passManagerSetupFn, emitBytecode, - implicitModule); + return performActions(os, sourceMgr, &context, config); } SourceMgrDiagnosticVerifierHandler sourceMgrHandler(*sourceMgr, &context); @@ -145,22 +160,17 @@ // Do any processing requested by command line flags. We don't care whether // these actions succeed or fail, we only care what diagnostics they produce // and whether they match our expectations. - (void)performActions(os, verifyDiagnostics, verifyPasses, sourceMgr, &context, - passManagerSetupFn, emitBytecode, implicitModule); + (void)performActions(os, sourceMgr, &context, config); // Verify the diagnostic handler to make sure that each of the diagnostics // matched. 
return sourceMgrHandler.verify(); } -LogicalResult mlir::MlirOptMain(raw_ostream &outputStream, - std::unique_ptr buffer, - PassPipelineFn passManagerSetupFn, - DialectRegistry &registry, bool splitInputFile, - bool verifyDiagnostics, bool verifyPasses, - bool allowUnregisteredDialects, - bool preloadDialectsInContext, - bool emitBytecode, bool implicitModule) { +LogicalResult mlir::MlirOptMain(llvm::raw_ostream &outputStream, + std::unique_ptr buffer, + DialectRegistry &registry, + const MlirOptMainConfig &config) { // The split-input-file mode is a very specific mode that slices the file // up into small pieces and checks each independently. // We use an explicit threadpool to avoid creating and joining/destroying @@ -177,13 +187,32 @@ auto chunkFn = [&](std::unique_ptr chunkBuffer, raw_ostream &os) { - return processBuffer(os, std::move(chunkBuffer), verifyDiagnostics, - verifyPasses, allowUnregisteredDialects, - preloadDialectsInContext, emitBytecode, implicitModule, - passManagerSetupFn, registry, threadPool); + return processBuffer(os, std::move(chunkBuffer), config, registry, + threadPool); }; return splitAndProcessBuffer(std::move(buffer), chunkFn, outputStream, - splitInputFile, /*insertMarkerInOutput=*/true); + config.shouldSplitInputFile(), + /*insertMarkerInOutput=*/true); +} + +LogicalResult mlir::MlirOptMain(raw_ostream &outputStream, + std::unique_ptr buffer, + PassPipelineFn passManagerSetupFn, + DialectRegistry &registry, bool splitInputFile, + bool verifyDiagnostics, bool verifyPasses, + bool allowUnregisteredDialects, + bool preloadDialectsInContext, + bool emitBytecode, bool implicitModule) { + return MlirOptMain(outputStream, std::move(buffer), registry, + MlirOptMainConfig{} + .splitInputFile(splitInputFile) + .verifyDiagnostics(verifyDiagnostics) + .verifyPasses(verifyPasses) + .allowUnregisteredDialects(allowUnregisteredDialects) + .preloadDialectsInContext(preloadDialectsInContext) + .emitBytecode(emitBytecode) + .useImplicitModule(implicitModule) +
.setPassPipelineSetupFn(passManagerSetupFn)); } LogicalResult mlir::MlirOptMain( @@ -192,23 +221,17 @@ bool splitInputFile, bool verifyDiagnostics, bool verifyPasses, bool allowUnregisteredDialects, bool preloadDialectsInContext, bool emitBytecode, bool implicitModule, bool dumpPassPipeline) { - auto passManagerSetupFn = [&](PassManager &pm) { - auto errorHandler = [&](const Twine &msg) { - emitError(UnknownLoc::get(pm.getContext())) << msg; - return failure(); - }; - if (failed(passPipeline.addToPipeline(pm, errorHandler))) - return failure(); - if (dumpPassPipeline) { - pm.dump(); - llvm::errs() << "\n"; - } - return success(); - }; - return MlirOptMain(outputStream, std::move(buffer), passManagerSetupFn, - registry, splitInputFile, verifyDiagnostics, verifyPasses, - allowUnregisteredDialects, preloadDialectsInContext, - emitBytecode, implicitModule); + return MlirOptMain(outputStream, std::move(buffer), registry, + MlirOptMainConfig{} + .splitInputFile(splitInputFile) + .verifyDiagnostics(verifyDiagnostics) + .verifyPasses(verifyPasses) + .allowUnregisteredDialects(allowUnregisteredDialects) + .preloadDialectsInContext(preloadDialectsInContext) + .emitBytecode(emitBytecode) + .useImplicitModule(implicitModule) + .dumpPassPipeline(dumpPassPipeline) + .setPassPipelineParser(passPipeline)); } LogicalResult mlir::MlirOptMain(int argc, char **argv, llvm::StringRef toolName, @@ -301,12 +324,19 @@ llvm::errs() << errorMessage << "\n"; return failure(); } - - if (failed(MlirOptMain(output->os(), std::move(file), passPipeline, registry, - splitInputFile, verifyDiagnostics, verifyPasses, - allowUnregisteredDialects, preloadDialectsInContext, - emitBytecode, /*implicitModule=*/!noImplicitModule, - dumpPassPipeline))) + // Setup the configuration for the main function. 
+ MlirOptMainConfig config; + config.setPassPipelineParser(passPipeline) + .splitInputFile(splitInputFile) + .verifyDiagnostics(verifyDiagnostics) + .verifyPasses(verifyPasses) + .allowUnregisteredDialects(allowUnregisteredDialects) + .preloadDialectsInContext(preloadDialectsInContext) + .emitBytecode(emitBytecode) + .useImplicitModule(!noImplicitModule) + .dumpPassPipeline(dumpPassPipeline); + + if (failed(MlirOptMain(output->os(), std::move(file), registry, config))) return failure(); // Keep the output file if the invocation of MlirOptMain was successful.