diff --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md --- a/mlir/docs/OpDefinitions.md +++ b/mlir/docs/OpDefinitions.md @@ -1245,7 +1245,8 @@ [`StringAttr`][StringAttr] in the op. * `IntEnumAttr`: each enum case is an integer, the attribute is stored as a [`IntegerAttr`][IntegerAttr] in the op. -* `BitEnumAttr`: each enum case is a bit, the attribute is stored as a +* `BitEnumAttr`: each enum case is either the empty case, a single bit, + or a group of single bits, and the attribute is stored as a [`IntegerAttr`][IntegerAttr] in the op. All these `*EnumAttr` attributes require fully specifying all of the allowed @@ -1349,13 +1350,14 @@ Similarly for the following `BitEnumAttr` definition: ```tablegen -def None: BitEnumAttrCase<"None", 0x0000>; -def Bit1: BitEnumAttrCase<"Bit1", 0x0001>; -def Bit2: BitEnumAttrCase<"Bit2", 0x0002>; -def Bit3: BitEnumAttrCase<"Bit3", 0x0004>; +def None: BitEnumAttrCaseNone<"None">; +def Bit0: BitEnumAttrCaseBit<"Bit0", 0>; +def Bit1: BitEnumAttrCaseBit<"Bit1", 1>; +def Bit2: BitEnumAttrCaseBit<"Bit2", 2>; +def Bit3: BitEnumAttrCaseBit<"Bit3", 3>; def MyBitEnum: BitEnumAttr<"MyBitEnum", "An example bit enum", - [None, Bit1, Bit2, Bit3]>; + [None, Bit0, Bit1, Bit2, Bit3]>; ``` We can have: @@ -1364,9 +1366,10 @@ // An example bit enum enum class MyBitEnum : uint32_t { None = 0, - Bit1 = 1, - Bit2 = 2, - Bit3 = 4, + Bit0 = 1, + Bit1 = 2, + Bit2 = 4, + Bit3 = 8, }; llvm::Optional symbolizeMyBitEnum(uint32_t); @@ -1407,15 +1410,15 @@ ```c++ std::string stringifyMyBitEnum(MyBitEnum symbol) { auto val = static_cast(symbol); + assert(15u == (15u | val) && "invalid bits set in bit enum"); // Special case for all bits unset. 
if (val == 0) return "None"; - llvm::SmallVector strs; - if (1u & val) { strs.push_back("Bit1"); val &= ~1u; } - if (2u & val) { strs.push_back("Bit2"); val &= ~2u; } - if (4u & val) { strs.push_back("Bit3"); val &= ~4u; } - - if (val) return ""; + if (1u == (1u & val)) { strs.push_back("Bit0"); } + if (2u == (2u & val)) { strs.push_back("Bit1"); } + if (4u == (4u & val)) { strs.push_back("Bit2"); } + if (8u == (8u & val)) { strs.push_back("Bit3"); } + return llvm::join(strs, "|"); } @@ -1429,9 +1432,10 @@ uint32_t val = 0; for (auto symbol : symbols) { auto bit = llvm::StringSwitch>(symbol) - .Case("Bit1", 1) - .Case("Bit2", 2) - .Case("Bit3", 4) + .Case("Bit0", 1) + .Case("Bit1", 2) + .Case("Bit2", 4) + .Case("Bit3", 8) .Default(llvm::None); if (bit) { val |= *bit; } else { return llvm::None; } } @@ -1442,7 +1446,7 @@ // Special case for all bits unset. if (value == 0) return MyBitEnum::None; - if (value & ~(1u | 2u | 4u)) return llvm::None; + if (value & ~(1u | 2u | 4u | 8u)) return llvm::None; return static_cast(value); } ``` diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -21,14 +21,14 @@ include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" -def FMFnnan : BitEnumAttrCase<"nnan", 0x1>; -def FMFninf : BitEnumAttrCase<"ninf", 0x2>; -def FMFnsz : BitEnumAttrCase<"nsz", 0x4>; -def FMFarcp : BitEnumAttrCase<"arcp", 0x8>; -def FMFcontract : BitEnumAttrCase<"contract", 0x10>; -def FMFafn : BitEnumAttrCase<"afn", 0x20>; -def FMFreassoc : BitEnumAttrCase<"reassoc", 0x40>; -def FMFfast : BitEnumAttrCase<"fast", 0x80>; +def FMFnnan : BitEnumAttrCaseBit<"nnan", 0>; +def FMFninf : BitEnumAttrCaseBit<"ninf", 1>; +def FMFnsz : BitEnumAttrCaseBit<"nsz", 2>; +def FMFarcp : BitEnumAttrCaseBit<"arcp", 3>; +def FMFcontract : BitEnumAttrCaseBit<"contract", 4>; +def FMFafn : 
BitEnumAttrCaseBit<"afn", 5>; +def FMFreassoc : BitEnumAttrCaseBit<"reassoc", 6>; +def FMFfast : BitEnumAttrCaseBit<"fast", 7>; def FastmathFlags_DoNotUse : BitEnumAttr< "FastmathFlags", diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td @@ -3082,12 +3082,12 @@ SPV_EM_AnyHitKHR, SPV_EM_ClosestHitKHR, SPV_EM_MissKHR, SPV_EM_CallableKHR ]>; -def SPV_FC_None : BitEnumAttrCase<"None", 0x0000>; -def SPV_FC_Inline : BitEnumAttrCase<"Inline", 0x0001>; -def SPV_FC_DontInline : BitEnumAttrCase<"DontInline", 0x0002>; -def SPV_FC_Pure : BitEnumAttrCase<"Pure", 0x0004>; -def SPV_FC_Const : BitEnumAttrCase<"Const", 0x0008>; -def SPV_FC_OptNoneINTEL : BitEnumAttrCase<"OptNoneINTEL", 0x10000> { +def SPV_FC_None : BitEnumAttrCaseNone<"None">; +def SPV_FC_Inline : BitEnumAttrCaseBit<"Inline", 0>; +def SPV_FC_DontInline : BitEnumAttrCaseBit<"DontInline", 1>; +def SPV_FC_Pure : BitEnumAttrCaseBit<"Pure", 2>; +def SPV_FC_Const : BitEnumAttrCaseBit<"Const", 3>; +def SPV_FC_OptNoneINTEL : BitEnumAttrCaseBit<"OptNoneINTEL", 16> { list availability = [ Capability<[SPV_C_OptNoneINTEL]> ]; @@ -3366,62 +3366,62 @@ SPV_IF_R8ui, SPV_IF_R64ui, SPV_IF_R64i ]>; -def SPV_IO_None : BitEnumAttrCase<"None", 0x0000>; -def SPV_IO_Bias : BitEnumAttrCase<"Bias", 0x0001> { +def SPV_IO_None : BitEnumAttrCaseNone<"None">; +def SPV_IO_Bias : BitEnumAttrCaseBit<"Bias", 0> { list availability = [ Capability<[SPV_C_Shader]> ]; } -def SPV_IO_Lod : BitEnumAttrCase<"Lod", 0x0002>; -def SPV_IO_Grad : BitEnumAttrCase<"Grad", 0x0004>; -def SPV_IO_ConstOffset : BitEnumAttrCase<"ConstOffset", 0x0008>; -def SPV_IO_Offset : BitEnumAttrCase<"Offset", 0x0010> { +def SPV_IO_Lod : BitEnumAttrCaseBit<"Lod", 1>; +def SPV_IO_Grad : BitEnumAttrCaseBit<"Grad", 2>; +def SPV_IO_ConstOffset : BitEnumAttrCaseBit<"ConstOffset", 3>; +def SPV_IO_Offset : 
BitEnumAttrCaseBit<"Offset", 4> { list availability = [ Capability<[SPV_C_ImageGatherExtended]> ]; } -def SPV_IO_ConstOffsets : BitEnumAttrCase<"ConstOffsets", 0x0020> { +def SPV_IO_ConstOffsets : BitEnumAttrCaseBit<"ConstOffsets", 5> { list availability = [ Capability<[SPV_C_ImageGatherExtended]> ]; } -def SPV_IO_Sample : BitEnumAttrCase<"Sample", 0x0040>; -def SPV_IO_MinLod : BitEnumAttrCase<"MinLod", 0x0080> { +def SPV_IO_Sample : BitEnumAttrCaseBit<"Sample", 6>; +def SPV_IO_MinLod : BitEnumAttrCaseBit<"MinLod", 7> { list availability = [ Capability<[SPV_C_MinLod]> ]; } -def SPV_IO_MakeTexelAvailable : BitEnumAttrCase<"MakeTexelAvailable", 0x0100> { +def SPV_IO_MakeTexelAvailable : BitEnumAttrCaseBit<"MakeTexelAvailable", 8> { list availability = [ MinVersion, Capability<[SPV_C_VulkanMemoryModel]> ]; } -def SPV_IO_MakeTexelVisible : BitEnumAttrCase<"MakeTexelVisible", 0x0200> { +def SPV_IO_MakeTexelVisible : BitEnumAttrCaseBit<"MakeTexelVisible", 9> { list availability = [ MinVersion, Capability<[SPV_C_VulkanMemoryModel]> ]; } -def SPV_IO_NonPrivateTexel : BitEnumAttrCase<"NonPrivateTexel", 0x0400> { +def SPV_IO_NonPrivateTexel : BitEnumAttrCaseBit<"NonPrivateTexel", 10> { list availability = [ MinVersion, Capability<[SPV_C_VulkanMemoryModel]> ]; } -def SPV_IO_VolatileTexel : BitEnumAttrCase<"VolatileTexel", 0x0800> { +def SPV_IO_VolatileTexel : BitEnumAttrCaseBit<"VolatileTexel", 11> { list availability = [ MinVersion, Capability<[SPV_C_VulkanMemoryModel]> ]; } -def SPV_IO_SignExtend : BitEnumAttrCase<"SignExtend", 0x1000> { +def SPV_IO_SignExtend : BitEnumAttrCaseBit<"SignExtend", 12> { list availability = [ MinVersion ]; } -def SPV_IO_Offsets : BitEnumAttrCase<"Offsets", 0x10000>; -def SPV_IO_ZeroExtend : BitEnumAttrCase<"ZeroExtend", 0x2000> { +def SPV_IO_Offsets : BitEnumAttrCaseBit<"Offsets", 16>; +def SPV_IO_ZeroExtend : BitEnumAttrCaseBit<"ZeroExtend", 13> { list availability = [ MinVersion ]; @@ -3457,87 +3457,87 @@ SPV_LT_Export, SPV_LT_Import, 
SPV_LT_LinkOnceODR ]>; -def SPV_LC_None : BitEnumAttrCase<"None", 0x0000>; -def SPV_LC_Unroll : BitEnumAttrCase<"Unroll", 0x0001>; -def SPV_LC_DontUnroll : BitEnumAttrCase<"DontUnroll", 0x0002>; -def SPV_LC_DependencyInfinite : BitEnumAttrCase<"DependencyInfinite", 0x0004> { +def SPV_LC_None : BitEnumAttrCaseNone<"None">; +def SPV_LC_Unroll : BitEnumAttrCaseBit<"Unroll", 0>; +def SPV_LC_DontUnroll : BitEnumAttrCaseBit<"DontUnroll", 1>; +def SPV_LC_DependencyInfinite : BitEnumAttrCaseBit<"DependencyInfinite", 2> { list availability = [ MinVersion ]; } -def SPV_LC_DependencyLength : BitEnumAttrCase<"DependencyLength", 0x0008> { +def SPV_LC_DependencyLength : BitEnumAttrCaseBit<"DependencyLength", 3> { list availability = [ MinVersion ]; } -def SPV_LC_MinIterations : BitEnumAttrCase<"MinIterations", 0x0010> { +def SPV_LC_MinIterations : BitEnumAttrCaseBit<"MinIterations", 4> { list availability = [ MinVersion ]; } -def SPV_LC_MaxIterations : BitEnumAttrCase<"MaxIterations", 0x0020> { +def SPV_LC_MaxIterations : BitEnumAttrCaseBit<"MaxIterations", 5> { list availability = [ MinVersion ]; } -def SPV_LC_IterationMultiple : BitEnumAttrCase<"IterationMultiple", 0x0040> { +def SPV_LC_IterationMultiple : BitEnumAttrCaseBit<"IterationMultiple", 6> { list availability = [ MinVersion ]; } -def SPV_LC_PeelCount : BitEnumAttrCase<"PeelCount", 0x0080> { +def SPV_LC_PeelCount : BitEnumAttrCaseBit<"PeelCount", 7> { list availability = [ MinVersion ]; } -def SPV_LC_PartialCount : BitEnumAttrCase<"PartialCount", 0x0100> { +def SPV_LC_PartialCount : BitEnumAttrCaseBit<"PartialCount", 8> { list availability = [ MinVersion ]; } -def SPV_LC_InitiationIntervalINTEL : BitEnumAttrCase<"InitiationIntervalINTEL", 0x10000> { +def SPV_LC_InitiationIntervalINTEL : BitEnumAttrCaseBit<"InitiationIntervalINTEL", 16> { list availability = [ Extension<[SPV_INTEL_fpga_loop_controls]>, Capability<[SPV_C_FPGALoopControlsINTEL]> ]; } -def SPV_LC_LoopCoalesceINTEL : BitEnumAttrCase<"LoopCoalesceINTEL", 
0x100000> { +def SPV_LC_LoopCoalesceINTEL : BitEnumAttrCaseBit<"LoopCoalesceINTEL", 20> { list availability = [ Extension<[SPV_INTEL_fpga_loop_controls]>, Capability<[SPV_C_FPGALoopControlsINTEL]> ]; } -def SPV_LC_MaxConcurrencyINTEL : BitEnumAttrCase<"MaxConcurrencyINTEL", 0x20000> { +def SPV_LC_MaxConcurrencyINTEL : BitEnumAttrCaseBit<"MaxConcurrencyINTEL", 17> { list availability = [ Extension<[SPV_INTEL_fpga_loop_controls]>, Capability<[SPV_C_FPGALoopControlsINTEL]> ]; } -def SPV_LC_MaxInterleavingINTEL : BitEnumAttrCase<"MaxInterleavingINTEL", 0x200000> { +def SPV_LC_MaxInterleavingINTEL : BitEnumAttrCaseBit<"MaxInterleavingINTEL", 21> { list availability = [ Extension<[SPV_INTEL_fpga_loop_controls]>, Capability<[SPV_C_FPGALoopControlsINTEL]> ]; } -def SPV_LC_DependencyArrayINTEL : BitEnumAttrCase<"DependencyArrayINTEL", 0x40000> { +def SPV_LC_DependencyArrayINTEL : BitEnumAttrCaseBit<"DependencyArrayINTEL", 18> { list availability = [ Extension<[SPV_INTEL_fpga_loop_controls]>, Capability<[SPV_C_FPGALoopControlsINTEL]> ]; } -def SPV_LC_SpeculatedIterationsINTEL : BitEnumAttrCase<"SpeculatedIterationsINTEL", 0x400000> { +def SPV_LC_SpeculatedIterationsINTEL : BitEnumAttrCaseBit<"SpeculatedIterationsINTEL", 22> { list availability = [ Extension<[SPV_INTEL_fpga_loop_controls]>, Capability<[SPV_C_FPGALoopControlsINTEL]> ]; } -def SPV_LC_PipelineEnableINTEL : BitEnumAttrCase<"PipelineEnableINTEL", 0x80000> { +def SPV_LC_PipelineEnableINTEL : BitEnumAttrCaseBit<"PipelineEnableINTEL", 19> { list availability = [ Extension<[SPV_INTEL_fpga_loop_controls]>, Capability<[SPV_C_FPGALoopControlsINTEL]> ]; } -def SPV_LC_NoFusionINTEL : BitEnumAttrCase<"NoFusionINTEL", 0x800000> { +def SPV_LC_NoFusionINTEL : BitEnumAttrCaseBit<"NoFusionINTEL", 23> { list availability = [ Extension<[SPV_INTEL_fpga_loop_controls]>, Capability<[SPV_C_FPGALoopControlsINTEL]> @@ -3555,23 +3555,23 @@ SPV_LC_PipelineEnableINTEL, SPV_LC_NoFusionINTEL ]>; -def SPV_MA_None : BitEnumAttrCase<"None", 
0x0000>; -def SPV_MA_Volatile : BitEnumAttrCase<"Volatile", 0x0001>; -def SPV_MA_Aligned : BitEnumAttrCase<"Aligned", 0x0002>; -def SPV_MA_Nontemporal : BitEnumAttrCase<"Nontemporal", 0x0004>; -def SPV_MA_MakePointerAvailable : BitEnumAttrCase<"MakePointerAvailable", 0x0008> { +def SPV_MA_None : BitEnumAttrCaseNone<"None">; +def SPV_MA_Volatile : BitEnumAttrCaseBit<"Volatile", 0>; +def SPV_MA_Aligned : BitEnumAttrCaseBit<"Aligned", 1>; +def SPV_MA_Nontemporal : BitEnumAttrCaseBit<"Nontemporal", 2>; +def SPV_MA_MakePointerAvailable : BitEnumAttrCaseBit<"MakePointerAvailable", 3> { list availability = [ MinVersion, Capability<[SPV_C_VulkanMemoryModel]> ]; } -def SPV_MA_MakePointerVisible : BitEnumAttrCase<"MakePointerVisible", 0x0010> { +def SPV_MA_MakePointerVisible : BitEnumAttrCaseBit<"MakePointerVisible", 4> { list availability = [ MinVersion, Capability<[SPV_C_VulkanMemoryModel]> ]; } -def SPV_MA_NonPrivatePointer : BitEnumAttrCase<"NonPrivatePointer", 0x0020> { +def SPV_MA_NonPrivatePointer : BitEnumAttrCaseBit<"NonPrivatePointer", 5> { list availability = [ MinVersion, Capability<[SPV_C_VulkanMemoryModel]> @@ -3612,44 +3612,44 @@ SPV_MM_Simple, SPV_MM_GLSL450, SPV_MM_OpenCL, SPV_MM_Vulkan ]>; -def SPV_MS_None : BitEnumAttrCase<"None", 0x0000>; -def SPV_MS_Acquire : BitEnumAttrCase<"Acquire", 0x0002>; -def SPV_MS_Release : BitEnumAttrCase<"Release", 0x0004>; -def SPV_MS_AcquireRelease : BitEnumAttrCase<"AcquireRelease", 0x0008>; -def SPV_MS_SequentiallyConsistent : BitEnumAttrCase<"SequentiallyConsistent", 0x0010>; -def SPV_MS_UniformMemory : BitEnumAttrCase<"UniformMemory", 0x0040> { +def SPV_MS_None : BitEnumAttrCaseNone<"None">; +def SPV_MS_Acquire : BitEnumAttrCaseBit<"Acquire", 1>; +def SPV_MS_Release : BitEnumAttrCaseBit<"Release", 2>; +def SPV_MS_AcquireRelease : BitEnumAttrCaseBit<"AcquireRelease", 3>; +def SPV_MS_SequentiallyConsistent : BitEnumAttrCaseBit<"SequentiallyConsistent", 4>; +def SPV_MS_UniformMemory : BitEnumAttrCaseBit<"UniformMemory", 6> 
{ list availability = [ Capability<[SPV_C_Shader]> ]; } -def SPV_MS_SubgroupMemory : BitEnumAttrCase<"SubgroupMemory", 0x0080>; -def SPV_MS_WorkgroupMemory : BitEnumAttrCase<"WorkgroupMemory", 0x0100>; -def SPV_MS_CrossWorkgroupMemory : BitEnumAttrCase<"CrossWorkgroupMemory", 0x0200>; -def SPV_MS_AtomicCounterMemory : BitEnumAttrCase<"AtomicCounterMemory", 0x0400> { +def SPV_MS_SubgroupMemory : BitEnumAttrCaseBit<"SubgroupMemory", 7>; +def SPV_MS_WorkgroupMemory : BitEnumAttrCaseBit<"WorkgroupMemory", 8>; +def SPV_MS_CrossWorkgroupMemory : BitEnumAttrCaseBit<"CrossWorkgroupMemory", 9>; +def SPV_MS_AtomicCounterMemory : BitEnumAttrCaseBit<"AtomicCounterMemory", 10> { list availability = [ Capability<[SPV_C_AtomicStorage]> ]; } -def SPV_MS_ImageMemory : BitEnumAttrCase<"ImageMemory", 0x0800>; -def SPV_MS_OutputMemory : BitEnumAttrCase<"OutputMemory", 0x1000> { +def SPV_MS_ImageMemory : BitEnumAttrCaseBit<"ImageMemory", 11>; +def SPV_MS_OutputMemory : BitEnumAttrCaseBit<"OutputMemory", 12> { list availability = [ MinVersion, Capability<[SPV_C_VulkanMemoryModel]> ]; } -def SPV_MS_MakeAvailable : BitEnumAttrCase<"MakeAvailable", 0x2000> { +def SPV_MS_MakeAvailable : BitEnumAttrCaseBit<"MakeAvailable", 13> { list availability = [ MinVersion, Capability<[SPV_C_VulkanMemoryModel]> ]; } -def SPV_MS_MakeVisible : BitEnumAttrCase<"MakeVisible", 0x4000> { +def SPV_MS_MakeVisible : BitEnumAttrCaseBit<"MakeVisible", 14> { list availability = [ MinVersion, Capability<[SPV_C_VulkanMemoryModel]> ]; } -def SPV_MS_Volatile : BitEnumAttrCase<"Volatile", 0x8000> { +def SPV_MS_Volatile : BitEnumAttrCaseBit<"Volatile", 15> { list availability = [ Extension<[SPV_KHR_vulkan_memory_model]>, Capability<[SPV_C_VulkanMemoryModel]> @@ -3688,9 +3688,9 @@ SPV_S_Invocation, SPV_S_QueueFamily, SPV_S_ShaderCallKHR ]>; -def SPV_SC_None : BitEnumAttrCase<"None", 0x0000>; -def SPV_SC_Flatten : BitEnumAttrCase<"Flatten", 0x0001>; -def SPV_SC_DontFlatten : BitEnumAttrCase<"DontFlatten", 0x0002>; +def 
SPV_SC_None : BitEnumAttrCaseNone<"None">; +def SPV_SC_Flatten : BitEnumAttrCaseBit<"Flatten", 0>; +def SPV_SC_DontFlatten : BitEnumAttrCaseBit<"DontFlatten", 1>; def SPV_SelectionControlAttr : SPV_BitEnumAttr<"SelectionControl", "valid SPIR-V SelectionControl", [ diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td @@ -39,17 +39,17 @@ } // The "kind" of combining function for contractions and reductions. -def COMBINING_KIND_ADD : BitEnumAttrCase<"ADD", 0x1, "add">; -def COMBINING_KIND_MUL : BitEnumAttrCase<"MUL", 0x2, "mul">; -def COMBINING_KIND_MINUI : BitEnumAttrCase<"MINUI", 0x4, "minui">; -def COMBINING_KIND_MINSI : BitEnumAttrCase<"MINSI", 0x8, "minsi">; -def COMBINING_KIND_MINF : BitEnumAttrCase<"MINF", 0x10, "minf">; -def COMBINING_KIND_MAXUI : BitEnumAttrCase<"MAXUI", 0x20, "maxui">; -def COMBINING_KIND_MAXSI : BitEnumAttrCase<"MAXSI", 0x40, "maxsi">; -def COMBINING_KIND_MAXF : BitEnumAttrCase<"MAXF", 0x80, "maxf">; -def COMBINING_KIND_AND : BitEnumAttrCase<"AND", 0x100, "and">; -def COMBINING_KIND_OR : BitEnumAttrCase<"OR", 0x200, "or">; -def COMBINING_KIND_XOR : BitEnumAttrCase<"XOR", 0x400, "xor">; +def COMBINING_KIND_ADD : BitEnumAttrCaseBit<"ADD", 0, "add">; +def COMBINING_KIND_MUL : BitEnumAttrCaseBit<"MUL", 1, "mul">; +def COMBINING_KIND_MINUI : BitEnumAttrCaseBit<"MINUI", 2, "minui">; +def COMBINING_KIND_MINSI : BitEnumAttrCaseBit<"MINSI", 3, "minsi">; +def COMBINING_KIND_MINF : BitEnumAttrCaseBit<"MINF", 4, "minf">; +def COMBINING_KIND_MAXUI : BitEnumAttrCaseBit<"MAXUI", 5, "maxui">; +def COMBINING_KIND_MAXSI : BitEnumAttrCaseBit<"MAXSI", 6, "maxsi">; +def COMBINING_KIND_MAXF : BitEnumAttrCaseBit<"MAXF", 7, "maxf">; +def COMBINING_KIND_AND : BitEnumAttrCaseBit<"AND", 8, "and">; +def COMBINING_KIND_OR : BitEnumAttrCaseBit<"OR", 9, "or">; +def COMBINING_KIND_XOR : BitEnumAttrCaseBit<"XOR", 10, 
"xor">; def CombiningKind : BitEnumAttr< "CombiningKind", diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -1323,12 +1323,30 @@ // A bit enum case stored with 32-bit IntegerAttr. `val` here is *not* the // ordinal number of the bit that is set. It is the 32-bit integer with only // one bit set. -class BitEnumAttrCase : - EnumAttrCaseInfo, - SignlessIntegerAttrBase { - let predicate = CPred< - "$_self.cast<::mlir::IntegerAttr>().getValue().getZExtValue() & " - # val # "u">; +class BitEnumAttrCase + : EnumAttrCaseInfo, + SignlessIntegerAttrBase; + +// The special bit enum case for no bits set (i.e. value = 0). +class BitEnumAttrCaseNone + : BitEnumAttrCase; + +// The bit enum case for a single bit, specified by the bit position. +// The pos argument refers to the index of the bit, and is currently +// limited to be in the range [0, 31]. +class BitEnumAttrCaseBit + : BitEnumAttrCase { + assert !and(!ge(pos, 0), !le(pos, 31)), + "bit position must be between 0 and 31"; +} + +// A bit enum case for a group/list of previously declared single bits, +// providing a convenient alias for that group. +class BitEnumAttrCaseGroup cases, + string str = sym> + : BitEnumAttrCase< + sym, !foldl(0, cases, value, bitcase, !or(value, bitcase.value)), + str> { } // Additional information for an enum attribute. @@ -1452,7 +1470,7 @@ // A bit enum stored with 32-bit IntegerAttr. // // Op attributes of this kind are stored as IntegerAttr. Extra verification will -// be generated on the integer to make sure only allowed bit are set. Besides, +// be generated on the integer to make sure only allowed bits are set. Besides, // helper methods are generated to parse a string separated with a specified // delimiter to a symbol and vice versa. 
class BitEnumAttrBase cases, string summary> : @@ -1470,6 +1488,9 @@ EnumAttrInfo> { let underlyingType = "uint32_t"; + // Determine "valid" bits from enum cases for error checking + int validBits = !foldl(0, cases, value, bitcase, !or(value, bitcase.value)); + // We need to return a string because we may concatenate symbols for multiple // bits together. let symbolToStringFnRetType = "std::string"; diff --git a/mlir/tools/mlir-tblgen/EnumsGen.cpp b/mlir/tools/mlir-tblgen/EnumsGen.cpp --- a/mlir/tools/mlir-tblgen/EnumsGen.cpp +++ b/mlir/tools/mlir-tblgen/EnumsGen.cpp @@ -193,6 +193,11 @@ os << formatv(" auto val = static_cast<{0}>(symbol);\n", enumAttr.getUnderlyingType()); + // Emit an assertion that no bits outside the valid enum cases are set. + int64_t validBits = enumDef.getValueAsInt("validBits"); + os << formatv(" assert({0}u == ({0}u | val) && \"invalid bits set in bit " + "enum\");\n", + validBits); if (allBitsUnsetCase) { os << " // Special case for all bits unset.\n"; os << formatv(" if (val == 0) return \"{0}\";\n\n", @@ -201,13 +206,11 @@ os << " ::llvm::SmallVector<::llvm::StringRef, 2> strs;\n"; for (const auto &enumerant : enumerants) { // Skip the special enumerant for None. - if (auto val = enumerant.getValue()) - os << formatv(" if ({0}u & val) {{ strs.push_back(\"{1}\"); " - "val &= ~{0}u; }\n", - val, enumerant.getStr()); + if (int64_t val = enumerant.getValue()) + os << formatv( + " if ({0}u == ({0}u & val)) {{ strs.push_back(\"{1}\"); }\n ", val, + enumerant.getStr()); } - // If we have unknown bit set, return an empty string to signal errors. 
- os << "\n if (val) return \"\";\n"; os << formatv(" return ::llvm::join(strs, \"{0}\");\n", separator); os << "}\n\n"; diff --git a/mlir/unittests/TableGen/EnumsGenTest.cpp b/mlir/unittests/TableGen/EnumsGenTest.cpp --- a/mlir/unittests/TableGen/EnumsGenTest.cpp +++ b/mlir/unittests/TableGen/EnumsGenTest.cpp @@ -68,25 +68,25 @@ TEST(EnumsGenTest, GeneratedBitEnumDefinition) { EXPECT_EQ(0u, static_cast(BitEnumWithNone::None)); - EXPECT_EQ(1u, static_cast(BitEnumWithNone::Bit1)); - EXPECT_EQ(4u, static_cast(BitEnumWithNone::Bit3)); + EXPECT_EQ(1u, static_cast(BitEnumWithNone::Bit0)); + EXPECT_EQ(8u, static_cast(BitEnumWithNone::Bit3)); } TEST(EnumsGenTest, GeneratedSymbolToStringFnForBitEnum) { EXPECT_EQ(stringifyBitEnumWithNone(BitEnumWithNone::None), "None"); - EXPECT_EQ(stringifyBitEnumWithNone(BitEnumWithNone::Bit1), "Bit1"); + EXPECT_EQ(stringifyBitEnumWithNone(BitEnumWithNone::Bit0), "Bit0"); EXPECT_EQ(stringifyBitEnumWithNone(BitEnumWithNone::Bit3), "Bit3"); EXPECT_EQ( - stringifyBitEnumWithNone(BitEnumWithNone::Bit1 | BitEnumWithNone::Bit3), - "Bit1|Bit3"); + stringifyBitEnumWithNone(BitEnumWithNone::Bit0 | BitEnumWithNone::Bit3), + "Bit0|Bit3"); } TEST(EnumsGenTest, GeneratedStringToSymbolForBitEnum) { EXPECT_EQ(symbolizeBitEnumWithNone("None"), BitEnumWithNone::None); - EXPECT_EQ(symbolizeBitEnumWithNone("Bit1"), BitEnumWithNone::Bit1); + EXPECT_EQ(symbolizeBitEnumWithNone("Bit0"), BitEnumWithNone::Bit0); EXPECT_EQ(symbolizeBitEnumWithNone("Bit3"), BitEnumWithNone::Bit3); - EXPECT_EQ(symbolizeBitEnumWithNone("Bit3|Bit1"), - BitEnumWithNone::Bit3 | BitEnumWithNone::Bit1); + EXPECT_EQ(symbolizeBitEnumWithNone("Bit3|Bit0"), + BitEnumWithNone::Bit3 | BitEnumWithNone::Bit0); EXPECT_EQ(symbolizeBitEnumWithNone("Bit2"), llvm::None); EXPECT_EQ(symbolizeBitEnumWithNone("Bit3|Bit4"), llvm::None); @@ -94,11 +94,31 @@ EXPECT_EQ(symbolizeBitEnumWithoutNone("None"), llvm::None); } +TEST(EnumsGenTest, GeneratedSymbolToStringFnForGroupedBitEnum) { + 
EXPECT_EQ(stringifyBitEnumWithGroup(BitEnumWithGroup::Bit0), "Bit0"); + EXPECT_EQ(stringifyBitEnumWithGroup(BitEnumWithGroup::Bit3), "Bit3"); + EXPECT_EQ(stringifyBitEnumWithGroup(BitEnumWithGroup::Bits0To3), + "Bit0|Bit1|Bit2|Bit3|Bits0To3"); + EXPECT_EQ(stringifyBitEnumWithGroup(BitEnumWithGroup::Bit4), "Bit4"); + EXPECT_EQ(stringifyBitEnumWithGroup( + BitEnumWithGroup::Bit0 | BitEnumWithGroup::Bit1 | + BitEnumWithGroup::Bit2 | BitEnumWithGroup::Bit4), + "Bit0|Bit1|Bit2|Bit4"); +} + +TEST(EnumsGenTest, GeneratedStringToSymbolForGroupedBitEnum) { + EXPECT_EQ(symbolizeBitEnumWithGroup("Bit0"), BitEnumWithGroup::Bit0); + EXPECT_EQ(symbolizeBitEnumWithGroup("Bit3"), BitEnumWithGroup::Bit3); + EXPECT_EQ(symbolizeBitEnumWithGroup("Bit5"), llvm::None); + EXPECT_EQ(symbolizeBitEnumWithGroup("Bit3|Bit0"), + BitEnumWithGroup::Bit3 | BitEnumWithGroup::Bit0); +} + TEST(EnumsGenTest, GeneratedOperator) { - EXPECT_TRUE(bitEnumContains(BitEnumWithNone::Bit1 | BitEnumWithNone::Bit3, - BitEnumWithNone::Bit1)); - EXPECT_FALSE(bitEnumContains(BitEnumWithNone::Bit1 & BitEnumWithNone::Bit3, - BitEnumWithNone::Bit1)); + EXPECT_TRUE(bitEnumContains(BitEnumWithNone::Bit0 | BitEnumWithNone::Bit3, + BitEnumWithNone::Bit0)); + EXPECT_FALSE(bitEnumContains(BitEnumWithNone::Bit0 & BitEnumWithNone::Bit3, + BitEnumWithNone::Bit0)); } TEST(EnumsGenTest, GeneratedSymbolToCustomStringFn) { @@ -152,7 +172,11 @@ mlir::Type intType = mlir::IntegerType::get(&ctx, 32); mlir::Attribute intAttr = mlir::IntegerAttr::get( intType, - static_cast(BitEnumWithNone::Bit1 | BitEnumWithNone::Bit3)); + static_cast(BitEnumWithNone::Bit0 | BitEnumWithNone::Bit3)); EXPECT_TRUE(intAttr.isa()); EXPECT_TRUE(intAttr.isa()); + + intAttr = mlir::IntegerAttr::get( + intType, static_cast(BitEnumWithGroup::Bits0To3) | (1u << 6)); + EXPECT_FALSE(intAttr.isa()); } diff --git a/mlir/unittests/TableGen/enums.td b/mlir/unittests/TableGen/enums.td --- a/mlir/unittests/TableGen/enums.td +++ b/mlir/unittests/TableGen/enums.td @@ 
-23,15 +23,25 @@ def I32Enum: I32EnumAttr<"I32Enum", "A test enum", [Case5, Case10]>; -def Bit0 : BitEnumAttrCase<"None", 0x0000>; -def Bit1 : BitEnumAttrCase<"Bit1", 0x0001>; -def Bit3 : BitEnumAttrCase<"Bit3", 0x0004>; +def NoBits : BitEnumAttrCaseNone<"None">; +def Bit0 : BitEnumAttrCaseBit<"Bit0", 0>; +def Bit1 : BitEnumAttrCaseBit<"Bit1", 1>; +def Bit2 : BitEnumAttrCaseBit<"Bit2", 2>; +def Bit3 : BitEnumAttrCaseBit<"Bit3", 3>; +def Bit4 : BitEnumAttrCaseBit<"Bit4", 4>; +def Bit5 : BitEnumAttrCaseBit<"Bit5", 5>; def BitEnumWithNone : BitEnumAttr<"BitEnumWithNone", "A test enum", - [Bit0, Bit1, Bit3]>; + [NoBits, Bit0, Bit3]>; def BitEnumWithoutNone : BitEnumAttr<"BitEnumWithoutNone", "A test enum", - [Bit1, Bit3]>; + [Bit0, Bit3]>; + +def Bits0To3 : BitEnumAttrCaseGroup<"Bits0To3", + [Bit0, Bit1, Bit2, Bit3]>; + +def BitEnumWithGroup : BitEnumAttr<"BitEnumWithGroup", "A test enum", + [Bit0, Bit1, Bit2, Bit3, Bit4, Bits0To3]>; def PrettyIntEnumCase1: I32EnumAttrCase<"Case1", 1, "case_one">; def PrettyIntEnumCase2: I32EnumAttrCase<"Case2", 2, "case_two">;