// Patch summary (preamble text placed before the first `diff --git` header).
//
// Adds a "client API" field to the SPIR-V target environment attribute.
//
// Files touched:
//   * SPIRVAttributes.h   - `TargetEnvAttr::get` now takes
//                           (triple, limits, clientAPI, vendorID, deviceType,
//                           deviceId); the last four parameters have defaults
//                           (`ClientAPI::Unknown`, `Vendor::Unknown`,
//                           `DeviceType::Unknown`, `kUnknownDeviceID`).
//                           Declares the new `getClientAPI()` accessor.
//   * SPIRVBase.td        - Defines the `ClientAPI` enum attribute with cases
//                           Metal=0, OpenCL=1, Vulkan=2, WebGPU=3,
//                           Unknown=0xffffffff. Also changes the `Unknown`
//                           case of `DeviceType` (4 -> 0xffffffff) and of
//                           `Vendor` (0xff -> 0xffffffff).
//   * SPIRVAttributes.cpp - `TargetEnvAttributeStorage` stores and compares a
//                           `clientAPI` member; `TargetEnvAttr::get` forwards
//                           it to `Base::get`; `getClientAPI()` reads it back.
//                           The attribute parser accepts an optional
//                           `api=<keyword>,` right after the triple, and the
//                           printer emits `, api=<value>` only when the value
//                           is not `ClientAPI::Unknown`.
//   * TargetAndABI.cpp    - Updates the default target-env construction to
//                           the new `get` argument order and includes
//                           SPIRVEnums.h.
//   * target-and-abi.mlir - Adds round-trip tests for `api=Metal`,
//                           `api=Unknown` (checked as not printed), and
//                           `api=Vulkan` combined with
//                           vendor:device-type:device-id.
//
// NOTE(review): when the `api` keyword is not a recognized symbol, the parser
// emits "unknown client API: " but does not return; it goes on to parse the
// comma. This mirrors the vendor / device-type handling visible in the same
// hunk - confirm that continuing after the diagnostic is intended.
// NOTE(review): this copy of the patch looks whitespace-collapsed, and some
// angle-bracket arguments appear to be missing (e.g. `std::tuple;`,
// `ArrayRef()`, `#spirv.vce,`) - presumably extraction damage; verify against
// the original patch before applying.
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h @@ -138,9 +138,11 @@ using Base::Base; /// Gets a TargetEnvAttr instance. - static TargetEnvAttr get(VerCapExtAttr triple, Vendor vendorID, - DeviceType deviceType, uint32_t deviceId, - ResourceLimitsAttr limits); + static TargetEnvAttr get(VerCapExtAttr triple, ResourceLimitsAttr limits, + ClientAPI clientAPI = ClientAPI::Unknown, + Vendor vendorID = Vendor::Unknown, + DeviceType deviceType = DeviceType::Unknown, + uint32_t deviceId = kUnknownDeviceID); /// Returns the attribute kind's name (without the 'spirv.' prefix). static StringRef getKindName(); @@ -161,6 +163,9 @@ /// Returns the target capabilities as an integer array attribute. ArrayAttr getCapabilitiesAttr(); + /// Returns the client API. + ClientAPI getClientAPI() const; + /// Returns the vendor ID. Vendor getVendorID() const; diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td @@ -267,7 +267,7 @@ // An accelerator other than GPU or CPU def SPIRV_DT_Other : I32EnumAttrCase<"Other", 3>; // Information missing. 
-def SPIRV_DT_Unknown : I32EnumAttrCase<"Unknown", 4>; +def SPIRV_DT_Unknown : I32EnumAttrCase<"Unknown", 0xffffffff>; def SPIRV_DeviceTypeAttr : SPIRV_I32EnumAttr< "DeviceType", "valid SPIR-V device types", "device_type", [ @@ -283,7 +283,7 @@ def SPIRV_V_NVIDIA : I32EnumAttrCase<"NVIDIA", 5>; def SPIRV_V_Qualcomm : I32EnumAttrCase<"Qualcomm", 6>; def SPIRV_V_SwiftShader : I32EnumAttrCase<"SwiftShader", 7>; -def SPIRV_V_Unknown : I32EnumAttrCase<"Unknown", 0xff>; +def SPIRV_V_Unknown : I32EnumAttrCase<"Unknown", 0xffffffff>; def SPIRV_VendorAttr : SPIRV_I32EnumAttr< "Vendor", "recognized SPIR-V vendor strings", "vendor", [ @@ -292,6 +292,18 @@ SPIRV_V_Unknown ]>; +def SPIRV_CA_Metal : I32EnumAttrCase<"Metal", 0>; +def SPIRV_CA_OpenCL : I32EnumAttrCase<"OpenCL", 1>; +def SPIRV_CA_Vulkan : I32EnumAttrCase<"Vulkan", 2>; +def SPIRV_CA_WebGPU : I32EnumAttrCase<"WebGPU", 3>; +def SPIRV_CA_Unknown : I32EnumAttrCase<"Unknown", 0xffffffff>; + +def SPIRV_ClientAPIAttr : SPIRV_I32EnumAttr< + "ClientAPI", "recognized SPIR-V client APIs", "client_api", [ + SPIRV_CA_Metal, SPIRV_CA_OpenCL, SPIRV_CA_Vulkan, SPIRV_CA_WebGPU, + SPIRV_CA_Unknown + ]>; + //===----------------------------------------------------------------------===// // SPIR-V extension definitions //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp @@ -82,17 +82,18 @@ }; struct TargetEnvAttributeStorage : public AttributeStorage { - using KeyTy = std::tuple; + using KeyTy = + std::tuple; - TargetEnvAttributeStorage(Attribute triple, Vendor vendorID, - DeviceType deviceType, uint32_t deviceID, - Attribute limits) - : triple(triple), limits(limits), vendorID(vendorID), - deviceType(deviceType), deviceID(deviceID) {} + TargetEnvAttributeStorage(Attribute triple, ClientAPI clientAPI, 
+ Vendor vendorID, DeviceType deviceType, + uint32_t deviceID, Attribute limits) + : triple(triple), limits(limits), clientAPI(clientAPI), + vendorID(vendorID), deviceType(deviceType), deviceID(deviceID) {} bool operator==(const KeyTy &key) const { - return key == - std::make_tuple(triple, vendorID, deviceType, deviceID, limits); + return key == std::make_tuple(triple, clientAPI, vendorID, deviceType, + deviceID, limits); } static TargetEnvAttributeStorage * @@ -100,11 +101,12 @@ return new (allocator.allocate()) TargetEnvAttributeStorage(std::get<0>(key), std::get<1>(key), std::get<2>(key), std::get<3>(key), - std::get<4>(key)); + std::get<4>(key), std::get<5>(key)); } Attribute triple; Attribute limits; + ClientAPI clientAPI; Vendor vendorID; DeviceType deviceType; uint32_t deviceID; @@ -282,14 +284,13 @@ // TargetEnvAttr //===----------------------------------------------------------------------===// -spirv::TargetEnvAttr spirv::TargetEnvAttr::get(spirv::VerCapExtAttr triple, - Vendor vendorID, - DeviceType deviceType, - uint32_t deviceID, - ResourceLimitsAttr limits) { +spirv::TargetEnvAttr spirv::TargetEnvAttr::get( + spirv::VerCapExtAttr triple, ResourceLimitsAttr limits, ClientAPI clientAPI, + Vendor vendorID, DeviceType deviceType, uint32_t deviceID) { assert(triple && limits && "expected valid triple and limits"); MLIRContext *context = triple.getContext(); - return Base::get(context, triple, vendorID, deviceType, deviceID, limits); + return Base::get(context, triple, clientAPI, vendorID, deviceType, deviceID, + limits); } StringRef spirv::TargetEnvAttr::getKindName() { return "target_env"; } @@ -318,6 +319,10 @@ return getTripleAttr().getCapabilitiesAttr(); } +spirv::ClientAPI spirv::TargetEnvAttr::getClientAPI() const { + return getImpl()->clientAPI; +} + spirv::Vendor spirv::TargetEnvAttr::getVendorID() const { return getImpl()->vendorID; } @@ -523,6 +528,22 @@ if (parser.parseAttribute(tripleAttr) || parser.parseComma()) return {}; + auto clientAPI = 
spirv::ClientAPI::Unknown; + if (succeeded(parser.parseOptionalKeyword("api"))) { + if (parser.parseEqual()) + return {}; + auto loc = parser.getCurrentLocation(); + StringRef apiStr; + if (parser.parseKeyword(&apiStr)) + return {}; + if (auto apiSymbol = spirv::symbolizeClientAPI(apiStr)) + clientAPI = *apiSymbol; + else + parser.emitError(loc, "unknown client API: ") << apiStr; + if (parser.parseComma()) + return {}; + } + // Parse [vendor[:device-type[:device-id]]] Vendor vendorID = Vendor::Unknown; DeviceType deviceType = DeviceType::Unknown; @@ -531,22 +552,20 @@ auto loc = parser.getCurrentLocation(); StringRef vendorStr; if (succeeded(parser.parseOptionalKeyword(&vendorStr))) { - if (auto vendorSymbol = spirv::symbolizeVendor(vendorStr)) { + if (auto vendorSymbol = spirv::symbolizeVendor(vendorStr)) vendorID = *vendorSymbol; - } else { + else parser.emitError(loc, "unknown vendor: ") << vendorStr; - } if (succeeded(parser.parseOptionalColon())) { loc = parser.getCurrentLocation(); StringRef deviceTypeStr; if (parser.parseKeyword(&deviceTypeStr)) return {}; - if (auto deviceTypeSymbol = spirv::symbolizeDeviceType(deviceTypeStr)) { + if (auto deviceTypeSymbol = spirv::symbolizeDeviceType(deviceTypeStr)) deviceType = *deviceTypeSymbol; - } else { + else parser.emitError(loc, "unknown device type: ") << deviceTypeStr; - } if (succeeded(parser.parseOptionalColon())) { loc = parser.getCurrentLocation(); @@ -563,8 +582,8 @@ if (parser.parseAttribute(limitsAttr) || parser.parseGreater()) return {}; - return spirv::TargetEnvAttr::get(tripleAttr, vendorID, deviceType, deviceID, - limitsAttr); + return spirv::TargetEnvAttr::get(tripleAttr, limitsAttr, clientAPI, vendorID, + deviceType, deviceID); } Attribute SPIRVDialect::parseAttribute(DialectAsmParser &parser, @@ -616,6 +635,9 @@ static void print(spirv::TargetEnvAttr targetEnv, DialectAsmPrinter &printer) { printer << spirv::TargetEnvAttr::getKindName() << "<#spirv."; print(targetEnv.getTripleAttr(), printer); + 
auto clientAPI = targetEnv.getClientAPI(); + if (clientAPI != spirv::ClientAPI::Unknown) + printer << ", api=" << clientAPI; spirv::Vendor vendorID = targetEnv.getVendorID(); spirv::DeviceType deviceType = targetEnv.getDeviceType(); uint32_t deviceID = targetEnv.getDeviceID(); diff --git a/mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp b/mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp --- a/mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h" +#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/FunctionInterfaces.h" @@ -170,10 +171,10 @@ auto triple = spirv::VerCapExtAttr::get(spirv::Version::V_1_0, {spirv::Capability::Shader}, ArrayRef(), context); - return spirv::TargetEnvAttr::get(triple, spirv::Vendor::Unknown, - spirv::DeviceType::Unknown, - spirv::TargetEnvAttr::kUnknownDeviceID, - spirv::getDefaultResourceLimits(context)); + return spirv::TargetEnvAttr::get( + triple, spirv::getDefaultResourceLimits(context), + spirv::ClientAPI::Unknown, spirv::Vendor::Unknown, + spirv::DeviceType::Unknown, spirv::TargetEnvAttr::kUnknownDeviceID); } spirv::TargetEnvAttr spirv::lookupTargetEnv(Operation *op) { diff --git a/mlir/test/Dialect/SPIRV/IR/target-and-abi.mlir b/mlir/test/Dialect/SPIRV/IR/target-and-abi.mlir --- a/mlir/test/Dialect/SPIRV/IR/target-and-abi.mlir +++ b/mlir/test/Dialect/SPIRV/IR/target-and-abi.mlir @@ -118,6 +118,24 @@ // ----- +func.func @target_env_client_api() attributes { + // CHECK: spirv.target_env = #spirv.target_env< + // CHECK-SAME: #spirv.vce, + // CHECK-SAME: api=Metal, + // CHECK-SAME: #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, api=Metal, #spirv.resource_limits<>> +} { return } + +// ----- + +func.func @target_env_client_api() attributes { + // 
CHECK: spirv.target_env = #spirv.target_env + // CHECK-NOT: api= + spirv.target_env = #spirv.target_env<#spirv.vce, api=Unknown, #spirv.resource_limits<>> +} { return } + +// ----- + func.func @target_env_vendor_id() attributes { // CHECK: spirv.target_env = #spirv.target_env< // CHECK-SAME: #spirv.vce, @@ -148,6 +166,17 @@ // ----- +func.func @target_env_client_api_vendor_id_device_type_device_id() attributes { + // CHECK: spirv.target_env = #spirv.target_env< + // CHECK-SAME: #spirv.vce, + // CHECK-SAME: api=Vulkan, + // CHECK-SAME: Qualcomm:IntegratedGPU:100925441, + // CHECK-SAME: #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, api=Vulkan, Qualcomm:IntegratedGPU:0x6040001, #spirv.resource_limits<>> +} { return } + +// ----- + func.func @target_env_extra_fields() attributes { // expected-error @+3 {{expected '>'}} spirv.target_env = #spirv.target_env<