diff --git a/mlir/include/mlir/Translation.h b/mlir/include/mlir/Translation.h
--- a/mlir/include/mlir/Translation.h
+++ b/mlir/include/mlir/Translation.h
@@ -12,9 +12,7 @@
 #ifndef MLIR_TRANSLATION_H
 #define MLIR_TRANSLATION_H
 
-#include "llvm/ADT/StringMap.h"
-
-#include <memory>
+#include "llvm/Support/CommandLine.h"
 
 namespace llvm {
 class MemoryBuffer;
@@ -82,13 +80,13 @@
 };
 /// \}
 
-/// Get a read-only reference to the translator registry.
-const llvm::StringMap<TranslateSourceMgrToMLIRFunction> &
-getTranslationToMLIRRegistry();
-const llvm::StringMap<TranslateFromMLIRFunction> &
-getTranslationFromMLIRRegistry();
-const llvm::StringMap<TranslateFunction> &getTranslationRegistry();
+/// A command line parser for translation functions.
+struct TranslationParser : public llvm::cl::parser<const TranslateFunction *> {
+  TranslationParser(llvm::cl::Option &opt);
 
+  void printOptionInfo(const llvm::cl::Option &o,
+                       size_t globalWidth) const override;
+};
 } // namespace mlir
 
 #endif // MLIR_TRANSLATION_H
diff --git a/mlir/lib/Support/CMakeLists.txt b/mlir/lib/Support/CMakeLists.txt
--- a/mlir/lib/Support/CMakeLists.txt
+++ b/mlir/lib/Support/CMakeLists.txt
@@ -4,7 +4,6 @@
   MlirOptMain.cpp
   StorageUniquer.cpp
   ToolUtilities.cpp
-  TranslateClParser.cpp
   )
 
 add_mlir_library(MLIRSupport
@@ -34,19 +33,6 @@
   MLIRSupport
   )
 
-add_mlir_library(MLIRTranslateClParser
-  TranslateClParser.cpp
-
-  ADDITIONAL_HEADER_DIRS
-  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Support
-  )
-target_link_libraries(MLIRTranslateClParser
-  PUBLIC
-  LLVMSupport
-  MLIRIR
-  MLIRTranslation
-  MLIRParser)
-
 add_llvm_library(MLIRJitRunner
   JitRunner.cpp
   )
diff --git a/mlir/lib/Translation/CMakeLists.txt b/mlir/lib/Translation/CMakeLists.txt
--- a/mlir/lib/Translation/CMakeLists.txt
+++ b/mlir/lib/Translation/CMakeLists.txt
@@ -8,4 +8,5 @@
   PUBLIC
   LLVMSupport
   MLIRIR
+  MLIRParser
   )
diff --git a/mlir/lib/Translation/Translation.cpp b/mlir/lib/Translation/Translation.cpp
--- a/mlir/lib/Translation/Translation.cpp
+++ b/mlir/lib/Translation/Translation.cpp
@@ -11,46 +11,59 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Translation.h"
+#include "mlir/Analysis/Verifier.h"
 #include "mlir/IR/Module.h"
+#include "mlir/Parser.h"
 #include "mlir/Support/LLVM.h"
-#include "mlir/Support/LogicalResult.h"
 #include "llvm/Support/SourceMgr.h"
 
 using namespace mlir;
 
-// Get the mutable static map between registered "to MLIR" translations and the
-// TranslateToMLIRFunctions that perform those translations.
-static llvm::StringMap<TranslateSourceMgrToMLIRFunction> &
-getMutableTranslationToMLIRRegistry() {
-  static llvm::StringMap<TranslateSourceMgrToMLIRFunction>
-      translationToMLIRRegistry;
-  return translationToMLIRRegistry;
-}
-// Get the mutable static map between registered "from MLIR" translations and
-// the TranslateFromMLIRFunctions that perform those translations.
-static llvm::StringMap<TranslateFromMLIRFunction> &
-getMutableTranslationFromMLIRRegistry() {
-  static llvm::StringMap<TranslateFromMLIRFunction> translationFromMLIRRegistry;
-  return translationFromMLIRRegistry;
-}
+//===----------------------------------------------------------------------===//
+// Translation Registry
+//===----------------------------------------------------------------------===//
 
-// Get the mutable static map between registered file-to-file MLIR translations
-// and the TranslateFunctions that perform those translations.
-static llvm::StringMap<TranslateFunction> &getMutableTranslationRegistry() {
+/// Get the mutable static map between registered file-to-file MLIR translations
+/// and the TranslateFunctions that perform those translations.
+static llvm::StringMap<TranslateFunction> &getTranslationRegistry() {
   static llvm::StringMap<TranslateFunction> translationRegistry;
   return translationRegistry;
 }
 
+/// Register the given translation.
+static void registerTranslation(StringRef name,
+                                const TranslateFunction &function) {
+  auto &translationRegistry = getTranslationRegistry();
+  if (translationRegistry.find(name) != translationRegistry.end())
+    llvm::report_fatal_error(
+        "Attempting to overwrite an existing <file-to-file> function");
+  assert(function &&
+         "Attempting to register an empty translate <file-to-file> function");
+  translationRegistry[name] = function;
+}
+
+TranslateRegistration::TranslateRegistration(
+    StringRef name, const TranslateFunction &function) {
+  registerTranslation(name, function);
+}
+
+//===----------------------------------------------------------------------===//
+// Translation to MLIR
+//===----------------------------------------------------------------------===//
+
 // Puts `function` into the to-MLIR translation registry unless there is already
 // a function registered for the same name.
 static void registerTranslateToMLIRFunction(
     StringRef name, const TranslateSourceMgrToMLIRFunction &function) {
-  auto &translationToMLIRRegistry = getMutableTranslationToMLIRRegistry();
-  if (translationToMLIRRegistry.find(name) != translationToMLIRRegistry.end())
-    llvm::report_fatal_error(
-        "Attempting to overwrite an existing <to> function");
-  assert(function && "Attempting to register an empty translate <to> function");
-  translationToMLIRRegistry[name] = function;
+  auto wrappedFn = [function](llvm::SourceMgr &sourceMgr, raw_ostream &output,
+                              MLIRContext *context) {
+    OwningModuleRef module = function(sourceMgr, context);
+    if (!module || failed(verify(*module)))
+      return failure();
+    module->print(output);
+    return success();
+  };
+  registerTranslation(name, wrappedFn);
 }
 
 TranslateToMLIRRegistration::TranslateToMLIRRegistration(
@@ -58,54 +71,51 @@
   registerTranslateToMLIRFunction(name, function);
 }
 
-// Wraps `function` with a lambda that extracts a StringRef from a source
-// manager and registers the wrapper lambda as a to-MLIR conversion.
+/// Wraps `function` with a lambda that extracts a StringRef from a source
+/// manager and registers the wrapper lambda as a to-MLIR conversion.
 TranslateToMLIRRegistration::TranslateToMLIRRegistration(
     StringRef name, const TranslateStringRefToMLIRFunction &function) {
-  auto translationFunction = [function](llvm::SourceMgr &sourceMgr,
-                                        MLIRContext *ctx) {
-    const llvm::MemoryBuffer *buffer =
-        sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID());
-    return function(buffer->getBuffer(), ctx);
-  };
-  registerTranslateToMLIRFunction(name, translationFunction);
+  registerTranslateToMLIRFunction(
+      name, [function](llvm::SourceMgr &sourceMgr, MLIRContext *ctx) {
+        const llvm::MemoryBuffer *buffer =
+            sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID());
+        return function(buffer->getBuffer(), ctx);
+      });
 }
 
+//===----------------------------------------------------------------------===//
+// Translation from MLIR
+//===----------------------------------------------------------------------===//
+
 TranslateFromMLIRRegistration::TranslateFromMLIRRegistration(
     StringRef name, const TranslateFromMLIRFunction &function) {
-  auto &translationFromMLIRRegistry = getMutableTranslationFromMLIRRegistry();
-  if (translationFromMLIRRegistry.find(name) !=
-      translationFromMLIRRegistry.end())
-    llvm::report_fatal_error(
-        "Attempting to overwrite an existing <from> function");
-  assert(function &&
-         "Attempting to register an empty translate <from> function");
-  translationFromMLIRRegistry[name] = function;
-}
-
-TranslateRegistration::TranslateRegistration(
-    StringRef name, const TranslateFunction &function) {
-  auto &translationRegistry = getMutableTranslationRegistry();
-  if (translationRegistry.find(name) != translationRegistry.end())
-    llvm::report_fatal_error(
-        "Attempting to overwrite an existing <file-to-file> function");
-  assert(function &&
-         "Attempting to register an empty translate <file-to-file> function");
-  translationRegistry[name] = function;
+  registerTranslation(name, [function](llvm::SourceMgr &sourceMgr,
+                                       raw_ostream &output,
+                                       MLIRContext *context) {
+    auto module = OwningModuleRef(parseSourceFile(sourceMgr, context));
+    if (!module)
+      return failure();
+    return function(module.get(), output);
+  });
 }
 
-// Merely add the const qualifier to the mutable registry so that external users
-// cannot modify it.
-const llvm::StringMap<TranslateSourceMgrToMLIRFunction> &
-mlir::getTranslationToMLIRRegistry() {
-  return getMutableTranslationToMLIRRegistry();
-}
+//===----------------------------------------------------------------------===//
+// Translation Parser
+//===----------------------------------------------------------------------===//
 
-const llvm::StringMap<TranslateFromMLIRFunction> &
-mlir::getTranslationFromMLIRRegistry() {
-  return getMutableTranslationFromMLIRRegistry();
+TranslationParser::TranslationParser(llvm::cl::Option &opt)
+    : llvm::cl::parser<const TranslateFunction *>(opt) {
+  for (const auto &kv : getTranslationRegistry())
+    addLiteralOption(kv.first(), &kv.second, kv.first());
 }
 
-const llvm::StringMap<TranslateFunction> &mlir::getTranslationRegistry() {
-  return getMutableTranslationRegistry();
+void TranslationParser::printOptionInfo(const llvm::cl::Option &o,
+                                        size_t globalWidth) const {
+  TranslationParser *tp = const_cast<TranslationParser *>(this);
+  llvm::array_pod_sort(tp->Values.begin(), tp->Values.end(),
+                       [](const TranslationParser::OptionInfo *lhs,
+                          const TranslationParser::OptionInfo *rhs) {
+                         return lhs->Name.compare(rhs->Name);
+                       });
+  llvm::cl::parser<const TranslateFunction *>::printOptionInfo(o, globalWidth);
 }
diff --git a/mlir/tools/mlir-translate/CMakeLists.txt b/mlir/tools/mlir-translate/CMakeLists.txt
--- a/mlir/tools/mlir-translate/CMakeLists.txt
+++ b/mlir/tools/mlir-translate/CMakeLists.txt
@@ -24,4 +24,4 @@
   )
 llvm_update_compile_flags(mlir-translate)
 whole_archive_link(mlir-translate ${FULL_LIBS})
-target_link_libraries(mlir-translate PRIVATE MLIRIR MLIRTranslateClParser ${LIBS} LLVMSupport)
+target_link_libraries(mlir-translate PRIVATE MLIRIR MLIRTranslation ${LIBS} LLVMSupport)
diff --git a/mlir/tools/mlir-translate/mlir-translate.cpp b/mlir/tools/mlir-translate/mlir-translate.cpp
--- a/mlir/tools/mlir-translate/mlir-translate.cpp
+++ b/mlir/tools/mlir-translate/mlir-translate.cpp
@@ -17,7 +17,7 @@
 #include "mlir/Support/FileUtilities.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Support/ToolUtilities.h"
-#include "mlir/Support/TranslateClParser.h"
+#include "mlir/Translation.h"
 #include "llvm/Support/InitLLVM.h"
 #include "llvm/Support/MemoryBuffer.h"
 #include "llvm/Support/SourceMgr.h"