diff --git a/mlir/include/mlir/Tools/mlir-translate/MlirTranslateMain.h b/mlir/include/mlir/Tools/mlir-translate/MlirTranslateMain.h
--- a/mlir/include/mlir/Tools/mlir-translate/MlirTranslateMain.h
+++ b/mlir/include/mlir/Tools/mlir-translate/MlirTranslateMain.h
@@ -13,16 +13,32 @@
 #ifndef MLIR_TOOLS_MLIRTRANSLATE_MLIRTRANSLATEMAIN_H
 #define MLIR_TOOLS_MLIRTRANSLATE_MLIRTRANSLATEMAIN_H
 
+#include "mlir/IR/Dialect.h"
 #include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/STLFunctionalExtras.h"
 #include "llvm/ADT/StringRef.h"
 
 namespace mlir {
+
 /// Translate to/from an MLIR module from/to an external representation (e.g.
 /// LLVM IR, SPIRV binary, ...). This is the entry point for the implementation
 /// of tools like `mlir-translate`. The translation to perform is parsed from
 /// the command line. The `toolName` argument is used for the header displayed
 /// by `--help`.
-LogicalResult mlirTranslateMain(int argc, char **argv, StringRef toolName);
+///
+/// Translations typically register the dialects they consume or produce
+/// themselves, but some translation testing tools may want additional
+/// dialects registered so the .mlir parser can read them. In this case,
+/// `extraDialects` may be specified with additional dialects to use.
+///
+/// The client may also specify a `customization` callback, which is invoked
+/// on each MLIRContext set up by the tool, after `extraDialects` has been
+/// appended to it. Context settings driven by command line flags are applied
+/// after the callback and take precedence over it.
+LogicalResult
+mlirTranslateMain(int argc, char **argv, StringRef toolName,
+                  const DialectRegistry &extraDialects = DialectRegistry(),
+                  llvm::function_ref<void(MLIRContext &)> customization = {});
 } // namespace mlir
 
 #endif // MLIR_TOOLS_MLIRTRANSLATE_MLIRTRANSLATEMAIN_H
diff --git a/mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp b/mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp
--- a/mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp
+++ b/mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp
@@ -22,11 +22,13 @@
 using namespace mlir;
 
 //===----------------------------------------------------------------------===//
-// Translation Parser
+// mlir-translate tool driver
 //===----------------------------------------------------------------------===//
 
-LogicalResult mlir::mlirTranslateMain(int argc, char **argv,
-                                      llvm::StringRef toolName) {
+LogicalResult
+mlir::mlirTranslateMain(int argc, char **argv, llvm::StringRef toolName,
+                        const DialectRegistry &extraDialects,
+                        llvm::function_ref<void(MLIRContext &)> customization) {
 
   static llvm::cl::opt<std::string> inputFilename(
       llvm::cl::Positional, llvm::cl::desc("<input file>"),
@@ -80,8 +82,23 @@
   auto processBuffer = [&](std::unique_ptr<llvm::MemoryBuffer> ownedBuffer,
                            raw_ostream &os) {
     MLIRContext context;
-    context.allowUnregisteredDialects(allowUnregisteredDialects);
+
+    // If the client wanted to register additional dialects, go ahead and add
+    // them to our context.
+    context.appendDialectRegistry(extraDialects);
+
+    // If a customization callback was provided, apply it to the MLIRContext.
+    // This could add dialects to the registry or change context defaults.
+    if (customization)
+      customization(context);
+
+    // Apply command line settings last so they override the customization.
+    // The unregistered-dialect flag is only forwarded when given explicitly,
+    // so its default value does not clobber a setting made by `customization`.
+    if (allowUnregisteredDialects.getNumOccurrences())
+      context.allowUnregisteredDialects(allowUnregisteredDialects);
     context.printOpOnDiagnostic(!verifyDiagnostics);
+
     llvm::SourceMgr sourceMgr;
     sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), SMLoc());
 