[mlir] Let enum attribute cases carry a custom string representation

EnumAttrCaseInfo (OpBase.td) gains a `str` field holding the textual form of
an enumerant. I32EnumAttrCase/I64EnumAttrCase take it as an optional argument
that, by default, equals the C++ symbol name. The string-conversion code that
mlir-tblgen emits (EnumsGen.cpp) now uses `str` instead of the C++ symbol, and
the new EnumsGenTest cases check this through stringifyPrettyIntEnum and
symbolizePrettyIntEnum.

The LLVM dialect uses this to declare the Linkage cases with their textual IR
spellings (e.g. "available_externally", "weak_odr") via LLVM_EnumAttrCase /
LLVM_EnumAttr. That replaces the hand-written linkageToStr, printLinkage and
parseOptionalLinkageKeyword helpers in LLVMDialect.cpp with the generated
stringifyLinkage and a generic parseOptionalLLVMKeyword.

NOTE(review): in this copy of the patch, line breaks are collapsed and several
<...> template-argument lists appear to have been stripped (e.g. the parameter
lists of LLVM_EnumAttrCase and EnumAttrCaseInfo), so it will not apply as-is.
Regenerate it from the source commit before applying.

diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td @@ -62,4 +62,15 @@ class LLVM_IntrOp traits = []> : LLVM_Op<"intr."#mnemonic, traits>; +// Case of the LLVM enum attribute backed by I64Attr with customized string +// representation that corresponds to what is visible in the textual IR form. +class LLVM_EnumAttrCase : + I64EnumAttrCase; + +// LLVM enum attribute backed by I64Attr with string representation +// corresponding to what is visible in the textual IR form. +class LLVM_EnumAttr cases> : + I64EnumAttr; + #endif // LLVMIR_OP_BASE diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -494,18 +494,21 @@ // https://llvm.org/docs/LangRef.html#linkage-types. The names are equivalent to // visible names in the IR rather than to enum values names in llvm::GlobalValue // since the latter is easier to change. 
-def LinkagePrivate : I64EnumAttrCase<"Private", 0>; -def LinkageInternal : I64EnumAttrCase<"Internal", 1>; -def LinkageAvailableExternally : I64EnumAttrCase<"AvailableExternally", 2>; -def LinkageLinkonce : I64EnumAttrCase<"Linkonce", 3>; -def LinkageWeak : I64EnumAttrCase<"Weak", 4>; -def LinkageCommon : I64EnumAttrCase<"Common", 5>; -def LinkageAppending : I64EnumAttrCase<"Appending", 6>; -def LinkageExternWeak : I64EnumAttrCase<"ExternWeak", 7>; -def LinkageLinkonceODR : I64EnumAttrCase<"LinkonceODR", 8>; -def LinkageWeakODR : I64EnumAttrCase<"WeakODR", 9>; -def LinkageExternal : I64EnumAttrCase<"External", 10>; -def Linkage : I64EnumAttr< +def LinkagePrivate : LLVM_EnumAttrCase<"Private", "private", 0>; +def LinkageInternal : LLVM_EnumAttrCase<"Internal", "internal", 1>; +def LinkageAvailableExternally : LLVM_EnumAttrCase<"AvailableExternally", + "available_externally", 2>; +def LinkageLinkonce : LLVM_EnumAttrCase<"Linkonce", "linkonce", 3>; +def LinkageWeak : LLVM_EnumAttrCase<"Weak", "weak", 4>; +def LinkageCommon : LLVM_EnumAttrCase<"Common", "common", 5>; +def LinkageAppending : LLVM_EnumAttrCase<"Appending", "appending", 6>; +def LinkageExternWeak : LLVM_EnumAttrCase<"ExternWeak", + "extern_weak", 7>; +def LinkageLinkonceODR : LLVM_EnumAttrCase<"LinkonceODR", + "linkonce_odr", 8>; +def LinkageWeakODR : LLVM_EnumAttrCase<"WeakODR", "weak_odr", 9>; +def LinkageExternal : LLVM_EnumAttrCase<"External", "external", 10>; +def Linkage : LLVM_EnumAttr< "Linkage", "LLVM linkage types", [LinkagePrivate, LinkageInternal, LinkageAvailableExternally, diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -808,39 +808,47 @@ // Enum attribute kinds // Additional information for an enum attribute case. -class EnumAttrCaseInfo { - // The C++ enumerant symbol +class EnumAttrCaseInfo { + // The C++ enumerant symbol. 
string symbol = sym; - // The C++ enumerant value + // The C++ enumerant value. // If less than zero, there will be no explicit discriminator values assigned // to enumerators in the generated enum class. - int value = val; + int value = intVal; + + // The string representation of the enumerant. May be the same as symbol. + string str = strVal; } // An enum attribute case stored with StringAttr. class StrEnumAttrCase : - EnumAttrCaseInfo, + EnumAttrCaseInfo, StringBasedAttr< CPred<"$_self.cast().getValue() == \"" # sym # "\"">, "case " # sym>; -// An enum attribute case stored with IntegerAttr. -class IntEnumAttrCaseBase : - EnumAttrCaseInfo, - IntegerAttrBase { +// An enum attribute case stored with IntegerAttr, which has an integer value, +// its representation as a string and a C++ symbol name which may be different. +class IntEnumAttrCaseBase : + EnumAttrCaseInfo, + IntegerAttrBase { let predicate = - CPred<"$_self.cast().getInt() == " # val>; + CPred<"$_self.cast().getInt() == " # intVal>; } -class I32EnumAttrCase : IntEnumAttrCaseBase; -class I64EnumAttrCase : IntEnumAttrCaseBase; +// Cases of integer enum attributes with a specific type. By default, the string +// representation is the same as the C++ symbol name. +class I32EnumAttrCase + : IntEnumAttrCaseBase; +class I64EnumAttrCase + : IntEnumAttrCaseBase; // A bit enum case stored with 32-bit IntegerAttr. `val` here is *not* the // ordinal number of the bit that is set. It is the 32-bit integer with only // one bit set. class BitEnumAttrCase : - EnumAttrCaseInfo, + EnumAttrCaseInfo, IntegerAttrBase { let predicate = CPred< "$_self.cast().getValue().getZExtValue() & " # val # "u">; diff --git a/mlir/include/mlir/TableGen/Attribute.h b/mlir/include/mlir/TableGen/Attribute.h --- a/mlir/include/mlir/TableGen/Attribute.h +++ b/mlir/include/mlir/TableGen/Attribute.h @@ -134,6 +134,9 @@ // Returns the symbol of this enum attribute case. 
StringRef getSymbol() const; + // Returns the textual representation of this enum attribute case. + StringRef getStr() const; + // Returns the value of this enum attribute case. int64_t getValue() const; diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -1079,44 +1079,8 @@ result.addRegion(); } -// Returns the textual representation of the given linkage. -static StringRef linkageToStr(LLVM::Linkage linkage) { - switch (linkage) { - case LLVM::Linkage::Private: - return "private"; - case LLVM::Linkage::Internal: - return "internal"; - case LLVM::Linkage::AvailableExternally: - return "available_externally"; - case LLVM::Linkage::Linkonce: - return "linkonce"; - case LLVM::Linkage::Weak: - return "weak"; - case LLVM::Linkage::Common: - return "common"; - case LLVM::Linkage::Appending: - return "appending"; - case LLVM::Linkage::ExternWeak: - return "extern_weak"; - case LLVM::Linkage::LinkonceODR: - return "linkonce_odr"; - case LLVM::Linkage::WeakODR: - return "weak_odr"; - case LLVM::Linkage::External: - return "external"; - } - llvm_unreachable("unknown linkage type"); -} - -// Prints the keyword for the linkage type using the printer. -static void printLinkage(OpAsmPrinter &p, LLVM::Linkage linkage) { - p << linkageToStr(linkage); -} - static void printGlobalOp(OpAsmPrinter &p, GlobalOp op) { - p << op.getOperationName() << ' '; - printLinkage(p, op.linkage()); - p << ' '; + p << op.getOperationName() << ' ' << stringifyLinkage(op.linkage()) << ' '; if (op.constant()) p << "constant "; p.printSymbolName(op.sym_name()); @@ -1150,22 +1114,30 @@ return -1; } -// Parses one of the linkage keywords and, if succeeded, appends the "linkage" -// integer attribute with the corresponding value to `result`. 
-// -// linkage ::= `private` | `internal` | `available_externally` | `linkonce` -// | `weak` | `common` | `appending` | `extern_weak` -// | `linkonce_odr` | `weak_odr` | `external -static ParseResult parseOptionalLinkageKeyword(OpAsmParser &parser, - OperationState &result) { - int index = parseOptionalKeywordAlternative( - parser, {"private", "internal", "available_externally", "linkonce", - "weak", "common", "appending", "extern_weak", "linkonce_odr", - "weak_odr", "external"}); +namespace { +template <typename Ty> struct EnumTraits {}; +// REGISTER_ENUM_TYPE(Ty) exposes Ty's generated stringify/max-value helpers. +#define REGISTER_ENUM_TYPE(Ty) \ + template <> struct EnumTraits<Ty> { \ + static StringRef stringify(Ty value) { return stringify##Ty(value); } \ + static unsigned getMaxEnumVal() { return getMaxEnumValFor##Ty(); } \ + } + +REGISTER_ENUM_TYPE(Linkage); +} // end namespace +// Parses an EnumTy keyword and, on success, adds it as I64 attribute `name`. +template <typename EnumTy> +static ParseResult parseOptionalLLVMKeyword(OpAsmParser &parser, + OperationState &result, + StringRef name) { + SmallVector<StringRef, 16> names; + for (unsigned i = 0, e = EnumTraits<EnumTy>::getMaxEnumVal(); i <= e; ++i) + names.push_back(EnumTraits<EnumTy>::stringify(static_cast<EnumTy>(i))); + // The keyword's index in `names` is the value of the enum case it spells. + int index = parseOptionalKeywordAlternative(parser, names); if (index == -1) return failure(); - result.addAttribute(getLinkageAttrName(), - parser.getBuilder().getI64IntegerAttr(index)); + result.addAttribute(name, parser.getBuilder().getI64IntegerAttr(index)); return success(); } @@ -1175,7 +1147,8 @@ // The type can be omitted for string attributes, in which case it will be // inferred from the value of the string as [strlen(value) x i8]. 
static ParseResult parseGlobalOp(OpAsmParser &parser, OperationState &result) { - if (failed(parseOptionalLinkageKeyword(parser, result))) + if (failed(parseOptionalLLVMKeyword<Linkage>(parser, result, + getLinkageAttrName()))) return parser.emitError(parser.getCurrentLocation(), "expected linkage"); if (succeeded(parser.parseOptionalKeyword("constant"))) @@ -1398,7 +1371,8 @@ static ParseResult parseLLVMFuncOp(OpAsmParser &parser, OperationState &result) { // Default to external linkage if no keyword is provided. - if (failed(parseOptionalLinkageKeyword(parser, result))) + if (failed(parseOptionalLLVMKeyword<Linkage>(parser, result, + getLinkageAttrName()))) result.addAttribute(getLinkageAttrName(), parser.getBuilder().getI64IntegerAttr( static_cast<int64_t>(LLVM::Linkage::External))); @@ -1441,10 +1415,8 @@ // the external linkage since it is the default value. static void printLLVMFuncOp(OpAsmPrinter &p, LLVMFuncOp op) { p << op.getOperationName() << ' '; - if (op.linkage() != LLVM::Linkage::External) { - printLinkage(p, op.linkage()); - p << ' '; - } + if (op.linkage() != LLVM::Linkage::External) + p << stringifyLinkage(op.linkage()) << ' '; p.printSymbolName(op.getName()); LLVMType fnType = op.getType(); @@ -1510,16 +1482,16 @@ static LogicalResult verify(LLVMFuncOp op) { if (op.linkage() == LLVM::Linkage::Common) return op.emitOpError() - << "functions cannot have '" << linkageToStr(LLVM::Linkage::Common) - << "' linkage"; + << "functions cannot have '" + << stringifyLinkage(LLVM::Linkage::Common) << "' linkage"; if (op.isExternal()) { if (op.linkage() != LLVM::Linkage::External && op.linkage() != LLVM::Linkage::ExternWeak) return op.emitOpError() << "external functions must have '" - << linkageToStr(LLVM::Linkage::External) << "' or '" - << linkageToStr(LLVM::Linkage::ExternWeak) << "' linkage"; + << stringifyLinkage(LLVM::Linkage::External) << "' or '" + << stringifyLinkage(LLVM::Linkage::ExternWeak) << "' linkage"; return success(); } 
diff --git a/mlir/lib/TableGen/Attribute.cpp b/mlir/lib/TableGen/Attribute.cpp --- a/mlir/lib/TableGen/Attribute.cpp +++ b/mlir/lib/TableGen/Attribute.cpp @@ -154,6 +154,10 @@ return def->getValueAsString("symbol"); } +StringRef tblgen::EnumAttrCase::getStr() const { + return def->getValueAsString("str"); +} + int64_t tblgen::EnumAttrCase::getValue() const { return def->getValueAsInt("value"); } diff --git a/mlir/tools/mlir-tblgen/EnumsGen.cpp b/mlir/tools/mlir-tblgen/EnumsGen.cpp --- a/mlir/tools/mlir-tblgen/EnumsGen.cpp +++ b/mlir/tools/mlir-tblgen/EnumsGen.cpp @@ -165,8 +165,9 @@ os << " switch (val) {\n"; for (const auto &enumerant : enumerants) { auto symbol = enumerant.getSymbol(); + auto str = enumerant.getStr(); os << formatv(" case {0}::{1}: return \"{2}\";\n", enumName, - makeIdentifier(symbol), symbol); + makeIdentifier(symbol), str); } os << " }\n"; os << " return \"\";\n"; @@ -219,7 +220,8 @@ enumName); for (const auto &enumerant : enumerants) { auto symbol = enumerant.getSymbol(); - os << formatv(" .Case(\"{1}\", {0}::{2})\n", enumName, symbol, + auto str = enumerant.getStr(); + os << formatv(" .Case(\"{1}\", {0}::{2})\n", enumName, str, makeIdentifier(symbol)); } os << " .Default(llvm::None);\n"; diff --git a/mlir/unittests/TableGen/EnumsGenTest.cpp b/mlir/unittests/TableGen/EnumsGenTest.cpp --- a/mlir/unittests/TableGen/EnumsGenTest.cpp +++ b/mlir/unittests/TableGen/EnumsGenTest.cpp @@ -94,3 +94,21 @@ EXPECT_FALSE(bitEnumContains(BitEnumWithNone::Bit1 & BitEnumWithNone::Bit3, BitEnumWithNone::Bit1)); } + +TEST(EnumsGenTest, GeneratedSymbolToCustomStringFn) { + EXPECT_EQ(stringifyPrettyIntEnum(PrettyIntEnum::Case1), "case_one"); + EXPECT_EQ(stringifyPrettyIntEnum(PrettyIntEnum::Case2), "case_two"); +} + +TEST(EnumsGenTest, GeneratedCustomStringToSymbolFn) { + auto one = symbolizePrettyIntEnum("case_one"); + ASSERT_TRUE(one); + EXPECT_EQ(*one, PrettyIntEnum::Case1); + + auto two = symbolizePrettyIntEnum("case_two"); + ASSERT_TRUE(two); + EXPECT_EQ(*two, PrettyIntEnum::Case2); + + auto none = 
symbolizePrettyIntEnum("Case1"); + EXPECT_FALSE(none); +}
diff --git a/mlir/unittests/TableGen/enums.td b/mlir/unittests/TableGen/enums.td --- a/mlir/unittests/TableGen/enums.td +++ b/mlir/unittests/TableGen/enums.td @@ -31,3 +31,9 @@ def BitEnumWithoutNone : BitEnumAttr<"BitEnumWithoutNone", "A test enum", [Bit1, Bit3]>; + +def PrettyIntEnumCase1: I32EnumAttrCase<"Case1", 1, "case_one">; +def PrettyIntEnumCase2: I32EnumAttrCase<"Case2", 2, "case_two">; + +def PrettyIntEnum: I32EnumAttr<"PrettyIntEnum", "A test enum", + [PrettyIntEnumCase1, PrettyIntEnumCase2]>;