diff --git a/mlir/include/mlir/Translation.h b/mlir/include/mlir/Translation.h --- a/mlir/include/mlir/Translation.h +++ b/mlir/include/mlir/Translation.h @@ -101,8 +101,12 @@ /// of tools like `mlir-translate`. The translation to perform is parsed from /// the command line. The `toolName` argument is used for the header displayed /// by `--help`. -LogicalResult mlirTranslateMain(int argc, char **argv, - llvm::StringRef toolName); +/// +/// The client may specify a "customization" function if they'd like, which +/// is invoked when an MLIRContext is set up, allowing custom settings. +LogicalResult +mlirTranslateMain(int argc, char **argv, llvm::StringRef toolName, + std::function customization = {}); } // namespace mlir diff --git a/mlir/lib/Translation/Translation.cpp b/mlir/lib/Translation/Translation.cpp --- a/mlir/lib/Translation/Translation.cpp +++ b/mlir/lib/Translation/Translation.cpp @@ -129,8 +129,9 @@ llvm::cl::parser::printOptionInfo(o, globalWidth); } -LogicalResult mlir::mlirTranslateMain(int argc, char **argv, - llvm::StringRef toolName) { +LogicalResult +mlir::mlirTranslateMain(int argc, char **argv, llvm::StringRef toolName, + std::function customization) { static llvm::cl::opt inputFilename( llvm::cl::Positional, llvm::cl::desc(""), @@ -184,8 +185,19 @@ auto processBuffer = [&](std::unique_ptr ownedBuffer, raw_ostream &os) { MLIRContext context; - context.allowUnregisteredDialects(allowUnregisteredDialects); + + // If a customization callback was provided, apply it to the MLIRContext. + // This could add dialects to the registry or change context defaults. + if (customization) + customization(context); + + // If command line flags were used to customize the context, apply their + // settings. + if (allowUnregisteredDialects.getNumOccurrences()) + context.allowUnregisteredDialects(allowUnregisteredDialects); + context.printOpOnDiagnostic(!verifyDiagnostics); + llvm::SourceMgr sourceMgr; sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), SMLoc());