diff --git a/mlir/include/mlir/Support/MlirOptMain.h b/mlir/include/mlir/Support/MlirOptMain.h --- a/mlir/include/mlir/Support/MlirOptMain.h +++ b/mlir/include/mlir/Support/MlirOptMain.h @@ -27,6 +27,9 @@ namespace mlir { class DialectRegistry; class PassPipelineCLParser; +class PassManager; + +using DefaultPassPipelineFn = llvm::function_ref<void(PassManager &)>; /// Perform the core processing behind `mlir-opt`: /// - outputStream is the stream where the resulting IR is printed. @@ -50,7 +53,8 @@ DialectRegistry &registry, bool splitInputFile, bool verifyDiagnostics, bool verifyPasses, bool allowUnregisteredDialects, - bool preloadDialectsInContext = false); + bool preloadDialectsInContext = false, + DefaultPassPipelineFn fn = nullptr); /// Implementation for tools like `mlir-opt`. /// - toolName is used for the header displayed by `--help`. diff --git a/mlir/lib/Support/MlirOptMain.cpp b/mlir/lib/Support/MlirOptMain.cpp --- a/mlir/lib/Support/MlirOptMain.cpp +++ b/mlir/lib/Support/MlirOptMain.cpp @@ -26,6 +26,7 @@ #include "mlir/Support/FileUtilities.h" #include "mlir/Support/Timing.h" #include "mlir/Support/ToolUtilities.h" +#include "mlir/Transforms/Passes.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/FileUtilities.h" #include "llvm/Support/InitLLVM.h" @@ -48,7 +49,8 @@ static LogicalResult performActions(raw_ostream &os, bool verifyDiagnostics, bool verifyPasses, SourceMgr &sourceMgr, MLIRContext *context, - const PassPipelineCLParser &passPipeline) { + const PassPipelineCLParser &passPipeline, + DefaultPassPipelineFn fn) { DefaultTimingManager tm; applyDefaultTimingManagerCLOptions(tm); TimingScope timing = tm.getRootScope(); @@ -77,9 +79,14 @@ return failure(); }; - // Build the provided pipeline. - if (failed(passPipeline.addToPipeline(pm, errorHandler))) - return failure(); + if (passPipeline.hasAnyOccurrences()) { + // Build the provided pipeline. 
+ if (failed(passPipeline.addToPipeline(pm, errorHandler))) + return failure(); + } else if (fn) { + // Set default pipeline. + fn(pm); + } // Run the pipeline. if (failed(pm.run(*module))) @@ -99,7 +106,8 @@ bool verifyDiagnostics, bool verifyPasses, bool allowUnregisteredDialects, bool preloadDialectsInContext, const PassPipelineCLParser &passPipeline, - DialectRegistry &registry, llvm::ThreadPool &threadPool) { + DialectRegistry &registry, llvm::ThreadPool &threadPool, + DefaultPassPipelineFn fn) { // Tell sourceMgr about this buffer, which is what the parser will pick up. SourceMgr sourceMgr; sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), SMLoc()); @@ -122,7 +130,7 @@ if (!verifyDiagnostics) { SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &context); return performActions(os, verifyDiagnostics, verifyPasses, sourceMgr, - &context, passPipeline); + &context, passPipeline, fn); } SourceMgrDiagnosticVerifierHandler sourceMgrHandler(sourceMgr, &context); @@ -131,7 +139,7 @@ // these actions succeed or fail, we only care what diagnostics they produce // and whether they match our expectations. (void)performActions(os, verifyDiagnostics, verifyPasses, sourceMgr, &context, - passPipeline); + passPipeline, fn); // Verify the diagnostic handler to make sure that each of the diagnostics // matched. @@ -144,7 +152,8 @@ DialectRegistry &registry, bool splitInputFile, bool verifyDiagnostics, bool verifyPasses, bool allowUnregisteredDialects, - bool preloadDialectsInContext) { + bool preloadDialectsInContext, + DefaultPassPipelineFn fn) { // The split-input-file mode is a very specific mode that slices the file // up into small pieces and checks each independently. 
// We use an explicit threadpool to avoid creating and joining/destroying @@ -157,14 +166,14 @@ return processBuffer(os, std::move(chunkBuffer), verifyDiagnostics, verifyPasses, allowUnregisteredDialects, preloadDialectsInContext, passPipeline, registry, - threadPool); + threadPool, fn); }, outputStream); return processBuffer(outputStream, std::move(buffer), verifyDiagnostics, verifyPasses, allowUnregisteredDialects, preloadDialectsInContext, passPipeline, registry, - threadPool); + threadPool, fn); } LogicalResult mlir::MlirOptMain(int argc, char **argv, llvm::StringRef toolName, @@ -266,9 +275,11 @@ return failure(); } - if (failed(MlirOptMain(output->os(), std::move(file), passPipeline, registry, - splitInputFile, verifyDiagnostics, verifyPasses, - allowUnregisteredDialects, preloadDialectsInContext))) + if (failed(MlirOptMain( + output->os(), std::move(file), passPipeline, registry, splitInputFile, + verifyDiagnostics, verifyPasses, allowUnregisteredDialects, + preloadDialectsInContext, + [&](PassManager &pm) { pm.addPass(createCanonicalizerPass()); }))) return failure(); // Keep the output file if the invocation of MlirOptMain was successful.