diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td @@ -247,6 +247,24 @@ "getCapabilities">]; } +//===----------------------------------------------------------------------===// +// SPIR-V target GPU vendor and device definitions +//===----------------------------------------------------------------------===// + +// An accelerator that is neither a GPU nor a CPU. +def SPV_DT_Other : I32EnumAttrCase<"Other", 0>; +def SPV_DT_IntegratedGPU : I32EnumAttrCase<"IntegratedGPU", 1>; +def SPV_DT_DiscreteGPU : I32EnumAttrCase<"DiscreteGPU", 2>; +def SPV_DT_CPU : I32EnumAttrCase<"CPU", 3>; +// The device type is not known (information missing). +def SPV_DT_Unknown : I32EnumAttrCase<"Unknown", 0x7FFFFFFF>; + +def SPV_DeviceTypeAttr : SPV_I32EnumAttr< + "DeviceType", "valid SPIR-V device types", [ + SPV_DT_Other, SPV_DT_IntegratedGPU, SPV_DT_DiscreteGPU, + SPV_DT_CPU, SPV_DT_Unknown + ]>; + //===----------------------------------------------------------------------===// // SPIR-V extension definitions //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h b/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h --- a/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h +++ b/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h @@ -29,6 +29,8 @@ public: explicit TargetEnv(TargetEnvAttr targetAttr); + DeviceType getDeviceType(); + Version getVersion(); /// Returns true if the given capability is allowed. 
diff --git a/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.td b/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.td --- a/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.td +++ b/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.td @@ -45,10 +45,31 @@ // are the from Vulkan limit requirements: // https://www.khronos.org/registry/vulkan/specs/1.2-extensions/html/vkspec.html#limits-minmax def SPV_ResourceLimitsAttr : StructAttr<"ResourceLimitsAttr", SPIRV_Dialect, [ + // Unique identifiers for the vendor and the target device. + // A value of 0x7FFFFFFF means unknown. + StructFieldAttr<"vendor_id", DefaultValuedAttr>, + StructFieldAttr<"device_id", DefaultValuedAttr>, + // The target device type. + StructFieldAttr<"device_type", + DefaultValuedAttr>, + + // The maximum total storage size, in bytes, available for variables + // declared with the Workgroup storage class. + StructFieldAttr<"max_compute_shared_memory_size", + DefaultValuedAttr>, + + // The maximum total number of compute shader invocations in a single local + // workgroup. StructFieldAttr<"max_compute_workgroup_invocations", DefaultValuedAttr>, + // The maximum size of a local compute workgroup, per dimension. StructFieldAttr<"max_compute_workgroup_size", - DefaultValuedAttr> + DefaultValuedAttr>, + + // The default number of invocations in each subgroup. + // A value of 0x7FFFFFFF means unknown.
+ StructFieldAttr<"subgroup_size", DefaultValuedAttr> ]>; #endif // SPIRV_TARGET_AND_ABI diff --git a/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp b/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp --- a/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp +++ b/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp @@ -38,6 +38,14 @@ } } +spirv::DeviceType spirv::TargetEnv::getDeviceType() { + auto deviceType = spirv::symbolizeDeviceType( + targetAttr.getResourceLimits().device_type().getInt()); + if (!deviceType) + return DeviceType::Unknown; + return *deviceType; +} + spirv::Version spirv::TargetEnv::getVersion() { return targetAttr.getVersion(); } @@ -134,13 +142,16 @@ spirv::ResourceLimitsAttr spirv::getDefaultResourceLimits(MLIRContext *context) { - auto i32Type = IntegerType::get(32, context); - auto v3i32Type = VectorType::get(3, i32Type); - - // These numbers are from "Table 46. Required Limits" of the Vulkan spec. + // All the fields have default values; passing nullptr for every field is + // just a concise way to construct a resource limit attribute with defaults. return spirv::ResourceLimitsAttr ::get( - IntegerAttr::get(i32Type, 128), - DenseIntElementsAttr::get(v3i32Type, {128, 128, 64}), context); + /*vendor_id=*/nullptr, + /*device_id=*/nullptr, + /*device_type=*/nullptr, + /*max_compute_shared_memory_size=*/nullptr, + /*max_compute_workgroup_invocations=*/nullptr, + /*max_compute_workgroup_size=*/nullptr, + /*subgroup_size=*/nullptr, context); } StringRef spirv::getTargetEnvAttrName() { return "spv.target_env"; }