diff --git a/mlir/test/mlir-tblgen/enums-gen.td b/mlir/test/mlir-tblgen/enums-gen.td --- a/mlir/test/mlir-tblgen/enums-gen.td +++ b/mlir/test/mlir-tblgen/enums-gen.td @@ -10,9 +10,12 @@ def Bit1: I32BitEnumAttrCaseBit<"Bit1", 1>; def Bit2: I32BitEnumAttrCaseBit<"Bit2", 2>; def Bit3: I32BitEnumAttrCaseBit<"Bit3", 3>; +def BitGroup: I32BitEnumAttrCaseGroup<"BitGroup", [ + Bit0, Bit1 +]>; def MyBitEnum: I32BitEnumAttr<"MyBitEnum", "An example bit enum", - [None, Bit0, Bit1, Bit2, Bit3]> { + [None, Bit0, Bit1, Bit2, Bit3, BitGroup]> { let genSpecializedAttr = 0; } @@ -44,6 +47,15 @@ // DECL: inline ::llvm::raw_ostream &operator<<(::llvm::raw_ostream &p, ::MyBitEnum value) { // DECL: auto valueStr = stringifyEnum(value); +// DECL: switch (value) { +// DECL: case ::MyBitEnum::BitGroup: +// DECL: return p << valueStr; +// DECL: default: +// DECL: break; +// DECL: } +// DECL: auto underlyingValue = static_cast<std::make_unsigned_t<::MyBitEnum>>(value); +// DECL: if (underlyingValue && !llvm::has_single_bit(underlyingValue)) +// DECL: return p << '"' << valueStr << '"'; // DECL: return p << valueStr; // DEF-LABEL: std::string stringifyMyBitEnum diff --git a/mlir/tools/mlir-tblgen/EnumsGen.cpp b/mlir/tools/mlir-tblgen/EnumsGen.cpp --- a/mlir/tools/mlir-tblgen/EnumsGen.cpp +++ b/mlir/tools/mlir-tblgen/EnumsGen.cpp @@ -80,16 +80,6 @@ if (!mlir::tblgen::canFormatStringAsKeyword(caseVal.getStr())) nonKeywordCases.set(index); - // If this is a bit enum attribute, don't allow cases that may overlap with - // other cases. For simplicity sake, only allow cases with a single bit value. - if (enumAttr.isBitEnum()) { - for (auto [index, caseVal] : llvm::enumerate(cases)) { - int64_t value = caseVal.getValue(); - if (value < 0 || (value != 0 && !llvm::isPowerOf2_64(value))) - nonKeywordCases.set(index); - } - } - // Generate the parser and the start of the printer for the enum. 
const char *parsedAndPrinterStart = R"( namespace mlir { @@ -137,7 +127,7 @@ if (nonKeywordCases.test(it.index())) continue; StringRef symbol = it.value().getSymbol(); - os << llvm::formatv(" case {0}::{1}:\n", qualName, + os << llvm::formatv(" case {0}::{1}:\n", qualName, llvm::isDigit(symbol.front()) ? ("_" + symbol) : symbol); } @@ -145,6 +135,37 @@ " default:\n" " return p << '\"' << valueStr << '\"';\n" " }\n"; + + // If this is a bit enum, conservatively print the string form if the value + // is not a power of two (i.e. not a single bit case) and not a known case. + } else if (enumAttr.isBitEnum()) { + // Process the known multi-bit cases that use valid keywords. + std::vector<EnumAttrCase *> validMultiBitCases; + for (auto [index, caseVal] : llvm::enumerate(cases)) { + uint64_t value = caseVal.getValue(); + if (value && !nonKeywordCases.test(index) && !llvm::has_single_bit(value)) + validMultiBitCases.push_back(&caseVal); + } + if (!validMultiBitCases.empty()) { + os << " switch (value) {\n"; + for (EnumAttrCase *caseVal : validMultiBitCases) { + StringRef symbol = caseVal->getSymbol(); + os << llvm::formatv(" case {0}::{1}:\n", qualName, + llvm::isDigit(symbol.front()) ? ("_" + symbol) + : symbol); + } + os << " return p << valueStr;\n" + " default:\n" + " break;\n" + " }\n"; + } + + // All other multi-bit cases should be printed as strings. + os << formatv(" auto underlyingValue = " + "static_cast<std::make_unsigned_t<{0}>>(value);\n", + qualName); + os << " if (underlyingValue && !llvm::has_single_bit(underlyingValue))\n" + " return p << '\"' << valueStr << '\"';\n"; } os << " return p << valueStr;\n" "}\n"