diff --git a/mlir/include/mlir/Support/MlirOptMain.h b/mlir/include/mlir/Support/MlirOptMain.h --- a/mlir/include/mlir/Support/MlirOptMain.h +++ b/mlir/include/mlir/Support/MlirOptMain.h @@ -27,11 +27,16 @@ namespace mlir { class DialectRegistry; class PassPipelineCLParser; +class PassManager; + +using PassPipelineFn = llvm::function_ref<LogicalResult(PassManager &pm, function_ref<LogicalResult(const Twine &)> errorHandler)>; /// Perform the core processing behind `mlir-opt`: /// - outputStream is the stream where the resulting IR is printed. /// - buffer is the in-memory file to parser and process. -/// - passPipeline is the specification of the pipeline that will be applied. +/// - pipelineFn is the callback to the pass pipeline that will be applied. /// - registry should contain all the dialects that can be parsed in the source. /// - splitInputFile will look for a "-----" marker in the input file, and load /// each chunk in an individual ModuleOp processed separately. @@ -46,10 +51,9 @@ /// deprecated and will be removed soon. LogicalResult MlirOptMain(llvm::raw_ostream &outputStream, std::unique_ptr<llvm::MemoryBuffer> buffer, - const PassPipelineCLParser &passPipeline, - DialectRegistry &registry, bool splitInputFile, - bool verifyDiagnostics, bool verifyPasses, - bool allowUnregisteredDialects, + PassPipelineFn pipelineFn, DialectRegistry &registry, + bool splitInputFile, bool verifyDiagnostics, + bool verifyPasses, bool allowUnregisteredDialects, bool preloadDialectsInContext = false); /// Implementation for tools like `mlir-opt`. 
diff --git a/mlir/lib/Support/MlirOptMain.cpp b/mlir/lib/Support/MlirOptMain.cpp --- a/mlir/lib/Support/MlirOptMain.cpp +++ b/mlir/lib/Support/MlirOptMain.cpp @@ -48,7 +48,7 @@ static LogicalResult performActions(raw_ostream &os, bool verifyDiagnostics, bool verifyPasses, SourceMgr &sourceMgr, MLIRContext *context, - const PassPipelineCLParser &passPipeline) { + PassPipelineFn pipelineFn) { DefaultTimingManager tm; applyDefaultTimingManagerCLOptions(tm); TimingScope timing = tm.getRootScope(); @@ -77,8 +77,8 @@ return failure(); }; - // Build the provided pipeline. - if (failed(passPipeline.addToPipeline(pm, errorHandler))) + // Callback to build the pipeline. + if (failed(pipelineFn(pm, errorHandler))) return failure(); // Run the pipeline. @@ -98,8 +98,8 @@ processBuffer(raw_ostream &os, std::unique_ptr<MemoryBuffer> ownedBuffer, bool verifyDiagnostics, bool verifyPasses, bool allowUnregisteredDialects, bool preloadDialectsInContext, - const PassPipelineCLParser &passPipeline, - DialectRegistry &registry, llvm::ThreadPool &threadPool) { + PassPipelineFn pipelineFn, DialectRegistry &registry, + llvm::ThreadPool &threadPool) { // Tell sourceMgr about this buffer, which is what the parser will pick up. SourceMgr sourceMgr; sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), SMLoc()); @@ -122,7 +122,7 @@ if (!verifyDiagnostics) { SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &context); return performActions(os, verifyDiagnostics, verifyPasses, sourceMgr, - &context, passPipeline); + &context, pipelineFn); } SourceMgrDiagnosticVerifierHandler sourceMgrHandler(sourceMgr, &context); @@ -131,7 +131,7 @@ // these actions succeed or fail, we only care what diagnostics they produce // and whether they match our expectations. (void)performActions(os, verifyDiagnostics, verifyPasses, sourceMgr, &context, - passPipeline); + pipelineFn); // Verify the diagnostic handler to make sure that each of the diagnostics // matched. 
@@ -140,7 +140,7 @@ LogicalResult mlir::MlirOptMain(raw_ostream &outputStream, std::unique_ptr<MemoryBuffer> buffer, - const PassPipelineCLParser &passPipeline, + PassPipelineFn pipelineFn, DialectRegistry &registry, bool splitInputFile, bool verifyDiagnostics, bool verifyPasses, bool allowUnregisteredDialects, @@ -156,14 +156,14 @@ [&](std::unique_ptr<MemoryBuffer> chunkBuffer, raw_ostream &os) { return processBuffer(os, std::move(chunkBuffer), verifyDiagnostics, verifyPasses, allowUnregisteredDialects, - preloadDialectsInContext, passPipeline, registry, + preloadDialectsInContext, pipelineFn, registry, threadPool); }, outputStream); return processBuffer(outputStream, std::move(buffer), verifyDiagnostics, verifyPasses, allowUnregisteredDialects, - preloadDialectsInContext, passPipeline, registry, + preloadDialectsInContext, pipelineFn, registry, threadPool); } @@ -266,9 +266,14 @@ return failure(); } - if (failed(MlirOptMain(output->os(), std::move(file), passPipeline, registry, - splitInputFile, verifyDiagnostics, verifyPasses, - allowUnregisteredDialects, preloadDialectsInContext))) + if (failed(MlirOptMain( + output->os(), std::move(file), + [&](PassManager &pm, + function_ref<LogicalResult(const Twine &)> errorHandler) { + return passPipeline.addToPipeline(pm, errorHandler); + }, + registry, splitInputFile, verifyDiagnostics, verifyPasses, + allowUnregisteredDialects, preloadDialectsInContext))) return failure(); // Keep the output file if the invocation of MlirOptMain was successful.