diff --git a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
--- a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
@@ -10,6 +10,8 @@
 
 set(LLVM_TARGET_DEFINITIONS LLVMOps.td)
 mlir_tablegen(LLVMConversions.inc -gen-llvmir-conversions)
+mlir_tablegen(LLVMConversionEnumsToLLVM.inc -gen-enum-to-llvmir-conversions)
+mlir_tablegen(LLVMConversionEnumsFromLLVM.inc -gen-enum-from-llvmir-conversions)
 add_public_tablegen_target(MLIRLLVMConversionsIncGen)
 set(LLVM_TARGET_DEFINITIONS NVVMOps.td)
 mlir_tablegen(NVVMConversions.inc -gen-llvmir-conversions)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
@@ -64,13 +64,36 @@
 
 // Case of the LLVM enum attribute backed by I64Attr with customized string
 // representation that corresponds to what is visible in the textual IR form.
-class LLVM_EnumAttrCase<string cppSym, string irSym, int val> :
-    I64EnumAttrCase<cppSym, val, irSym>;
+// The parameters are as follows:
+//   - `cppSym`: name of the C++ enumerant for this case in MLIR API;
+//   - `irSym`: keyword used in the custom form of MLIR operation;
+//   - `llvmSym`: name of the C++ enumerant for this case in LLVM API.
+// For example, `LLVM_EnumAttrCase<"Weak", "weak", "WeakAnyLinkage">` is usable
+// as `<MlirEnumName>::Weak` in MLIR API, `WeakAnyLinkage` in LLVM API and
+// is printed/parsed as `weak` in MLIR custom textual format.
+class LLVM_EnumAttrCase<string cppSym, string irSym, string llvmSym, int val> :
+    I64EnumAttrCase<cppSym, val, irSym> {
+
+  // The name of the equivalent enumerant in LLVM.
+  string llvmEnumerant = llvmSym;
+}
 
 // LLVM enum attribute backed by I64Attr with string representation
 // corresponding to what is visible in the textual IR form.
-class LLVM_EnumAttr<string name, string description,
+// The parameters are as follows:
+//   - `name`: name of the C++ enum class in MLIR API;
+//   - `llvmName`: name of the C++ enum in LLVM API;
+//   - `description`: textual description for documentation purposes;
+//   - `cases`: list of enum cases.
+// For example, `LLVM_EnumAttr<"Linkage", "::llvm::GlobalValue::LinkageTypes",
+// ...>` produces the `mlir::LLVM::Linkage` enum class in MLIR API whose cases
+// correspond to (a subset of) the `llvm::GlobalValue::LinkageTypes` values.
+class LLVM_EnumAttr<string name, string llvmName, string description,
                     list<LLVM_EnumAttrCase> cases> :
-    I64EnumAttr<name, description, cases>;
+    I64EnumAttr<name, description, cases> {
+
+  // The equivalent enum class name in LLVM.
+  string llvmClassName = llvmName;
+}
 
 #endif // LLVMIR_OP_BASE
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -495,22 +495,33 @@
 // https://llvm.org/docs/LangRef.html#linkage-types. The names are equivalent to
 // visible names in the IR rather than to enum values names in llvm::GlobalValue
 // since the latter is easier to change.
-def LinkagePrivate : LLVM_EnumAttrCase<"Private", "private", 0>;
-def LinkageInternal : LLVM_EnumAttrCase<"Internal", "internal", 1>;
-def LinkageAvailableExternally : LLVM_EnumAttrCase<"AvailableExternally",
-                                                   "available_externally", 2>;
-def LinkageLinkonce : LLVM_EnumAttrCase<"Linkonce", "linkonce", 3>;
-def LinkageWeak : LLVM_EnumAttrCase<"Weak", "weak", 4>;
-def LinkageCommon : LLVM_EnumAttrCase<"Common", "common", 5>;
-def LinkageAppending : LLVM_EnumAttrCase<"Appending", "appending", 6>;
-def LinkageExternWeak : LLVM_EnumAttrCase<"ExternWeak",
-                                          "extern_weak", 7>;
-def LinkageLinkonceODR : LLVM_EnumAttrCase<"LinkonceODR",
-                                           "linkonce_odr", 8>;
-def LinkageWeakODR : LLVM_EnumAttrCase<"WeakODR", "weak_odr", 9>;
-def LinkageExternal : LLVM_EnumAttrCase<"External", "external", 10>;
+def LinkagePrivate
+    : LLVM_EnumAttrCase<"Private", "private", "PrivateLinkage", 0>;
+def LinkageInternal
+    : LLVM_EnumAttrCase<"Internal", "internal", "InternalLinkage", 1>;
+def LinkageAvailableExternally
+    : LLVM_EnumAttrCase<"AvailableExternally", "available_externally",
+                        "AvailableExternallyLinkage", 2>;
+def LinkageLinkonce
+    : LLVM_EnumAttrCase<"Linkonce", "linkonce", "LinkOnceAnyLinkage", 3>;
+def LinkageWeak
+    : LLVM_EnumAttrCase<"Weak", "weak", "WeakAnyLinkage", 4>;
+def LinkageCommon
+    : LLVM_EnumAttrCase<"Common", "common", "CommonLinkage", 5>;
+def LinkageAppending
+    : LLVM_EnumAttrCase<"Appending", "appending", "AppendingLinkage", 6>;
+def LinkageExternWeak
+    : LLVM_EnumAttrCase<"ExternWeak", "extern_weak", "ExternalWeakLinkage", 7>;
+def LinkageLinkonceODR
+    : LLVM_EnumAttrCase<"LinkonceODR", "linkonce_odr", "LinkOnceODRLinkage", 8>;
+def LinkageWeakODR
+    : LLVM_EnumAttrCase<"WeakODR", "weak_odr", "WeakODRLinkage", 9>;
+def LinkageExternal
+    : LLVM_EnumAttrCase<"External", "external", "ExternalLinkage", 10>;
+
 def Linkage : LLVM_EnumAttr<
     "Linkage",
+    "::llvm::GlobalValue::LinkageTypes",
     "LLVM linkage types",
     [LinkagePrivate, LinkageInternal, LinkageAvailableExternally,
      LinkageLinkonce, LinkageWeak, LinkageCommon, LinkageAppending,
diff --git a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
--- a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
+++ b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
@@ -30,6 +30,8 @@
 using namespace mlir;
 using namespace mlir::LLVM;
 
+#include "mlir/Dialect/LLVMIR/LLVMConversionEnumsFromLLVM.inc"
+
 // Utility to print an LLVM value as a string for passing to emitError().
 // FIXME: Diagnostic should be able to natively handle types that have
 // operator << (raw_ostream&) defined.
@@ -363,37 +365,6 @@
   return nullptr;
 }
 
-/// Converts LLVM global variable linkage type into the LLVM dialect predicate.
-static LLVM::Linkage
-processLinkage(llvm::GlobalVariable::LinkageTypes linkage) {
-  switch (linkage) {
-  case llvm::GlobalValue::PrivateLinkage:
-    return LLVM::Linkage::Private;
-  case llvm::GlobalValue::InternalLinkage:
-    return LLVM::Linkage::Internal;
-  case llvm::GlobalValue::AvailableExternallyLinkage:
-    return LLVM::Linkage::AvailableExternally;
-  case llvm::GlobalValue::LinkOnceAnyLinkage:
-    return LLVM::Linkage::Linkonce;
-  case llvm::GlobalValue::WeakAnyLinkage:
-    return LLVM::Linkage::Weak;
-  case llvm::GlobalValue::CommonLinkage:
-    return LLVM::Linkage::Common;
-  case llvm::GlobalValue::AppendingLinkage:
-    return LLVM::Linkage::Appending;
-  case llvm::GlobalValue::ExternalWeakLinkage:
-    return LLVM::Linkage::ExternWeak;
-  case llvm::GlobalValue::LinkOnceODRLinkage:
-    return LLVM::Linkage::LinkonceODR;
-  case llvm::GlobalValue::WeakODRLinkage:
-    return LLVM::Linkage::WeakODR;
-  case llvm::GlobalValue::ExternalLinkage:
-    return LLVM::Linkage::External;
-  }
-
-  llvm_unreachable("unhandled linkage type");
-}
-
 GlobalOp Importer::processGlobal(llvm::GlobalVariable *GV) {
   auto it = globals.find(GV);
   if (it != globals.end())
@@ -408,7 +379,7 @@
     return nullptr;
   GlobalOp op = b.create<GlobalOp>(
       UnknownLoc::get(context), type, GV->isConstant(),
-      processLinkage(GV->getLinkage()), GV->getName(), valueAttr);
+      convertLinkageFromLLVM(GV->getLinkage()), GV->getName(), valueAttr);
   if (GV->hasInitializer() && !valueAttr) {
     Region &r = op.getInitializerRegion();
     currentEntryBlock = b.createBlock(&r);
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -31,6 +31,8 @@
 using namespace mlir;
 using namespace mlir::LLVM;
 
+#include "mlir/Dialect/LLVMIR/LLVMConversionEnumsToLLVM.inc"
+
 /// Builds a constant of a sequential LLVM type `type`, potentially containing
 /// other sequential types recursively, from the individual constant values
 /// provided in `constants`. `shape` contains the number of elements in nested
@@ -400,35 +402,6 @@
   return success();
 }
 
-/// Convert the LLVM dialect linkage type to LLVM IR linkage type.
-llvm::GlobalVariable::LinkageTypes convertLinkageType(LLVM::Linkage linkage) {
-  switch (linkage) {
-  case LLVM::Linkage::Private:
-    return llvm::GlobalValue::PrivateLinkage;
-  case LLVM::Linkage::Internal:
-    return llvm::GlobalValue::InternalLinkage;
-  case LLVM::Linkage::AvailableExternally:
-    return llvm::GlobalValue::AvailableExternallyLinkage;
-  case LLVM::Linkage::Linkonce:
-    return llvm::GlobalValue::LinkOnceAnyLinkage;
-  case LLVM::Linkage::Weak:
-    return llvm::GlobalValue::WeakAnyLinkage;
-  case LLVM::Linkage::Common:
-    return llvm::GlobalValue::CommonLinkage;
-  case LLVM::Linkage::Appending:
-    return llvm::GlobalValue::AppendingLinkage;
-  case LLVM::Linkage::ExternWeak:
-    return llvm::GlobalValue::ExternalWeakLinkage;
-  case LLVM::Linkage::LinkonceODR:
-    return llvm::GlobalValue::LinkOnceODRLinkage;
-  case LLVM::Linkage::WeakODR:
-    return llvm::GlobalValue::WeakODRLinkage;
-  case LLVM::Linkage::External:
-    return llvm::GlobalValue::ExternalLinkage;
-  }
-  llvm_unreachable("unknown linkage type");
-}
-
 /// Create named global variables that correspond to llvm.mlir.global
 /// definitions.
 void ModuleTranslation::convertGlobals() {
@@ -458,7 +431,7 @@
       cst = cast<llvm::Constant>(valueMapping.lookup(ret.getOperand(0)));
     }
 
-    auto linkage = convertLinkageType(op.linkage());
+    auto linkage = convertLinkageToLLVM(op.linkage());
     bool anyExternalLinkage =
         (linkage == llvm::GlobalVariable::ExternalLinkage ||
          linkage == llvm::GlobalVariable::ExternalWeakLinkage);
diff --git a/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp b/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp
--- a/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp
+++ b/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp
@@ -11,6 +11,8 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/TableGen/Attribute.h"
 #include "mlir/TableGen/GenInfo.h"
 #include "mlir/TableGen/Operator.h"
 
@@ -171,6 +173,126 @@
   return false;
 }
 
+namespace {
+// Wrapper class around a Tablegen definition of an LLVM enum attribute case.
+class LLVMEnumAttrCase : public tblgen::EnumAttrCase {
+public:
+  using tblgen::EnumAttrCase::EnumAttrCase;
+
+  // Constructs a case from a non LLVM-specific enum attribute case.
+  explicit LLVMEnumAttrCase(const tblgen::EnumAttrCase &other)
+      : tblgen::EnumAttrCase(&other.getDef()) {}
+
+  // Returns the C++ enumerant for the LLVM API.
+  StringRef getLLVMEnumerant() const {
+    return def->getValueAsString("llvmEnumerant");
+  }
+};
+
+// Wrapper class around a Tablegen definition of an LLVM enum attribute.
+class LLVMEnumAttr : public tblgen::EnumAttr {
+public:
+  using tblgen::EnumAttr::EnumAttr;
+
+  // Returns the C++ enum name for the LLVM API.
+  StringRef getLLVMClassName() const {
+    return def->getValueAsString("llvmClassName");
+  }
+
+  // Returns all associated cases viewed as LLVM-specific enum cases.
+  std::vector<LLVMEnumAttrCase> getAllCases() const {
+    std::vector<LLVMEnumAttrCase> cases;
+
+    for (const auto &c : tblgen::EnumAttr::getAllCases())
+      cases.emplace_back(c);
+
+    return cases;
+  }
+};
+} // namespace
+
+// Emits conversion function "LLVMClass convertEnumToLLVM(Enum)" containing
+// switch-based logic to convert from the MLIR LLVM dialect enum attribute case
+// (Enum) to the corresponding LLVM API enumerant.
+static void emitOneEnumToConversion(const llvm::Record *record,
+                                    raw_ostream &os) {
+  LLVMEnumAttr enumAttr(record);
+  StringRef llvmClass = enumAttr.getLLVMClassName();
+  StringRef cppClassName = enumAttr.getEnumClassName();
+  StringRef cppNamespace = enumAttr.getCppNamespace();
+
+  // Emit the function converting the enum attribute to its LLVM counterpart.
+  os << formatv("static {0} convert{1}ToLLVM({2}::{1} value) {{\n", llvmClass,
+                cppClassName, cppNamespace);
+  os << "  switch (value) {\n";
+
+  for (const auto &enumerant : enumAttr.getAllCases()) {
+    StringRef llvmEnumerant = enumerant.getLLVMEnumerant();
+    StringRef cppEnumerant = enumerant.getSymbol();
+    os << formatv("  case {0}::{1}::{2}:\n", cppNamespace, cppClassName,
+                  cppEnumerant);
+    os << formatv("    return {0}::{1};\n", llvmClass, llvmEnumerant);
+  }
+
+  os << "  }\n";
+  os << formatv("  llvm_unreachable(\"unknown {0} type\");\n",
+                enumAttr.getEnumClassName());
+  os << "}\n\n";
+}
+
+// Emits conversion function "Enum convertEnumFromLLVM(LLVMClass)" containing
+// switch-based logic to convert from the LLVM API enumerant to the MLIR LLVM
+// dialect enum attribute case (Enum).
+static void emitOneEnumFromConversion(const llvm::Record *record,
+                                      raw_ostream &os) {
+  LLVMEnumAttr enumAttr(record);
+  StringRef llvmClass = enumAttr.getLLVMClassName();
+  StringRef cppClassName = enumAttr.getEnumClassName();
+  StringRef cppNamespace = enumAttr.getCppNamespace();
+
+  // Emit the function converting the enum attribute from its LLVM counterpart.
+  os << formatv("static {0}::{1} convert{1}FromLLVM({2} value) {{\n",
+                cppNamespace, cppClassName, llvmClass);
+  os << "  switch (value) {\n";
+
+  for (const auto &enumerant : enumAttr.getAllCases()) {
+    StringRef llvmEnumerant = enumerant.getLLVMEnumerant();
+    StringRef cppEnumerant = enumerant.getSymbol();
+    os << formatv("  case {0}::{1}:\n", llvmClass, llvmEnumerant);
+    os << formatv("    return {0}::{1}::{2};\n", cppNamespace, cppClassName,
+                  cppEnumerant);
+  }
+
+  os << "  }\n";
+  os << formatv("  llvm_unreachable(\"unknown {0} type\");\n",
+                enumAttr.getLLVMClassName());
+  os << "}\n\n";
+}
+
+// Emits conversion functions between MLIR enum attribute case and corresponding
+// LLVM API enumerants for all registered LLVM dialect enum attributes.
+template <bool ConvertTo>
+static bool emitEnumConversionDefs(const RecordKeeper &recordKeeper,
+                                   raw_ostream &os) {
+  for (const auto *def : recordKeeper.getAllDerivedDefinitions("LLVM_EnumAttr"))
+    if (ConvertTo)
+      emitOneEnumToConversion(def, os);
+    else
+      emitOneEnumFromConversion(def, os);
+
+  return false;
+}
+
 static mlir::GenRegistration
     genLLVMIRConversions("gen-llvmir-conversions",
                          "Generate LLVM IR conversions", emitBuilders);
+
+static mlir::GenRegistration
+    genEnumToLLVMConversion("gen-enum-to-llvmir-conversions",
+                            "Generate conversions of EnumAttrs to LLVM IR",
+                            emitEnumConversionDefs</*ConvertTo=*/true>);
+
+static mlir::GenRegistration
+    genEnumFromLLVMConversion("gen-enum-from-llvmir-conversions",
+                              "Generate conversions of EnumAttrs from LLVM IR",
+                              emitEnumConversionDefs</*ConvertTo=*/false>);