diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td @@ -3000,7 +3000,7 @@ def SPV_Numerical : AnyTypeOf<[SPV_Integer, SPV_Float]>; def SPV_Scalar : AnyTypeOf<[SPV_Numerical, SPV_Bool]>; -def SPV_Aggregate : AnyTypeOf<[SPV_AnyArray, SPV_AnyStruct]>; +def SPV_Aggregate : AnyTypeOf<[SPV_AnyArray, SPV_AnyRTArray, SPV_AnyStruct]>; def SPV_Composite : AnyTypeOf<[SPV_Vector, SPV_AnyArray, SPV_AnyRTArray, SPV_AnyStruct]>; def SPV_Type : AnyTypeOf<[ diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h @@ -71,16 +71,65 @@ }; } -// SPIR-V composite type: VectorType, SPIR-V ArrayType, or SPIR-V StructType. -class CompositeType : public Type { +// Base SPIR-V type for providing availability queries. +class SPIRVType : public Type { public: using Type::Type; static bool classof(Type type); + /// The extension requirements for each type follow the + /// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D)) + /// convention. + using ExtensionArrayRefVector = SmallVectorImpl>; + + /// Appends to `extensions` the extensions needed for this type to appear in + /// the given `storage` class. This method does not guarantee the uniqueness + /// of extensions; the same extension may be appended multiple times. + void getExtensions(ExtensionArrayRefVector &extensions, + Optional storage = llvm::None); + + /// The capability requirements for each type follow the + /// ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D)) + /// convention. + using CapabilityArrayRefVector = SmallVectorImpl>; + + /// Appends to `capabilities` the capabilities needed for this type to appear + /// in the given `storage` class.
This method does not guarantee the + /// uniqueness of capabilities; the same capability may be appended multiple + /// times. + void getCapabilities(CapabilityArrayRefVector &capabilities, + Optional storage = llvm::None); +}; + +// SPIR-V scalar type: bool type, integer type, floating point type. +class ScalarType : public SPIRVType { +public: + using SPIRVType::SPIRVType; + + static bool classof(Type type); + + void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, + Optional storage = llvm::None); + void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, + Optional storage = llvm::None); +}; + +// SPIR-V composite type: vector, array, runtime array, or struct type. +class CompositeType : public SPIRVType { +public: + using SPIRVType::SPIRVType; + + static bool classof(Type type); + unsigned getNumElements() const; Type getElementType(unsigned) const; + + void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, + Optional storage = llvm::None); + void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, + Optional storage = llvm::None); }; // SPIR-V array type @@ -105,11 +154,16 @@ bool hasLayout() const; uint64_t getArrayStride() const; + + void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, + Optional storage = llvm::None); + void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, + Optional storage = llvm::None); }; // SPIR-V image type class ImageType - : public Type::TypeBase { + : public Type::TypeBase { public: using Base::Base; @@ -141,11 +195,16 @@ ImageSamplerUseInfo getSamplerUseInfo() const; ImageFormat getImageFormat() const; // TODO(ravishankarm): Add support for Access qualifier + + void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, + Optional storage = llvm::None); + void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, + Optional storage = llvm::None); }; // SPIR-V pointer type -class PointerType - : public Type::TypeBase {
+class PointerType : public Type::TypeBase { public: using Base::Base; @@ -156,11 +215,16 @@ Type getPointeeType() const; StorageClass getStorageClass() const; + + void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, + Optional storage = llvm::None); + void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, + Optional storage = llvm::None); }; // SPIR-V run-time array type class RuntimeArrayType - : public Type::TypeBase { public: using Base::Base; @@ -170,6 +234,11 @@ static RuntimeArrayType get(Type elementType); Type getElementType() const; + + void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, + Optional storage = llvm::None); + void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, + Optional storage = llvm::None); }; // SPIR-V struct type @@ -203,6 +272,31 @@ Type getElementType(unsigned) const; + /// Range class for element types. + class ElementTypeRange + : public ::mlir::detail::indexed_accessor_range_base< + ElementTypeRange, const Type *, Type, Type, Type> { + private: + using RangeBaseT::RangeBaseT; + + /// See `mlir::detail::indexed_accessor_range_base` for details. + static const Type *offset_base(const Type *object, ptrdiff_t index) { + return object + index; + } + /// See `mlir::detail::indexed_accessor_range_base` for details. + static Type dereference_iterator(const Type *object, ptrdiff_t index) { + return object[index]; + } + + /// Allow StructType class access to constructors. + friend class StructType; + + /// Allow base class access to `offset_base` and `dereference_iterator`. + friend RangeBaseT; + }; + + ElementTypeRange getElementTypes() const; + bool hasLayout() const; uint64_t getOffset(unsigned) const; @@ -216,6 +310,11 @@ // Offset) associated with the `i`-th member of the StructType.
void getMemberDecorations( unsigned i, SmallVectorImpl &memberDecorations) const; + + void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, + Optional storage = llvm::None); + void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, + Optional storage = llvm::None); }; } // end namespace spirv diff --git a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp @@ -57,6 +57,7 @@ default: return {}; case Version::V_1_3: { + // The following manual ArrayRef constructor call is to satisfy GCC 5. static const Extension exts[] = {V_1_3_IMPLIED_EXTS}; return ArrayRef(exts, llvm::array_lengthof(exts)); } @@ -142,6 +143,17 @@ uint64_t ArrayType::getArrayStride() const { return getImpl()->layoutInfo; } +void ArrayType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, + Optional storage) { + getElementType().cast().getExtensions(extensions, storage); +} + +void ArrayType::getCapabilities( + SPIRVType::CapabilityArrayRefVector &capabilities, + Optional storage) { + getElementType().cast().getCapabilities(capabilities, storage); +} + //===----------------------------------------------------------------------===// // CompositeType //===----------------------------------------------------------------------===// @@ -189,6 +201,50 @@ } } +void CompositeType::getExtensions( + SPIRVType::ExtensionArrayRefVector &extensions, + Optional storage) { + switch (getKind()) { + case spirv::TypeKind::Array: + cast().getExtensions(extensions, storage); + break; + case spirv::TypeKind::RuntimeArray: + cast().getExtensions(extensions, storage); + break; + case spirv::TypeKind::Struct: + cast().getExtensions(extensions, storage); + break; + case StandardTypes::Vector: + cast().getElementType().cast().getExtensions( + extensions, storage); + break; + default: + llvm_unreachable("invalid composite type"); + } +} + +void
CompositeType::getCapabilities( + SPIRVType::CapabilityArrayRefVector &capabilities, + Optional storage) { + switch (getKind()) { + case spirv::TypeKind::Array: + cast().getCapabilities(capabilities, storage); + break; + case spirv::TypeKind::RuntimeArray: + cast().getCapabilities(capabilities, storage); + break; + case spirv::TypeKind::Struct: + cast().getCapabilities(capabilities, storage); + break; + case StandardTypes::Vector: + cast().getElementType().cast().getCapabilities( + capabilities, storage); + break; + default: + llvm_unreachable("invalid composite type"); + } +} + //===----------------------------------------------------------------------===// // ImageType //===----------------------------------------------------------------------===// @@ -372,6 +428,20 @@ return getImpl()->getImageFormat(); } +void ImageType::getExtensions(SPIRVType::ExtensionArrayRefVector &, + Optional) { + // Image types do not require extra extensions thus far. +} + +void ImageType::getCapabilities( + SPIRVType::CapabilityArrayRefVector &capabilities, Optional) { + if (auto dimCaps = spirv::getCapabilities(getDim())) + capabilities.push_back(*dimCaps); + + if (auto fmtCaps = spirv::getCapabilities(getImageFormat())) + capabilities.push_back(*fmtCaps); +} + //===----------------------------------------------------------------------===// // PointerType //===----------------------------------------------------------------------===// @@ -413,6 +483,35 @@ return getImpl()->getStorageClass(); } +void PointerType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, + Optional storage) { + if (storage) + assert(*storage == getStorageClass() && "inconsistent storage class!"); + + // Use this pointer type's storage class because this pointer indicates we are + // using the pointee type in that specific storage class.
+ getPointeeType().cast().getExtensions(extensions, + getStorageClass()); + + if (auto scExts = spirv::getExtensions(getStorageClass())) + extensions.push_back(*scExts); +} + +void PointerType::getCapabilities( + SPIRVType::CapabilityArrayRefVector &capabilities, + Optional storage) { + if (storage) + assert(*storage == getStorageClass() && "inconsistent storage class!"); + + // Use this pointer type's storage class because this pointer indicates we are + // using the pointee type in that specific storage class. + getPointeeType().cast().getCapabilities(capabilities, + getStorageClass()); + + if (auto scCaps = spirv::getCapabilities(getStorageClass())) + capabilities.push_back(*scCaps); +} + //===----------------------------------------------------------------------===// // RuntimeArrayType //===----------------------------------------------------------------------===// @@ -440,6 +539,181 @@ Type RuntimeArrayType::getElementType() const { return getImpl()->elementType; } +void RuntimeArrayType::getExtensions( + SPIRVType::ExtensionArrayRefVector &extensions, + Optional storage) { + getElementType().cast().getExtensions(extensions, storage); +} + +void RuntimeArrayType::getCapabilities( + SPIRVType::CapabilityArrayRefVector &capabilities, + Optional storage) { + { + static const Capability caps[] = {Capability::Shader}; + ArrayRef ref(caps, llvm::array_lengthof(caps)); + capabilities.push_back(ref); + } + getElementType().cast().getCapabilities(capabilities, storage); +} + +//===----------------------------------------------------------------------===// +// ScalarType +//===----------------------------------------------------------------------===// + +bool ScalarType::classof(Type type) { return type.isIntOrFloat(); } + +void ScalarType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, + Optional storage) { + // 8- or 16-bit integer/floating-point numbers will require extra extensions + // to appear in interface storage classes.
See SPV_KHR_16bit_storage and + // SPV_KHR_8bit_storage for more details. + if (!storage) + return; + + switch (*storage) { + case StorageClass::PushConstant: + case StorageClass::StorageBuffer: + case StorageClass::Uniform: + if (getIntOrFloatBitWidth() == 8) { + static const Extension exts[] = {Extension::SPV_KHR_8bit_storage}; + ArrayRef ref(exts, llvm::array_lengthof(exts)); + extensions.push_back(ref); + } + LLVM_FALLTHROUGH; + case StorageClass::Input: + case StorageClass::Output: + if (getIntOrFloatBitWidth() == 16) { + static const Extension exts[] = {Extension::SPV_KHR_16bit_storage}; + ArrayRef ref(exts, llvm::array_lengthof(exts)); + extensions.push_back(ref); + } + break; + default: + break; + } +} + +void ScalarType::getCapabilities( + SPIRVType::CapabilityArrayRefVector &capabilities, + Optional storage) { + unsigned bitwidth = getIntOrFloatBitWidth(); + + // 8- or 16-bit integer/floating-point numbers will require extra capabilities + // to appear in interface storage classes. See SPV_KHR_16bit_storage and + // SPV_KHR_8bit_storage for more details.
+ +#define STORAGE_CASE(storage, cap8, cap16) \ + case StorageClass::storage: { \ + if (bitwidth == 8) { \ + static const Capability caps[] = {Capability::cap8}; \ + ArrayRef ref(caps, llvm::array_lengthof(caps)); \ + capabilities.push_back(ref); \ + } else if (bitwidth == 16) { \ + static const Capability caps[] = {Capability::cap16}; \ + ArrayRef ref(caps, llvm::array_lengthof(caps)); \ + capabilities.push_back(ref); \ + } \ + } break + + if (storage) { + switch (*storage) { + STORAGE_CASE(PushConstant, StoragePushConstant8, StoragePushConstant16); + STORAGE_CASE(StorageBuffer, StorageBuffer8BitAccess, + StorageBuffer16BitAccess); + STORAGE_CASE(Uniform, UniformAndStorageBuffer8BitAccess, + StorageUniform16); + case StorageClass::Input: + case StorageClass::Output: + if (bitwidth == 16) { + static const Capability caps[] = {Capability::StorageInputOutput16}; + ArrayRef ref(caps, llvm::array_lengthof(caps)); + capabilities.push_back(ref); + } + break; + default: + break; + } + return; + } +#undef STORAGE_CASE + + // For other non-interface storage classes, require a different set of + // capabilities for special bitwidths.
+ +#define WIDTH_CASE(type, width) \ + case width: { \ + static const Capability caps[] = {Capability::type##width}; \ + ArrayRef ref(caps, llvm::array_lengthof(caps)); \ + capabilities.push_back(ref); \ + } break + + if (auto intType = dyn_cast()) { + switch (bitwidth) { + case 32: + case 1: + break; + WIDTH_CASE(Int, 8); + WIDTH_CASE(Int, 16); + WIDTH_CASE(Int, 64); + default: + llvm_unreachable("invalid bitwidth to getCapabilities"); + } + } else { + assert(isa()); + switch (bitwidth) { + case 32: + break; + WIDTH_CASE(Float, 16); + WIDTH_CASE(Float, 64); + default: + llvm_unreachable("invalid bitwidth to getCapabilities"); + } + } + +#undef WIDTH_CASE +} + +//===----------------------------------------------------------------------===// +// SPIRVType +//===----------------------------------------------------------------------===// + +bool SPIRVType::classof(Type type) { + return type.isa() || type.isa() || + (type.getKind() >= Type::FIRST_SPIRV_TYPE && + type.getKind() <= TypeKind::LAST_SPIRV_TYPE); +} + +void SPIRVType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, + Optional storage) { + if (auto scalarType = dyn_cast()) { + scalarType.getExtensions(extensions, storage); + } else if (auto compositeType = dyn_cast()) { + compositeType.getExtensions(extensions, storage); + } else if (auto ptrType = dyn_cast()) { + ptrType.getExtensions(extensions, storage); + } else if (auto imageType = dyn_cast()) { + imageType.getExtensions(extensions, storage); + } else { + llvm_unreachable("invalid SPIR-V Type to getExtensions"); + } +} + +void SPIRVType::getCapabilities( + SPIRVType::CapabilityArrayRefVector &capabilities, + Optional storage) { + if (auto scalarType = dyn_cast()) { + scalarType.getCapabilities(capabilities, storage); + } else if (auto compositeType = dyn_cast()) { + compositeType.getCapabilities(capabilities, storage); + } else if (auto ptrType = dyn_cast()) { + ptrType.getCapabilities(capabilities, storage); + } else if (auto imageType =
dyn_cast()) { + imageType.getCapabilities(capabilities, storage); + } else { + llvm_unreachable("invalid SPIR-V Type to getCapabilities"); + } +} + //===----------------------------------------------------------------------===// // StructType //===----------------------------------------------------------------------===// @@ -540,18 +814,18 @@ } Type StructType::getElementType(unsigned index) const { - assert( - getNumElements() > index && - "element index is more than number of members of the SPIR-V StructType"); + assert(getNumElements() > index && "member index out of range"); return getImpl()->memberTypes[index]; } +StructType::ElementTypeRange StructType::getElementTypes() const { + return ElementTypeRange(getImpl()->memberTypes, getNumElements()); +} + bool StructType::hasLayout() const { return getImpl()->layoutInfo; } uint64_t StructType::getOffset(unsigned index) const { - assert( - getNumElements() > index && - "element index is more than number of members of the SPIR-V StructType"); + assert(getNumElements() > index && "member index out of range"); return getImpl()->layoutInfo[index]; } @@ -579,3 +853,16 @@ } } } + +void StructType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, + Optional storage) { + for (Type elementType : getElementTypes()) + elementType.cast().getExtensions(extensions, storage); +} + +void StructType::getCapabilities( + SPIRVType::CapabilityArrayRefVector &capabilities, + Optional storage) { + for (Type elementType : getElementTypes()) + elementType.cast().getCapabilities(capabilities, storage); +} diff --git a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp --- a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp @@ -32,6 +32,73 @@ }; } // namespace +/// Checks that the `candidates` extension requirements can be satisfied +/// with the given `allowedExtensions` and updates `deducedExtensions`
if so. +/// Emits errors attached to the given `op` on failures. +/// +/// `candidates` is a vector of vectors of extension requirements following +/// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D)) +/// convention. +static LogicalResult checkAndUpdateExtensionRequirements( + Operation *op, const llvm::SmallSet &allowedExtensions, + const spirv::SPIRVType::ExtensionArrayRefVector &candidates, + llvm::SetVector &deducedExtensions) { + for (const auto &ors : candidates) { + auto chosen = llvm::find_if(ors, [&](spirv::Extension ext) { + return allowedExtensions.count(ext); + }); + + if (chosen != ors.end()) { + deducedExtensions.insert(*chosen); + } else { + SmallVector extStrings; + for (spirv::Extension ext : ors) + extStrings.push_back(spirv::stringifyExtension(ext)); + + emitError(op->getLoc(), "'") + << op->getName() << "' requires at least one extension in [" + << llvm::join(extStrings, ", ") + << "] but none allowed in target environment"; + return failure(); + } + } + return success(); +} + +/// Checks that the `candidates` capability requirements can be satisfied +/// with the given `allowedCapabilities` and updates `deducedCapabilities` if +/// so. Emits errors attached to the given `op` on failures. +/// +/// `candidates` is a vector of vectors of capability requirements following +/// ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D)) +/// convention.
+static LogicalResult checkAndUpdateCapabilityRequirements( + Operation *op, + const llvm::SmallSet &allowedCapabilities, + const spirv::SPIRVType::CapabilityArrayRefVector &candidates, + llvm::SetVector &deducedCapabilities) { + for (const auto &ors : candidates) { + auto chosen = llvm::find_if(ors, [&](spirv::Capability cap) { + return allowedCapabilities.count(cap); + }); + + if (chosen != ors.end()) { + deducedCapabilities.insert(*chosen); + } else { + SmallVector capStrings; + for (spirv::Capability cap : ors) + capStrings.push_back(spirv::stringifyCapability(cap)); + + emitError(op->getLoc(), "'") + << op->getName() << "' requires at least one capability in [" + << llvm::join(capStrings, ", ") + << "] but none allowed in target environment"; + return failure(); + } + } + return success(); +} + void UpdateVCEPass::runOnOperation() { spirv::ModuleOp module = getOperation(); @@ -69,6 +136,7 @@ // Walk each SPIR-V op to deduce the minimal version/extension/capability // requirements. WalkResult walkResult = module.walk([&](Operation *op) { + // Op min version requirements if (auto minVersion = dyn_cast(op)) { deducedVersion = std::max(deducedVersion, minVersion.getMinVersion()); if (deducedVersion > allowedVersion) { @@ -81,64 +149,44 @@ } } - // Deduce this op's extension requirement. For each op, the query interfacce - // returns a vector of vector for its extension requirements following - // ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D)) - // convention. Ops not implementing QueryExtensionInterface do not require - // extensions to be available.
- if (auto extensions = dyn_cast(op)) { - for (const auto &ors : extensions.getExtensions()) { - bool satisfied = false; // True when at least one extension can be used - for (spirv::Extension ext : ors) { - if (allowedExtensions.count(ext)) { - deducedExtensions.insert(ext); - satisfied = true; - break; - } - } - - if (!satisfied) { - SmallVector extStrings; - for (spirv::Extension ext : ors) - extStrings.push_back(spirv::stringifyExtension(ext)); - - emitError(op->getLoc(), "'") - << op->getName() << "' requires at least one extension in [" - << llvm::join(extStrings, ", ") - << "] but none allowed in target environment"; - return WalkResult::interrupt(); - } - } - } + // Op extension requirements + if (auto extensions = dyn_cast(op)) + if (failed(checkAndUpdateExtensionRequirements( + op, allowedExtensions, extensions.getExtensions(), + deducedExtensions))) + return WalkResult::interrupt(); - // Deduce this op's capability requirement. For each op, the queryinterface - // returns a vector of vector for its capability requirements following - // ((Capability::A OR Extension::B) AND (Capability::C OR Capability::D)) - // convention. Ops not implementing QueryExtensionInterface do not require - // extensions to be available.
- if (auto capabilities = dyn_cast(op)) { - for (const auto &ors : capabilities.getCapabilities()) { - bool satisfied = false; // True when at least one capability can be used - for (spirv::Capability cap : ors) { - if (allowedCapabilities.count(cap)) { - deducedCapabilities.insert(cap); - satisfied = true; - break; - } - } - - if (!satisfied) { - SmallVector capStrings; - for (spirv::Capability cap : ors) - capStrings.push_back(spirv::stringifyCapability(cap)); - - emitError(op->getLoc(), "'") - << op->getName() << "' requires at least one capability in [" - << llvm::join(capStrings, ", ") - << "] but none allowed in target environment"; - return WalkResult::interrupt(); - } - } + // Op capability requirements + if (auto capabilities = dyn_cast(op)) + if (failed(checkAndUpdateCapabilityRequirements( + op, allowedCapabilities, capabilities.getCapabilities(), + deducedCapabilities))) + return WalkResult::interrupt(); + + SmallVector valueTypes; + valueTypes.append(op->operand_type_begin(), op->operand_type_end()); + valueTypes.append(op->result_type_begin(), op->result_type_end()); + + // Special treatment for global variables, whose type requirements are + // conveyed by type attributes.
+ if (auto globalVar = dyn_cast(op)) + valueTypes.push_back(globalVar.type()); + + // Requirements from values' types + SmallVector, 4> typeExtensions; + SmallVector, 8> typeCapabilities; + for (Type valueType : valueTypes) { + typeExtensions.clear(); + valueType.cast().getExtensions(typeExtensions); + if (failed(checkAndUpdateExtensionRequirements( + op, allowedExtensions, typeExtensions, deducedExtensions))) + return WalkResult::interrupt(); + + typeCapabilities.clear(); + valueType.cast().getCapabilities(typeCapabilities); + if (failed(checkAndUpdateCapabilityRequirements( + op, allowedCapabilities, typeCapabilities, deducedCapabilities))) + return WalkResult::interrupt(); } return WalkResult::advance(); diff --git a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir --- a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir @@ -107,6 +107,36 @@ } } +// Test type required capabilities + +// Using 8-bit integers in non-interface storage class requires Int8. +// CHECK: requires #spv.vce +spv.module Logical GLSL450 attributes { + spv.target_env = #spv.target_env< + #spv.vce, + {max_compute_workgroup_invocations = 128 : i32, + max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}> +} { + spv.func @iadd_function(%val : i8) -> i8 "None" { + %0 = spv.IAdd %val, %val : i8 + spv.ReturnValue %0: i8 + } +} + +// Using 16-bit floats in non-interface storage class requires Float16.
+// CHECK: requires #spv.vce +spv.module Logical GLSL450 attributes { + spv.target_env = #spv.target_env< + #spv.vce, + {max_compute_workgroup_invocations = 128 : i32, + max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}> +} { + spv.func @fadd_function(%val : f16) -> f16 "None" { + %0 = spv.FAdd %val, %val : f16 + spv.ReturnValue %0: f16 + } +} + //===----------------------------------------------------------------------===// // Extension //===----------------------------------------------------------------------===// @@ -144,3 +174,35 @@ spv.ReturnValue %0: i32 } } + +// Test type required extensions + +// Using 16-bit integers in interface storage class requires additional +// extensions and capabilities. +// CHECK: requires #spv.vce +spv.module Logical GLSL450 attributes { + spv.target_env = #spv.target_env< + #spv.vce, + {max_compute_workgroup_invocations = 128 : i32, + max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}> +} { + spv.func @iadd_storage_buffer(%ptr : !spv.ptr) -> i16 "None" { + %0 = spv.Load "StorageBuffer" %ptr : i16 + %1 = spv.IAdd %0, %0 : i16 + spv.ReturnValue %1: i16 + } +} + +// Complicated nested types +// * Buffer requires ImageBuffer or SampledBuffer. +// * Rg32f requires StorageImageExtendedFormats. +// CHECK: requires #spv.vce +spv.module Logical GLSL450 attributes { + spv.target_env = #spv.target_env< + #spv.vce, + {max_compute_workgroup_invocations = 128 : i32, + max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}> +} { + spv.globalVariable @data : !spv.ptr, Uniform> + spv.globalVariable @img : !spv.ptr, UniformConstant> +}