diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td @@ -987,6 +987,11 @@ SPV_C_ShaderViewportIndexLayerEXT, SPV_C_ShaderViewportMaskNV, SPV_C_ShaderStereoViewNV ]> { + let predicate = And<[ + I32Attr.predicate, + CPred<"::mlir::spirv::symbolizeCapability(" + "$_self.cast<IntegerAttr>().getValue().getZExtValue()).hasValue()"> + ]>; let cppNamespace = "::mlir::spirv"; } @@ -1014,6 +1019,11 @@ SPV_AM_Logical, SPV_AM_Physical32, SPV_AM_Physical64, SPV_AM_PhysicalStorageBuffer64 ]> { + let predicate = And<[ + I32Attr.predicate, + CPred<"::mlir::spirv::symbolizeAddressingModel(" + "$_self.cast<IntegerAttr>().getValue().getZExtValue()).hasValue()"> + ]>; let cppNamespace = "::mlir::spirv"; } @@ -1556,6 +1566,11 @@ SPV_BI_HitTNV, SPV_BI_HitKindNV, SPV_BI_IncomingRayFlagsNV, SPV_BI_WarpsPerSMNV, SPV_BI_SMCountNV, SPV_BI_WarpIDNV, SPV_BI_SMIDNV ]> { + let predicate = And<[ + I32Attr.predicate, + CPred<"::mlir::spirv::symbolizeBuiltIn(" + "$_self.cast<IntegerAttr>().getValue().getZExtValue()).hasValue()"> + ]>; let cppNamespace = "::mlir::spirv"; } @@ -1888,6 +1903,11 @@ SPV_D_AliasedPointer, SPV_D_CounterBuffer, SPV_D_UserSemantic, SPV_D_UserTypeGOOGLE ]> { + let predicate = And<[ + I32Attr.predicate, + CPred<"::mlir::spirv::symbolizeDecoration(" + "$_self.cast<IntegerAttr>().getValue().getZExtValue()).hasValue()"> + ]>; let cppNamespace = "::mlir::spirv"; } @@ -1928,6 +1948,11 @@ SPV_D_1D, SPV_D_2D, SPV_D_3D, SPV_D_Cube, SPV_D_Rect, SPV_D_Buffer, SPV_D_SubpassData ]> { + let predicate = And<[ + I32Attr.predicate, + CPred<"::mlir::spirv::symbolizeDim(" + "$_self.cast<IntegerAttr>().getValue().getZExtValue()).hasValue()"> + ]>; let cppNamespace = "::mlir::spirv"; } @@ -2260,6 +2285,11 @@ SPV_EM_SampleInterlockOrderedEXT, SPV_EM_SampleInterlockUnorderedEXT, SPV_EM_ShadingRateInterlockOrderedEXT, SPV_EM_ShadingRateInterlockUnorderedEXT ]> { + let
predicate = And<[ + I32Attr.predicate, + CPred<"::mlir::spirv::symbolizeExecutionMode(" + "$_self.cast<IntegerAttr>().getValue().getZExtValue()).hasValue()"> + ]>; let cppNamespace = "::mlir::spirv"; } @@ -2346,6 +2376,11 @@ SPV_EM_TaskNV, SPV_EM_MeshNV, SPV_EM_RayGenerationNV, SPV_EM_IntersectionNV, SPV_EM_AnyHitNV, SPV_EM_ClosestHitNV, SPV_EM_MissNV, SPV_EM_CallableNV ]> { + let predicate = And<[ + I32Attr.predicate, + CPred<"::mlir::spirv::symbolizeExecutionModel(" + "$_self.cast<IntegerAttr>().getValue().getZExtValue()).hasValue()"> + ]>; let cppNamespace = "::mlir::spirv"; } @@ -2359,6 +2394,11 @@ BitEnumAttr<"FunctionControl", "valid SPIR-V FunctionControl", [ SPV_FC_None, SPV_FC_Inline, SPV_FC_DontInline, SPV_FC_Pure, SPV_FC_Const ]> { + let predicate = And<[ + I32Attr.predicate, + CPred<"::mlir::spirv::symbolizeFunctionControl(" + "$_self.cast<IntegerAttr>().getValue().getZExtValue()).hasValue()"> + ]>; let cppNamespace = "::mlir::spirv"; } @@ -2571,6 +2611,11 @@ SPV_IF_Rgb10a2ui, SPV_IF_Rg32ui, SPV_IF_Rg16ui, SPV_IF_Rg8ui, SPV_IF_R16ui, SPV_IF_R8ui ]> { + let predicate = And<[ + I32Attr.predicate, + CPred<"::mlir::spirv::symbolizeImageFormat(" + "$_self.cast<IntegerAttr>().getValue().getZExtValue()).hasValue()"> + ]>; let cppNamespace = "::mlir::spirv"; } @@ -2589,6 +2634,11 @@ I32EnumAttr<"LinkageType", "valid SPIR-V LinkageType", [ SPV_LT_Export, SPV_LT_Import ]> { + let predicate = And<[ + I32Attr.predicate, + CPred<"::mlir::spirv::symbolizeLinkageType(" + "$_self.cast<IntegerAttr>().getValue().getZExtValue()).hasValue()"> + ]>; let cppNamespace = "::mlir::spirv"; } @@ -2637,6 +2687,11 @@ SPV_LC_DependencyLength, SPV_LC_MinIterations, SPV_LC_MaxIterations, SPV_LC_IterationMultiple, SPV_LC_PeelCount, SPV_LC_PartialCount ]> { + let predicate = And<[ + I32Attr.predicate, + CPred<"::mlir::spirv::symbolizeLoopControl(" + "$_self.cast<IntegerAttr>().getValue().getZExtValue()).hasValue()"> + ]>; let cppNamespace = "::mlir::spirv"; } @@ -2669,6 +2724,11 @@ SPV_MA_MakePointerAvailable,
SPV_MA_MakePointerVisible, SPV_MA_NonPrivatePointer ]> { + let predicate = And<[ + I32Attr.predicate, + CPred<"::mlir::spirv::symbolizeMemoryAccess(" + "$_self.cast<IntegerAttr>().getValue().getZExtValue()).hasValue()"> + ]>; let cppNamespace = "::mlir::spirv"; } @@ -2698,6 +2758,11 @@ I32EnumAttr<"MemoryModel", "valid SPIR-V MemoryModel", [ SPV_MM_Simple, SPV_MM_GLSL450, SPV_MM_OpenCL, SPV_MM_Vulkan ]> { + let predicate = And<[ + I32Attr.predicate, + CPred<"::mlir::spirv::symbolizeMemoryModel(" + "$_self.cast<IntegerAttr>().getValue().getZExtValue()).hasValue()"> + ]>; let cppNamespace = "::mlir::spirv"; } @@ -2754,6 +2819,11 @@ SPV_MS_AtomicCounterMemory, SPV_MS_ImageMemory, SPV_MS_OutputMemory, SPV_MS_MakeAvailable, SPV_MS_MakeVisible, SPV_MS_Volatile ]> { + let predicate = And<[ + I32Attr.predicate, + CPred<"::mlir::spirv::symbolizeMemorySemantics(" + "$_self.cast<IntegerAttr>().getValue().getZExtValue()).hasValue()"> + ]>; let cppNamespace = "::mlir::spirv"; } @@ -2774,6 +2844,11 @@ SPV_S_CrossDevice, SPV_S_Device, SPV_S_Workgroup, SPV_S_Subgroup, SPV_S_Invocation, SPV_S_QueueFamily ]> { + let predicate = And<[ + I32Attr.predicate, + CPred<"::mlir::spirv::symbolizeScope(" + "$_self.cast<IntegerAttr>().getValue().getZExtValue()).hasValue()"> + ]>; let cppNamespace = "::mlir::spirv"; } @@ -2785,6 +2860,11 @@ BitEnumAttr<"SelectionControl", "valid SPIR-V SelectionControl", [ SPV_SC_None, SPV_SC_Flatten, SPV_SC_DontFlatten ]> { + let predicate = And<[ + I32Attr.predicate, + CPred<"::mlir::spirv::symbolizeSelectionControl(" + "$_self.cast<IntegerAttr>().getValue().getZExtValue()).hasValue()"> + ]>; let cppNamespace = "::mlir::spirv"; } @@ -2884,6 +2964,11 @@ SPV_SC_RayPayloadNV, SPV_SC_HitAttributeNV, SPV_SC_IncomingRayPayloadNV, SPV_SC_ShaderRecordBufferNV, SPV_SC_PhysicalStorageBuffer ]> { + let predicate = And<[ + I32Attr.predicate, + CPred<"::mlir::spirv::symbolizeStorageClass(" + "$_self.cast<IntegerAttr>().getValue().getZExtValue()).hasValue()"> + ]>; let cppNamespace = "::mlir::spirv"; }
diff --git a/mlir/utils/spirv/gen_spirv_dialect.py b/mlir/utils/spirv/gen_spirv_dialect.py --- a/mlir/utils/spirv/gen_spirv_dialect.py +++ b/mlir/utils/spirv/gen_spirv_dialect.py @@ -379,10 +379,19 @@ case_names = ',\n'.join(case_names) # Generate the enum attribute definition - enum_attr = 'def SPV_{name}Attr :\n '\ - '{category}EnumAttr<"{name}", "valid SPIR-V {name}", [\n{cases}\n'\ - ' ]> {{\n'\ - ' let cppNamespace = "::mlir::spirv";\n}}'.format( + # The predicate accepts only 32-bit integer attributes whose value maps to a + # known case, i.e. for which ::mlir::spirv::symbolize{Enum}() returns a value. + enum_attr = '''def SPV_{name}Attr : + {category}EnumAttr<"{name}", "valid SPIR-V {name}", [ +{cases} + ]> {{ + let predicate = And<[ + I32Attr.predicate, + CPred<"::mlir::spirv::symbolize{name}(" + "$_self.cast<IntegerAttr>().getValue().getZExtValue()).hasValue()"> + ]>; + let cppNamespace = "::mlir::spirv"; +}}'''.format( name=kind_name, category=kind_category, cases=case_names) return kind_name, case_defs + '\n\n' + enum_attr