diff --git a/mlir/include/mlir/Translation.h b/mlir/include/mlir/Translation.h --- a/mlir/include/mlir/Translation.h +++ b/mlir/include/mlir/Translation.h @@ -21,6 +21,7 @@ } // namespace llvm namespace mlir { +class DialectRegistry; struct LogicalResult; class MLIRContext; class ModuleOp; @@ -91,6 +92,13 @@ void printOptionInfo(const llvm::cl::Option &o, size_t globalWidth) const override; }; + +/// Implementation for tools like `mlir-translate`. `toolName` is used as the +/// overview text displayed by `--help`. Returns failure if the input or output +/// file cannot be opened, or if translation (or diagnostic verification) fails. +LogicalResult mlirTranslateMain(int argc, char **argv, + llvm::StringRef toolName); + } // namespace mlir #endif // MLIR_TRANSLATION_H diff --git a/mlir/lib/Translation/Translation.cpp b/mlir/lib/Translation/Translation.cpp --- a/mlir/lib/Translation/Translation.cpp +++ b/mlir/lib/Translation/Translation.cpp @@ -11,11 +11,15 @@ //===----------------------------------------------------------------------===// #include "mlir/Translation.h" +#include "mlir/IR/AsmState.h" #include "mlir/IR/Module.h" #include "mlir/IR/Verifier.h" #include "mlir/Parser.h" -#include "mlir/Support/LLVM.h" +#include "mlir/Support/FileUtilities.h" +#include "mlir/Support/ToolUtilities.h" +#include "llvm/Support/InitLLVM.h" #include "llvm/Support/SourceMgr.h" +#include "llvm/Support/ToolOutputFile.h" using namespace mlir; @@ -119,3 +123,84 @@ }); llvm::cl::parser<const TranslateFunction *>::printOptionInfo(o, globalWidth); } + +LogicalResult mlir::mlirTranslateMain(int argc, char **argv, + llvm::StringRef toolName) { + + static llvm::cl::opt<std::string> inputFilename( + llvm::cl::Positional, llvm::cl::desc("<input file>"), + llvm::cl::init("-")); + + static llvm::cl::opt<std::string> outputFilename( + "o", llvm::cl::desc("Output filename"), llvm::cl::value_desc("filename"), + llvm::cl::init("-")); + + static llvm::cl::opt<bool> splitInputFile( + "split-input-file", + llvm::cl::desc("Split the input file into pieces and " + "process each chunk independently"), + llvm::cl::init(false)); + + static llvm::cl::opt<bool> verifyDiagnostics( + "verify-diagnostics",
+ llvm::cl::desc("Check that emitted diagnostics match " + "expected-* lines on the corresponding line"), + llvm::cl::init(false)); + + llvm::InitLLVM y(argc, argv); + + // Add flags for all the registered translations. + llvm::cl::opt<const TranslateFunction *, false, TranslationParser> + translationRequested("", llvm::cl::desc("Translation to perform"), + llvm::cl::Required); + registerAsmPrinterCLOptions(); + registerMLIRContextCLOptions(); + llvm::cl::ParseCommandLineOptions(argc, argv, toolName); + + std::string errorMessage; + auto input = openInputFile(inputFilename, &errorMessage); + if (!input) { + llvm::errs() << errorMessage << "\n"; + return failure(); + } + + auto output = openOutputFile(outputFilename, &errorMessage); + if (!output) { + llvm::errs() << errorMessage << "\n"; + return failure(); + } + + // Processes the memory buffer with a new MLIRContext. + auto processBuffer = [&](std::unique_ptr<llvm::MemoryBuffer> ownedBuffer, + raw_ostream &os) { + MLIRContext context; + // Inputs may contain operations from unregistered dialects; accept them. + context.allowUnregisteredDialects(); + context.printOpOnDiagnostic(!verifyDiagnostics); + llvm::SourceMgr sourceMgr; + sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), llvm::SMLoc()); + + if (!verifyDiagnostics) { + SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &context); + return (*translationRequested)(sourceMgr, os, &context); + } + + // In the diagnostic verification flow, we ignore whether the translation + // failed (in most cases, it is expected to fail). Instead, we check if the + // diagnostics were produced as expected.
+ SourceMgrDiagnosticVerifierHandler sourceMgrHandler(sourceMgr, &context); + (*translationRequested)(sourceMgr, os, &context); + return sourceMgrHandler.verify(); + }; + + if (splitInputFile) { + if (failed(splitAndProcessBuffer(std::move(input), processBuffer, + output->os()))) + return failure(); + } else if (failed(processBuffer(std::move(input), output->os()))) { + return failure(); + } + + output->keep(); + return success(); +} diff --git a/mlir/tools/mlir-translate/mlir-translate.cpp b/mlir/tools/mlir-translate/mlir-translate.cpp --- a/mlir/tools/mlir-translate/mlir-translate.cpp +++ b/mlir/tools/mlir-translate/mlir-translate.cpp @@ -11,41 +11,12 @@ // //===----------------------------------------------------------------------===// -#include "mlir/IR/AsmState.h" -#include "mlir/IR/Diagnostics.h" -#include "mlir/IR/MLIRContext.h" #include "mlir/InitAllDialects.h" #include "mlir/InitAllTranslations.h" -#include "mlir/Support/FileUtilities.h" -#include "mlir/Support/LogicalResult.h" -#include "mlir/Support/ToolUtilities.h" #include "mlir/Translation.h" -#include "llvm/Support/InitLLVM.h" -#include "llvm/Support/MemoryBuffer.h" -#include "llvm/Support/SourceMgr.h" -#include "llvm/Support/ToolOutputFile.h" using namespace mlir; -static llvm::cl::opt<std::string> inputFilename(llvm::cl::Positional, - llvm::cl::desc("<input file>"), - llvm::cl::init("-")); - -static llvm::cl::opt<std::string> - outputFilename("o", llvm::cl::desc("Output filename"), - llvm::cl::value_desc("filename"), llvm::cl::init("-")); - -static llvm::cl::opt<bool> - splitInputFile("split-input-file", - llvm::cl::desc("Split the input file into pieces and " - "process each chunk independently"), - llvm::cl::init(false)); - -static llvm::cl::opt<bool> verifyDiagnostics( - "verify-diagnostics", - llvm::cl::desc("Check that emitted diagnostics match " - "expected-* lines on the corresponding line"), - llvm::cl::init(false)); namespace mlir { // Defined in the test directory, no public header.
@@ -59,64 +30,9 @@ } int main(int argc, char **argv) { - registerAllDialects(); registerAllTranslations(); registerTestTranslations(); - llvm::InitLLVM y(argc, argv); - - // Add flags for all the registered translations. - llvm::cl::opt<const TranslateFunction *, false, TranslationParser> - translationRequested("", llvm::cl::desc("Translation to perform"), - llvm::cl::Required); - registerAsmPrinterCLOptions(); - registerMLIRContextCLOptions(); - llvm::cl::ParseCommandLineOptions(argc, argv, "MLIR translation driver\n"); - - std::string errorMessage; - auto input = openInputFile(inputFilename, &errorMessage); - if (!input) { - llvm::errs() << errorMessage << "\n"; - return 1; - } - - auto output = openOutputFile(outputFilename, &errorMessage); - if (!output) { - llvm::errs() << errorMessage << "\n"; - return 1; - } - - // Processes the memory buffer with a new MLIRContext. - auto processBuffer = [&](std::unique_ptr<llvm::MemoryBuffer> ownedBuffer, - raw_ostream &os) { - MLIRContext context(false); - registerAllDialects(&context); - context.allowUnregisteredDialects(); - context.printOpOnDiagnostic(!verifyDiagnostics); - llvm::SourceMgr sourceMgr; - sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), llvm::SMLoc()); - - if (!verifyDiagnostics) { - SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &context); - return (*translationRequested)(sourceMgr, os, &context); - } - - // In the diagnostic verification flow, we ignore whether the translation - // failed (in most cases, it is expected to fail). Instead, we check if the - // diagnostics were produced as expected.
- SourceMgrDiagnosticVerifierHandler sourceMgrHandler(sourceMgr, &context); - (*translationRequested)(sourceMgr, os, &context); - return sourceMgrHandler.verify(); - }; - - if (splitInputFile) { - if (failed(splitAndProcessBuffer(std::move(input), processBuffer, - output->os()))) - return 1; - } else { - if (failed(processBuffer(std::move(input), output->os()))) - return 1; - } - - output->keep(); - return 0; + // TODO: remove the global dialect registry (presumably what mlirTranslateMain's MLIRContext loads from). + registerAllDialects(); + return failed(mlirTranslateMain(argc, argv, "MLIR Translation Testing Tool")); }