diff --git a/mlir/include/mlir/Support/MlirOptMain.h b/mlir/include/mlir/Support/MlirOptMain.h --- a/mlir/include/mlir/Support/MlirOptMain.h +++ b/mlir/include/mlir/Support/MlirOptMain.h @@ -10,8 +10,10 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Support/LogicalResult.h" +#include "llvm/ADT/StringRef.h" + #include -#include namespace llvm { class raw_ostream; @@ -19,13 +21,31 @@ } // end namespace llvm namespace mlir { -struct LogicalResult; class PassPipelineCLParser; -LogicalResult MlirOptMain(llvm::raw_ostream &os, +/// Perform the core processing behind `mlir-opt`: +/// - outputStream is the stream where the resulting IR is printed. +/// - buffer is the in-memory file to parse and process. +/// - passPipeline is the specification of the pipeline that will be applied. +/// - splitInputFile will look for a "-----" marker in the input file, and load +/// each chunk in an individual ModuleOp processed separately. +/// - verifyDiagnostics enables a verification mode where comments starting with +/// "expected-(error|note|remark|warning)" are parsed in the input and matched +/// against emitted diagnostics. +/// - verifyPasses enables the IR verifier in-between each pass in the pipeline. +/// - allowUnregisteredDialects allows parsing and creating operations without +/// registering the Dialect in the MLIRContext. +LogicalResult MlirOptMain(llvm::raw_ostream &outputStream, std::unique_ptr buffer, const PassPipelineCLParser &passPipeline, bool splitInputFile, bool verifyDiagnostics, bool verifyPasses, bool allowUnregisteredDialects); +/// Implementation for tools like `mlir-opt`. +/// - argc and argv are the command line arguments to parse. +/// - toolName is the overview printed in the `--help` message; a newline is +/// appended to it, so it should not end with one. +/// Returns failure if a file cannot be opened or if processing fails. 
+LogicalResult MlirOptMain(int argc, char **argv, llvm::StringRef toolName); + } // end namespace mlir diff --git a/mlir/lib/Support/MlirOptMain.cpp b/mlir/lib/Support/MlirOptMain.cpp --- a/mlir/lib/Support/MlirOptMain.cpp +++ b/mlir/lib/Support/MlirOptMain.cpp @@ -12,19 +12,25 @@ //===----------------------------------------------------------------------===// #include "mlir/Support/MlirOptMain.h" +#include "mlir/IR/AsmState.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Diagnostics.h" +#include "mlir/IR/Dialect.h" #include "mlir/IR/Location.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Module.h" #include "mlir/Parser.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" +#include "mlir/Support/FileUtilities.h" #include "mlir/Support/ToolUtilities.h" #include "mlir/Transforms/Passes.h" +#include "llvm/Support/CommandLine.h" #include "llvm/Support/FileUtilities.h" +#include "llvm/Support/InitLLVM.h" #include "llvm/Support/Regex.h" #include "llvm/Support/SourceMgr.h" +#include "llvm/Support/ToolOutputFile.h" using namespace mlir; using namespace llvm; @@ -106,7 +112,7 @@ return sourceMgrHandler.verify(); } -LogicalResult mlir::MlirOptMain(raw_ostream &os, +LogicalResult mlir::MlirOptMain(raw_ostream &outputStream, std::unique_ptr buffer, const PassPipelineCLParser &passPipeline, bool splitInputFile, bool verifyDiagnostics, @@ -122,8 +128,84 @@ verifyPasses, allowUnregisteredDialects, passPipeline); }, - os); + outputStream); - return processBuffer(os, std::move(buffer), verifyDiagnostics, verifyPasses, - allowUnregisteredDialects, passPipeline); + return processBuffer(outputStream, std::move(buffer), verifyDiagnostics, + verifyPasses, allowUnregisteredDialects, passPipeline); +} + +LogicalResult mlir::MlirOptMain(int argc, char **argv, StringRef toolName) { + static cl::opt inputFilename( + cl::Positional, cl::desc(""), cl::init("-")); + + static cl::opt outputFilename("o", cl::desc("Output filename"), + cl::value_desc("filename"), + 
cl::init("-")); + + static cl::opt splitInputFile( + "split-input-file", + cl::desc("Split the input file into pieces and process each " + "chunk independently"), + cl::init(false)); + + static cl::opt verifyDiagnostics( + "verify-diagnostics", + cl::desc("Check that emitted diagnostics match " + "expected-* lines on the corresponding line"), + cl::init(false)); + + static cl::opt verifyPasses( + "verify-each", + cl::desc("Run the verifier after each transformation pass"), + cl::init(true)); + + static cl::opt allowUnregisteredDialects( + "allow-unregistered-dialect", + cl::desc("Allow operation with no registered dialects"), cl::init(false)); + + static cl::opt showDialects( + "show-dialects", cl::desc("Print the list of registered dialects"), + cl::init(false)); + + InitLLVM y(argc, argv); + + // Register any command line options. + registerAsmPrinterCLOptions(); + registerMLIRContextCLOptions(); + registerPassManagerCLOptions(); + PassPipelineCLParser passPipeline("", "Compiler passes to run"); + + // Parse pass names in main to ensure static initialization completed. + cl::ParseCommandLineOptions(argc, argv, (toolName + "\n").str()); + + if (showDialects) { + llvm::outs() << "Registered Dialects:\n"; + MLIRContext context; + for (Dialect *dialect : context.getRegisteredDialects()) + llvm::outs() << dialect->getNamespace() << "\n"; + return success(); + } + + // Set up the input file. 
+ std::string errorMessage; + auto file = openInputFile(inputFilename, &errorMessage); + if (!file) { + llvm::errs() << errorMessage << "\n"; + return failure(); + } + + auto output = openOutputFile(outputFilename, &errorMessage); + if (!output) { + llvm::errs() << errorMessage << "\n"; + return failure(); + } + + if (failed(MlirOptMain(output->os(), std::move(file), passPipeline, + splitInputFile, verifyDiagnostics, verifyPasses, + allowUnregisteredDialects))) + return failure(); + + // Keep the output file if the invocation of MlirOptMain was successful. + output->keep(); + return success(); } diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -72,34 +72,6 @@ void registerVectorizerTestPass(); } // namespace mlir -static cl::opt - inputFilename(cl::Positional, cl::desc(""), cl::init("-")); - -static cl::opt outputFilename("o", cl::desc("Output filename"), - cl::value_desc("filename"), - cl::init("-")); - -static cl::opt - splitInputFile("split-input-file", - cl::desc("Split the input file into pieces and process each " - "chunk independently"), - cl::init(false)); - -static cl::opt - verifyDiagnostics("verify-diagnostics", - cl::desc("Check that emitted diagnostics match " - "expected-* lines on the corresponding line"), - cl::init(false)); - -static cl::opt - verifyPasses("verify-each", - cl::desc("Run the verifier after each transformation pass"), - cl::init(true)); - -static cl::opt allowUnregisteredDialects( - "allow-unregistered-dialect", - cl::desc("Allow operation with no registered dialects"), cl::init(false)); - #ifdef MLIR_INCLUDE_TESTS void registerTestPasses() { registerConvertToTargetEnvPass(); @@ -150,57 +122,11 @@ } #endif -static cl::opt - showDialects("show-dialects", - cl::desc("Print the list of registered dialects"), - cl::init(false)); - int main(int argc, char **argv) { registerAllDialects(); registerAllPasses(); #ifdef 
MLIR_INCLUDE_TESTS registerTestPasses(); #endif - InitLLVM y(argc, argv); - - // Register any command line options. - registerAsmPrinterCLOptions(); - registerMLIRContextCLOptions(); - registerPassManagerCLOptions(); - PassPipelineCLParser passPipeline("", "Compiler passes to run"); - - // Parse pass names in main to ensure static initialization completed. - cl::ParseCommandLineOptions(argc, argv, "MLIR modular optimizer driver\n"); - - if(showDialects) { - llvm::outs() << "Registered Dialects:\n"; - MLIRContext context; - for(Dialect *dialect : context.getRegisteredDialects()) { - llvm::outs() << dialect->getNamespace() << "\n"; - } - return 0; - } - - // Set up the input file. - std::string errorMessage; - auto file = openInputFile(inputFilename, &errorMessage); - if (!file) { - llvm::errs() << errorMessage << "\n"; - return 1; - } - - auto output = openOutputFile(outputFilename, &errorMessage); - if (!output) { - llvm::errs() << errorMessage << "\n"; - exit(1); - } - - if (failed(MlirOptMain(output->os(), std::move(file), passPipeline, - splitInputFile, verifyDiagnostics, verifyPasses, - allowUnregisteredDialects))) { - return 1; - } - // Keep the output file if the invocation of MlirOptMain was successful. - output->keep(); - return 0; + return failed(MlirOptMain(argc, argv, "MLIR modular optimizer driver")); }