diff --git a/mlir/docs/Dialects/SPIR-V.md b/mlir/docs/Dialects/SPIR-V.md --- a/mlir/docs/Dialects/SPIR-V.md +++ b/mlir/docs/Dialects/SPIR-V.md @@ -805,8 +805,14 @@ spirv-capability-list `,` spirv-extensions-list `>` +spirv-vendor-id ::= `AMD` | `NVIDIA` | ... +spirv-device-type ::= `DiscreteGPU` | `IntegratedGPU` | `CPU` | ... +spirv-device-id ::= integer-literal +spirv-device-info ::= spirv-vendor-id (`:` spirv-device-type (`:` spirv-device-id)?)? + spirv-target-env-attribute ::= `#` `spv.target_env` `<` spirv-vce-attribute, + (spirv-device-info `,`)? spirv-resource-limits `>` ``` @@ -827,6 +833,7 @@ module attributes { spv.target_env = #spv.target_env< #spv.vce, + ARM:IntegratedGPU, { max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32> diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVAttributes.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVAttributes.h --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVAttributes.h +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVAttributes.h @@ -23,7 +23,9 @@ namespace mlir { namespace spirv { enum class Capability : uint32_t; +enum class DeviceType; enum class Extension; +enum class Vendor; enum class Version : uint32_t; namespace detail { @@ -123,10 +125,15 @@ : public Attribute::AttrBase { public: + /// ID for unknown devices. + static constexpr uint32_t kUnknownDeviceID = 0x7FFFFFFF; + using Base::Base; /// Gets a TargetEnvAttr instance. - static TargetEnvAttr get(VerCapExtAttr triple, DictionaryAttr limits); + static TargetEnvAttr get(VerCapExtAttr triple, Vendor vendorID, + DeviceType deviceType, uint32_t deviceId, + DictionaryAttr limits); /// Returns the attribute kind's name (without the 'spv.' prefix). static StringRef getKindName(); @@ -147,12 +154,22 @@ /// Returns the target capabilities as an integer array attribute. ArrayAttr getCapabilitiesAttr(); + /// Returns the vendor ID. + Vendor getVendorID(); + + /// Returns the device type. 
+ DeviceType getDeviceType(); + + /// Returns the device ID. + uint32_t getDeviceID(); + /// Returns the target resource limits. ResourceLimitsAttr getResourceLimits(); - static LogicalResult verifyConstructionInvariants(Location loc, - VerCapExtAttr triple, - DictionaryAttr limits); + static LogicalResult + verifyConstructionInvariants(Location loc, VerCapExtAttr triple, + Vendor vendorID, DeviceType deviceType, + uint32_t deviceID, DictionaryAttr limits); }; } // namespace spirv } // namespace mlir diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td @@ -254,20 +254,36 @@ // SPIR-V target GPU vendor and device definitions //===----------------------------------------------------------------------===// +def SPV_DT_CPU : StrEnumAttrCase<"CPU">; +def SPV_DT_DiscreteGPU : StrEnumAttrCase<"DiscreteGPU">; +def SPV_DT_IntegratedGPU : StrEnumAttrCase<"IntegratedGPU">; // An accelerator other than GPU or CPU -def SPV_DT_Other : I32EnumAttrCase<"Other", 0>; -def SPV_DT_IntegratedGPU : I32EnumAttrCase<"IntegratedGPU", 1>; -def SPV_DT_DiscreteGPU : I32EnumAttrCase<"DiscreteGPU", 2>; -def SPV_DT_CPU : I32EnumAttrCase<"CPU", 3>; +def SPV_DT_Other : StrEnumAttrCase<"Other">; // Information missing. 
-def SPV_DT_Unknown : I32EnumAttrCase<"Unknown", 0x7FFFFFFF>; +def SPV_DT_Unknown : StrEnumAttrCase<"Unknown">; -def SPV_DeviceTypeAttr : SPV_I32EnumAttr< +def SPV_DeviceTypeAttr : SPV_StrEnumAttr< "DeviceType", "valid SPIR-V device types", [ SPV_DT_Other, SPV_DT_IntegratedGPU, SPV_DT_DiscreteGPU, SPV_DT_CPU, SPV_DT_Unknown ]>; +def SPV_V_AMD : StrEnumAttrCase<"AMD">; +def SPV_V_ARM : StrEnumAttrCase<"ARM">; +def SPV_V_Imagination : StrEnumAttrCase<"Imagination">; +def SPV_V_Intel : StrEnumAttrCase<"Intel">; +def SPV_V_NVIDIA : StrEnumAttrCase<"NVIDIA">; +def SPV_V_Qualcomm : StrEnumAttrCase<"Qualcomm">; +def SPV_V_SwiftShader : StrEnumAttrCase<"SwiftShader">; +def SPV_V_Unknown : StrEnumAttrCase<"Unknown">; + +def SPV_VendorAttr : SPV_StrEnumAttr< + "Vendor", "recognized SPIR-V vendor strings", [ + SPV_V_AMD, SPV_V_ARM, SPV_V_Imagination, SPV_V_Intel, + SPV_V_NVIDIA, SPV_V_Qualcomm, SPV_V_SwiftShader, + SPV_V_Unknown + ]>; + //===----------------------------------------------------------------------===// // SPIR-V extension definitions //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h b/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h --- a/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h +++ b/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h @@ -29,8 +29,6 @@ public: explicit TargetEnv(TargetEnvAttr targetAttr); - DeviceType getDeviceType(); - Version getVersion(); /// Returns true if the given capability is allowed. 
diff --git a/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.td b/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.td --- a/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.td +++ b/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.td @@ -45,15 +45,6 @@ // are the from Vulkan limit requirements: // https://www.khronos.org/registry/vulkan/specs/1.2-extensions/html/vkspec.html#limits-minmax def SPV_ResourceLimitsAttr : StructAttr<"ResourceLimitsAttr", SPIRV_Dialect, [ - // Unique identifier for the vendor and target GPU. - // 0x7FFFFFFF means unknown. - StructFieldAttr<"vendor_id", DefaultValuedAttr>, - StructFieldAttr<"device_id", DefaultValuedAttr>, - // Target device type. - StructFieldAttr<"device_type", - DefaultValuedAttr>, - // The maximum total storage size, in bytes, available for variables // declared with the Workgroup storage class. StructFieldAttr<"max_compute_shared_memory_size", diff --git a/mlir/lib/Dialect/SPIRV/SPIRVAttributes.cpp b/mlir/lib/Dialect/SPIRV/SPIRVAttributes.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVAttributes.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVAttributes.cpp @@ -77,23 +77,32 @@ }; struct TargetEnvAttributeStorage : public AttributeStorage { - using KeyTy = std::pair; + using KeyTy = std::tuple; - TargetEnvAttributeStorage(Attribute triple, Attribute limits) - : triple(triple), limits(limits) {} + TargetEnvAttributeStorage(Attribute triple, Vendor vendorID, + DeviceType deviceType, uint32_t deviceID, + Attribute limits) + : triple(triple), limits(limits), vendorID(vendorID), + deviceType(deviceType), deviceID(deviceID) {} bool operator==(const KeyTy &key) const { - return key.first == triple && key.second == limits; + return key == + std::make_tuple(triple, vendorID, deviceType, deviceID, limits); } static TargetEnvAttributeStorage * construct(AttributeStorageAllocator &allocator, const KeyTy &key) { return new (allocator.allocate()) - TargetEnvAttributeStorage(key.first, key.second); + TargetEnvAttributeStorage(std::get<0>(key), std::get<1>(key), + 
std::get<2>(key), std::get<3>(key),
+                                  std::get<4>(key));
   }
 
   Attribute triple;
   Attribute limits;
+  Vendor vendorID;
+  DeviceType deviceType;
+  uint32_t deviceID;
 };
 } // namespace detail
 } // namespace spirv
@@ -268,10 +277,13 @@
 //===----------------------------------------------------------------------===//
 
 spirv::TargetEnvAttr spirv::TargetEnvAttr::get(spirv::VerCapExtAttr triple,
+                                               Vendor vendorID,
+                                               DeviceType deviceType,
+                                               uint32_t deviceID,
                                                DictionaryAttr limits) {
   assert(triple && limits && "expected valid triple and limits");
   MLIRContext *context = triple.getContext();
-  return Base::get(context, triple, limits);
+  return Base::get(context, triple, vendorID, deviceType, deviceID, limits);
 }
 
 StringRef spirv::TargetEnvAttr::getKindName() { return "target_env"; }
@@ -300,12 +312,24 @@
   return getTripleAttr().getCapabilitiesAttr();
 }
 
+spirv::Vendor spirv::TargetEnvAttr::getVendorID() {
+  return getImpl()->vendorID;
+}
+
+spirv::DeviceType spirv::TargetEnvAttr::getDeviceType() {
+  return getImpl()->deviceType;
+}
+
+uint32_t spirv::TargetEnvAttr::getDeviceID() { return getImpl()->deviceID; }
+
 spirv::ResourceLimitsAttr spirv::TargetEnvAttr::getResourceLimits() {
   return getImpl()->limits.cast<spirv::ResourceLimitsAttr>();
 }
 
 LogicalResult spirv::TargetEnvAttr::verifyConstructionInvariants(
-    Location loc, spirv::VerCapExtAttr triple, DictionaryAttr limits) {
+    Location loc, spirv::VerCapExtAttr /*triple*/, spirv::Vendor /*vendorID*/,
+    spirv::DeviceType /*deviceType*/, uint32_t /*deviceID*/,
+    DictionaryAttr limits) {
   if (!limits.isa<spirv::ResourceLimitsAttr>())
     return emitError(loc, "expected spirv::ResourceLimitsAttr for limits");
 
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
--- a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
@@ -918,6 +918,46 @@
   if (parser.parseAttribute(tripleAttr) || parser.parseComma())
     return {};
 
+  // Parse [vendor[:device-type[:device-id]]]
+  Vendor vendorID = Vendor::Unknown;
+  DeviceType deviceType = DeviceType::Unknown;
+  uint32_t deviceID = spirv::TargetEnvAttr::kUnknownDeviceID;
+  {
+    auto loc = parser.getCurrentLocation();
+    StringRef vendorStr;
+    if (succeeded(parser.parseOptionalKeyword(&vendorStr))) {
+      if (auto vendorSymbol = spirv::symbolizeVendor(vendorStr)) {
+        vendorID = *vendorSymbol;
+      } else {
+        // Fail the parse instead of silently falling back to `Unknown`.
+        parser.emitError(loc, "unknown vendor: ") << vendorStr;
+        return {};
+      }
+
+      if (succeeded(parser.parseOptionalColon())) {
+        loc = parser.getCurrentLocation();
+        StringRef deviceTypeStr;
+        if (parser.parseKeyword(&deviceTypeStr))
+          return {};
+        if (auto deviceTypeSymbol = spirv::symbolizeDeviceType(deviceTypeStr)) {
+          deviceType = *deviceTypeSymbol;
+        } else {
+          // Fail the parse instead of silently falling back to `Unknown`.
+          parser.emitError(loc, "unknown device type: ") << deviceTypeStr;
+          return {};
+        }
+
+        if (succeeded(parser.parseOptionalColon())) {
+          loc = parser.getCurrentLocation();
+          if (parser.parseInteger(deviceID))
+            return {};
+        }
+      }
+      if (parser.parseComma())
+        return {};
+    }
+  }
+
   DictionaryAttr limitsAttr;
   {
     auto loc = parser.getCurrentLocation();
@@ -937,7 +977,8 @@
   if (parser.parseGreater())
     return {};
 
-  return spirv::TargetEnvAttr::get(tripleAttr, limitsAttr);
+  return spirv::TargetEnvAttr::get(tripleAttr, vendorID, deviceType, deviceID,
+                                   limitsAttr);
 }
 
 Attribute SPIRVDialect::parseAttribute(DialectAsmParser &parser,
@@ -986,6 +1027,22 @@
 static void print(spirv::TargetEnvAttr targetEnv, DialectAsmPrinter &printer) {
   printer << spirv::TargetEnvAttr::getKindName() << "<#spv.";
   print(targetEnv.getTripleAttr(), printer);
+  spirv::Vendor vendorID = targetEnv.getVendorID();
+  spirv::DeviceType deviceType = targetEnv.getDeviceType();
+  uint32_t deviceID = targetEnv.getDeviceID();
+  // The syntax is positional (vendor[:device-type[:device-id]]), so a known
+  // later field forces the earlier ones to be printed, spelled as `Unknown`
+  // if necessary; otherwise the information would be lost on round-trip.
+  bool hasDeviceID = deviceID != spirv::TargetEnvAttr::kUnknownDeviceID;
+  bool hasDeviceType = hasDeviceID || deviceType != spirv::DeviceType::Unknown;
+  if (hasDeviceType || vendorID != spirv::Vendor::Unknown) {
+    printer << ", " << spirv::stringifyVendor(vendorID);
+    if (hasDeviceType) {
+      printer << ":" << spirv::stringifyDeviceType(deviceType);
+      if (hasDeviceID)
+        printer << ":" << deviceID;
+    }
+  }
   printer << ", " << targetEnv.getResourceLimits()
<< ">"; } diff --git a/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp b/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp --- a/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp +++ b/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp @@ -38,14 +38,6 @@ } } -spirv::DeviceType spirv::TargetEnv::getDeviceType() { - auto deviceType = spirv::symbolizeDeviceType( - targetAttr.getResourceLimits().device_type().getInt()); - if (!deviceType) - return DeviceType::Unknown; - return *deviceType; -} - spirv::Version spirv::TargetEnv::getVersion() { return targetAttr.getVersion(); } @@ -145,9 +137,6 @@ // All the fields have default values. Here we just provide a nicer way to // construct a default resource limit attribute. return spirv::ResourceLimitsAttr ::get( - /*vendor_id=*/nullptr, - /*device_id*/ nullptr, - /*device_type=*/nullptr, /*max_compute_shared_memory_size=*/nullptr, /*max_compute_workgroup_invocations=*/nullptr, /*max_compute_workgroup_size=*/nullptr, @@ -160,7 +149,9 @@ auto triple = spirv::VerCapExtAttr::get(spirv::Version::V_1_0, {spirv::Capability::Shader}, ArrayRef(), context); - return spirv::TargetEnvAttr::get(triple, + return spirv::TargetEnvAttr::get(triple, spirv::Vendor::Unknown, + spirv::DeviceType::Unknown, + spirv::TargetEnvAttr::kUnknownDeviceID, spirv::getDefaultResourceLimits(context)); } diff --git a/mlir/test/Dialect/SPIRV/target-and-abi.mlir b/mlir/test/Dialect/SPIRV/target-and-abi.mlir --- a/mlir/test/Dialect/SPIRV/target-and-abi.mlir +++ b/mlir/test/Dialect/SPIRV/target-and-abi.mlir @@ -127,6 +127,36 @@ // ----- +func @target_env_vendor_id() attributes { + // CHECK: spv.target_env = #spv.target_env< + // CHECK-SAME: #spv.vce, + // CHECK-SAME: NVIDIA, + // CHECK-SAME: {}> + spv.target_env = #spv.target_env<#spv.vce, NVIDIA, {}> +} { return } + +// ----- + +func @target_env_vendor_id_device_type() attributes { + // CHECK: spv.target_env = #spv.target_env< + // CHECK-SAME: #spv.vce, + // CHECK-SAME: AMD:DiscreteGPU, + // CHECK-SAME: {}> + spv.target_env = #spv.target_env<#spv.vce, 
AMD:DiscreteGPU, {}> +} { return } + +// ----- + +func @target_env_vendor_id_device_type_device_id() attributes { + // CHECK: spv.target_env = #spv.target_env< + // CHECK-SAME: #spv.vce, + // CHECK-SAME: Qualcomm:IntegratedGPU:100925441, + // CHECK-SAME: {}> + spv.target_env = #spv.target_env<#spv.vce, Qualcomm:IntegratedGPU:0x6040001, {}> +} { return } + +// ----- + func @target_env_extra_fields() attributes { // expected-error @+6 {{expected '>'}} spv.target_env = #spv.target_env<