diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAvailability.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAvailability.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAvailability.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAvailability.td @@ -56,7 +56,7 @@ string instance = ?; } -class MinVersionBase +class MinVersionBase : Availability { let interfaceName = name; @@ -69,13 +69,13 @@ "std::max(*$overall, $instance)); " "} else { $overall = $instance; }}"; let initializer = "::llvm::None"; - let instanceType = scheme.cppNamespace # "::" # scheme.className; + let instanceType = scheme.cppNamespace # "::" # scheme.enum.className; - let instance = scheme.cppNamespace # "::" # scheme.className # "::" # + let instance = scheme.cppNamespace # "::" # scheme.enum.className # "::" # min.symbol; } -class MaxVersionBase +class MaxVersionBase : Availability { let interfaceName = name; @@ -88,9 +88,9 @@ "std::min(*$overall, $instance)); " "} else { $overall = $instance; }}"; let initializer = "::llvm::None"; - let instanceType = scheme.cppNamespace # "::" # scheme.className; + let instanceType = scheme.cppNamespace # "::" # scheme.enum.className; - let instance = scheme.cppNamespace # "::" # scheme.className # "::" # + let instance = scheme.cppNamespace # "::" # scheme.enum.className # "::" # max.symbol; } diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBarrierOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBarrierOps.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBarrierOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBarrierOps.td @@ -77,8 +77,6 @@ let results = (outs); - let autogenSerialization = 0; - let assemblyFormat = [{ $execution_scope `,` $memory_scope `,` $memory_semantics attr-dict }]; @@ -129,8 +127,6 @@ let results = (outs); - let autogenSerialization = 0; - let assemblyFormat = "$memory_scope `,` $memory_semantics attr-dict"; } diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td @@ -82,43 +82,31 @@ // Utility definitions //===----------------------------------------------------------------------===// -// A predicate that checks whether `$_self` is a known enum case for the -// enum class with `name`. -class SPV_IsKnownEnumCaseFor : - CPred<"::mlir::spirv::symbolize" # name # "(" - "$_self.cast().getValue().getZExtValue()).has_value()">; - // Wrapper over base BitEnumAttr to set common fields. -class SPV_BitEnumAttr cases> : - I32BitEnumAttr { - let predicate = And<[ - I32Attr.predicate, - SPV_IsKnownEnumCaseFor, - ]>; +class SPV_BitEnum cases> + : I32BitEnumAttr { + let genSpecializedAttr = 0; let cppNamespace = "::mlir::spirv"; } - -// Wrapper over base I32EnumAttr to set common fields. -class SPV_I32EnumAttr cases> : - I32EnumAttr { - let predicate = And<[ - I32Attr.predicate, - SPV_IsKnownEnumCaseFor, - ]>; - let cppNamespace = "::mlir::spirv"; +class SPV_BitEnumAttr cases> : + EnumAttr, mnemonic> { + let assemblyFormat = "`<` $value `>`"; } // Wrapper over base I32EnumAttr to set common fields. -class SPV_Enum cases> +class SPV_I32Enum cases> : I32EnumAttr { let genSpecializedAttr = 0; let cppNamespace = "::mlir::spirv"; } -class SPV_EnumAttr cases> : - EnumAttr, mnemonic>; + EnumAttr, mnemonic> { + let assemblyFormat = "`<` $value `>`"; +} //===----------------------------------------------------------------------===// // SPIR-V availability definitions @@ -132,7 +120,8 @@ def SPV_V_1_5 : I32EnumAttrCase<"V_1_5", 5, "v1.5">; def SPV_V_1_6 : I32EnumAttrCase<"V_1_6", 6, "v1.6">; -def SPV_VersionAttr : SPV_I32EnumAttr<"Version", "valid SPIR-V version", [ +def SPV_VersionAttr : SPV_I32EnumAttr< + "Version", "valid SPIR-V version", "version", [ SPV_V_1_0, SPV_V_1_1, SPV_V_1_2, SPV_V_1_3, SPV_V_1_4, SPV_V_1_5, SPV_V_1_6]>; @@ -284,7 +273,7 @@ // Information missing. def SPV_DT_Unknown : I32EnumAttrCase<"Unknown", 4>; -def SPV_DeviceTypeAttr : SPV_EnumAttr< +def SPV_DeviceTypeAttr : SPV_I32EnumAttr< "DeviceType", "valid SPIR-V device types", "device_type", [ SPV_DT_Other, SPV_DT_IntegratedGPU, SPV_DT_DiscreteGPU, SPV_DT_CPU, SPV_DT_Unknown @@ -300,7 +289,7 @@ def SPV_V_SwiftShader : I32EnumAttrCase<"SwiftShader", 7>; def SPV_V_Unknown : I32EnumAttrCase<"Unknown", 0xff>; -def SPV_VendorAttr : SPV_EnumAttr< +def SPV_VendorAttr : SPV_I32EnumAttr< "Vendor", "recognized SPIR-V vendor strings", "vendor", [ SPV_V_AMD, SPV_V_Apple, SPV_V_ARM, SPV_V_Imagination, SPV_V_Intel, SPV_V_NVIDIA, SPV_V_Qualcomm, SPV_V_SwiftShader, @@ -418,7 +407,7 @@ def SPV_NVX_multiview_per_view_attributes : I32EnumAttrCase<"SPV_NVX_multiview_per_view_attributes", 5015>; def SPV_ExtensionAttr : - SPV_EnumAttr<"Extension", "supported SPIR-V extensions", "ext", [ + SPV_I32EnumAttr<"Extension", "supported SPIR-V extensions", "ext", [ SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_device_group, SPV_KHR_float_controls, SPV_KHR_physical_storage_buffer, SPV_KHR_multiview, SPV_KHR_no_integer_wrap_decoration, SPV_KHR_post_depth_coverage, @@ -1402,7 +1391,7 @@ } def SPV_CapabilityAttr : - SPV_I32EnumAttr<"Capability", "valid SPIR-V Capability", [ + SPV_I32EnumAttr<"Capability", "valid SPIR-V Capability", "capability", [ SPV_C_Matrix, SPV_C_Addresses, SPV_C_Linkage, SPV_C_Kernel, SPV_C_Float16, SPV_C_Float64, SPV_C_Int64, SPV_C_Groups, SPV_C_Int16, SPV_C_Int8, SPV_C_Sampled1D, SPV_C_SampledBuffer, SPV_C_GroupNonUniform, SPV_C_ShaderLayer, @@ -1514,7 +1503,7 @@ } def SPV_AddressingModelAttr : - SPV_I32EnumAttr<"AddressingModel", "valid SPIR-V AddressingModel", [ + SPV_I32EnumAttr<"AddressingModel", "valid SPIR-V AddressingModel", "addressing_model", [ SPV_AM_Logical, SPV_AM_Physical32, SPV_AM_Physical64, SPV_AM_PhysicalStorageBuffer64 ]>; @@ -2049,7 +2038,7 @@ } def SPV_BuiltInAttr : - SPV_I32EnumAttr<"BuiltIn", "valid SPIR-V BuiltIn", [ + SPV_I32EnumAttr<"BuiltIn", "valid SPIR-V BuiltIn", "built_in", [ SPV_BI_Position, SPV_BI_PointSize, SPV_BI_ClipDistance, SPV_BI_CullDistance, SPV_BI_VertexId, SPV_BI_InstanceId, SPV_BI_PrimitiveId, SPV_BI_InvocationId, SPV_BI_Layer, SPV_BI_ViewportIndex, SPV_BI_TessLevelOuter, @@ -2610,7 +2599,7 @@ } def SPV_DecorationAttr : - SPV_I32EnumAttr<"Decoration", "valid SPIR-V Decoration", [ + SPV_I32EnumAttr<"Decoration", "valid SPIR-V Decoration", "decoration", [ SPV_D_RelaxedPrecision, SPV_D_SpecId, SPV_D_Block, SPV_D_BufferBlock, SPV_D_RowMajor, SPV_D_ColMajor, SPV_D_ArrayStride, SPV_D_MatrixStride, SPV_D_GLSLShared, SPV_D_GLSLPacked, SPV_D_CPacked, SPV_D_BuiltIn, @@ -2679,7 +2668,7 @@ } def SPV_DimAttr : - SPV_I32EnumAttr<"Dim", "valid SPIR-V Dim", [ + SPV_I32EnumAttr<"Dim", "valid SPIR-V Dim", "dim", [ SPV_D_1D, SPV_D_2D, SPV_D_3D, SPV_D_Cube, SPV_D_Rect, SPV_D_Buffer, SPV_D_SubpassData ]>; @@ -3093,7 +3082,7 @@ } def SPV_ExecutionModeAttr : - SPV_I32EnumAttr<"ExecutionMode", "valid SPIR-V ExecutionMode", [ + SPV_I32EnumAttr<"ExecutionMode", "valid SPIR-V ExecutionMode", "execution_mode", [ SPV_EM_Invocations, SPV_EM_SpacingEqual, SPV_EM_SpacingFractionalEven, SPV_EM_SpacingFractionalOdd, SPV_EM_VertexOrderCw, SPV_EM_VertexOrderCcw, SPV_EM_PixelCenterInteger, SPV_EM_OriginUpperLeft, SPV_EM_OriginLowerLeft, @@ -3203,7 +3192,7 @@ } def SPV_ExecutionModelAttr : - SPV_I32EnumAttr<"ExecutionModel", "valid SPIR-V ExecutionModel", [ + SPV_I32EnumAttr<"ExecutionModel", "valid SPIR-V ExecutionModel", "execution_model", [ SPV_EM_Vertex, SPV_EM_TessellationControl, SPV_EM_TessellationEvaluation, SPV_EM_Geometry, SPV_EM_Fragment, SPV_EM_GLCompute, SPV_EM_Kernel, SPV_EM_TaskNV, SPV_EM_MeshNV, SPV_EM_RayGenerationKHR, SPV_EM_IntersectionKHR, @@ -3222,7 +3211,7 @@ } def SPV_FunctionControlAttr : - SPV_BitEnumAttr<"FunctionControl", "valid SPIR-V FunctionControl", [ + SPV_BitEnumAttr<"FunctionControl", "valid SPIR-V FunctionControl", "function_control", [ SPV_FC_None, SPV_FC_Inline, SPV_FC_DontInline, SPV_FC_Pure, SPV_FC_Const, SPV_FC_OptNoneINTEL ]>; @@ -3268,7 +3257,7 @@ } def SPV_GroupOperationAttr : - SPV_I32EnumAttr<"GroupOperation", "valid SPIR-V GroupOperation", [ + SPV_I32EnumAttr<"GroupOperation", "valid SPIR-V GroupOperation", "group_operation", [ SPV_GO_Reduce, SPV_GO_InclusiveScan, SPV_GO_ExclusiveScan, SPV_GO_ClusteredReduce, SPV_GO_PartitionedReduceNV, SPV_GO_PartitionedInclusiveScanNV, SPV_GO_PartitionedExclusiveScanNV @@ -3482,7 +3471,7 @@ } def SPV_ImageFormatAttr : - SPV_I32EnumAttr<"ImageFormat", "valid SPIR-V ImageFormat", [ + SPV_I32EnumAttr<"ImageFormat", "valid SPIR-V ImageFormat", "image_format", [ SPV_IF_Unknown, SPV_IF_Rgba32f, SPV_IF_Rgba16f, SPV_IF_R32f, SPV_IF_Rgba8, SPV_IF_Rgba8Snorm, SPV_IF_Rg32f, SPV_IF_Rg16f, SPV_IF_R11fG11fB10f, SPV_IF_R16f, SPV_IF_Rgba16, SPV_IF_Rgb10A2, SPV_IF_Rg16, SPV_IF_Rg8, @@ -3561,7 +3550,7 @@ } def SPV_ImageOperandsAttr : - SPV_BitEnumAttr<"ImageOperands", "valid SPIR-V ImageOperands", [ + SPV_BitEnumAttr<"ImageOperands", "valid SPIR-V ImageOperands", "image_operands", [ SPV_IO_None, SPV_IO_Bias, SPV_IO_Lod, SPV_IO_Grad, SPV_IO_ConstOffset, SPV_IO_Offset, SPV_IO_ConstOffsets, SPV_IO_Sample, SPV_IO_MinLod, SPV_IO_MakeTexelAvailable, SPV_IO_MakeTexelVisible, SPV_IO_NonPrivateTexel, @@ -3587,7 +3576,7 @@ } def SPV_LinkageTypeAttr : - SPV_I32EnumAttr<"LinkageType", "valid SPIR-V LinkageType", [ + SPV_I32EnumAttr<"LinkageType", "valid SPIR-V LinkageType", "linkage_type", [ SPV_LT_Export, SPV_LT_Import, SPV_LT_LinkOnceODR ]>; @@ -3679,7 +3668,7 @@ } def SPV_LoopControlAttr : - SPV_BitEnumAttr<"LoopControl", "valid SPIR-V LoopControl", [ + SPV_BitEnumAttr<"LoopControl", "valid SPIR-V LoopControl", "loop_control", [ SPV_LC_None, SPV_LC_Unroll, SPV_LC_DontUnroll, SPV_LC_DependencyInfinite, SPV_LC_DependencyLength, SPV_LC_MinIterations, SPV_LC_MaxIterations, SPV_LC_IterationMultiple, SPV_LC_PeelCount, SPV_LC_PartialCount, @@ -3725,7 +3714,7 @@ } def SPV_MemoryAccessAttr : - SPV_BitEnumAttr<"MemoryAccess", "valid SPIR-V MemoryAccess", [ + SPV_BitEnumAttr<"MemoryAccess", "valid SPIR-V MemoryAccess", "memory_access", [ SPV_MA_None, SPV_MA_Volatile, SPV_MA_Aligned, SPV_MA_Nontemporal, SPV_MA_MakePointerAvailable, SPV_MA_MakePointerVisible, SPV_MA_NonPrivatePointer, SPV_MA_AliasScopeINTELMask, SPV_MA_NoAliasINTELMask @@ -3754,7 +3743,7 @@ } def SPV_MemoryModelAttr : - SPV_I32EnumAttr<"MemoryModel", "valid SPIR-V MemoryModel", [ + SPV_I32EnumAttr<"MemoryModel", "valid SPIR-V MemoryModel", "memory_model", [ SPV_MM_Simple, SPV_MM_GLSL450, SPV_MM_OpenCL, SPV_MM_Vulkan ]>; @@ -3803,7 +3792,7 @@ } def SPV_MemorySemanticsAttr : - SPV_BitEnumAttr<"MemorySemantics", "valid SPIR-V MemorySemantics", [ + SPV_BitEnumAttr<"MemorySemantics", "valid SPIR-V MemorySemantics", "memory_semantics", [ SPV_MS_None, SPV_MS_Acquire, SPV_MS_Release, SPV_MS_AcquireRelease, SPV_MS_SequentiallyConsistent, SPV_MS_UniformMemory, SPV_MS_SubgroupMemory, SPV_MS_WorkgroupMemory, SPV_MS_CrossWorkgroupMemory, @@ -3829,7 +3818,7 @@ } def SPV_ScopeAttr : - SPV_I32EnumAttr<"Scope", "valid SPIR-V Scope", [ + SPV_I32EnumAttr<"Scope", "valid SPIR-V Scope", "scope", [ SPV_S_CrossDevice, SPV_S_Device, SPV_S_Workgroup, SPV_S_Subgroup, SPV_S_Invocation, SPV_S_QueueFamily, SPV_S_ShaderCallKHR ]>; @@ -3839,7 +3828,7 @@ def SPV_SC_DontFlatten : I32BitEnumAttrCaseBit<"DontFlatten", 1>; def SPV_SelectionControlAttr : - SPV_BitEnumAttr<"SelectionControl", "valid SPIR-V SelectionControl", [ + SPV_BitEnumAttr<"SelectionControl", "valid SPIR-V SelectionControl", "selection_control", [ SPV_SC_None, SPV_SC_Flatten, SPV_SC_DontFlatten ]>; @@ -3947,7 +3936,7 @@ } def SPV_StorageClassAttr : - SPV_I32EnumAttr<"StorageClass", "valid SPIR-V StorageClass", [ + SPV_I32EnumAttr<"StorageClass", "valid SPIR-V StorageClass", "storage_class", [ SPV_SC_UniformConstant, SPV_SC_Input, SPV_SC_Uniform, SPV_SC_Output, SPV_SC_Workgroup, SPV_SC_CrossWorkgroup, SPV_SC_Private, SPV_SC_Function, SPV_SC_Generic, SPV_SC_PushConstant, SPV_SC_AtomicCounter, SPV_SC_Image, @@ -3965,34 +3954,32 @@ def SPV_IDI_IsDepth : I32EnumAttrCase<"IsDepth", 1>; def SPV_IDI_DepthUnknown : I32EnumAttrCase<"DepthUnknown", 2>; -def SPV_DepthAttr : - SPV_I32EnumAttr<"ImageDepthInfo", "valid SPIR-V Image Depth specification", - [SPV_IDI_NoDepth, SPV_IDI_IsDepth, SPV_IDI_DepthUnknown]>; +def SPV_DepthAttr : SPV_I32EnumAttr< + "ImageDepthInfo", "valid SPIR-V Image Depth specification", + "image_depth_info", [SPV_IDI_NoDepth, SPV_IDI_IsDepth, SPV_IDI_DepthUnknown]>; def SPV_IAI_NonArrayed : I32EnumAttrCase<"NonArrayed", 0>; def SPV_IAI_Arrayed : I32EnumAttrCase<"Arrayed", 1>; -def SPV_ArrayedAttr : - SPV_I32EnumAttr< - "ImageArrayedInfo", "valid SPIR-V Image Arrayed specification", - [SPV_IAI_NonArrayed, SPV_IAI_Arrayed]>; +def SPV_ArrayedAttr : SPV_I32EnumAttr< + "ImageArrayedInfo", "valid SPIR-V Image Arrayed specification", + "image_arrayed_info", [SPV_IAI_NonArrayed, SPV_IAI_Arrayed]>; def SPV_ISI_SingleSampled : I32EnumAttrCase<"SingleSampled", 0>; def SPV_ISI_MultiSampled : I32EnumAttrCase<"MultiSampled", 1>; -def SPV_SamplingAttr: - SPV_I32EnumAttr< - "ImageSamplingInfo", "valid SPIR-V Image Sampling specification", - [SPV_ISI_SingleSampled, SPV_ISI_MultiSampled]>; +def SPV_SamplingAttr: SPV_I32EnumAttr< + "ImageSamplingInfo", "valid SPIR-V Image Sampling specification", + "image_sampling_info", [SPV_ISI_SingleSampled, SPV_ISI_MultiSampled]>; def SPV_ISUI_SamplerUnknown : I32EnumAttrCase<"SamplerUnknown", 0>; def SPV_ISUI_NeedSampler : I32EnumAttrCase<"NeedSampler", 1>; def SPV_ISUI_NoSampler : I32EnumAttrCase<"NoSampler", 2>; -def SPV_SamplerUseAttr: - SPV_I32EnumAttr< - "ImageSamplerUseInfo", "valid SPIR-V Sampler Use specification", - [SPV_ISUI_SamplerUnknown, SPV_ISUI_NeedSampler, SPV_ISUI_NoSampler]>; +def SPV_SamplerUseAttr: SPV_I32EnumAttr< + "ImageSamplerUseInfo", "valid SPIR-V Sampler Use specification", + "image_sampler_use_info", + [SPV_ISUI_SamplerUnknown, SPV_ISUI_NeedSampler, SPV_ISUI_NoSampler]>; //===----------------------------------------------------------------------===// // SPIR-V attribute definitions @@ -4326,7 +4313,7 @@ def SPV_OC_OpAtomicFAddEXT : I32EnumAttrCase<"OpAtomicFAddEXT", 6035>; def SPV_OpcodeAttr : - SPV_I32EnumAttr<"Opcode", "valid SPIR-V instructions", [ + SPV_I32EnumAttr<"Opcode", "valid SPIR-V instructions", "opcode", [ SPV_OC_OpNop, SPV_OC_OpUndef, SPV_OC_OpSourceContinued, SPV_OC_OpSource, SPV_OC_OpSourceExtension, SPV_OC_OpName, SPV_OC_OpMemberName, SPV_OC_OpString, SPV_OC_OpLine, SPV_OC_OpExtension, SPV_OC_OpExtInstImport, SPV_OC_OpExtInst, diff --git a/mlir/include/mlir/TableGen/Attribute.h b/mlir/include/mlir/TableGen/Attribute.h --- a/mlir/include/mlir/TableGen/Attribute.h +++ b/mlir/include/mlir/TableGen/Attribute.h @@ -113,6 +113,9 @@ // Returns the dialect for the attribute if defined. Dialect getDialect() const; + + // Returns the TableGen definition this Attribute was constructed from. + const llvm::Record &getDef() const; }; // Wrapper class providing helper methods for accessing MLIR constant attribute diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp --- a/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp +++ b/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp @@ -15,8 +15,8 @@ #include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h" #include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" -#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/StringExtras.h" diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp --- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp @@ -15,6 +15,7 @@ #include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" +#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" #include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h" #include "mlir/IR/BuiltinOps.h" @@ -643,15 +644,15 @@ // this entry point's execution mode. We set it to be: // __spv__{SPIR-V module name}_{function name}_execution_mode_info_{mode} ModuleOp module = op->getParentOfType(); - IntegerAttr executionModeAttr = op.execution_modeAttr(); + spirv::ExecutionModeAttr executionModeAttr = op.execution_modeAttr(); std::string moduleName; if (module.getName().has_value()) moduleName = "_" + module.getName().value().str(); else moduleName = ""; - std::string executionModeInfoName = - llvm::formatv("__spv_{0}_{1}_execution_mode_info_{2}", moduleName, - op.fn().str(), executionModeAttr.getValue()); + std::string executionModeInfoName = llvm::formatv( + "__spv_{0}_{1}_execution_mode_info_{2}", moduleName, op.fn().str(), + static_cast(executionModeAttr.getValue())); MLIRContext *context = rewriter.getContext(); OpBuilder::InsertionGuard guard(rewriter); @@ -684,8 +685,10 @@ // Initialize the struct and set the execution mode value. rewriter.setInsertionPoint(block, block->begin()); Value structValue = rewriter.create(loc, structType); - Value executionMode = - rewriter.create(loc, llvmI32Type, executionModeAttr); + Value executionMode = rewriter.create( + loc, llvmI32Type, + rewriter.getI32IntegerAttr( + static_cast(executionModeAttr.getValue()))); structValue = rewriter.create( loc, structType, structValue, executionMode, ArrayAttr::get(context, @@ -1391,8 +1394,8 @@ auto llvmI32Type = IntegerType::get(context, 32); Value targetOp = rewriter.create(loc, dstType); for (unsigned i = 0; i < componentsArray.size(); i++) { - if (componentsArray[i].isa()) - op.emitError("unable to support non-constant component"); + if (!componentsArray[i].isa()) + return op.emitError("unable to support non-constant component"); int indexVal = componentsArray[i].cast().getInt(); if (indexVal == -1) diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -15,6 +15,7 @@ #include "mlir/Dialect/SPIRV/IR/ParserUtils.h" #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" +#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" #include "mlir/Dialect/SPIRV/IR/SPIRVOpTraits.h" #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h" @@ -175,19 +176,16 @@ NamedAttrList attr; auto loc = parser.getCurrentLocation(); if (parser.parseAttribute(attrVal, parser.getBuilder().getNoneType(), - attrName, attr)) { + attrName, attr)) return failure(); - } - if (!attrVal.isa()) { + if (!attrVal.isa()) return parser.emitError(loc, "expected ") << attrName << " attribute specified as string"; - } auto attrOptional = spirv::symbolizeEnum(attrVal.cast().getValue()); - if (!attrOptional) { + if (!attrOptional) return parser.emitError(loc, "invalid ") << attrName << " attribute specification: " << attrVal; - } value = *attrOptional; return success(); } @@ -195,50 +193,52 @@ /// Parses the next string attribute in `parser` as an enumerant of the given /// `EnumClass` and inserts the enumerant into `state` as an 32-bit integer /// attribute with the enum class's name as attribute name. -template +template static ParseResult parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, OperationState &state, StringRef attrName = spirv::attributeName()) { - if (parseEnumStrAttr(value, parser)) { + if (parseEnumStrAttr(value, parser)) return failure(); - } - state.addAttribute(attrName, parser.getBuilder().getI32IntegerAttr( - llvm::bit_cast(value))); + state.addAttribute(attrName, + parser.getBuilder().getAttr(value)); return success(); } /// Parses the next keyword in `parser` as an enumerant of the given `EnumClass` /// and inserts the enumerant into `state` as an 32-bit integer attribute with /// the enum class's name as attribute name. -template +template static ParseResult parseEnumKeywordAttr(EnumClass &value, OpAsmParser &parser, OperationState &state, StringRef attrName = spirv::attributeName()) { - if (parseEnumKeywordAttr(value, parser)) { + if (parseEnumKeywordAttr(value, parser)) return failure(); - } - state.addAttribute(attrName, parser.getBuilder().getI32IntegerAttr( - llvm::bit_cast(value))); + state.addAttribute(attrName, + parser.getBuilder().getAttr(value)); return success(); } /// Parses Function, Selection and Loop control attributes. If no control is /// specified, "None" is used as a default. -template +template static ParseResult parseControlAttribute(OpAsmParser &parser, OperationState &state, StringRef attrName = spirv::attributeName()) { if (succeeded(parser.parseOptionalKeyword(kControl))) { EnumClass control; - if (parser.parseLParen() || parseEnumKeywordAttr(control, parser, state) || + if (parser.parseLParen() || + parseEnumKeywordAttr(control, parser, state) || parser.parseRParen()) return failure(); return success(); } // Set control to "None" otherwise. Builder builder = parser.getBuilder(); - state.addAttribute(attrName, builder.getI32IntegerAttr(0)); + state.addAttribute(attrName, + builder.getAttr(static_cast(0))); return success(); } @@ -257,10 +257,9 @@ } spirv::MemoryAccess memoryAccessAttr; - if (parseEnumStrAttr(memoryAccessAttr, parser, state, - kMemoryAccessAttrName)) { + if (parseEnumStrAttr(memoryAccessAttr, parser, state, + kMemoryAccessAttrName)) return failure(); - } if (spirv::bitEnumContains(memoryAccessAttr, spirv::MemoryAccess::Aligned)) { // Parse integer attribute for alignment. @@ -288,10 +287,9 @@ } spirv::MemoryAccess memoryAccessAttr; - if (parseEnumStrAttr(memoryAccessAttr, parser, state, - kSourceMemoryAccessAttrName)) { + if (parseEnumStrAttr(memoryAccessAttr, parser, state, + kSourceMemoryAccessAttrName)) return failure(); - } if (spirv::bitEnumContains(memoryAccessAttr, spirv::MemoryAccess::Aligned)) { // Parse integer attribute for alignment. @@ -480,15 +478,15 @@ return success(); } - auto memAccessVal = memAccessAttr.template cast(); - auto memAccess = spirv::symbolizeMemoryAccess(memAccessVal.getInt()); + auto memAccess = memAccessAttr.template cast(); if (!memAccess) { return memoryOp.emitOpError("invalid memory access specifier: ") - << memAccessVal; + << memAccessAttr; } - if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) { + if (spirv::bitEnumContains(memAccess.getValue(), + spirv::MemoryAccess::Aligned)) { if (!op->getAttr(kAlignmentAttrName)) { return memoryOp.emitOpError("missing alignment value"); } @@ -524,15 +522,15 @@ return success(); } - auto memAccessVal = memAccessAttr.template cast(); - auto memAccess = spirv::symbolizeMemoryAccess(memAccessVal.getInt()); + auto memAccess = memAccessAttr.template cast(); if (!memAccess) { return memoryOp.emitOpError("invalid memory access specifier: ") - << memAccessVal; + << memAccess; } - if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) { + if (spirv::bitEnumContains(memAccess.getValue(), + spirv::MemoryAccess::Aligned)) { if (!op->getAttr(kSourceAlignmentAttrName)) { return memoryOp.emitOpError("missing alignment value"); } @@ -771,8 +769,10 @@ OpAsmParser::UnresolvedOperand ptrInfo, valueInfo; Type type; SMLoc loc; - if (parseEnumStrAttr(scope, parser, state, kMemoryScopeAttrName) || - parseEnumStrAttr(memoryScope, parser, state, kSemanticsAttrName) || + if (parseEnumStrAttr(scope, parser, state, + kMemoryScopeAttrName) || + parseEnumStrAttr(memoryScope, parser, state, + kSemanticsAttrName) || parser.parseOperandList(operandInfo, (hasValue ? 2 : 1)) || parser.getCurrentLocation(&loc) || parser.parseColonType(type)) return failure(); @@ -794,14 +794,11 @@ // Prints an atomic update op. static void printAtomicUpdateOp(Operation *op, OpAsmPrinter &printer) { printer << " \""; - auto scopeAttr = op->getAttrOfType(kMemoryScopeAttrName); - printer << spirv::stringifyScope( - static_cast(scopeAttr.getInt())) - << "\" \""; - auto memorySemanticsAttr = op->getAttrOfType(kSemanticsAttrName); - printer << spirv::stringifyMemorySemantics( - static_cast( - memorySemanticsAttr.getInt())) + auto scopeAttr = op->getAttrOfType(kMemoryScopeAttrName); + printer << spirv::stringifyScope(scopeAttr.getValue()) << "\" \""; + auto memorySemanticsAttr = + op->getAttrOfType(kSemanticsAttrName); + printer << spirv::stringifyMemorySemantics(memorySemanticsAttr.getValue()) << "\" " << op->getOperands() << " : " << op->getOperand(0).getType(); } @@ -835,8 +832,9 @@ "pointer operand's pointee type ") << elementType << ", but found " << valueType; } - auto memorySemantics = static_cast( - op->getAttrOfType(kSemanticsAttrName).getInt()); + auto memorySemantics = + op->getAttrOfType(kSemanticsAttrName) + .getValue(); if (failed(verifyMemorySemantics(op, memorySemantics))) { return failure(); } @@ -848,10 +846,10 @@ spirv::Scope executionScope; spirv::GroupOperation groupOperation; OpAsmParser::UnresolvedOperand valueInfo; - if (parseEnumStrAttr(executionScope, parser, state, - kExecutionScopeAttrName) || - parseEnumStrAttr(groupOperation, parser, state, - kGroupOperationAttrName) || + if (parseEnumStrAttr(executionScope, parser, state, + kExecutionScopeAttrName) || + parseEnumStrAttr(groupOperation, parser, state, + kGroupOperationAttrName) || parser.parseOperand(valueInfo)) return failure(); @@ -881,15 +879,17 @@ static void printGroupNonUniformArithmeticOp(Operation *groupOp, OpAsmPrinter &printer) { - printer << " \"" - << stringifyScope(static_cast( - groupOp->getAttrOfType(kExecutionScopeAttrName) - .getInt())) - << "\" \"" - << stringifyGroupOperation(static_cast( - groupOp->getAttrOfType(kGroupOperationAttrName) - .getInt())) - << "\" " << groupOp->getOperand(0); + printer + << " \"" + << stringifyScope( + groupOp->getAttrOfType(kExecutionScopeAttrName) + .getValue()) + << "\" \"" + << stringifyGroupOperation(groupOp + ->getAttrOfType( + kGroupOperationAttrName) + .getValue()) + << "\" " << groupOp->getOperand(0); if (groupOp->getNumOperands() > 1) printer << " " << kClusterSize << '(' << groupOp->getOperand(1) << ')'; @@ -897,14 +897,16 @@ } static LogicalResult verifyGroupNonUniformArithmeticOp(Operation *groupOp) { - spirv::Scope scope = static_cast( - groupOp->getAttrOfType(kExecutionScopeAttrName).getInt()); + spirv::Scope scope = + groupOp->getAttrOfType(kExecutionScopeAttrName) + .getValue(); if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup) return groupOp->emitOpError( "execution scope must be 'Workgroup' or 'Subgroup'"); - spirv::GroupOperation operation = static_cast( - groupOp->getAttrOfType(kGroupOperationAttrName).getInt()); + spirv::GroupOperation operation = + groupOp->getAttrOfType(kGroupOperationAttrName) + .getValue(); if (operation == spirv::GroupOperation::ClusteredReduce && groupOp->getNumOperands() == 1) return groupOp->emitOpError("cluster size operand must be provided for " @@ -1146,11 +1148,12 @@ spirv::MemorySemantics equalSemantics, unequalSemantics; SmallVector operandInfo; Type type; - if (parseEnumStrAttr(memoryScope, parser, state, kMemoryScopeAttrName) || - parseEnumStrAttr(equalSemantics, parser, state, - kEqualSemanticsAttrName) || - parseEnumStrAttr(unequalSemantics, parser, state, - kUnequalSemanticsAttrName) || + if (parseEnumStrAttr(memoryScope, parser, state, + kMemoryScopeAttrName) || + parseEnumStrAttr( + equalSemantics, parser, state, kEqualSemanticsAttrName) || + parseEnumStrAttr( + unequalSemantics, parser, state, kUnequalSemanticsAttrName) || parser.parseOperandList(operandInfo, 3)) return failure(); @@ -1268,8 +1271,10 @@ spirv::MemorySemantics semantics; SmallVector operandInfo; Type type; - if (parseEnumStrAttr(memoryScope, parser, state, kMemoryScopeAttrName) || - parseEnumStrAttr(semantics, parser, state, kSemanticsAttrName) || + if (parseEnumStrAttr(memoryScope, parser, state, + kMemoryScopeAttrName) || + parseEnumStrAttr(semantics, parser, state, + kSemanticsAttrName) || parser.parseOperandList(operandInfo, 2)) return failure(); @@ -2076,7 +2081,7 @@ SmallVector interfaceVars; FlatSymbolRefAttr fn; - if (parseEnumStrAttr(execModel, parser, state) || + if (parseEnumStrAttr(execModel, parser, state) || parser.parseAttribute(fn, Type(), kFnNameAttrName, state.attributes)) { return failure(); } @@ -2133,7 +2138,7 @@ spirv::ExecutionMode execMode; Attribute fn; if (parser.parseAttribute(fn, kFnNameAttrName, state.attributes) || - parseEnumStrAttr(execMode, parser, state)) { + parseEnumStrAttr(execMode, parser, state)) { return failure(); } @@ -2221,7 +2226,7 @@ // Parse the optional function control keyword. spirv::FunctionControl fnControl; - if (parseEnumStrAttr(fnControl, parser, state)) + if (parseEnumStrAttr(fnControl, parser, state)) return failure(); // If additional attributes are present, parse them. @@ -2308,7 +2313,7 @@ builder.getStringAttr(name)); state.addAttribute(getTypeAttrName(), TypeAttr::get(type)); state.addAttribute(spirv::attributeName(), - builder.getI32IntegerAttr(static_cast(control))); + builder.getAttr(control)); state.attributes.append(attrs.begin(), attrs.end()); state.addRegion(); } @@ -2997,14 +3002,14 @@ //===----------------------------------------------------------------------===// void spirv::LoopOp::build(OpBuilder &builder, OperationState &state) { - state.addAttribute("loop_control", - builder.getI32IntegerAttr( - static_cast(spirv::LoopControl::None))); + state.addAttribute("loop_control", builder.getAttr( + spirv::LoopControl::None)); state.addRegion(); } ParseResult spirv::LoopOp::parse(OpAsmParser &parser, OperationState &state) { - if (parseControlAttribute(parser, state)) + if (parseControlAttribute(parser, + state)) return failure(); return parser.parseRegion(*state.addRegion(), /*arguments=*/{}, /*argTypes=*/{}); @@ -3196,9 +3201,9 @@ Optional name) { state.addAttribute( "addressing_model", - builder.getI32IntegerAttr(static_cast(addressingModel))); - state.addAttribute("memory_model", builder.getI32IntegerAttr( - static_cast(memoryModel))); + builder.getAttr(addressingModel)); + state.addAttribute("memory_model", + builder.getAttr(memoryModel)); OpBuilder::InsertionGuard guard(builder); builder.createBlock(state.addRegion()); if (vceTriple) @@ -3219,8 +3224,10 @@ // Parse attributes spirv::AddressingModel addrModel; spirv::MemoryModel memoryModel; - if (::parseEnumKeywordAttr(addrModel, parser, state) || - ::parseEnumKeywordAttr(memoryModel, parser, state)) + if (::parseEnumKeywordAttr(addrModel, parser, + state) || + ::parseEnumKeywordAttr(memoryModel, parser, + state)) return failure(); if (succeeded(parser.parseOptionalKeyword("requires"))) { @@ -3401,7 +3408,8 @@ ParseResult spirv::SelectionOp::parse(OpAsmParser &parser, OperationState &state) { - if (parseControlAttribute(parser, state)) + if (parseControlAttribute(parser, state)) return failure(); return parser.parseRegion(*state.addRegion(), /*arguments=*/{}, /*argTypes=*/{}); @@ -3666,8 +3674,8 @@ return failure(); } - auto attr = parser.getBuilder().getI32IntegerAttr( - llvm::bit_cast(ptrType.getStorageClass())); + auto attr = parser.getBuilder().getAttr( + ptrType.getStorageClass()); state.addAttribute(spirv::attributeName(), attr); return success(); diff --git a/mlir/lib/TableGen/Attribute.cpp b/mlir/lib/TableGen/Attribute.cpp --- a/mlir/lib/TableGen/Attribute.cpp +++ b/mlir/lib/TableGen/Attribute.cpp @@ -132,6 +132,8 @@ return Dialect(nullptr); } +const llvm::Record &Attribute::getDef() const { return *def; } + ConstantAttr::ConstantAttr(const DefInit *init) : def(init->getDef()) { assert(def->isSubClassOf("ConstantAttr") && "must be subclass of TableGen 'ConstantAttr' class"); diff --git a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp --- a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp @@ -12,6 +12,7 @@ #include "Deserializer.h" +#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Location.h" @@ -406,35 +407,6 @@ return success(); } -template <> -LogicalResult -Deserializer::processOp(ArrayRef operands) { - if (operands.size() != 3) { - return emitError( - unknownLoc, - "OpControlBarrier must have execution scope , memory scope " - "and memory semantics "); - } - - SmallVector argAttrs; - for (auto operand : operands) { - auto argAttr = getConstantInt(operand); - if (!argAttr) { - return emitError(unknownLoc, - "expected 32-bit integer constant from ") - << operand << " for OpControlBarrier"; - } - argAttrs.push_back(argAttr); - } - - opBuilder.create( - unknownLoc, argAttrs[0].cast(), - argAttrs[1].cast(), - argAttrs[2].cast()); - - return success(); -} - template <> LogicalResult Deserializer::processOp(ArrayRef operands) { @@ -477,31 +449,6 @@ return success(); } -template <> -LogicalResult -Deserializer::processOp(ArrayRef operands) { - if (operands.size() != 2) { - return emitError(unknownLoc, "OpMemoryBarrier must have memory scope " - "and memory semantics "); - } - - SmallVector argAttrs; - for (auto operand : operands) { - auto argAttr = getConstantInt(operand); - if (!argAttr) { - return emitError(unknownLoc, - "expected 32-bit integer constant from ") - << operand << " for OpMemoryBarrier"; - } - argAttrs.push_back(argAttr); - } - - opBuilder.create( - unknownLoc, argAttrs[0].cast(), - argAttrs[1].cast()); - return success(); -} - template <> LogicalResult Deserializer::processOp(ArrayRef words) { @@ -538,8 +485,9 @@ if (wordIndex < words.size()) { auto attrValue = words[wordIndex++]; - attributes.push_back(opBuilder.getNamedAttr( - "memory_access", opBuilder.getI32IntegerAttr(attrValue))); + auto attr = opBuilder.getAttr( + static_cast(attrValue)); + attributes.push_back(opBuilder.getNamedAttr("memory_access", attr)); isAlignedAttr = (attrValue == 2); } @@ -549,9 +497,10 @@ } if (wordIndex < words.size()) { - attributes.push_back(opBuilder.getNamedAttr( - "source_memory_access", - opBuilder.getI32IntegerAttr(words[wordIndex++]))); + auto attrValue = words[wordIndex++]; + auto attr = opBuilder.getAttr( + static_cast(attrValue)); + attributes.push_back(opBuilder.getNamedAttr("source_memory_access", attr)); } if (wordIndex < words.size()) { diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp @@ -216,10 +216,11 @@ (*module)->setAttr( "addressing_model", - opBuilder.getI32IntegerAttr(llvm::bit_cast(operands.front()))); - (*module)->setAttr( - "memory_model", - opBuilder.getI32IntegerAttr(llvm::bit_cast(operands.back()))); + opBuilder.getAttr( + static_cast(operands.front()))); + (*module)->setAttr("memory_model", + opBuilder.getAttr( + static_cast(operands.back()))); return success(); } diff --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp --- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp @@ -13,6 +13,7 @@ #include "Serializer.h" #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" +#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" #include "mlir/IR/RegionGraphTraits.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h" @@ -277,8 +278,8 @@ operands.push_back(resultID); auto attr = op->getAttr(spirv::attributeName()); if (attr) { - operands.push_back(static_cast( - attr.cast().getValue().getZExtValue())); + operands.push_back( + static_cast(attr.cast().getValue())); } elidedAttrs.push_back(spirv::attributeName()); for (auto arg : op.getODSOperands(0)) { @@ -565,27 +566,6 @@ return success(); } -template <> -LogicalResult -Serializer::processOp(spirv::ControlBarrierOp op) { - StringRef argNames[] = {"execution_scope", "memory_scope", - "memory_semantics"}; - SmallVector operands; - - for (auto argName : argNames) { - auto argIntAttr = op->getAttrOfType(argName); - auto operand = prepareConstantInt(op.getLoc(), argIntAttr); - if (!operand) { - return failure(); - } - operands.push_back(operand); - } - - encodeInstructionInto(functionBody, spirv::Opcode::OpControlBarrier, - operands); - return success(); -} - template <> LogicalResult Serializer::processOp(spirv::ExecutionModeOp op) { @@ -615,25 +595,6 @@ return success(); } -template <> -LogicalResult -Serializer::processOp(spirv::MemoryBarrierOp op) { - StringRef argNames[] = {"memory_scope", "memory_semantics"}; - SmallVector operands; - - for (auto argName : argNames) { - auto argIntAttr = op->getAttrOfType(argName); - auto operand = prepareConstantInt(op.getLoc(), argIntAttr); - if (!operand) { - return failure(); - } - operands.push_back(operand); - } - - encodeInstructionInto(functionBody, spirv::Opcode::OpMemoryBarrier, operands); - return success(); -} - template <> LogicalResult Serializer::processOp(spirv::FunctionCallOp op) { @@ -674,8 +635,8 @@ } if (auto attr = op->getAttr("memory_access")) { - operands.push_back(static_cast( - attr.cast().getValue().getZExtValue())); + operands.push_back( + static_cast(attr.cast().getValue())); } elidedAttrs.push_back("memory_access"); @@ -688,8 +649,8 @@ elidedAttrs.push_back("alignment"); if (auto attr = op->getAttr("source_memory_access")) { - operands.push_back(static_cast( - attr.cast().getValue().getZExtValue())); + operands.push_back( + static_cast(attr.cast().getValue())); } elidedAttrs.push_back("source_memory_access"); diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp --- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" +#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h" @@ -23,6 +24,7 @@ #include "llvm/ADT/TypeSwitch.h" #include "llvm/ADT/bit.h" #include "llvm/Support/Debug.h" +#include #define DEBUG_TYPE "spirv-serialization" @@ -192,8 +194,11 @@ } void Serializer::processMemoryModel() { - uint32_t mm = module->getAttrOfType("memory_model").getInt(); - uint32_t am = module->getAttrOfType("addressing_model").getInt(); + uint32_t mm = static_cast( + module->getAttrOfType("memory_model").getValue()); + uint32_t am = static_cast( + module->getAttrOfType("addressing_model") + .getValue()); encodeInstructionInto(memoryModel, spirv::Opcode::OpMemoryModel, {am, mm}); } diff --git a/mlir/test/Conversion/GPUToSPIRV/simple.mlir b/mlir/test/Conversion/GPUToSPIRV/simple.mlir --- a/mlir/test/Conversion/GPUToSPIRV/simple.mlir +++ b/mlir/test/Conversion/GPUToSPIRV/simple.mlir @@ -112,7 +112,7 @@ // CHECK-LABEL: spv.func @barrier gpu.func @barrier(%arg0 : f32, %arg1 : memref<12xf32>) kernel attributes {spv.entry_point_abi = #spv.entry_point_abi: vector<3xi32>>} { - // CHECK: spv.ControlBarrier Workgroup, Workgroup, "AcquireRelease|WorkgroupMemory" + // CHECK: spv.ControlBarrier , , gpu.barrier gpu.return } diff --git a/mlir/test/Conversion/LinalgToSPIRV/linalg-to-spirv.mlir b/mlir/test/Conversion/LinalgToSPIRV/linalg-to-spirv.mlir --- a/mlir/test/Conversion/LinalgToSPIRV/linalg-to-spirv.mlir +++ b/mlir/test/Conversion/LinalgToSPIRV/linalg-to-spirv.mlir @@ -32,7 +32,7 @@ // CHECK: %[[ADD:.+]] = spv.GroupNonUniformIAdd "Subgroup" "Reduce" %[[VAL]] : i32 // CHECK: %[[OUTPTR:.+]] = spv.AccessChain %[[OUTPUT]][%[[ZERO]], %[[ZERO]]] -// CHECK: %[[ELECT:.+]] = spv.GroupNonUniformElect Subgroup : i1 +// CHECK: %[[ELECT:.+]] = spv.GroupNonUniformElect : i1 // CHECK: spv.mlir.selection { // CHECK: spv.BranchConditional %[[ELECT]], ^bb1, ^bb2 diff --git a/mlir/test/Conversion/MemRefToSPIRV/map-storage-class.mlir b/mlir/test/Conversion/MemRefToSPIRV/map-storage-class.mlir --- a/mlir/test/Conversion/MemRefToSPIRV/map-storage-class.mlir +++ b/mlir/test/Conversion/MemRefToSPIRV/map-storage-class.mlir @@ -1,32 +1,30 @@ // RUN: mlir-opt -split-input-file -allow-unregistered-dialect -map-memref-spirv-storage-class='client-api=vulkan' -verify-diagnostics %s -o - | FileCheck %s --check-prefix=VULKAN // Vulkan Mappings: -// 0 -> StorageBuffer (12) -// 1 -> Generic (8) -// 3 -> Workgroup (4) -// 4 -> Uniform (2) -// TODO: create a StorageClass wrapper class so we can print the symbolc -// storage class (instead of the backing IntegerAttr) and be able to -// round trip the IR. +// 0 -> StorageBuffer +// 1 -> Generic +// 2 -> [null] +// 3 -> Workgroup +// 4 -> Uniform // VULKAN-LABEL: func @operand_result func.func @operand_result() { - // VULKAN: memref + // VULKAN: memref> %0 = "dialect.memref_producer"() : () -> (memref) - // VULKAN: memref<4xi32, 8 : i32> + // VULKAN: memref<4xi32, #spv.storage_class> %1 = "dialect.memref_producer"() : () -> (memref<4xi32, 1>) - // VULKAN: memref + // VULKAN: memref> %2 = "dialect.memref_producer"() : () -> (memref) - // VULKAN: memref<*xf16, 2 : i32> + // VULKAN: memref<*xf16, #spv.storage_class> %3 = "dialect.memref_producer"() : () -> (memref<*xf16, 4>) "dialect.memref_consumer"(%0) : (memref) -> () - // VULKAN: memref<4xi32, 8 : i32> + // VULKAN: memref<4xi32, #spv.storage_class> "dialect.memref_consumer"(%1) : (memref<4xi32, 1>) -> () - // VULKAN: memref + // VULKAN: memref> "dialect.memref_consumer"(%2) : (memref) -> () - // VULKAN: memref<*xf16, 2 : i32> + // VULKAN: memref<*xf16, #spv.storage_class> "dialect.memref_consumer"(%3) : (memref<*xf16, 4>) -> () return @@ -36,7 +34,7 @@ // VULKAN-LABEL: func @type_attribute func.func @type_attribute() { - // VULKAN: attr = memref + // VULKAN: attr = memref> "dialect.memref_producer"() { attr = memref } : () -> () return } @@ -45,9 +43,9 @@ // VULKAN-LABEL: func @function_io func.func @function_io - // VULKAN-SAME: (%{{.+}}: memref, %{{.+}}: memref<4xi32, 4 : i32>) + // VULKAN-SAME: (%{{.+}}: memref>, %{{.+}}: memref<4xi32, #spv.storage_class>) (%arg0: memref, %arg1: memref<4xi32, 3>) - // VULKAN-SAME: -> (memref, memref<4xi32, 4 : i32>) + // VULKAN-SAME: -> (memref>, memref<4xi32, #spv.storage_class>) -> (memref, memref<4xi32, 3>) { return %arg0, %arg1: memref, memref<4xi32, 3> } @@ -57,8 +55,8 @@ // VULKAN: func @region func.func @region(%cond: i1, %arg0: memref) { scf.if %cond { - // VULKAN: "dialect.memref_consumer"(%{{.+}}) {attr = memref} - // VULKAN-SAME: (memref) -> memref + // VULKAN: "dialect.memref_consumer"(%{{.+}}) {attr = memref>} + // VULKAN-SAME: (memref>) -> memref> %0 = "dialect.memref_consumer"(%arg0) { attr = memref } : (memref) -> (memref) } return diff --git a/mlir/test/Dialect/SPIRV/IR/atomic-ops.mlir b/mlir/test/Dialect/SPIRV/IR/atomic-ops.mlir --- a/mlir/test/Dialect/SPIRV/IR/atomic-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/atomic-ops.mlir @@ -14,7 +14,7 @@ func.func @atomic_and(%ptr : !spv.ptr, %value : i32) -> i32 { // expected-error @+1 {{pointer operand must point to an integer value, found 'f32'}} - %0 = "spv.AtomicAnd"(%ptr, %value) {memory_scope = 4: i32, semantics = 0x4 : i32} : (!spv.ptr, i32) -> (i32) + %0 = "spv.AtomicAnd"(%ptr, %value) {memory_scope = #spv.scope, semantics = #spv.memory_semantics} : (!spv.ptr, i32) -> (i32) return %0 : i32 } @@ -23,7 +23,7 @@ func.func @atomic_and(%ptr : !spv.ptr, %value : i64) -> i64 { // expected-error @+1 {{expected value to have the same type as the pointer operand's pointee type 'i32', but found 'i64'}} - %0 = "spv.AtomicAnd"(%ptr, %value) {memory_scope = 2: i32, semantics = 0x8 : i32} : (!spv.ptr, i64) -> (i64) + %0 = "spv.AtomicAnd"(%ptr, %value) {memory_scope = #spv.scope, semantics = #spv.memory_semantics} : (!spv.ptr, i64) -> (i64) return %0 : i64 } @@ -51,7 +51,7 @@ func.func @atomic_compare_exchange(%ptr: !spv.ptr, %value: i64, %comparator: i32) -> i32 { // expected-error @+1 {{value operand must have the same type as the op result, but found 'i64' vs 'i32'}} - %0 = "spv.AtomicCompareExchange"(%ptr, %value, %comparator) {memory_scope = 4: i32, equal_semantics = 0x4: i32, unequal_semantics = 0x2:i32} : (!spv.ptr, i64, i32) -> (i32) + %0 = "spv.AtomicCompareExchange"(%ptr, %value, %comparator) {memory_scope = #spv.scope, equal_semantics = #spv.memory_semantics, unequal_semantics = #spv.memory_semantics} : (!spv.ptr, i64, i32) -> (i32) return %0: i32 } @@ -59,7 +59,7 @@ func.func @atomic_compare_exchange(%ptr: !spv.ptr, %value: i32, %comparator: i16) -> i32 { // expected-error @+1 {{comparator operand must have the same type as the op result, but found 'i16' vs 'i32'}} - %0 = "spv.AtomicCompareExchange"(%ptr, %value, %comparator) {memory_scope = 4: i32, equal_semantics = 0x4: i32, unequal_semantics = 0x2:i32} : (!spv.ptr, i32, i16) -> (i32) + %0 = "spv.AtomicCompareExchange"(%ptr, %value, %comparator) {memory_scope = #spv.scope, equal_semantics = #spv.memory_semantics, unequal_semantics = #spv.memory_semantics} : (!spv.ptr, i32, i16) -> (i32) return %0: i32 } @@ -67,7 +67,7 @@ func.func @atomic_compare_exchange(%ptr: !spv.ptr, %value: i32, %comparator: i32) -> i32 { // expected-error @+1 {{pointer operand's pointee type must have the same as the op result type, but found 'i64' vs 'i32'}} - %0 = "spv.AtomicCompareExchange"(%ptr, %value, %comparator) {memory_scope = 4: i32, equal_semantics = 0x4: i32, unequal_semantics = 0x2:i32} : (!spv.ptr, i32, i32) -> (i32) + %0 = "spv.AtomicCompareExchange"(%ptr, %value, %comparator) {memory_scope = #spv.scope, equal_semantics = #spv.memory_semantics, unequal_semantics = #spv.memory_semantics} : (!spv.ptr, i32, i32) -> (i32) return %0: i32 } @@ -87,7 +87,7 @@ func.func @atomic_compare_exchange_weak(%ptr: !spv.ptr, %value: i64, %comparator: i32) -> i32 { // expected-error @+1 {{value operand must have the same type as the op result, but found 'i64' vs 'i32'}} - %0 = "spv.AtomicCompareExchangeWeak"(%ptr, %value, %comparator) {memory_scope = 4: i32, equal_semantics = 0x4: i32, unequal_semantics = 0x2:i32} : (!spv.ptr, i64, i32) -> (i32) + %0 = "spv.AtomicCompareExchangeWeak"(%ptr, %value, %comparator) {memory_scope = #spv.scope, equal_semantics = #spv.memory_semantics, unequal_semantics = #spv.memory_semantics} : (!spv.ptr, i64, i32) -> (i32) return %0: i32 } @@ -95,7 +95,7 @@ func.func @atomic_compare_exchange_weak(%ptr: !spv.ptr, %value: i32, %comparator: i16) -> i32 { // expected-error @+1 {{comparator operand must have the same type as the op result, but found 'i16' vs 'i32'}} - %0 = "spv.AtomicCompareExchangeWeak"(%ptr, %value, %comparator) {memory_scope = 4: i32, equal_semantics = 0x4: i32, unequal_semantics = 0x2:i32} : (!spv.ptr, i32, i16) -> (i32) + %0 = "spv.AtomicCompareExchangeWeak"(%ptr, %value, %comparator) {memory_scope = #spv.scope, equal_semantics = #spv.memory_semantics, unequal_semantics = #spv.memory_semantics} : (!spv.ptr, i32, i16) -> (i32) return %0: i32 } @@ -103,7 +103,7 @@ func.func @atomic_compare_exchange_weak(%ptr: !spv.ptr, %value: i32, %comparator: i32) -> i32 { // expected-error @+1 {{pointer operand's pointee type must have the same as the op result type, but found 'i64' vs 'i32'}} - %0 = "spv.AtomicCompareExchangeWeak"(%ptr, %value, %comparator) {memory_scope = 4: i32, equal_semantics = 0x4: i32, unequal_semantics = 0x2:i32} : (!spv.ptr, i32, i32) -> (i32) + %0 = "spv.AtomicCompareExchangeWeak"(%ptr, %value, %comparator) {memory_scope = #spv.scope, equal_semantics = #spv.memory_semantics, unequal_semantics = #spv.memory_semantics} : (!spv.ptr, i32, i32) -> (i32) return %0: i32 } @@ -123,7 +123,7 @@ func.func @atomic_exchange(%ptr: !spv.ptr, %value: i64) -> i32 { // expected-error @+1 {{value operand must have the same type as the op result, but found 'i64' vs 'i32'}} - %0 = "spv.AtomicExchange"(%ptr, %value) {memory_scope = 4: i32, semantics = 0x4: i32} : (!spv.ptr, i64) -> (i32) + %0 = "spv.AtomicExchange"(%ptr, %value) {memory_scope = #spv.scope, semantics = #spv.memory_semantics} : (!spv.ptr, i64) -> (i32) return %0: i32 } @@ -131,7 +131,7 @@ func.func @atomic_exchange(%ptr: !spv.ptr, %value: i32) -> i32 { // expected-error @+1 {{pointer operand's pointee type must have the same as the op result type, but found 'i64' vs 'i32'}} - %0 = "spv.AtomicExchange"(%ptr, %value) {memory_scope = 4: i32, semantics = 0x4: i32} : (!spv.ptr, i32) -> (i32) + %0 = "spv.AtomicExchange"(%ptr, %value) {memory_scope = #spv.scope, semantics = #spv.memory_semantics} : (!spv.ptr, i32) -> (i32) return %0: i32 } @@ -253,7 +253,7 @@ func.func @atomic_fadd(%ptr : !spv.ptr, %value : f32) -> f32 { // expected-error @+1 {{pointer operand must point to an float value, found 'i32'}} - %0 = "spv.AtomicFAddEXT"(%ptr, %value) {memory_scope = 4: i32, semantics = 0x4 : i32} : (!spv.ptr, f32) -> (f32) + %0 = "spv.AtomicFAddEXT"(%ptr, %value) {memory_scope = #spv.scope, semantics = #spv.memory_semantics} : (!spv.ptr, f32) -> (f32) return %0 : f32 } @@ -261,7 +261,7 @@ func.func @atomic_fadd(%ptr : !spv.ptr, %value : f64) -> f64 { // expected-error @+1 {{expected value to have the same type as the pointer operand's pointee type 'f32', but found 'f64'}} - %0 = "spv.AtomicFAddEXT"(%ptr, %value) {memory_scope = 2: i32, semantics = 0x8 : i32} : (!spv.ptr, f64) -> (f64) + %0 = "spv.AtomicFAddEXT"(%ptr, %value) {memory_scope = #spv.scope, semantics = #spv.memory_semantics} : (!spv.ptr, f64) -> (f64) return %0 : f64 } diff --git a/mlir/test/Dialect/SPIRV/IR/availability.mlir b/mlir/test/Dialect/SPIRV/IR/availability.mlir --- a/mlir/test/Dialect/SPIRV/IR/availability.mlir +++ b/mlir/test/Dialect/SPIRV/IR/availability.mlir @@ -26,7 +26,7 @@ // CHECK: max version: v1.6 // CHECK: extensions: [ ] // CHECK: capabilities: [ [GroupNonUniformBallot] ] - %0 = spv.GroupNonUniformBallot Workgroup %predicate : vector<4xi32> + %0 = spv.GroupNonUniformBallot %predicate : vector<4xi32> return %0: vector<4xi32> } diff --git a/mlir/test/Dialect/SPIRV/IR/barrier-ops.mlir b/mlir/test/Dialect/SPIRV/IR/barrier-ops.mlir --- a/mlir/test/Dialect/SPIRV/IR/barrier-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/barrier-ops.mlir @@ -5,16 +5,17 @@ //===----------------------------------------------------------------------===// func.func @control_barrier_0() -> () { - // CHECK: spv.ControlBarrier Workgroup, Device, "Acquire|UniformMemory" - spv.ControlBarrier Workgroup, Device, "Acquire|UniformMemory" + // CHECK: spv.ControlBarrier , , + spv.ControlBarrier , , return } // ----- func.func @control_barrier_1() -> () { - // expected-error @+1 {{expected string or keyword containing one of the following enum values}} - spv.ControlBarrier Something, Device, "Acquire|UniformMemory" + // expected-error @+2 {{to be one of}} + // expected-error @+1 {{failed to parse SPV_ScopeAttr}} + spv.ControlBarrier , , return } @@ -26,16 +27,16 @@ //===----------------------------------------------------------------------===// func.func @memory_barrier_0() -> () { - // CHECK: spv.MemoryBarrier Device, "Acquire|UniformMemory" - spv.MemoryBarrier Device, "Acquire|UniformMemory" + // CHECK: spv.MemoryBarrier , + spv.MemoryBarrier , return } // ----- func.func @memory_barrier_1() -> () { - // CHECK: spv.MemoryBarrier Workgroup, Acquire - spv.MemoryBarrier Workgroup, Acquire + // CHECK: spv.MemoryBarrier , + spv.MemoryBarrier , return } @@ -43,7 +44,7 @@ func.func @memory_barrier_2() -> () { // expected-error @+1 {{expected at most one of these four memory constraints to be set: `Acquire`, `Release`,`AcquireRelease` or `SequentiallyConsistent`}} - spv.MemoryBarrier Device, "Acquire|Release" + spv.MemoryBarrier , return } diff --git a/mlir/test/Dialect/SPIRV/IR/group-ops.mlir b/mlir/test/Dialect/SPIRV/IR/group-ops.mlir --- a/mlir/test/Dialect/SPIRV/IR/group-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/group-ops.mlir @@ -17,24 +17,24 @@ //===----------------------------------------------------------------------===// func.func @group_broadcast_scalar(%value: f32, %localid: i32 ) -> f32 { - // CHECK: spv.GroupBroadcast Workgroup %{{.*}}, %{{.*}} : f32, i32 - %0 = spv.GroupBroadcast Workgroup %value, %localid : f32, i32 + // CHECK: spv.GroupBroadcast %{{.*}}, %{{.*}} : f32, i32 + %0 = spv.GroupBroadcast %value, %localid : f32, i32 return %0: f32 } // ----- func.func @group_broadcast_scalar_vector(%value: f32, %localid: vector<3xi32> ) -> f32 { - // CHECK: spv.GroupBroadcast Workgroup %{{.*}}, %{{.*}} : f32, vector<3xi32> - %0 = spv.GroupBroadcast Workgroup %value, %localid : f32, vector<3xi32> + // CHECK: spv.GroupBroadcast %{{.*}}, %{{.*}} : f32, vector<3xi32> + %0 = spv.GroupBroadcast %value, %localid : f32, vector<3xi32> return %0: f32 } // ----- func.func @group_broadcast_vector(%value: vector<4xf32>, %localid: vector<3xi32> ) -> vector<4xf32> { - // CHECK: spv.GroupBroadcast Subgroup %{{.*}}, %{{.*}} : vector<4xf32>, vector<3xi32> - %0 = spv.GroupBroadcast Subgroup %value, %localid : vector<4xf32>, vector<3xi32> + // CHECK: spv.GroupBroadcast %{{.*}}, %{{.*}} : vector<4xf32>, vector<3xi32> + %0 = spv.GroupBroadcast %value, %localid : vector<4xf32>, vector<3xi32> return %0: vector<4xf32> } @@ -42,7 +42,7 @@ func.func @group_broadcast_negative_scope(%value: f32, %localid: vector<3xi32> ) -> f32 { // expected-error @+1 {{execution scope must be 'Workgroup' or 'Subgroup'}} - %0 = spv.GroupBroadcast Device %value, %localid : f32, vector<3xi32> + %0 = spv.GroupBroadcast %value, %localid : f32, vector<3xi32> return %0: f32 } @@ -50,7 +50,7 @@ func.func @group_broadcast_negative_locid_dtype(%value: f32, %localid: vector<3xf32> ) -> f32 { // expected-error @+1 {{operand #1 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values}} - %0 = spv.GroupBroadcast Subgroup %value, %localid : f32, vector<3xf32> + %0 = spv.GroupBroadcast %value, %localid : f32, vector<3xf32> return %0: f32 } @@ -58,7 +58,7 @@ func.func @group_broadcast_negative_locid_vec4(%value: f32, %localid: vector<4xi32> ) -> f32 { // expected-error @+1 {{localid is a vector and can be with only 2 or 3 components, actual number is 4}} - %0 = spv.GroupBroadcast Subgroup %value, %localid : f32, vector<4xi32> + %0 = spv.GroupBroadcast %value, %localid : f32, vector<4xi32> return %0: f32 } diff --git a/mlir/test/Dialect/SPIRV/IR/memory-ops.mlir b/mlir/test/Dialect/SPIRV/IR/memory-ops.mlir --- a/mlir/test/Dialect/SPIRV/IR/memory-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/memory-ops.mlir @@ -198,7 +198,7 @@ %0 = spv.Variable : !spv.ptr // CHECK: spv.Load // CHECK-SAME: ["None"] - %1 = "spv.Load"(%0) {memory_access = 0 : i32} : (!spv.ptr) -> (f32) + %1 = "spv.Load"(%0) {memory_access = #spv.memory_access} : (!spv.ptr) -> (f32) return } @@ -207,7 +207,7 @@ %0 = spv.Variable : !spv.ptr // CHECK: spv.Load // CHECK-SAME: ["Volatile"] - %1 = "spv.Load"(%0) {memory_access = 1 : i32} : (!spv.ptr) -> (f32) + %1 = "spv.Load"(%0) {memory_access = #spv.memory_access} : (!spv.ptr) -> (f32) return } @@ -216,7 +216,7 @@ %0 = spv.Variable : !spv.ptr // CHECK: spv.Load // CHECK-SAME: ["Aligned", 4] - %1 = "spv.Load"(%0) {memory_access = 2 : i32, alignment = 4 : i32} : (!spv.ptr) -> (f32) + %1 = "spv.Load"(%0) {memory_access = #spv.memory_access, alignment = 4 : i32} : (!spv.ptr) -> (f32) return } @@ -225,7 +225,7 @@ %0 = spv.Variable : !spv.ptr // CHECK: spv.Load // CHECK-SAME: ["Volatile|Aligned", 4] - %1 = "spv.Load"(%0) {memory_access = 3 : i32, alignment = 4 : i32} : (!spv.ptr) -> (f32) + %1 = "spv.Load"(%0) {memory_access = #spv.memory_access, alignment = 4 : i32} : (!spv.ptr) -> (f32) return } @@ -588,7 +588,7 @@ %0 = spv.Variable : !spv.ptr %1 = spv.Variable : !spv.ptr // expected-error @+1 {{missing alignment value}} - "spv.CopyMemory"(%0, %1) {memory_access=0x0002 : i32} : (!spv.ptr, !spv.ptr) -> () + "spv.CopyMemory"(%0, %1) {memory_access=#spv.memory_access} : (!spv.ptr, !spv.ptr) -> () spv.Return } @@ -598,7 +598,7 @@ %0 = spv.Variable : !spv.ptr %1 = spv.Variable : !spv.ptr // expected-error @+1 {{invalid alignment specification with non-aligned memory access specification}} - "spv.CopyMemory"(%0, %1) {source_memory_access=0x0001 : i32, memory_access=0x0002 : i32, source_alignment=8 : i32, alignment=4 : i32} : (!spv.ptr, !spv.ptr) -> () + "spv.CopyMemory"(%0, %1) {source_memory_access=#spv.memory_access, memory_access=#spv.memory_access, source_alignment=8 : i32, alignment=4 : i32} : (!spv.ptr, !spv.ptr) -> () spv.Return } @@ -608,7 +608,7 @@ %0 = spv.Variable : !spv.ptr %1 = spv.Variable : !spv.ptr // expected-error @+1 {{missing alignment value}} - "spv.CopyMemory"(%0, %1) {source_memory_access=0x0002 : i32, memory_access=0x0002 : i32, alignment=4 : i32} : (!spv.ptr, !spv.ptr) -> () + "spv.CopyMemory"(%0, %1) {source_memory_access=#spv.memory_access, memory_access=#spv.memory_access, alignment=4 : i32} : (!spv.ptr, !spv.ptr) -> () spv.Return } @@ -619,16 +619,16 @@ %1 = spv.Variable : !spv.ptr // CHECK: spv.CopyMemory "Function" %{{.*}}, "Function" %{{.*}} ["Volatile"] : f32 - "spv.CopyMemory"(%0, %1) {memory_access=0x0001 : i32} : (!spv.ptr, !spv.ptr) -> () + "spv.CopyMemory"(%0, %1) {memory_access=#spv.memory_access} : (!spv.ptr, !spv.ptr) -> () // CHECK: spv.CopyMemory "Function" %{{.*}}, "Function" %{{.*}} ["Aligned", 4] : f32 - "spv.CopyMemory"(%0, %1) {memory_access=0x0002 : i32, alignment=4 : i32} : (!spv.ptr, !spv.ptr) -> () + "spv.CopyMemory"(%0, %1) {memory_access=#spv.memory_access, alignment=4 : i32} : (!spv.ptr, !spv.ptr) -> () // CHECK: spv.CopyMemory "Function" %{{.*}}, "Function" %{{.*}} ["Aligned", 4], ["Volatile"] : f32 - "spv.CopyMemory"(%0, %1) {source_memory_access=0x0001 : i32, memory_access=0x0002 : i32, alignment=4 : i32} : (!spv.ptr, !spv.ptr) -> () + "spv.CopyMemory"(%0, %1) {source_memory_access=#spv.memory_access, memory_access=#spv.memory_access, alignment=4 : i32} : (!spv.ptr, !spv.ptr) -> () // CHECK: spv.CopyMemory "Function" %{{.*}}, "Function" %{{.*}} ["Aligned", 4], ["Aligned", 8] : f32 - "spv.CopyMemory"(%0, %1) {source_memory_access=0x0002 : i32, memory_access=0x0002 : i32, source_alignment=8 : i32, alignment=4 : i32} : (!spv.ptr, !spv.ptr) -> () + "spv.CopyMemory"(%0, %1) {source_memory_access=#spv.memory_access, memory_access=#spv.memory_access, source_alignment=8 : i32, alignment=4 : i32} : (!spv.ptr, !spv.ptr) -> () spv.Return } diff --git a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir --- a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir @@ -5,8 +5,8 @@ //===----------------------------------------------------------------------===// func.func @group_non_uniform_ballot(%predicate: i1) -> vector<4xi32> { - // CHECK: %{{.*}} = spv.GroupNonUniformBallot Workgroup %{{.*}}: vector<4xi32> - %0 = spv.GroupNonUniformBallot Workgroup %predicate : vector<4xi32> + // CHECK: %{{.*}} = spv.GroupNonUniformBallot %{{.*}}: vector<4xi32> + %0 = spv.GroupNonUniformBallot %predicate : vector<4xi32> return %0: vector<4xi32> } @@ -14,7 +14,7 @@ func.func @group_non_uniform_ballot(%predicate: i1) -> vector<4xi32> { // expected-error @+1 {{execution scope must be 'Workgroup' or 'Subgroup'}} - %0 = spv.GroupNonUniformBallot Device %predicate : vector<4xi32> + %0 = spv.GroupNonUniformBallot %predicate : vector<4xi32> return %0: vector<4xi32> } @@ -22,7 +22,7 @@ func.func @group_non_uniform_ballot(%predicate: i1) -> vector<4xsi32> { // expected-error @+1 {{op result #0 must be vector of 8/16/32/64-bit signless/unsigned integer values of length 4, but got 'vector<4xsi32>'}} - %0 = spv.GroupNonUniformBallot Workgroup %predicate : vector<4xsi32> + %0 = spv.GroupNonUniformBallot %predicate : vector<4xsi32> return %0: vector<4xsi32> } @@ -34,8 +34,8 @@ func.func @group_non_uniform_broadcast_scalar(%value: f32) -> f32 { %one = spv.Constant 1 : i32 - // CHECK: spv.GroupNonUniformBroadcast Workgroup %{{.*}}, %{{.*}} : f32, i32 - %0 = spv.GroupNonUniformBroadcast Workgroup %value, %one : f32, i32 + // CHECK: spv.GroupNonUniformBroadcast %{{.*}}, %{{.*}} : f32, i32 + %0 = spv.GroupNonUniformBroadcast %value, %one : f32, i32 return %0: f32 } @@ -43,8 +43,8 @@ func.func @group_non_uniform_broadcast_vector(%value: vector<4xf32>) -> vector<4xf32> { %one = spv.Constant 1 : i32 - // CHECK: spv.GroupNonUniformBroadcast Subgroup %{{.*}}, %{{.*}} : vector<4xf32>, i32 - %0 = spv.GroupNonUniformBroadcast Subgroup %value, %one : vector<4xf32>, i32 + // CHECK: spv.GroupNonUniformBroadcast %{{.*}}, %{{.*}} : vector<4xf32>, i32 + %0 = spv.GroupNonUniformBroadcast %value, %one : vector<4xf32>, i32 return %0: vector<4xf32> } @@ -53,7 +53,7 @@ func.func @group_non_uniform_broadcast_negative_scope(%value: f32, %localid: i32 ) -> f32 { %one = spv.Constant 1 : i32 // expected-error @+1 {{execution scope must be 'Workgroup' or 'Subgroup'}} - %0 = spv.GroupNonUniformBroadcast Device %value, %one : f32, i32 + %0 = spv.GroupNonUniformBroadcast %value, %one : f32, i32 return %0: f32 } @@ -61,7 +61,7 @@ func.func @group_non_uniform_broadcast_negative_non_const(%value: f32, %localid: i32) -> f32 { // expected-error @+1 {{id must be the result of a constant op}} - %0 = spv.GroupNonUniformBroadcast Subgroup %value, %localid : f32, i32 + %0 = spv.GroupNonUniformBroadcast %value, %localid : f32, i32 return %0: f32 } @@ -73,8 +73,8 @@ // CHECK-LABEL: @group_non_uniform_elect func.func @group_non_uniform_elect() -> i1 { - // CHECK: %{{.+}} = spv.GroupNonUniformElect Workgroup : i1 - %0 = spv.GroupNonUniformElect Workgroup : i1 + // CHECK: %{{.+}} = spv.GroupNonUniformElect : i1 + %0 = spv.GroupNonUniformElect : i1 return %0: i1 } @@ -82,7 +82,7 @@ func.func @group_non_uniform_elect() -> i1 { // expected-error @+1 {{execution scope must be 'Workgroup' or 'Subgroup'}} - %0 = spv.GroupNonUniformElect CrossDevice : i1 + %0 = spv.GroupNonUniformElect : i1 return %0: i1 } diff --git a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir --- a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir @@ -819,7 +819,7 @@ %0 = spv.Variable : !spv.ptr // expected-error @+1 {{invalid enclosed op}} - %1 = spv.SpecConstantOperation wraps "spv.Load"(%0) {memory_access = 0 : i32} : (!spv.ptr) -> i32 + %1 = spv.SpecConstantOperation wraps "spv.Load"(%0) {memory_access = #spv.memory_access} : (!spv.ptr) -> i32 spv.Return } } diff --git a/mlir/test/Dialect/SPIRV/IR/target-and-abi.mlir b/mlir/test/Dialect/SPIRV/IR/target-and-abi.mlir --- a/mlir/test/Dialect/SPIRV/IR/target-and-abi.mlir +++ b/mlir/test/Dialect/SPIRV/IR/target-and-abi.mlir @@ -165,11 +165,11 @@ // CHECK-SAME: #spv.coop_matrix_props< // CHECK-SAME: m_size = 8, n_size = 8, k_size = 32, // CHECK-SAME: a_type = i8, b_type = i8, c_type = i32, - // CHECK-SAME: result_type = i32, scope = 3 : i32> + // CHECK-SAME: result_type = i32, scope = > // CHECK-SAME: #spv.coop_matrix_props< // CHECK-SAME: m_size = 8, n_size = 8, k_size = 16, // CHECK-SAME: a_type = f16, b_type = f16, c_type = f16, - // CHECK-SAME: result_type = f16, scope = 3 : i32> + // CHECK-SAME: result_type = f16, scope = > spv.target_env = #spv.target_env< #spv.vce, @@ -182,7 +182,7 @@ b_type = i8, c_type = i32, result_type = i32, - scope = 3 : i32 + scope = #spv.scope >, #spv.coop_matrix_props< m_size = 8, n_size = 8, @@ -191,7 +191,7 @@ b_type = f16, c_type = f16, result_type = f16, - scope = 3 : i32 + scope = #spv.scope >] >> } { return } diff --git a/mlir/test/Dialect/SPIRV/IR/target-env.mlir b/mlir/test/Dialect/SPIRV/IR/target-env.mlir --- a/mlir/test/Dialect/SPIRV/IR/target-env.mlir +++ b/mlir/test/Dialect/SPIRV/IR/target-env.mlir @@ -59,7 +59,7 @@ func.func @group_non_uniform_ballot_suitable_version(%predicate: i1) -> vector<4xi32> attributes { spv.target_env = #spv.target_env<#spv.vce, #spv.resource_limits<>> } { - // CHECK: spv.GroupNonUniformBallot Workgroup + // CHECK: spv.GroupNonUniformBallot %0 = "test.convert_to_group_non_uniform_ballot_op"(%predicate): (i1) -> (vector<4xi32>) return %0: vector<4xi32> } diff --git a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir --- a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir @@ -27,7 +27,7 @@ #spv.vce, #spv.resource_limits<>> } { spv.func @group_non_uniform_ballot(%predicate : i1) -> vector<4xi32> "None" { - %0 = spv.GroupNonUniformBallot Workgroup %predicate : vector<4xi32> + %0 = spv.GroupNonUniformBallot %predicate : vector<4xi32> spv.ReturnValue %0: vector<4xi32> } } diff --git a/mlir/test/Target/SPIRV/barrier-ops.mlir b/mlir/test/Target/SPIRV/barrier-ops.mlir --- a/mlir/test/Target/SPIRV/barrier-ops.mlir +++ b/mlir/test/Target/SPIRV/barrier-ops.mlir @@ -2,23 +2,23 @@ spv.module Logical GLSL450 requires #spv.vce { spv.func @memory_barrier_0() -> () "None" { - // CHECK: spv.MemoryBarrier Device, "Release|UniformMemory" - spv.MemoryBarrier Device, "Release|UniformMemory" + // CHECK: spv.MemoryBarrier , + spv.MemoryBarrier , spv.Return } spv.func @memory_barrier_1() -> () "None" { - // CHECK: spv.MemoryBarrier Subgroup, "AcquireRelease|SubgroupMemory" - spv.MemoryBarrier Subgroup, "AcquireRelease|SubgroupMemory" + // CHECK: spv.MemoryBarrier , + spv.MemoryBarrier , spv.Return } spv.func @control_barrier_0() -> () "None" { - // CHECK: spv.ControlBarrier Device, Workgroup, "Release|UniformMemory" - spv.ControlBarrier Device, Workgroup, "Release|UniformMemory" + // CHECK: spv.ControlBarrier , , + spv.ControlBarrier , , spv.Return } spv.func @control_barrier_1() -> () "None" { - // CHECK: spv.ControlBarrier Workgroup, Invocation, "AcquireRelease|UniformMemory" - spv.ControlBarrier Workgroup, Invocation, "AcquireRelease|UniformMemory" + // CHECK: spv.ControlBarrier , , + spv.ControlBarrier , , spv.Return } } diff --git a/mlir/test/Target/SPIRV/group-ops.mlir b/mlir/test/Target/SPIRV/group-ops.mlir --- a/mlir/test/Target/SPIRV/group-ops.mlir +++ b/mlir/test/Target/SPIRV/group-ops.mlir @@ -9,14 +9,14 @@ } // CHECK-LABEL: @group_broadcast_1 spv.func @group_broadcast_1(%value: f32, %localid: i32 ) -> f32 "None" { - // CHECK: spv.GroupBroadcast Workgroup %{{.*}}, %{{.*}} : f32, i32 - %0 = spv.GroupBroadcast Workgroup %value, %localid : f32, i32 + // CHECK: spv.GroupBroadcast %{{.*}}, %{{.*}} : f32, i32 + %0 = spv.GroupBroadcast %value, %localid : f32, i32 spv.ReturnValue %0: f32 } // CHECK-LABEL: @group_broadcast_2 spv.func @group_broadcast_2(%value: f32, %localid: vector<3xi32> ) -> f32 "None" { - // CHECK: spv.GroupBroadcast Workgroup %{{.*}}, %{{.*}} : f32, vector<3xi32> - %0 = spv.GroupBroadcast Workgroup %value, %localid : f32, vector<3xi32> + // CHECK: spv.GroupBroadcast %{{.*}}, %{{.*}} : f32, vector<3xi32> + %0 = spv.GroupBroadcast %value, %localid : f32, vector<3xi32> spv.ReturnValue %0: f32 } // CHECK-LABEL: @subgroup_block_read_intel diff --git a/mlir/test/Target/SPIRV/non-uniform-ops.mlir b/mlir/test/Target/SPIRV/non-uniform-ops.mlir --- a/mlir/test/Target/SPIRV/non-uniform-ops.mlir +++ b/mlir/test/Target/SPIRV/non-uniform-ops.mlir @@ -3,23 +3,23 @@ spv.module Logical GLSL450 requires #spv.vce { // CHECK-LABEL: @group_non_uniform_ballot spv.func @group_non_uniform_ballot(%predicate: i1) -> vector<4xi32> "None" { - // CHECK: %{{.*}} = spv.GroupNonUniformBallot Workgroup %{{.*}}: vector<4xi32> - %0 = spv.GroupNonUniformBallot Workgroup %predicate : vector<4xi32> + // CHECK: %{{.*}} = spv.GroupNonUniformBallot %{{.*}}: vector<4xi32> + %0 = spv.GroupNonUniformBallot %predicate : vector<4xi32> spv.ReturnValue %0: vector<4xi32> } // CHECK-LABEL: @group_non_uniform_broadcast spv.func @group_non_uniform_broadcast(%value: f32) -> f32 "None" { %one = spv.Constant 1 : i32 - // CHECK: spv.GroupNonUniformBroadcast Subgroup %{{.*}}, %{{.*}} : f32, i32 - %0 = spv.GroupNonUniformBroadcast Subgroup %value, %one : f32, i32 + // CHECK: spv.GroupNonUniformBroadcast %{{.*}}, %{{.*}} : f32, i32 + %0 = spv.GroupNonUniformBroadcast %value, %one : f32, i32 spv.ReturnValue %0: f32 } // CHECK-LABEL: @group_non_uniform_elect spv.func @group_non_uniform_elect() -> i1 "None" { - // CHECK: %{{.+}} = spv.GroupNonUniformElect Workgroup : i1 - %0 = spv.GroupNonUniformElect Workgroup : i1 + // CHECK: %{{.+}} = spv.GroupNonUniformElect : i1 + %0 = spv.GroupNonUniformElect : i1 spv.ReturnValue %0: i1 } diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp --- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp +++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp @@ -519,10 +519,24 @@ << formatv("if (auto attr = {0}->getAttr(\"{1}\")) {{\n", opVar, attrName); if (attr.getAttrDefName() == "SPV_ScopeAttr" || attr.getAttrDefName() == "SPV_MemorySemanticsAttr") { + // These two enums are encoded as to constant values in SPIR-V blob, + // but we directly use the constant value as attribute in SPIR-V dialect. So + // need to handle them separately from normal enum attributes. + EnumAttr baseEnum(attr.getDef().getValueAsDef("enum")); os << tabs << formatv(" {0}.push_back(prepareConstantInt({1}.getLoc(), " - "attr.cast()));\n", - operandList, opVar); + "Builder({1}).getI32IntegerAttr(static_cast(" + "attr.cast<{2}::{3}Attr>().getValue()))));\n", + operandList, opVar, baseEnum.getCppNamespace(), + baseEnum.getEnumClassName()); + } else if (attr.isSubClassOf("SPV_BitEnumAttr") || + attr.isSubClassOf("SPV_I32EnumAttr")) { + EnumAttr baseEnum(attr.getDef().getValueAsDef("enum")); + os << tabs + << formatv(" {0}.push_back(static_cast(" + "attr.cast<{1}::{2}Attr>().getValue()));\n", + operandList, baseEnum.getCppNamespace(), + baseEnum.getEnumClassName()); } else if (attr.getAttrDefName() == "I32ArrayAttr") { // Serialize all the elements of the array os << tabs << " for (auto attrElem : attr.cast()) {\n"; @@ -531,7 +545,7 @@ "attrElem.cast().getValue().getZExtValue()));\n", operandList); os << tabs << " }\n"; - } else if (attr.isEnumAttr() || attr.getAttrDefName() == "I32Attr") { + } else if (attr.getAttrDefName() == "I32Attr") { os << tabs << formatv(" {0}.push_back(static_cast(" "attr.cast().getValue().getZExtValue()));\n", @@ -797,10 +811,25 @@ raw_ostream &os) { if (attr.getAttrDefName() == "SPV_ScopeAttr" || attr.getAttrDefName() == "SPV_MemorySemanticsAttr") { + // These two enums are encoded as to constant values in SPIR-V blob, + // but we directly use the constant value as attribute in SPIR-V dialect. So + // need to handle them separately from normal enum attributes. + EnumAttr baseEnum(attr.getDef().getValueAsDef("enum")); os << tabs << formatv("{0}.push_back(opBuilder.getNamedAttr(\"{1}\", " - "getConstantInt({2}[{3}++])));\n", - attrList, attrName, words, wordIndex); + "opBuilder.getAttr<{2}::{3}Attr>(static_cast<{2}::{3}>(" + "getConstantInt({4}[{5}++]).getValue().getZExtValue()))));\n", + attrList, attrName, baseEnum.getCppNamespace(), + baseEnum.getEnumClassName(), words, wordIndex); + } else if (attr.isSubClassOf("SPV_BitEnumAttr") || + attr.isSubClassOf("SPV_I32EnumAttr")) { + EnumAttr baseEnum(attr.getDef().getValueAsDef("enum")); + os << tabs + << formatv(" {0}.push_back(opBuilder.getNamedAttr(\"{1}\", " + "opBuilder.getAttr<{2}::{3}Attr>(" + "static_cast<{2}::{3}>({4}[{5}++]))));\n", + attrList, attrName, baseEnum.getCppNamespace(), + baseEnum.getEnumClassName(), words, wordIndex); } else if (attr.getAttrDefName() == "I32ArrayAttr") { os << tabs << "SmallVector attrListElems;\n"; os << tabs << formatv("while ({0} < {1}.size()) {{\n", wordIndex, words); @@ -815,7 +844,7 @@ << formatv("{0}.push_back(opBuilder.getNamedAttr(\"{1}\", " "opBuilder.getArrayAttr(attrListElems)));\n", attrList, attrName); - } else if (attr.isEnumAttr() || attr.getAttrDefName() == "I32Attr") { + } else if (attr.getAttrDefName() == "I32Attr") { os << tabs << formatv("{0}.push_back(opBuilder.getNamedAttr(\"{1}\", " "opBuilder.getI32IntegerAttr({2}[{3}++])));\n", @@ -1257,11 +1286,12 @@ for (const Availability &avail : opAvailabilities) availClasses.try_emplace(avail.getClass(), avail); for (const NamedAttribute &namedAttr : srcOp.getAttributes()) { - const auto *enumAttr = llvm::dyn_cast(&namedAttr.attr); - if (!enumAttr) + if (!namedAttr.attr.isSubClassOf("SPV_BitEnumAttr") && + !namedAttr.attr.isSubClassOf("SPV_I32EnumAttr")) continue; + EnumAttr enumAttr(namedAttr.attr.getDef().getValueAsDef("enum")); - for (const EnumAttrCase &enumerant : enumAttr->getAllCases()) + for (const EnumAttrCase &enumerant : enumAttr.getAllCases()) for (const Availability &caseAvail : getAvailabilities(enumerant.getDef())) availClasses.try_emplace(caseAvail.getClass(), caseAvail); @@ -1298,16 +1328,17 @@ // Update with enum attributes' specific availability spec. for (const NamedAttribute &namedAttr : srcOp.getAttributes()) { - const auto *enumAttr = llvm::dyn_cast(&namedAttr.attr); - if (!enumAttr) + if (!namedAttr.attr.isSubClassOf("SPV_BitEnumAttr") && + !namedAttr.attr.isSubClassOf("SPV_I32EnumAttr")) continue; + EnumAttr enumAttr(namedAttr.attr.getDef().getValueAsDef("enum")); // (enumerant, availability specification) pairs for this availability // class. SmallVector, 1> caseSpecs; // Collect all cases' availability specs. - for (const EnumAttrCase &enumerant : enumAttr->getAllCases()) + for (const EnumAttrCase &enumerant : enumAttr.getAllCases()) for (const Availability &caseAvail : getAvailabilities(enumerant.getDef())) if (availClassName == caseAvail.getClass()) @@ -1318,19 +1349,19 @@ if (caseSpecs.empty()) continue; - if (enumAttr->isBitEnum()) { + if (enumAttr.isBitEnum()) { // For BitEnumAttr, we need to iterate over each bit to query its // availability spec. os << formatv(" for (unsigned i = 0; " "i < std::numeric_limits<{0}>::digits; ++i) {{\n", - enumAttr->getUnderlyingType()); + enumAttr.getUnderlyingType()); os << formatv(" {0}::{1} tblgen_attrVal = this->{2}() & " "static_cast<{0}::{1}>(1 << i);\n", - enumAttr->getCppNamespace(), enumAttr->getEnumClassName(), + enumAttr.getCppNamespace(), enumAttr.getEnumClassName(), namedAttr.name); os << formatv( " if (static_cast<{0}>(tblgen_attrVal) == 0) continue;\n", - enumAttr->getUnderlyingType()); + enumAttr.getUnderlyingType()); } else { // For IntEnumAttr, we just need to query the value as a whole. os << " {\n"; @@ -1338,7 +1369,7 @@ namedAttr.name); } os << formatv(" auto tblgen_instance = {0}::{1}(tblgen_attrVal);\n", - enumAttr->getCppNamespace(), avail.getQueryFnName()); + enumAttr.getCppNamespace(), avail.getQueryFnName()); os << " if (tblgen_instance) " // TODO` here once ODS supports // dialect-specific contents so that we can use not implementing the @@ -1385,7 +1416,8 @@ raw_ostream &os) { llvm::emitSourceFileHeader("SPIR-V Capability Implication", os); - EnumAttr enumAttr(recordKeeper.getDef("SPV_CapabilityAttr")); + EnumAttr enumAttr( + recordKeeper.getDef("SPV_CapabilityAttr")->getValueAsDef("enum")); os << "ArrayRef " "spirv::getDirectImpliedCapabilities(spirv::Capability cap) {\n" diff --git a/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp b/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp --- a/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp +++ b/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp @@ -14,6 +14,7 @@ #include "mlir/Target/SPIRV/Serialization.h" #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" +#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" #include "mlir/IR/Builders.h" @@ -46,11 +47,10 @@ OperationState state(UnknownLoc::get(&context), spirv::ModuleOp::getOperationName()); state.addAttribute("addressing_model", - builder.getI32IntegerAttr(static_cast( - spirv::AddressingModel::Logical))); - state.addAttribute("memory_model", - builder.getI32IntegerAttr( - static_cast(spirv::MemoryModel::GLSL450))); + builder.getAttr( + spirv::AddressingModel::Logical)); + state.addAttribute("memory_model", builder.getAttr( + spirv::MemoryModel::GLSL450)); state.addAttribute("vce_triple", spirv::VerCapExtAttr::get( spirv::Version::V_1_0, ArrayRef(), diff --git a/mlir/utils/spirv/gen_spirv_dialect.py b/mlir/utils/spirv/gen_spirv_dialect.py --- a/mlir/utils/spirv/gen_spirv_dialect.py +++ b/mlir/utils/spirv/gen_spirv_dialect.py @@ -437,10 +437,13 @@ # Generate the enum attribute definition kind_category = 'Bit' if is_bit_enum else 'I32' enum_attr = '''def SPV_{name}Attr : - SPV_{category}EnumAttr<"{name}", "valid SPIR-V {name}", [ + SPV_{category}EnumAttr<"{name}", "valid SPIR-V {name}", "{snake_name}", [ {cases} ]>;'''.format( - name=kind_name, category=kind_category, cases=case_names) + name=kind_name, + snake_name=snake_casify(kind_name), + category=kind_category, + cases=case_names) return kind_name, case_defs + '\n\n' + enum_attr @@ -473,7 +476,8 @@ ] opcode_list = ',\n'.join(opcode_list) enum_attr = 'def SPV_OpcodeAttr :\n'\ - ' SPV_I32EnumAttr<"{name}", "valid SPIR-V instructions", [\n'\ + ' SPV_I32EnumAttr<"{name}", "valid SPIR-V instructions", '\ + '"opcode", [\n'\ '{lst}\n'\ ' ]>;'.format(name='Opcode', lst=opcode_list) return opcode_str + '\n\n' + enum_attr @@ -630,9 +634,7 @@ def snake_casify(name): """Turns the given name to follow snake_case convention.""" - name = re.sub('\W+', '', name).split() - name = [s.lower() for s in name] - return '_'.join(name) + return re.sub(r'(?