diff --git a/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h b/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h --- a/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h +++ b/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h @@ -52,6 +52,9 @@ using Base::Base; /// Gets a TargetEnvAttr instance. + static TargetEnvAttr get(Version version, ArrayRef extensions, + ArrayRef capabilities, + DictionaryAttr limits); static TargetEnvAttr get(IntegerAttr version, ArrayAttr extensions, ArrayAttr capabilities, DictionaryAttr limits); @@ -86,7 +89,7 @@ ArrayAttr getCapabilitiesAttr(); /// Returns the target resource limits. - DictionaryAttr getResourceLimits(); + ResourceLimitsAttr getResourceLimits(); static bool kindof(unsigned kind) { return kind == AttrKind::TargetEnv; } diff --git a/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp b/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp --- a/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp +++ b/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp @@ -48,6 +48,27 @@ } // namespace spirv } // namespace mlir +spirv::TargetEnvAttr spirv::TargetEnvAttr::get( + spirv::Version version, ArrayRef extensions, + ArrayRef capabilities, DictionaryAttr limits) { + Builder b(limits.getContext()); + + auto versionAttr = b.getI32IntegerAttr(static_cast(version)); + + SmallVector extAttrs; + extAttrs.reserve(extensions.size()); + for (spirv::Extension ext : extensions) + extAttrs.push_back(b.getStringAttr(spirv::stringifyExtension(ext))); + + SmallVector capAttrs; + capAttrs.reserve(capabilities.size()); + for (spirv::Capability cap : capabilities) + capAttrs.push_back(b.getI32IntegerAttr(static_cast(cap))); + + return get(versionAttr, b.getArrayAttr(extAttrs), b.getArrayAttr(capAttrs), + limits); +} + spirv::TargetEnvAttr spirv::TargetEnvAttr::get(IntegerAttr version, ArrayAttr extensions, ArrayAttr capabilities, @@ -98,8 +119,8 @@ return getImpl()->capabilities.cast(); } -DictionaryAttr spirv::TargetEnvAttr::getResourceLimits() { - return getImpl()->limits.cast(); +spirv::ResourceLimitsAttr spirv::TargetEnvAttr::getResourceLimits() { + return getImpl()->limits.cast(); } LogicalResult spirv::TargetEnvAttr::verifyConstructionInvariants(