diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -1549,6 +1549,14 @@ // The delimiter used to separate bit enum cases in strings. string separator = "|"; + + // Controls how bit enum values are stringified. When 0, printing lists + // every individual bit case name AND the name of every group whose bits + // are all present. When 1, groups are matched in reverse declaration + // order; each matched group's name is printed and its bits are cleared, + // so bits (and smaller groups) already covered by a printed group are + // omitted for conciseness. + bit printBitEnumPrimaryGroups = 0; } class I32BitEnumAttrgetValueAsString("specializedAttrClassName"); } +bool EnumAttr::printBitEnumPrimaryGroups() const { + return def->getValueAsBit("printBitEnumPrimaryGroups"); +} + StructFieldAttr::StructFieldAttr(const llvm::Record *record) : def(record) { assert(def->isSubClassOf("StructFieldAttr") && "must be subclass of TableGen 'StructFieldAttr' class"); diff --git a/mlir/tools/mlir-tblgen/EnumsGen.cpp b/mlir/tools/mlir-tblgen/EnumsGen.cpp --- a/mlir/tools/mlir-tblgen/EnumsGen.cpp +++ b/mlir/tools/mlir-tblgen/EnumsGen.cpp @@ -188,6 +188,12 @@ auto enumerants = enumAttr.getAllCases(); auto allBitsUnsetCase = getAllBitsUnsetCase(enumerants); + const char *const formatCompareRemove = + " if ({0}u == ({0}u & val)) {{ strs.push_back(\"{1}\"); " + "val &= ~static_cast<{2}>({0}u); }\n"; + const char *const formatCompare = + " if ({0}u == ({0}u & val)) {{ strs.push_back(\"{1}\"); }\n"; + os << formatv("{2} {1}({0} symbol) {{\n", enumName, symToStrFnName, symToStrFnRetType); @@ -204,12 +210,34 @@ allBitsUnsetCase->getSymbol()); } os << " ::llvm::SmallVector<::llvm::StringRef, 2> strs;\n"; - for (const auto &enumerant : enumerants) { - // Skip the special enumerant for None. 
- if (int64_t val = enumerant.getValue()) - os << formatv( - " if ({0}u == ({0}u & val)) {{ strs.push_back(\"{1}\"); }\n ", val, - enumerant.getStr()); + // Optionally elide bits that are members of groups that will be printed for + // more concise output. + if (enumAttr.printBitEnumPrimaryGroups()) { + os << " // Print bit enum groups before individual bits\n"; + // Check groups in reverse declaration order, clearing matched bits + for (auto it = enumerants.rbegin(); it != enumerants.rend(); ++it) { + const auto &enumerant = *it; + int64_t val = enumerant.getValue(); + if (val && enumerant.getDef().isSubClassOf("BitEnumAttrCaseGroup")) { + os << formatv(formatCompareRemove, val, enumerant.getStr(), + enumAttr.getUnderlyingType()); + } + } + // Check non-group cases against the bits not cleared above + for (const auto &enumerant : enumerants) { + int64_t val = enumerant.getValue(); + if (val && !enumerant.getDef().isSubClassOf("BitEnumAttrCaseGroup")) { + os << formatv(formatCompare, val, enumerant.getStr()); + } + } + } else { + for (const auto &enumerant : enumerants) { + // Skip the special enumerant for None. 
+ if (int64_t val = enumerant.getValue()) { + // Emit code to check ALL cases (individual bits and groups) + os << formatv(formatCompare, val, enumerant.getStr()); + } + } } os << formatv(" return ::llvm::join(strs, \"{0}\");\n", separator); diff --git a/mlir/unittests/TableGen/EnumsGenTest.cpp b/mlir/unittests/TableGen/EnumsGenTest.cpp --- a/mlir/unittests/TableGen/EnumsGenTest.cpp +++ b/mlir/unittests/TableGen/EnumsGenTest.cpp @@ -70,6 +70,9 @@ EXPECT_EQ(0u, static_cast<uint32_t>(BitEnumWithNone::None)); EXPECT_EQ(1u, static_cast<uint32_t>(BitEnumWithNone::Bit0)); EXPECT_EQ(8u, static_cast<uint32_t>(BitEnumWithNone::Bit3)); + + EXPECT_EQ(2u, static_cast<uint64_t>(BitEnum64_Test::Bit1)); + EXPECT_EQ(144115188075855872u, static_cast<uint64_t>(BitEnum64_Test::Bit57)); } TEST(EnumsGenTest, GeneratedSymbolToStringFnForBitEnum) { @@ -79,8 +82,11 @@ EXPECT_EQ( stringifyBitEnumWithNone(BitEnumWithNone::Bit0 | BitEnumWithNone::Bit3), "Bit0|Bit3"); - EXPECT_EQ(2u, static_cast<uint64_t>(BitEnum64_Test::Bit1)); - EXPECT_EQ(144115188075855872u, static_cast<uint64_t>(BitEnum64_Test::Bit57)); + + EXPECT_EQ(stringifyBitEnum64_Test(BitEnum64_Test::Bit1), "Bit1"); + EXPECT_EQ( + stringifyBitEnum64_Test(BitEnum64_Test::Bit1 | BitEnum64_Test::Bit57), + "Bit1|Bit57"); } TEST(EnumsGenTest, GeneratedStringToSymbolForBitEnum) { @@ -116,6 +122,26 @@ BitEnumWithGroup::Bit3 | BitEnumWithGroup::Bit0); } +TEST(EnumsGenTest, GeneratedSymbolToStringFnForPrimaryGroupBitEnum) { + EXPECT_EQ(stringifyBitEnumPrimaryGroup( + BitEnumPrimaryGroup::Bit0 | BitEnumPrimaryGroup::Bit1 | + BitEnumPrimaryGroup::Bit2 | BitEnumPrimaryGroup::Bit3), + "Bits0To3"); + EXPECT_EQ(stringifyBitEnumPrimaryGroup(BitEnumPrimaryGroup::Bit0 | + BitEnumPrimaryGroup::Bit2 | + BitEnumPrimaryGroup::Bit3), + "Bit0,Bit2,Bit3"); + EXPECT_EQ(stringifyBitEnumPrimaryGroup(BitEnumPrimaryGroup::Bit0 | + BitEnumPrimaryGroup::Bit4 | + BitEnumPrimaryGroup::Bit5), + "Bits4And5,Bit0"); + EXPECT_EQ(stringifyBitEnumPrimaryGroup( + BitEnumPrimaryGroup::Bit0 | BitEnumPrimaryGroup::Bit1 | + 
BitEnumPrimaryGroup::Bit2 | BitEnumPrimaryGroup::Bit3 | + BitEnumPrimaryGroup::Bit4 | BitEnumPrimaryGroup::Bit5), + "Bits0To5"); +} + TEST(EnumsGenTest, GeneratedOperator) { EXPECT_TRUE(bitEnumContains(BitEnumWithNone::Bit0 | BitEnumWithNone::Bit3, BitEnumWithNone::Bit0)); diff --git a/mlir/unittests/TableGen/enums.td b/mlir/unittests/TableGen/enums.td --- a/mlir/unittests/TableGen/enums.td +++ b/mlir/unittests/TableGen/enums.td @@ -39,10 +39,28 @@ def Bits0To3 : I32BitEnumAttrCaseGroup<"Bits0To3", [Bit0, Bit1, Bit2, Bit3]>; +def Bits4And5 : I32BitEnumAttrCaseGroup<"Bits4And5", + [Bit4, Bit5]>; +def Bits0To5 : I32BitEnumAttrCaseGroup<"Bits0To5", + [Bits0To3, Bits4And5]>; def BitEnumWithGroup : I32BitEnumAttr<"BitEnumWithGroup", "A test enum", [Bit0, Bit1, Bit2, Bit3, Bit4, Bits0To3]>; +def BitEnumPrimaryGroup : I32BitEnumAttr<"BitEnumPrimaryGroup", "test enum", + [Bit0, + Bit1, + Bit2, + Bit3, + Bit4, + Bit5, + Bits0To3, + Bits4And5, + Bits0To5]> { + let separator = ","; + let printBitEnumPrimaryGroups = 1; +} + def BitEnum64_None : I64BitEnumAttrCaseNone<"None">; def BitEnum64_57 : I64BitEnumAttrCaseBit<"Bit57", 57>; def BitEnum64_1 : I64BitEnumAttrCaseBit<"Bit1", 1>;