diff --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md --- a/mlir/docs/OpDefinitions.md +++ b/mlir/docs/OpDefinitions.md @@ -1403,11 +1403,11 @@ Similarly for the following `BitEnumAttr` definition: ```tablegen -def None: BitEnumAttrCaseNone<"None">; -def Bit0: BitEnumAttrCaseBit<"Bit0", 0>; -def Bit1: BitEnumAttrCaseBit<"Bit1", 1>; -def Bit2: BitEnumAttrCaseBit<"Bit2", 2>; -def Bit3: BitEnumAttrCaseBit<"Bit3", 3>; +def None: I32BitEnumAttrCaseNone<"None">; +def Bit0: I32BitEnumAttrCaseBit<"Bit0", 0, "tagged">; +def Bit1: I32BitEnumAttrCaseBit<"Bit1", 1>; +def Bit2: I32BitEnumAttrCaseBit<"Bit2", 2>; +def Bit3: I32BitEnumAttrCaseBit<"Bit3", 3>; -def MyBitEnum: BitEnumAttr<"MyBitEnum", "An example bit enum", - [None, Bit0, Bit1, Bit2, Bit3]>; +def MyBitEnum: I32BitEnumAttr<"MyBitEnum", "An example bit enum", + [None, Bit0, Bit1, Bit2, Bit3]>; @@ -1428,14 +1428,40 @@ llvm::Optional<MyBitEnum> symbolizeMyBitEnum(uint32_t); std::string stringifyMyBitEnum(MyBitEnum); llvm::Optional<MyBitEnum> symbolizeMyBitEnum(llvm::StringRef); -inline MyBitEnum operator|(MyBitEnum lhs, MyBitEnum rhs) { - return static_cast<MyBitEnum>(static_cast<uint32_t>(lhs) | static_cast<uint32_t>(rhs)); + +inline constexpr MyBitEnum operator|(MyBitEnum a, MyBitEnum b) { + return static_cast<MyBitEnum>(static_cast<uint32_t>(a) | static_cast<uint32_t>(b)); +} +inline constexpr MyBitEnum operator&(MyBitEnum a, MyBitEnum b) { + return static_cast<MyBitEnum>(static_cast<uint32_t>(a) & static_cast<uint32_t>(b)); +} +inline constexpr MyBitEnum operator^(MyBitEnum a, MyBitEnum b) { + return static_cast<MyBitEnum>(static_cast<uint32_t>(a) ^ static_cast<uint32_t>(b)); +} +inline constexpr MyBitEnum operator~(MyBitEnum bits) { + // Ensure only bits that can be present in the enum are set + return static_cast<MyBitEnum>(~static_cast<uint32_t>(bits) & static_cast<uint32_t>(15u)); +} +inline constexpr bool bitEnumContains(MyBitEnum bits, MyBitEnum bit) { + return (bits & bit) == bit; } -inline MyBitEnum operator&(MyBitEnum lhs, MyBitEnum rhs) { - return static_cast<MyBitEnum>(static_cast<uint32_t>(lhs) & static_cast<uint32_t>(rhs)); +inline constexpr MyBitEnum bitEnumClear(MyBitEnum bits, MyBitEnum bit) { + return bits & ~bit; } -inline bool bitEnumContains(MyBitEnum bits, MyBitEnum bit) { - return
(static_cast<uint32_t>(bits) & static_cast<uint32_t>(bit)) != 0; +inline constexpr MyBitEnum bitEnumSet(MyBitEnum bits, MyBitEnum bit, /*optional*/bool value=true) { + return value ? (bits | bit) : bitEnumClear(bits, bit); +} + +inline std::string stringifyEnum(MyBitEnum enumValue) { + return stringifyMyBitEnum(enumValue); +} + +template <typename EnumType> +::llvm::Optional<EnumType> symbolizeEnum(::llvm::StringRef); + +template <> +inline ::llvm::Optional<MyBitEnum> symbolizeEnum<MyBitEnum>(::llvm::StringRef str) { + return symbolizeMyBitEnum(str); } namespace llvm { @@ -1467,7 +1493,7 @@ // Special case for all bits unset. if (val == 0) return "None"; llvm::SmallVector<llvm::StringRef, 2> strs; - if (1u == (1u & val)) { strs.push_back("Bit0"); } + if (1u == (1u & val)) { strs.push_back("tagged"); } if (2u == (2u & val)) { strs.push_back("Bit1"); } if (4u == (4u & val)) { strs.push_back("Bit2"); } if (8u == (8u & val)) { strs.push_back("Bit3"); } @@ -1485,7 +1511,7 @@ uint32_t val = 0; for (auto symbol : symbols) { auto bit = llvm::StringSwitch<llvm::Optional<uint32_t>>(symbol) - .Case("Bit0", 1) + .Case("tagged", 1) .Case("Bit1", 2) .Case("Bit2", 4) .Case("Bit3", 8) @@ -1499,7 +1525,7 @@ // Special case for all bits unset. if (value == 0) return MyBitEnum::None; - if (value & ~(1u | 2u | 4u | 8u)) return llvm::None; + if (value & ~static_cast<uint32_t>(15u)) return llvm::None; return static_cast<MyBitEnum>(value); } ``` diff --git a/mlir/test/mlir-tblgen/enums-gen.td b/mlir/test/mlir-tblgen/enums-gen.td new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-tblgen/enums-gen.td @@ -0,0 +1,42 @@ +// RUN: mlir-tblgen -gen-enum-decls -I %S/../../include %s | FileCheck %s --check-prefix=DECL +// RUN: mlir-tblgen -gen-enum-defs -I %S/../../include %s | FileCheck %s --check-prefix=DEF + +include "mlir/IR/EnumAttr.td" +include "mlir/IR/OpBase.td" + +// Test bit enums +def None: I32BitEnumAttrCaseNone<"None">; +def Bit0: I32BitEnumAttrCaseBit<"Bit0", 0, "tagged">; +def Bit1: I32BitEnumAttrCaseBit<"Bit1", 1>; +def Bit2: I32BitEnumAttrCaseBit<"Bit2", 2>; +def Bit3: I32BitEnumAttrCaseBit<"Bit3", 3>; + +def MyBitEnum: I32BitEnumAttr<"MyBitEnum", "An example bit enum", + [None, Bit0, Bit1, Bit2, Bit3]> { + let genSpecializedAttr = 0; +} + +// DECL-LABEL: enum class
MyBitEnum : uint32_t +// DECL: None = 0, +// DECL: Bit0 = 1, +// DECL: Bit1 = 2, +// DECL: Bit2 = 4, +// DECL: Bit3 = 8, +// DECL: } + +// DECL: ::llvm::Optional<MyBitEnum> symbolizeMyBitEnum(uint32_t); +// DECL: std::string stringifyMyBitEnum(MyBitEnum); +// DECL: ::llvm::Optional<MyBitEnum> symbolizeMyBitEnum(::llvm::StringRef); + +// DEF-LABEL: std::string stringifyMyBitEnum +// DEF: auto val = static_cast<uint32_t> +// DEF: if (val == 0) return "None"; +// DEF: if (1u == (1u & val)) +// DEF-NEXT: push_back("tagged") +// DEF: if (2u == (2u & val)) +// DEF-NEXT: push_back("Bit1") + +// DEF-LABEL: ::llvm::Optional<MyBitEnum> symbolizeMyBitEnum(::llvm::StringRef str) +// DEF: if (str == "None") return MyBitEnum::None; +// DEF: .Case("tagged", 1) +// DEF: .Case("Bit1", 2) diff --git a/mlir/tools/mlir-tblgen/EnumsGen.cpp b/mlir/tools/mlir-tblgen/EnumsGen.cpp --- a/mlir/tools/mlir-tblgen/EnumsGen.cpp +++ b/mlir/tools/mlir-tblgen/EnumsGen.cpp @@ -136,28 +136,43 @@ // // inline constexpr <enum-type> operator|(<enum-type> a, <enum-type> b); // inline constexpr <enum-type> operator&(<enum-type> a, <enum-type> b); -// inline constexpr bool bitEnumContains(<enum-type> a, <enum-type> b); +// inline constexpr <enum-type> operator^(<enum-type> a, <enum-type> b); +// inline constexpr <enum-type> operator~(<enum-type> bits); +// inline constexpr bool bitEnumContains(<enum-type> bits, <enum-type> bit); +// inline constexpr <enum-type> bitEnumClear(<enum-type> bits, <enum-type> bit); +// inline constexpr <enum-type> bitEnumSet(<enum-type> bits, <enum-type> bit, +// bool value=true); static void emitOperators(const Record &enumDef, raw_ostream &os) { EnumAttr enumAttr(enumDef); StringRef enumName = enumAttr.getEnumClassName(); std::string underlyingType = std::string(enumAttr.getUnderlyingType()); - os << formatv("inline constexpr {0} operator|({0} lhs, {0} rhs) {{\n", - enumName) - << formatv(" return static_cast<{0}>(" - "static_cast<{1}>(lhs) | static_cast<{1}>(rhs));\n", - enumName, underlyingType) - << "}\n"; - os << formatv("inline constexpr {0} operator&({0} lhs, {0} rhs) {{\n", - enumName) - << formatv(" return static_cast<{0}>(" - "static_cast<{1}>(lhs) & static_cast<{1}>(rhs));\n", - enumName, underlyingType) - << "}\n"; - os << formatv( - "inline
constexpr bool bitEnumContains({0} bits, {0} bit) {{\n" - " return (static_cast<{1}>(bits) & static_cast<{1}>(bit)) != 0;\n", - enumName, underlyingType) - << "}\n"; + // Format as unsigned: a mask using bit 63 is negative as a TableGen int. + auto validBits = static_cast<uint64_t>(enumDef.getValueAsInt("validBits")); + const char *const operators = R"( +inline constexpr {0} operator|({0} a, {0} b) {{ + return static_cast<{0}>(static_cast<{1}>(a) | static_cast<{1}>(b)); +} +inline constexpr {0} operator&({0} a, {0} b) {{ + return static_cast<{0}>(static_cast<{1}>(a) & static_cast<{1}>(b)); +} +inline constexpr {0} operator^({0} a, {0} b) {{ + return static_cast<{0}>(static_cast<{1}>(a) ^ static_cast<{1}>(b)); +} +inline constexpr {0} operator~({0} bits) {{ + // Ensure only bits that can be present in the enum are set + return static_cast<{0}>(~static_cast<{1}>(bits) & static_cast<{1}>({2}u)); +} +inline constexpr bool bitEnumContains({0} bits, {0} bit) {{ + return (bits & bit) == bit; +} +inline constexpr {0} bitEnumClear({0} bits, {0} bit) {{ + return bits & ~bit; +} +inline constexpr {0} bitEnumSet({0} bits, {0} bit, /*optional*/bool value=true) {{ + return value ? (bits | bit) : bitEnumClear(bits, bit); +} +)"; + os << formatv(operators, enumName, underlyingType, validBits); } static void emitSymToStrFnForIntEnum(const Record &enumDef, raw_ostream &os) { @@ -424,13 +439,10 @@ os << formatv(" if (value == 0) return {0}::{1};\n\n", enumName, makeIdentifier(allBitsUnsetCase->getSymbol())); } - llvm::SmallVector values; - for (const auto &enumerant : enumerants) { - if (auto val = enumerant.getValue()) - values.push_back(std::string(formatv("{0}u", val))); - } - os << formatv(" if (value & ~static_cast<{0}>({1})) return llvm::None;\n", - underlyingType, llvm::join(values, " | ")); + // Format as unsigned: a mask using bit 63 is negative as a TableGen int. + auto validBits = static_cast<uint64_t>(enumDef.getValueAsInt("validBits")); + os << formatv(" if (value & ~static_cast<{0}>({1}u)) return llvm::None;\n", + underlyingType, validBits); os << formatv(" return static_cast<{0}>(value);\n", enumName); os << "}\n"; }