diff --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md --- a/mlir/docs/OpDefinitions.md +++ b/mlir/docs/OpDefinitions.md @@ -1400,10 +1400,9 @@ ```c++ std::string stringifyMyBitEnum(MyBitEnum symbol) { auto val = static_cast(symbol); + assert(15u == (15u | val) && "invalid bits set in bit enum"); // Special case for all bits unset. if (val == 0) return "None"; - // Return an empty string if any invalid bits are present - if (15 != (15 | val)) return ""; llvm::SmallVector strs; if (1u == (1u & val)) { strs.push_back("Bit0"); } if (2u == (2u & val)) { strs.push_back("Bit1"); } diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td @@ -77,7 +77,7 @@ // Wrapper over base BitEnumAttr to set common fields. class SPV_BitEnumAttr cases> : + list cases> : BitEnumAttr { let predicate = And<[ I32Attr.predicate, diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -1323,28 +1323,32 @@ // A bit enum case stored with 32-bit IntegerAttr. `val` here is *not* the // ordinal number of the bit that is set. It is the 32-bit integer with only // one bit set. -class BitEnumAttrCaseBase : - EnumAttrCaseInfo, - SignlessIntegerAttrBase { +class BitEnumAttrCase + : EnumAttrCaseInfo, + SignlessIntegerAttrBase { } // The special bit enum case for no bits set (i.e. value = 0). -class BitEnumAttrCaseNone : - BitEnumAttrCaseBase { +class BitEnumAttrCaseNone + : BitEnumAttrCase { } // The bit enum case for a single bit, specified by the bit position. // The pos argument refers to the index of the bit, and is currently // limited to be in the range [0, 31].
-class BitEnumAttrCaseBit : - BitEnumAttrCaseBase { - assert !and(!ge(pos, 0), !le(pos, 31)), "Bit position must be between 0 and 31"; +class BitEnumAttrCaseBit + : BitEnumAttrCase { + assert !and(!ge(pos, 0), !le(pos, 31)), + "Bit position must be between 0 and 31"; } // A bit enum case for a group/list of previously declared single bits, // providing a convenient alias for that group. -class BitEnumAttrCaseGroup cases, string str = sym> : - BitEnumAttrCaseBase { +class BitEnumAttrCaseGroup cases, + string str = sym> + : BitEnumAttrCase< + sym, !foldl(0, cases, value, bitcase, !or(value, bitcase.value)), + str> { } // Additional information for an enum attribute. @@ -1471,7 +1475,7 @@ // be generated on the integer to make sure only allowed bits are set. Besides, // helper methods are generated to parse a string separated with a specified // delimiter to a symbol and vice versa. -class BitEnumAttrBase cases, string summary> : +class BitEnumAttrBase cases, string summary> : SignlessIntegerAttrBase { let predicate = And<[ I32Attr.predicate, @@ -1482,7 +1486,7 @@ ]>; } -class BitEnumAttr cases> : +class BitEnumAttr cases> : EnumAttrInfo> { let underlyingType = "uint32_t"; diff --git a/mlir/tools/mlir-tblgen/EnumsGen.cpp b/mlir/tools/mlir-tblgen/EnumsGen.cpp --- a/mlir/tools/mlir-tblgen/EnumsGen.cpp +++ b/mlir/tools/mlir-tblgen/EnumsGen.cpp @@ -195,7 +195,9 @@ enumAttr.getUnderlyingType()); - // If we have unknown bit set, return an empty string to signal errors. + // Any bit set outside validBits is a programming error: assert on it. int64_t validBits = enumDef.getValueAsInt("validBits"); - os << formatv(" if ({0}u != ({0}u | val)) return \"\";\n", validBits); + os << formatv(" assert({0}u == ({0}u | val) && \"invalid bits set in bit " + "enum\");\n", + validBits); if (allBitsUnsetCase) { os << " // Special case for all bits unset.\n"; os << formatv(" if (val == 0) return \"{0}\";\n\n",