diff --git a/mlir/include/mlir/Support/MlirOptMain.h b/mlir/include/mlir/Support/MlirOptMain.h
--- a/mlir/include/mlir/Support/MlirOptMain.h
+++ b/mlir/include/mlir/Support/MlirOptMain.h
@@ -27,6 +27,12 @@
 namespace mlir {
 class DialectRegistry;
 class PassPipelineCLParser;
+class PassManager;
+
+/// This defines the function type used to setup the pass manager. This can be
+/// used to pass in a callback to setup a default pass pipeline to be applied on
+/// the loaded IR.
+using PassPipelineFn = llvm::function_ref<LogicalResult(PassManager &pm)>;
 
 /// Perform the core processing behind `mlir-opt`:
 /// - outputStream is the stream where the resulting IR is printed.
@@ -52,6 +58,17 @@
                           bool allowUnregisteredDialects,
                           bool preloadDialectsInContext = false);
 
+/// Support a callback to setup the pass manager.
+/// - passManagerSetupFn is the callback invoked to setup the pass manager to
+///   apply on the loaded IR.
+LogicalResult MlirOptMain(llvm::raw_ostream &outputStream,
+                          std::unique_ptr<llvm::MemoryBuffer> buffer,
+                          PassPipelineFn passManagerSetupFn,
+                          DialectRegistry &registry, bool splitInputFile,
+                          bool verifyDiagnostics, bool verifyPasses,
+                          bool allowUnregisteredDialects,
+                          bool preloadDialectsInContext = false);
+
 /// Implementation for tools like `mlir-opt`.
 /// - toolName is used for the header displayed by `--help`.
 /// - registry should contain all the dialects that can be parsed in the source.
diff --git a/mlir/lib/Support/MlirOptMain.cpp b/mlir/lib/Support/MlirOptMain.cpp
--- a/mlir/lib/Support/MlirOptMain.cpp
+++ b/mlir/lib/Support/MlirOptMain.cpp
@@ -48,7 +48,7 @@
 static LogicalResult performActions(raw_ostream &os, bool verifyDiagnostics,
                                     bool verifyPasses, SourceMgr &sourceMgr,
                                     MLIRContext *context,
-                                    const PassPipelineCLParser &passPipeline) {
+                                    PassPipelineFn passManagerSetupFn) {
   DefaultTimingManager tm;
   applyDefaultTimingManagerCLOptions(tm);
   TimingScope timing = tm.getRootScope();
@@ -72,13 +72,8 @@
   applyPassManagerCLOptions(pm);
   pm.enableTiming(timing);
 
-  auto errorHandler = [&](const Twine &msg) {
-    emitError(UnknownLoc::get(context)) << msg;
-    return failure();
-  };
-
-  // Build the provided pipeline.
-  if (failed(passPipeline.addToPipeline(pm, errorHandler)))
+  // Callback to build the pipeline.
+  if (failed(passManagerSetupFn(pm)))
     return failure();
 
   // Run the pipeline.
@@ -98,8 +93,8 @@
 processBuffer(raw_ostream &os, std::unique_ptr<MemoryBuffer> ownedBuffer,
               bool verifyDiagnostics, bool verifyPasses,
               bool allowUnregisteredDialects, bool preloadDialectsInContext,
-              const PassPipelineCLParser &passPipeline,
-              DialectRegistry &registry, llvm::ThreadPool &threadPool) {
+              PassPipelineFn passManagerSetupFn, DialectRegistry &registry,
+              llvm::ThreadPool &threadPool) {
   // Tell sourceMgr about this buffer, which is what the parser will pick up.
   SourceMgr sourceMgr;
   sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), SMLoc());
@@ -122,7 +117,7 @@
   if (!verifyDiagnostics) {
     SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &context);
     return performActions(os, verifyDiagnostics, verifyPasses, sourceMgr,
-                          &context, passPipeline);
+                          &context, passManagerSetupFn);
   }
 
   SourceMgrDiagnosticVerifierHandler sourceMgrHandler(sourceMgr, &context);
@@ -131,7 +126,7 @@
   // these actions succeed or fail, we only care what diagnostics they produce
   // and whether they match our expectations.
   (void)performActions(os, verifyDiagnostics, verifyPasses, sourceMgr, &context,
-                       passPipeline);
+                       passManagerSetupFn);
 
   // Verify the diagnostic handler to make sure that each of the diagnostics
   // matched.
@@ -140,7 +135,7 @@
 
 LogicalResult mlir::MlirOptMain(raw_ostream &outputStream,
                                 std::unique_ptr<MemoryBuffer> buffer,
-                                const PassPipelineCLParser &passPipeline,
+                                PassPipelineFn passManagerSetupFn,
                                 DialectRegistry &registry, bool splitInputFile,
                                 bool verifyDiagnostics, bool verifyPasses,
                                 bool allowUnregisteredDialects,
@@ -156,17 +151,36 @@
         [&](std::unique_ptr<MemoryBuffer> chunkBuffer, raw_ostream &os) {
           return processBuffer(os, std::move(chunkBuffer), verifyDiagnostics,
                                verifyPasses, allowUnregisteredDialects,
-                               preloadDialectsInContext, passPipeline, registry,
-                               threadPool);
+                               preloadDialectsInContext, passManagerSetupFn,
+                               registry, threadPool);
         },
         outputStream);
 
   return processBuffer(outputStream, std::move(buffer), verifyDiagnostics,
                        verifyPasses, allowUnregisteredDialects,
-                       preloadDialectsInContext, passPipeline, registry,
+                       preloadDialectsInContext, passManagerSetupFn, registry,
                        threadPool);
 }
 
+LogicalResult mlir::MlirOptMain(raw_ostream &outputStream,
+                                std::unique_ptr<MemoryBuffer> buffer,
+                                const PassPipelineCLParser &passPipeline,
+                                DialectRegistry &registry, bool splitInputFile,
+                                bool verifyDiagnostics, bool verifyPasses,
+                                bool allowUnregisteredDialects,
+                                bool preloadDialectsInContext) {
+  auto passManagerSetupFn = [&](PassManager &pm) {
+    auto errorHandler = [&](const Twine &msg) {
+      emitError(UnknownLoc::get(pm.getContext())) << msg;
+      return failure();
+    };
+    return passPipeline.addToPipeline(pm, errorHandler);
+  };
+  return MlirOptMain(outputStream, std::move(buffer), passManagerSetupFn,
+                     registry, splitInputFile, verifyDiagnostics, verifyPasses,
+                     allowUnregisteredDialects, preloadDialectsInContext);
+}
+
 LogicalResult mlir::MlirOptMain(int argc, char **argv, llvm::StringRef toolName,
                                 DialectRegistry &registry,
                                 bool preloadDialectsInContext) {