diff --git a/mlir/include/mlir/Target/LLVMIR.h b/mlir/include/mlir/Target/LLVMIR.h --- a/mlir/include/mlir/Target/LLVMIR.h +++ b/mlir/include/mlir/Target/LLVMIR.h @@ -27,14 +27,14 @@ class OwningModuleRef; class MLIRContext; -class ModuleOp; +class Operation; /// Convert the given MLIR module into LLVM IR. The LLVM context is extracted /// from the registered LLVM IR dialect. In case of error, report it /// to the error handler registered with the MLIR context, if any (obtained from /// the MLIR module), and return `nullptr`. std::unique_ptr -translateModuleToLLVMIR(ModuleOp m, llvm::LLVMContext &llvmContext, +translateModuleToLLVMIR(Operation *m, llvm::LLVMContext &llvmContext, StringRef name = "LLVMDialectModule"); /// Convert the given LLVM module into MLIR's LLVM dialect. The LLVM context is diff --git a/mlir/include/mlir/Translation.h b/mlir/include/mlir/Translation.h --- a/mlir/include/mlir/Translation.h +++ b/mlir/include/mlir/Translation.h @@ -15,7 +15,9 @@ #include "llvm/Support/CommandLine.h" namespace llvm { +class LLVMContext; class MemoryBuffer; +class Module; class SourceMgr; class StringRef; } // namespace llvm @@ -25,6 +27,7 @@ struct LogicalResult; class MLIRContext; class ModuleOp; +class Operation; class OwningModuleRef; /// Interface of the function that translates the sources managed by `sourceMgr` @@ -54,10 +57,20 @@ using TranslateFunction = std::function; +/// Interface of the function that translates an MLIR module into an LLVM +/// module. Such functions can be passed to mlir-cpu-runner to generate LLVM IR +/// for a specific target architecture. Returns nullptr on failure. +using TranslateMLIRToLLVMFunction = std::function( + Operation *, llvm::LLVMContext &)>; + /// Use Translate[ToMLIR|FromMLIR]Registration as an initializer that /// registers a function and associates it with name. This requires that a /// translation has not been registered to a given name. /// +/// Use TranslateFromMLIRToLLVMRegistration for MLIR to LLVMIR translations.
+/// Such translations are made available to mlir-cpu-runner and can be +/// selected there via the `-translation` command line option. +/// /// Usage: /// /// // At file scope. @@ -81,20 +94,37 @@ std::function dialectRegistration = [](DialectRegistry &) {}); }; +struct TranslateFromMLIRToLLVMRegistration + : public TranslateFromMLIRRegistration { + TranslateFromMLIRToLLVMRegistration( + llvm::StringRef name, const TranslateMLIRToLLVMFunction &function, + std::function dialectRegistration = + [](DialectRegistry &) {}); +}; struct TranslateRegistration { TranslateRegistration(llvm::StringRef name, const TranslateFunction &function); }; /// \} -/// A command line parser for translation functions. -struct TranslationParser : public llvm::cl::parser { - TranslationParser(llvm::cl::Option &opt); +/// A command line parser for translation functions. F specifies the type of +/// translations: TranslateFunction or TranslateMLIRToLLVMFunction. +template +struct TranslationParserBase : public llvm::cl::parser { + TranslationParserBase(llvm::cl::Option &opt); void printOptionInfo(const llvm::cl::Option &o, size_t globalWidth) const override; }; +/// A command line parser for general translation functions. +using TranslationParser = TranslationParserBase; + +/// A command line parser for MLIR to LLVMIR translations to be used with +/// mlir-cpu-runner. +using MLIRToLLVMTranslationParser = + TranslationParserBase; + /// Translate to/from an MLIR module from/to an external representation (e.g. /// LLVM IR, SPIRV binary, ...). This is the entry point for the implementation /// of tools like `mlir-translate`.
The translation to perform is parsed from diff --git a/mlir/lib/ExecutionEngine/JitRunner.cpp b/mlir/lib/ExecutionEngine/JitRunner.cpp --- a/mlir/lib/ExecutionEngine/JitRunner.cpp +++ b/mlir/lib/ExecutionEngine/JitRunner.cpp @@ -24,6 +24,7 @@ #include "mlir/InitAllDialects.h" #include "mlir/Parser.h" #include "mlir/Support/FileUtilities.h" +#include "mlir/Translation.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h" @@ -91,6 +92,12 @@ llvm::cl::opt objectFilename{ "object-filename", llvm::cl::desc("Dump JITted-compiled object to file .o")}; + + // Option selecting one of the registered MLIR to LLVM IR translations. + llvm::cl::opt + translationRequested{"translation", + llvm::cl::desc("Translation to perform")}; }; struct CompileAndExecuteConfig { @@ -362,6 +369,15 @@ compileAndExecuteConfig.llvmModuleBuilder = config.llvmModuleBuilder; compileAndExecuteConfig.runtimeSymbolMap = config.runtimesymbolMap; + if (const auto *f = options.translationRequested.getValue()) { + // Checked at runtime, not only via assert, so that release builds do not + // silently let the command line override the builder given in config. + if (compileAndExecuteConfig.llvmModuleBuilder) + llvm::report_fatal_error( + "Competing translations specified via both command line and config"); + compileAndExecuteConfig.llvmModuleBuilder = *f; + } + // Get the function used to compile and execute the module.
using CompileAndExecuteFnT = Error (*)(Options &, ModuleOp, StringRef, CompileAndExecuteConfig); diff --git a/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp --- a/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp @@ -21,33 +21,24 @@ #include "llvm/IR/Verifier.h" #include "llvm/Support/ToolOutputFile.h" -using namespace mlir; - +namespace mlir { std::unique_ptr -mlir::translateModuleToLLVMIR(ModuleOp m, llvm::LLVMContext &llvmContext, - StringRef name) { +translateModuleToLLVMIR(Operation *m, llvm::LLVMContext &llvmContext, + StringRef name) { auto llvmModule = LLVM::ModuleTranslation::translateModule<>(m, llvmContext, name); if (!llvmModule) - emitError(m.getLoc(), "Fail to convert MLIR to LLVM IR"); + emitError(m->getLoc(), "Fail to convert MLIR to LLVM IR"); else if (verifyModule(*llvmModule)) - emitError(m.getLoc(), "LLVM IR fails to verify"); + emitError(m->getLoc(), "LLVM IR fails to verify"); return llvmModule; } -namespace mlir { void registerToLLVMIRTranslation() { - TranslateFromMLIRRegistration registration( + TranslateFromMLIRToLLVMRegistration registration( "mlir-to-llvmir", - [](ModuleOp module, raw_ostream &output) { - llvm::LLVMContext llvmContext; - auto llvmModule = LLVM::ModuleTranslation::translateModule<>( - module, llvmContext, "LLVMDialectModule"); - if (!llvmModule) - return failure(); - - llvmModule->print(output, nullptr); - return success(); + [](Operation *m, llvm::LLVMContext &llvmContext) { + return translateModuleToLLVMIR(m, llvmContext, "LLVMDialectModule"); }, [](DialectRegistry &registry) { registry.insert(); diff --git a/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp --- a/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp @@ -98,16 +98,10 @@ namespace mlir { void registerToNVVMIRTranslation() { - TranslateFromMLIRRegistration registration( + TranslateFromMLIRToLLVMRegistration
registration( "mlir-to-nvvmir", - [](ModuleOp module, raw_ostream &output) { - llvm::LLVMContext llvmContext; - auto llvmModule = mlir::translateModuleToNVVMIR(module, llvmContext); - if (!llvmModule) - return failure(); - - llvmModule->print(output, nullptr); - return success(); + [](Operation *m, llvm::LLVMContext &llvmContext) { + return translateModuleToNVVMIR(m, llvmContext); }, [](DialectRegistry &registry) { registry.insert(); diff --git a/mlir/lib/Target/LLVMIR/ConvertToROCDLIR.cpp b/mlir/lib/Target/LLVMIR/ConvertToROCDLIR.cpp --- a/mlir/lib/Target/LLVMIR/ConvertToROCDLIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertToROCDLIR.cpp @@ -102,16 +102,10 @@ namespace mlir { void registerToROCDLIRTranslation() { - TranslateFromMLIRRegistration registration( + TranslateFromMLIRToLLVMRegistration registration( "mlir-to-rocdlir", - [](ModuleOp module, raw_ostream &output) { - llvm::LLVMContext llvmContext; - auto llvmModule = mlir::translateModuleToROCDLIR(module, llvmContext); - if (!llvmModule) - return failure(); - - llvmModule->print(output, nullptr); - return success(); + [](Operation *m, llvm::LLVMContext &llvmContext) { + return translateModuleToROCDLIR(m, llvmContext); }, [](DialectRegistry &registry) { registry.insert(); diff --git a/mlir/lib/Target/LLVMIR/LLVMAVX512Intr.cpp b/mlir/lib/Target/LLVMIR/LLVMAVX512Intr.cpp --- a/mlir/lib/Target/LLVMIR/LLVMAVX512Intr.cpp +++ b/mlir/lib/Target/LLVMIR/LLVMAVX512Intr.cpp @@ -33,28 +33,15 @@ return LLVM::ModuleTranslation::convertOperation(opInst, builder); } }; - -std::unique_ptr -translateLLVMAVX512ModuleToLLVMIR(Operation *m, llvm::LLVMContext &llvmContext, - StringRef name) { - return LLVM::ModuleTranslation::translateModule( - m, llvmContext, name); -} } // end namespace namespace mlir { void registerAVX512ToLLVMIRTranslation() { - TranslateFromMLIRRegistration reg( + TranslateFromMLIRToLLVMRegistration reg( "avx512-mlir-to-llvmir", - [](ModuleOp module, raw_ostream &output) { - llvm::LLVMContext llvmContext; - auto llvmModule
= translateLLVMAVX512ModuleToLLVMIR( - module, llvmContext, "LLVMDialectModule"); - if (!llvmModule) - return failure(); - - llvmModule->print(output, nullptr); - return success(); + [](Operation *m, llvm::LLVMContext &llvmContext) { + return LLVM::ModuleTranslation::translateModule< + LLVMAVX512ModuleTranslation>(m, llvmContext, "LLVMDialectModule"); }, [](DialectRegistry &registry) { registry.insert(); diff --git a/mlir/lib/Target/LLVMIR/LLVMArmNeonIntr.cpp b/mlir/lib/Target/LLVMIR/LLVMArmNeonIntr.cpp --- a/mlir/lib/Target/LLVMIR/LLVMArmNeonIntr.cpp +++ b/mlir/lib/Target/LLVMIR/LLVMArmNeonIntr.cpp @@ -44,17 +44,11 @@ namespace mlir { void registerArmNeonToLLVMIRTranslation() { - TranslateFromMLIRRegistration reg( + TranslateFromMLIRToLLVMRegistration reg( "arm-neon-mlir-to-llvmir", - [](ModuleOp module, raw_ostream &output) { - llvm::LLVMContext llvmContext; - auto llvmModule = translateLLVMArmNeonModuleToLLVMIR( - module, llvmContext, "LLVMDialectModule"); - if (!llvmModule) - return failure(); - - llvmModule->print(output, nullptr); - return success(); + [](Operation *m, llvm::LLVMContext &llvmContext) { + return translateLLVMArmNeonModuleToLLVMIR(m, llvmContext, + "LLVMDialectModule"); }, [](DialectRegistry &registry) { registry.insert(); diff --git a/mlir/lib/Target/LLVMIR/LLVMArmSVEIntr.cpp b/mlir/lib/Target/LLVMIR/LLVMArmSVEIntr.cpp --- a/mlir/lib/Target/LLVMIR/LLVMArmSVEIntr.cpp +++ b/mlir/lib/Target/LLVMIR/LLVMArmSVEIntr.cpp @@ -44,17 +44,11 @@ namespace mlir { void registerArmSVEToLLVMIRTranslation() { - TranslateFromMLIRRegistration reg( + TranslateFromMLIRToLLVMRegistration reg( "arm-sve-mlir-to-llvmir", - [](ModuleOp module, raw_ostream &output) { - llvm::LLVMContext llvmContext; - auto llvmModule = translateLLVMArmSVEModuleToLLVMIR( - module, llvmContext, "LLVMDialectModule"); - if (!llvmModule) - return failure(); - - llvmModule->print(output, nullptr); - return success(); + [](Operation *m, llvm::LLVMContext &llvmContext) { + return
translateLLVMArmSVEModuleToLLVMIR(m, llvmContext, + "LLVMDialectModule"); }, [](DialectRegistry &registry) { registry.insert(); diff --git a/mlir/lib/Translation/Translation.cpp b/mlir/lib/Translation/Translation.cpp --- a/mlir/lib/Translation/Translation.cpp +++ b/mlir/lib/Translation/Translation.cpp @@ -17,6 +17,9 @@ #include "mlir/Parser.h" #include "mlir/Support/FileUtilities.h" #include "mlir/Support/ToolUtilities.h" + +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Module.h" #include "llvm/Support/InitLLVM.h" #include "llvm/Support/SourceMgr.h" #include "llvm/Support/ToolOutputFile.h" @@ -27,28 +30,38 @@ // Translation Registry //===----------------------------------------------------------------------===// -/// Get the mutable static map between registered file-to-file MLIR translations -/// and the TranslateFunctions that perform those translations. -static llvm::StringMap &getTranslationRegistry() { - static llvm::StringMap translationRegistry; - return translationRegistry; -} +/// A TranslationRegistry stores mappings between MLIR translations and the +/// functions that perform those translations. There are two kinds of +/// translations: +/// +/// * File-to-file MLIR translations (F = TranslateFunction) +/// * MLIR module to LLVM module translations (F = TranslateMLIRToLLVMFunction) +/// +/// Only the second kind of translations can be used with mlir-cpu-runner. +template +struct TranslationRegistry { + /// Get the mutable static map between registered MLIR translations and the + /// functions that perform those translations. + static llvm::StringMap &get() { + static llvm::StringMap translationRegistry; + return translationRegistry; + } -/// Register the given translation.
-static void registerTranslation(StringRef name, - const TranslateFunction &function) { - auto &translationRegistry = getTranslationRegistry(); - if (translationRegistry.find(name) != translationRegistry.end()) - llvm::report_fatal_error( - "Attempting to overwrite an existing function"); - assert(function && - "Attempting to register an empty translate function"); - translationRegistry[name] = function; -} + /// Register the given translation. + static void registerTranslation(StringRef name, const F &function) { + auto &translationRegistry = get(); + if (translationRegistry.find(name) != translationRegistry.end()) + llvm::report_fatal_error( + "Attempting to overwrite an existing function"); + assert(function && + "Attempting to register an empty translate function"); + translationRegistry[name] = function; + } +}; TranslateRegistration::TranslateRegistration( StringRef name, const TranslateFunction &function) { - registerTranslation(name, function); + TranslationRegistry::registerTranslation(name, function); } //===----------------------------------------------------------------------===// @@ -67,7 +80,7 @@ module->print(output); return success(); }; - registerTranslation(name, wrappedFn); + TranslationRegistry::registerTranslation(name, wrappedFn); } TranslateToMLIRRegistration::TranslateToMLIRRegistration( @@ -94,38 +107,74 @@ TranslateFromMLIRRegistration::TranslateFromMLIRRegistration( StringRef name, const TranslateFromMLIRFunction &function, std::function dialectRegistration) { - registerTranslation(name, [function, dialectRegistration]( - llvm::SourceMgr &sourceMgr, raw_ostream &output, - MLIRContext *context) { + auto fn = [function, dialectRegistration](llvm::SourceMgr &sourceMgr, + raw_ostream &output, + MLIRContext *context) { dialectRegistration(context->getDialectRegistry()); auto module = OwningModuleRef(parseSourceFile(sourceMgr, context)); if (!module) return failure(); return function(module.get(), output); - }); + }; +
TranslationRegistry::registerTranslation(name, fn); +} + +//===----------------------------------------------------------------------===// +// Translation from MLIR to LLVM IR +//===----------------------------------------------------------------------===// + +TranslateFromMLIRToLLVMRegistration::TranslateFromMLIRToLLVMRegistration( + StringRef name, const TranslateMLIRToLLVMFunction &function, + std::function dialectRegistration) + : TranslateFromMLIRRegistration( + /*name=*/name, + /*function=*/ + [function](ModuleOp module, + llvm::raw_ostream &output) { + llvm::LLVMContext llvmContext; + auto llvmModule = function(module, llvmContext); + if (!llvmModule) + return failure(); + + llvmModule->print(output, nullptr); + return success(); + }, + /*dialectRegistration=*/dialectRegistration) { + TranslationRegistry::registerTranslation( + name, function); } //===----------------------------------------------------------------------===// // Translation Parser //===----------------------------------------------------------------------===// -TranslationParser::TranslationParser(llvm::cl::Option &opt) - : llvm::cl::parser(opt) { - for (const auto &kv : getTranslationRegistry()) - addLiteralOption(kv.first(), &kv.second, kv.first()); +template +TranslationParserBase::TranslationParserBase(llvm::cl::Option &opt) + : llvm::cl::parser(opt) { + for (const auto &kv : TranslationRegistry::get()) + this->addLiteralOption(kv.first(), &kv.second, kv.first()); } -void TranslationParser::printOptionInfo(const llvm::cl::Option &o, - size_t globalWidth) const { - TranslationParser *tp = const_cast(this); - llvm::array_pod_sort(tp->Values.begin(), tp->Values.end(), - [](const TranslationParser::OptionInfo *lhs, - const TranslationParser::OptionInfo *rhs) { - return lhs->Name.compare(rhs->Name); - }); - llvm::cl::parser::printOptionInfo(o, globalWidth); +template +void TranslationParserBase::printOptionInfo(const llvm::cl::Option &o, + size_t globalWidth) const { +
TranslationParserBase *tp = const_cast *>(this); + llvm::array_pod_sort( + tp->Values.begin(), tp->Values.end(), + [](const typename TranslationParserBase::OptionInfo *lhs, + const typename TranslationParserBase::OptionInfo *rhs) { + return lhs->Name.compare(rhs->Name); + }); + llvm::cl::parser::printOptionInfo(o, globalWidth); } +//===----------------------------------------------------------------------===// +// Explicit Translation Parser Instantiations +//===----------------------------------------------------------------------===// + +template class mlir::TranslationParserBase; +template class mlir::TranslationParserBase; + LogicalResult mlir::mlirTranslateMain(int argc, char **argv, llvm::StringRef toolName) { diff --git a/mlir/tools/mlir-cpu-runner/mlir-cpu-runner.cpp b/mlir/tools/mlir-cpu-runner/mlir-cpu-runner.cpp --- a/mlir/tools/mlir-cpu-runner/mlir-cpu-runner.cpp +++ b/mlir/tools/mlir-cpu-runner/mlir-cpu-runner.cpp @@ -15,6 +15,7 @@ #include "mlir/ExecutionEngine/JitRunner.h" #include "mlir/ExecutionEngine/OptUtils.h" #include "mlir/InitAllDialects.h" +#include "mlir/InitAllTranslations.h" #include "llvm/Support/InitLLVM.h" #include "llvm/Support/TargetSelect.h" @@ -24,6 +25,7 @@ llvm::InitializeNativeTargetAsmPrinter(); llvm::InitializeNativeTargetAsmParser(); mlir::initializeLLVMPasses(); + mlir::registerAllTranslations(); return mlir::JitRunnerMain(argc, argv); }