diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVAttributes.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVAttributes.h --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVAttributes.h +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVAttributes.h @@ -139,10 +139,10 @@ static StringRef getKindName(); /// Returns the (version, capabilities, extensions) triple attribute. - VerCapExtAttr getTripleAttr(); + VerCapExtAttr getTripleAttr() const; /// Returns the target version. - Version getVersion(); + Version getVersion() const; /// Returns the target extensions. VerCapExtAttr::ext_range getExtensions(); @@ -155,16 +155,16 @@ ArrayAttr getCapabilitiesAttr(); /// Returns the vendor ID. - Vendor getVendorID(); + Vendor getVendorID() const; /// Returns the device type. - DeviceType getDeviceType(); + DeviceType getDeviceType() const; /// Returns the device ID. - uint32_t getDeviceID(); + uint32_t getDeviceID() const; /// Returns the target resource limits. - ResourceLimitsAttr getResourceLimits(); + ResourceLimitsAttr getResourceLimits() const; static LogicalResult verifyConstructionInvariants(Location loc, VerCapExtAttr triple, diff --git a/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h b/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h --- a/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h +++ b/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h @@ -29,7 +29,7 @@ public: explicit TargetEnv(TargetEnvAttr targetAttr); - Version getVersion(); + Version getVersion() const; /// Returns true if the given capability is allowed. bool allows(Capability) const; @@ -43,9 +43,23 @@ /// Returns llvm::None otherwise. Optional allows(ArrayRef) const; + /// Returns the vendor ID. + Vendor getVendorID() const; + + /// Returns the device type. + DeviceType getDeviceType() const; + + /// Returns the device ID. + uint32_t getDeviceID() const; + /// Returns the MLIRContext. MLIRContext *getContext() const; + /// Returns the target resource limits. + ResourceLimitsAttr getResourceLimits() const; + + TargetEnvAttr getAttr() const { return targetAttr; } + /// Allows implicity converting to the underlying spirv::TargetEnvAttr. operator TargetEnvAttr() const { return targetAttr; } diff --git a/mlir/lib/Dialect/SPIRV/SPIRVAttributes.cpp b/mlir/lib/Dialect/SPIRV/SPIRVAttributes.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVAttributes.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVAttributes.cpp @@ -288,11 +288,11 @@ StringRef spirv::TargetEnvAttr::getKindName() { return "target_env"; } -spirv::VerCapExtAttr spirv::TargetEnvAttr::getTripleAttr() { +spirv::VerCapExtAttr spirv::TargetEnvAttr::getTripleAttr() const { return getImpl()->triple.cast(); } -spirv::Version spirv::TargetEnvAttr::getVersion() { +spirv::Version spirv::TargetEnvAttr::getVersion() const { return getTripleAttr().getVersion(); } @@ -312,17 +312,19 @@ return getTripleAttr().getCapabilitiesAttr(); } -spirv::Vendor spirv::TargetEnvAttr::getVendorID() { +spirv::Vendor spirv::TargetEnvAttr::getVendorID() const { return getImpl()->vendorID; } -spirv::DeviceType spirv::TargetEnvAttr::getDeviceType() { +spirv::DeviceType spirv::TargetEnvAttr::getDeviceType() const { return getImpl()->deviceType; } -uint32_t spirv::TargetEnvAttr::getDeviceID() { return getImpl()->deviceID; } +uint32_t spirv::TargetEnvAttr::getDeviceID() const { + return getImpl()->deviceID; +} -spirv::ResourceLimitsAttr spirv::TargetEnvAttr::getResourceLimits() { +spirv::ResourceLimitsAttr spirv::TargetEnvAttr::getResourceLimits() const { return getImpl()->limits.cast(); } diff --git a/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp b/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp --- a/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp +++ b/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp @@ -38,7 +38,7 @@ } } -spirv::Version spirv::TargetEnv::getVersion() { +spirv::Version spirv::TargetEnv::getVersion() const { return targetAttr.getVersion(); } @@ -48,7 +48,7 @@ Optional spirv::TargetEnv::allows(ArrayRef caps) const { - auto chosen = llvm::find_if(caps, [this](spirv::Capability cap) { + const auto *chosen = llvm::find_if(caps, [this](spirv::Capability cap) { return givenCapabilities.count(cap); }); if (chosen != caps.end()) @@ -62,7 +62,7 @@ Optional spirv::TargetEnv::allows(ArrayRef exts) const { - auto chosen = llvm::find_if(exts, [this](spirv::Extension ext) { + const auto *chosen = llvm::find_if(exts, [this](spirv::Extension ext) { return givenExtensions.count(ext); }); if (chosen != exts.end()) @@ -70,6 +70,22 @@ return llvm::None; } +spirv::Vendor spirv::TargetEnv::getVendorID() const { + return targetAttr.getVendorID(); +} + +spirv::DeviceType spirv::TargetEnv::getDeviceType() const { + return targetAttr.getDeviceType(); +} + +uint32_t spirv::TargetEnv::getDeviceID() const { + return targetAttr.getDeviceID(); +} + +spirv::ResourceLimitsAttr spirv::TargetEnv::getResourceLimits() const { + return targetAttr.getResourceLimits(); +} + MLIRContext *spirv::TargetEnv::getContext() const { return targetAttr.getContext(); }