diff --git a/mlir/examples/standalone/standalone-translate/standalone-translate.cpp b/mlir/examples/standalone/standalone-translate/standalone-translate.cpp
--- a/mlir/examples/standalone/standalone-translate/standalone-translate.cpp
+++ b/mlir/examples/standalone/standalone-translate/standalone-translate.cpp
@@ -11,16 +11,28 @@
 //
 //===----------------------------------------------------------------------===//

+#include "Standalone/StandaloneDialect.h"
+#include "mlir/IR/BuiltinOps.h"
 #include "mlir/InitAllTranslations.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Tools/mlir-translate/MlirTranslateMain.h"
-
-#include "Standalone/StandaloneDialect.h"
+#include "mlir/Tools/mlir-translate/Translation.h"

 int main(int argc, char **argv) {
   mlir::registerAllTranslations();

   // TODO: Register standalone translations here.
+  // Placeholder translation demonstrating the (name, description) overload:
+  // its description becomes the help text of the `--option` flag. It ignores
+  // the parsed module and writes nothing to the output.
+  mlir::TranslateFromMLIRRegistration withDescription(
+      "option", "different from option",
+      [](mlir::ModuleOp, llvm::raw_ostream &) {
+        return mlir::LogicalResult::success();
+      },
+      [](mlir::DialectRegistry &registry) {
+        registry.insert<mlir::standalone::StandaloneDialect>();
+      });

   return failed(
       mlir::mlirTranslateMain(argc, argv, "MLIR Translation Testing Tool"));
diff --git a/mlir/include/mlir/Tools/mlir-translate/Translation.h b/mlir/include/mlir/Tools/mlir-translate/Translation.h
--- a/mlir/include/mlir/Tools/mlir-translate/Translation.h
+++ b/mlir/include/mlir/Tools/mlir-translate/Translation.h
@@ -75,6 +75,12 @@
                               const TranslateSourceMgrToMLIRFunction &function);
   TranslateToMLIRRegistration(llvm::StringRef name,
                               const TranslateStringRefToMLIRFunction &function);
+  /// Same as above, but `description` (copied) is used as the help text of the
+  /// translation's command-line flag instead of `name`.
+  TranslateToMLIRRegistration(llvm::StringRef name, llvm::StringRef description,
+                              const TranslateSourceMgrToMLIRFunction &function);
+  TranslateToMLIRRegistration(llvm::StringRef name, llvm::StringRef description,
+                              const TranslateStringRefToMLIRFunction &function);
 };

 struct TranslateFromMLIRRegistration {
@@ -82,10 +88,21 @@
       llvm::StringRef name, const TranslateFromMLIRFunction &function,
       const std::function<void(DialectRegistry &)> &dialectRegistration =
           [](DialectRegistry &) {});
+  /// Same as above, but `description` (copied) is used as the help text of the
+  /// translation's command-line flag instead of `name`.
+  TranslateFromMLIRRegistration(
+      llvm::StringRef name, llvm::StringRef description,
+      const TranslateFromMLIRFunction &function,
+      const std::function<void(DialectRegistry &)> &dialectRegistration =
+          [](DialectRegistry &) {});
 };
 struct TranslateRegistration {
   TranslateRegistration(llvm::StringRef name,
                         const TranslateFunction &function);
+  /// Same as above, but `description` (copied) is used as the help text of the
+  /// translation's command-line flag instead of `name`.
+  TranslateRegistration(llvm::StringRef name, llvm::StringRef description,
+                        const TranslateFunction &function);
 };
 /// \}

diff --git a/mlir/lib/Tools/mlir-translate/Translation.cpp b/mlir/lib/Tools/mlir-translate/Translation.cpp
--- a/mlir/lib/Tools/mlir-translate/Translation.cpp
+++ b/mlir/lib/Tools/mlir-translate/Translation.cpp
@@ -31,8 +31,16 @@
   return translationRegistry;
 }

+/// Get the mutable static map from the command-line name of each registered
+/// translation to its description. Descriptions are owned by the map so that
+/// registrations may pass strings that do not outlive the registration call.
+static llvm::StringMap<std::string> &getTranslationDescription() {
+  static llvm::StringMap<std::string> translationDescription;
+  return translationDescription;
+}
+
 /// Register the given translation.
-static void registerTranslation(StringRef name,
+static void registerTranslation(StringRef name, StringRef description,
                                 const TranslateFunction &function) {
   auto &translationRegistry = getTranslationRegistry();
   if (translationRegistry.find(name) != translationRegistry.end())
@@ -41,11 +49,16 @@
   assert(function && "Attempting to register an empty translate function");

   translationRegistry[name] = function;
+  getTranslationDescription()[name] = description.str();
 }

 TranslateRegistration::TranslateRegistration(
-    StringRef name, const TranslateFunction &function) {
-  registerTranslation(name, function);
+    StringRef name, const TranslateFunction &function)
+    : TranslateRegistration(name, /*description=*/name, function) {}
+
+TranslateRegistration::TranslateRegistration(
+    StringRef name, StringRef description, const TranslateFunction &function) {
+  registerTranslation(name, description, function);
 }

 //===----------------------------------------------------------------------===//
@@ -55,7 +68,8 @@
 // Puts `function` into the to-MLIR translation registry unless there is already
 // a function registered for the same name.
 static void registerTranslateToMLIRFunction(
-    StringRef name, const TranslateSourceMgrToMLIRFunction &function) {
+    StringRef name, StringRef description,
+    const TranslateSourceMgrToMLIRFunction &function) {
   auto wrappedFn = [function](llvm::SourceMgr &sourceMgr, raw_ostream &output,
                               MLIRContext *context) {
     OwningOpRef<ModuleOp> module = function(sourceMgr, context);
@@ -64,20 +78,31 @@
     module->print(output);
     return success();
   };
-  registerTranslation(name, wrappedFn);
+  registerTranslation(name, description, wrappedFn);
 }

 TranslateToMLIRRegistration::TranslateToMLIRRegistration(
-    StringRef name, const TranslateSourceMgrToMLIRFunction &function) {
-  registerTranslateToMLIRFunction(name, function);
+    StringRef name, const TranslateSourceMgrToMLIRFunction &function)
+    : TranslateToMLIRRegistration(name, /*description=*/name, function) {}
+
+TranslateToMLIRRegistration::TranslateToMLIRRegistration(
+    StringRef name, StringRef description,
+    const TranslateSourceMgrToMLIRFunction &function) {
+  registerTranslateToMLIRFunction(name, description, function);
 }

 /// Wraps `function` with a lambda that extracts a StringRef from a source
 /// manager and registers the wrapper lambda as a to-MLIR conversion.
 TranslateToMLIRRegistration::TranslateToMLIRRegistration(
-    StringRef name, const TranslateStringRefToMLIRFunction &function) {
+    StringRef name, const TranslateStringRefToMLIRFunction &function)
+    : TranslateToMLIRRegistration(name, /*description=*/name, function) {}
+
+TranslateToMLIRRegistration::TranslateToMLIRRegistration(
+    StringRef name, StringRef description,
+    const TranslateStringRefToMLIRFunction &function) {
   registerTranslateToMLIRFunction(
-      name, [function](llvm::SourceMgr &sourceMgr, MLIRContext *ctx) {
+      name, description,
+      [function](llvm::SourceMgr &sourceMgr, MLIRContext *ctx) {
         const llvm::MemoryBuffer *buffer =
             sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID());
         return function(buffer->getBuffer(), ctx);
@@ -90,18 +115,26 @@

 TranslateFromMLIRRegistration::TranslateFromMLIRRegistration(
     StringRef name, const TranslateFromMLIRFunction &function,
-    const std::function<void(DialectRegistry &)> &dialectRegistration) {
-  registerTranslation(name, [function, dialectRegistration](
-                                llvm::SourceMgr &sourceMgr, raw_ostream &output,
-                                MLIRContext *context) {
+    const std::function<void(DialectRegistry &)> &dialectRegistration)
+    : TranslateFromMLIRRegistration(name, /*description=*/name, function,
+                                    dialectRegistration) {}
+
+TranslateFromMLIRRegistration::TranslateFromMLIRRegistration(
+    StringRef name, StringRef description,
+    const TranslateFromMLIRFunction &function,
+    const std::function<void(DialectRegistry &)> &dialectRegistration) {
+  auto wrappedFn = [function, dialectRegistration](llvm::SourceMgr &sourceMgr,
+                                                   raw_ostream &output,
+                                                   MLIRContext *context) {
     DialectRegistry registry;
     dialectRegistration(registry);
     context->appendDialectRegistry(registry);
     auto module = parseSourceFile<ModuleOp>(sourceMgr, context);
     if (!module || failed(verify(*module)))
       return failure();
     return function(module.get(), output);
-  });
+  };
+  registerTranslation(name, description, wrappedFn);
 }

 //===----------------------------------------------------------------------===//
@@ -110,8 +143,15 @@

 TranslationParser::TranslationParser(llvm::cl::Option &opt)
     : llvm::cl::parser<const TranslateFunction *>(opt) {
-  for (const auto &kv : getTranslationRegistry())
-    addLiteralOption(kv.first(), &kv.second, kv.first());
+  const auto &descriptions = getTranslationDescription();
+  for (const auto &kv : getTranslationRegistry()) {
+    // registerTranslation() records a description for every name; fall back
+    // to the name itself defensively rather than inserting an empty entry.
+    auto it = descriptions.find(kv.first());
+    StringRef help =
+        it != descriptions.end() ? StringRef(it->second) : kv.first();
+    addLiteralOption(kv.first(), &kv.second, help);
+  }
 }

 void TranslationParser::printOptionInfo(const llvm::cl::Option &o,