diff --git a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h --- a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h +++ b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h @@ -29,31 +29,153 @@ class PassPipelineCLParser; class PassManager; +/// Configuration options for the mlir-opt tool. +class MlirOptMainConfig { +public: + /// Set whether to split the input file based on the `// -----` marker into + /// pieces and process each chunk independently. + MlirOptMainConfig &setSplitInputFile(bool split) { + splitInputFile = split; + return *this; + } + bool shouldSplitInputFile() const { return splitInputFile; } + + /// Set whether to check that emitted diagnostics match `expected-*` lines on + /// the corresponding line. + MlirOptMainConfig &setVerifyDiagnostics(bool verify) { + verifyDiagnostics = verify; + return *this; + } + bool shouldVerifyDiagnostics() const { return verifyDiagnostics; } + + /// Set whether to run the verifier after each transformation pass. + MlirOptMainConfig &setVerifyPasses(bool verify) { + verifyPasses = verify; + return *this; + } + bool shouldVerifyPasses() const { return verifyPasses; } + + /// Set whether to allow operations from unregistered dialects. + MlirOptMainConfig &setAllowUnregisteredDialects(bool allow) { + allowUnregisteredDialects = allow; + return *this; + } + bool shouldAllowUnregisteredDialects() const { + return allowUnregisteredDialects; + } + + /// Set whether to print the list of registered dialects. + MlirOptMainConfig &setShowDialects(bool show) { + showDialects = show; + return *this; + } + bool shouldShowDialects() const { return showDialects; } + + /// Set whether to emit bytecode instead of textual assembly for the output. + MlirOptMainConfig &setEmitBytecode(bool emit) { + emitBytecode = emit; + return *this; + } + bool shouldEmitBytecode() const { return emitBytecode; } + + /// Set whether to use an implicit top-level module op during parsing.
+ MlirOptMainConfig &setUseImplicitModule(bool useImplicitModule) { + this->useImplicitModule = useImplicitModule; + return *this; + } + bool shouldUseImplicitModule() const { return useImplicitModule; } + + /// Set whether to print the pass pipeline that will be run. + MlirOptMainConfig &setDumpPassPipeline(bool dump) { + dumpPassPipeline = dump; + return *this; + } + bool shouldDumpPassPipeline() const { return dumpPassPipeline; } + + /// Set the callback to populate the pass manager. + MlirOptMainConfig & + setPassPipelineSetupFn(std::function<LogicalResult(PassManager &)> callback) { + passPipelineCallback = std::move(callback); + return *this; + } + /// Set the pipeline from `parser`, which must outlive this config. + MlirOptMainConfig &setPassPipelineParser(const PassPipelineCLParser &parser); + + /// Populate the pass manager, if any callback was set. + LogicalResult setupPassPipeline(PassManager &pm) const { + if (passPipelineCallback) + return passPipelineCallback(pm); + return success(); + } + + /// Deprecated: set whether to load all available dialects upfront. + MlirOptMainConfig &setPreloadDialectsInContext(bool preload) { + preloadDialectsInContext = preload; + return *this; + } + + /// Deprecated: whether all available dialects are loaded upfront. + bool shouldPreloadDialectsInContext() const { + return preloadDialectsInContext; + } + +private: + /// Input .mlir or .mlirbc filename for the mlir-opt tool. + std::string inputFilename = "-"; + + /// Output .mlir or .mlirbc filename for the mlir-opt tool. + std::string outputFilename = "-"; + + /// Split the input file based on the `// -----` marker into pieces and + /// process each chunk independently. + bool splitInputFile = false; + + /// Check that emitted diagnostics match `expected-*` lines on the + /// corresponding line. + bool verifyDiagnostics = false; + + /// Run the verifier after each transformation pass. + bool verifyPasses = true; + + /// Allow operations from unregistered dialects. + bool allowUnregisteredDialects = false; + + /// Print the list of registered dialects. + bool showDialects = false; + + /// Emit bytecode instead of textual assembly when generating output.
+ bool emitBytecode = false; + + /// Use an implicit top-level module op during parsing. + bool useImplicitModule = false; + + /// Deprecated: load all available dialects into the context upfront. + bool preloadDialectsInContext = false; + + /// Print the pipeline that will be run. + bool dumpPassPipeline = false; + + /// The callback to populate the pass manager. + std::function<LogicalResult(PassManager &)> passPipelineCallback; +}; + /// This defines the function type used to setup the pass manager. This can be /// used to pass in a callback to setup a default pass pipeline to be applied on /// the loaded IR. using PassPipelineFn = llvm::function_ref<LogicalResult(PassManager &pm)>; -/// Perform the core processing behind `mlir-opt`: +/// Perform the core processing behind `mlir-opt`. /// - outputStream is the stream where the resulting IR is printed. /// - buffer is the in-memory file to parser and process. -/// - passPipeline is the specification of the pipeline that will be applied. /// - registry should contain all the dialects that can be parsed in the source. -/// - splitInputFile will look for a "-----" marker in the input file, and load -/// each chunk in an individual ModuleOp processed separately. -/// - verifyDiagnostics enables a verification mode where comments starting with -/// "expected-(error|note|remark|warning)" are parsed in the input and matched -/// against emitted diagnostics. -/// - verifyPasses enables the IR verifier in-between each pass in the pipeline. -/// - allowUnregisteredDialects allows to parse and create operation without -/// registering the Dialect in the MLIRContext. -/// - preloadDialectsInContext will trigger the upfront loading of all -/// dialects from the global registry in the MLIRContext. This option is -/// deprecated and will be removed soon. -/// - emitBytecode will generate bytecode output instead of text. -/// - implicitModule will enable implicit addition of a top-level -/// 'builtin.module' if one doesn't already exist.
-/// - dumpPassPipeline will dump the pipeline being run to stderr +/// - config contains the configuration options for the tool. +LogicalResult MlirOptMain(llvm::raw_ostream &outputStream, + std::unique_ptr<llvm::MemoryBuffer> buffer, + DialectRegistry &registry, + const MlirOptMainConfig &config); + +/// Perform the core processing behind `mlir-opt`. +/// This API is deprecated, use the MlirOptMainConfig version above instead. LogicalResult MlirOptMain(llvm::raw_ostream &outputStream, std::unique_ptr<llvm::MemoryBuffer> buffer, @@ -63,9 +185,8 @@ bool preloadDialectsInContext = false, bool emitBytecode = false, bool implicitModule = false, bool dumpPassPipeline = false); -/// Support a callback to setup the pass manager. -/// - passManagerSetupFn is the callback invoked to setup the pass manager to -/// apply on the loaded IR. +/// Perform the core processing behind `mlir-opt`. +/// This API is deprecated, use the MlirOptMainConfig version above instead. LogicalResult MlirOptMain( llvm::raw_ostream &outputStream, std::unique_ptr<llvm::MemoryBuffer> buffer, PassPipelineFn passManagerSetupFn, DialectRegistry &registry, diff --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp --- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp +++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp @@ -40,6 +40,25 @@ using namespace mlir; using namespace llvm; +MlirOptMainConfig &MlirOptMainConfig::setPassPipelineParser( + const PassPipelineCLParser &passPipeline) { + passPipelineCallback = [&](PassManager &pm) { + auto errorHandler = [&](const Twine &msg) { + emitError(UnknownLoc::get(pm.getContext())) << msg; + return failure(); + }; + if (failed(passPipeline.addToPipeline(pm, errorHandler))) + return failure(); + if (this->shouldDumpPassPipeline()) { + // Queried lazily, so setDumpPassPipeline() may be called after this. + pm.dump(); + llvm::errs() << "\n"; + } + return success(); + }; + return *this; +} + /// Perform the actions on the input file indicated by the command line flags /// within the specified context. /// @@ -47,10 +66,9 @@ /// passes, then prints the output.
/// static LogicalResult -performActions(raw_ostream &os, bool verifyDiagnostics, bool verifyPasses, +performActions(raw_ostream &os, const std::shared_ptr<llvm::SourceMgr> &sourceMgr, - MLIRContext *context, PassPipelineFn passManagerSetupFn, - bool emitBytecode, bool implicitModule) { + MLIRContext *context, const MlirOptMainConfig &config) { DefaultTimingManager tm; applyDefaultTimingManagerCLOptions(tm); TimingScope timing = tm.getRootScope(); @@ -66,13 +84,14 @@ // untouched. PassReproducerOptions reproOptions; FallbackAsmResourceMap fallbackResourceMap; - ParserConfig config(context, /*verifyAfterParse=*/true, &fallbackResourceMap); - reproOptions.attachResourceParser(config); + ParserConfig parseConfig(context, /*verifyAfterParse=*/true, + &fallbackResourceMap); + reproOptions.attachResourceParser(parseConfig); // Parse the input file and reset the context threading state. TimingScope parserTiming = timing.nest("Parser"); - OwningOpRef<Operation *> op = - parseSourceFileForTool(sourceMgr, config, implicitModule); + OwningOpRef<Operation *> op = parseSourceFileForTool( + sourceMgr, parseConfig, config.shouldUseImplicitModule()); context->enableMultithreading(wasThreadingEnabled); if (!op) return failure(); @@ -80,10 +99,10 @@ // Prepare the pass manager, applying command-line and reproducer options. PassManager pm(op.get()->getName(), PassManager::Nesting::Implicit); - pm.enableVerifier(verifyPasses); + pm.enableVerifier(config.shouldVerifyPasses()); applyPassManagerCLOptions(pm); pm.enableTiming(timing); - if (failed(reproOptions.apply(pm)) || failed(passManagerSetupFn(pm))) + if (failed(reproOptions.apply(pm)) || failed(config.setupPassPipeline(pm))) return failure(); // Run the pipeline. @@ -92,7 +111,7 @@ // Print the output. TimingScope outputTiming = timing.nest("Output"); - if (emitBytecode) { + if (config.shouldEmitBytecode()) { BytecodeWriterConfig writerConfig(fallbackResourceMap); writeBytecodeToFile(op.get(), os, writerConfig); } else { @@ -106,13 +125,11 @@ /// Parses the memory buffer.
If successfully, run a series of passes against /// it and print the result. -static LogicalResult -processBuffer(raw_ostream &os, std::unique_ptr<MemoryBuffer> ownedBuffer, - bool verifyDiagnostics, bool verifyPasses, - bool allowUnregisteredDialects, bool preloadDialectsInContext, - bool emitBytecode, bool implicitModule, - PassPipelineFn passManagerSetupFn, DialectRegistry &registry, - llvm::ThreadPool *threadPool) { +static LogicalResult processBuffer(raw_ostream &os, + std::unique_ptr<MemoryBuffer> ownedBuffer, + const MlirOptMainConfig &config, + DialectRegistry &registry, + llvm::ThreadPool *threadPool) { // Tell sourceMgr about this buffer, which is what the parser will pick up. auto sourceMgr = std::make_shared<SourceMgr>(); sourceMgr->AddNewSourceBuffer(std::move(ownedBuffer), SMLoc()); @@ -124,20 +141,18 @@ context.setThreadPool(*threadPool); // Parse the input file. - if (preloadDialectsInContext) + if (config.shouldPreloadDialectsInContext()) context.loadAllAvailableDialects(); - context.allowUnregisteredDialects(allowUnregisteredDialects); - if (verifyDiagnostics) + context.allowUnregisteredDialects(config.shouldAllowUnregisteredDialects()); + if (config.shouldVerifyDiagnostics()) context.printOpOnDiagnostic(false); context.getDebugActionManager().registerActionHandler<DebugCounter>(); // If we are in verify diagnostics mode then we have a lot of work to do, // otherwise just perform the actions without worrying about it. - if (!verifyDiagnostics) { + if (!config.shouldVerifyDiagnostics()) { SourceMgrDiagnosticHandler sourceMgrHandler(*sourceMgr, &context); - return performActions(os, verifyDiagnostics, verifyPasses, sourceMgr, - &context, passManagerSetupFn, emitBytecode, - implicitModule); + return performActions(os, sourceMgr, &context, config); } SourceMgrDiagnosticVerifierHandler sourceMgrHandler(*sourceMgr, &context); @@ -145,22 +160,17 @@ // Do any processing requested by command line flags.
We don't care whether // these actions succeed or fail, we only care what diagnostics they produce // and whether they match our expectations. - (void)performActions(os, verifyDiagnostics, verifyPasses, sourceMgr, &context, - passManagerSetupFn, emitBytecode, implicitModule); + (void)performActions(os, sourceMgr, &context, config); // Verify the diagnostic handler to make sure that each of the diagnostics // matched. return sourceMgrHandler.verify(); } -LogicalResult mlir::MlirOptMain(raw_ostream &outputStream, - std::unique_ptr<MemoryBuffer> buffer, - PassPipelineFn passManagerSetupFn, - DialectRegistry &registry, bool splitInputFile, - bool verifyDiagnostics, bool verifyPasses, - bool allowUnregisteredDialects, - bool preloadDialectsInContext, - bool emitBytecode, bool implicitModule) { +LogicalResult mlir::MlirOptMain(llvm::raw_ostream &outputStream, + std::unique_ptr<llvm::MemoryBuffer> buffer, + DialectRegistry &registry, + const MlirOptMainConfig &config) { // The split-input-file mode is a very specific mode that slices the file // up into small pieces and checks each independently.
// We use an explicit threadpool to avoid creating and joining/destroying @@ -177,13 +187,33 @@ auto chunkFn = [&](std::unique_ptr<MemoryBuffer> chunkBuffer, raw_ostream &os) { - return processBuffer(os, std::move(chunkBuffer), verifyDiagnostics, - verifyPasses, allowUnregisteredDialects, - preloadDialectsInContext, emitBytecode, implicitModule, - passManagerSetupFn, registry, threadPool); + return processBuffer(os, std::move(chunkBuffer), config, registry, + threadPool); }; return splitAndProcessBuffer(std::move(buffer), chunkFn, outputStream, - splitInputFile, /*insertMarkerInOutput=*/true); + config.shouldSplitInputFile(), + /*insertMarkerInOutput=*/true); +} + +LogicalResult mlir::MlirOptMain(raw_ostream &outputStream, + std::unique_ptr<MemoryBuffer> buffer, + PassPipelineFn passManagerSetupFn, + DialectRegistry &registry, bool splitInputFile, + bool verifyDiagnostics, bool verifyPasses, + bool allowUnregisteredDialects, + bool preloadDialectsInContext, + bool emitBytecode, bool implicitModule) { + return MlirOptMain( + outputStream, std::move(buffer), registry, + MlirOptMainConfig{} + .setSplitInputFile(splitInputFile) + .setVerifyDiagnostics(verifyDiagnostics) + .setVerifyPasses(verifyPasses) + .setAllowUnregisteredDialects(allowUnregisteredDialects) + .setPreloadDialectsInContext(preloadDialectsInContext) + .setEmitBytecode(emitBytecode) + .setUseImplicitModule(implicitModule) + .setPassPipelineSetupFn(passManagerSetupFn)); } LogicalResult mlir::MlirOptMain( @@ -192,23 +222,18 @@ bool splitInputFile, bool verifyDiagnostics, bool verifyPasses, bool allowUnregisteredDialects, bool preloadDialectsInContext, bool emitBytecode, bool implicitModule, bool dumpPassPipeline) { - auto passManagerSetupFn = [&](PassManager &pm) { - auto errorHandler = [&](const Twine &msg) { - emitError(UnknownLoc::get(pm.getContext())) << msg; - return failure(); - }; - if (failed(passPipeline.addToPipeline(pm, errorHandler))) - return failure(); - if (dumpPassPipeline) { - pm.dump(); - llvm::errs() << "\n"; - }
- return success(); - }; - return MlirOptMain(outputStream, std::move(buffer), passManagerSetupFn, - registry, splitInputFile, verifyDiagnostics, verifyPasses, - allowUnregisteredDialects, preloadDialectsInContext, - emitBytecode, implicitModule); + return MlirOptMain( + outputStream, std::move(buffer), registry, + MlirOptMainConfig{} + .setSplitInputFile(splitInputFile) + .setVerifyDiagnostics(verifyDiagnostics) + .setVerifyPasses(verifyPasses) + .setAllowUnregisteredDialects(allowUnregisteredDialects) + .setPreloadDialectsInContext(preloadDialectsInContext) + .setEmitBytecode(emitBytecode) + .setUseImplicitModule(implicitModule) + .setDumpPassPipeline(dumpPassPipeline) + .setPassPipelineParser(passPipeline)); } LogicalResult mlir::MlirOptMain(int argc, char **argv, llvm::StringRef toolName, @@ -302,11 +327,18 @@ return failure(); } - if (failed(MlirOptMain(output->os(), std::move(file), passPipeline, registry, - splitInputFile, verifyDiagnostics, verifyPasses, - allowUnregisteredDialects, preloadDialectsInContext, - emitBytecode, /*implicitModule=*/!noImplicitModule, - dumpPassPipeline))) + if (failed(MlirOptMain( + output->os(), std::move(file), registry, + MlirOptMainConfig{} + .setPassPipelineParser(passPipeline) + .setSplitInputFile(splitInputFile) + .setVerifyDiagnostics(verifyDiagnostics) + .setVerifyPasses(verifyPasses) + .setAllowUnregisteredDialects(allowUnregisteredDialects) + .setPreloadDialectsInContext(preloadDialectsInContext) + .setEmitBytecode(emitBytecode) + .setUseImplicitModule(!noImplicitModule) + .setDumpPassPipeline(dumpPassPipeline)))) return failure(); // Keep the output file if the invocation of MlirOptMain was successful.