diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h @@ -23,9 +23,20 @@ /// Type conversion from standard types to SPIR-V types for shader interface. /// -/// For composite types, this converter additionally performs type wrapping to +/// Non-32-bit scalar types require special hardware support that may not exist +/// on all GPUs. This is reflected in SPIR-V in that non-32-bit scalar types +/// require special capabilities or extensions. Right now if a scalar type of a +/// certain bitwidth is not supported in the target environment, we use 32-bit +/// ones unconditionally. This requires the runtime to also feed in data with +/// a matched bitwidth and layout for interface types. The runtime can do that +/// by inspecting the SPIR-V module. +/// +/// For memref types, this converter additionally performs type wrapping to /// satisfy shader interface requirements: shader interface types must be /// pointers to structs. +/// +/// TODO(antiagainst): We might want to introduce a way to control how +/// unsupported bitwidths are handled and explicitly fail if wanted. class SPIRVTypeConverter : public TypeConverter { public: explicit SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr); diff --git a/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h b/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h --- a/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h +++ b/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h @@ -44,7 +44,7 @@ Optional allows(ArrayRef) const; /// Returns the MLIRContext. - MLIRContext *getContext(); + MLIRContext *getContext() const; /// Allows implicity converting to the underlying spirv::TargetEnvAttr.
 operator TargetEnvAttr() const { return targetAttr; } diff --git a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp @@ -24,6 +24,64 @@ using namespace mlir; +//===----------------------------------------------------------------------===// +// Utility functions +//===----------------------------------------------------------------------===// + +/// Checks that the `candidates` extension requirements can be satisfied +/// with the given `targetEnv`. `label` is only used for debug output. +/// +/// `candidates` is a vector of vectors for extension requirements following +/// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D)) +/// convention. +template +static LogicalResult checkExtensionRequirements( + LabelT label, const spirv::TargetEnv &targetEnv, + const spirv::SPIRVType::ExtensionArrayRefVector &candidates) { + for (const auto &ors : candidates) { + if (targetEnv.allows(ors)) + continue; + + SmallVector extStrings; + for (spirv::Extension ext : ors) + extStrings.push_back(spirv::stringifyExtension(ext)); + + LLVM_DEBUG(llvm::dbgs() + << label << " illegal: requires at least one extension in [" + << llvm::join(extStrings, ", ") + << "] but none allowed in target environment\n"); + return failure(); + } + return success(); +} + +/// Checks that the `candidates` capability requirements can be satisfied +/// with the given `targetEnv`. `label` is only used for debug output. +/// +/// `candidates` is a vector of vectors for capability requirements following +/// ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D)) +/// convention.
+template +static LogicalResult checkCapabilityRequirements( + LabelT label, const spirv::TargetEnv &targetEnv, + const spirv::SPIRVType::CapabilityArrayRefVector &candidates) { + for (const auto &ors : candidates) { + if (targetEnv.allows(ors)) + continue; + + SmallVector capStrings; + for (spirv::Capability cap : ors) + capStrings.push_back(spirv::stringifyCapability(cap)); + + LLVM_DEBUG(llvm::dbgs() + << label << " illegal: requires at least one capability in [" + << llvm::join(capStrings, ", ") + << "] but none allowed in target environment\n"); + return failure(); + } + return success(); +} + //===----------------------------------------------------------------------===// // Type Conversion //===----------------------------------------------------------------------===// @@ -157,62 +215,206 @@ return llvm::None; } +/// Converts a scalar `type` to a suitable type under the given `targetEnv`. +static Optional +convertScalarType(const spirv::TargetEnv &targetEnv, spirv::ScalarType type, + Optional storageClass = {}) { + // Get extension and capability requirements for the given type. + SmallVector, 1> extensions; + SmallVector, 2> capabilities; + type.getExtensions(extensions, storageClass); + type.getCapabilities(capabilities, storageClass); + + // If all requirements are met, then we can accept this type as-is. + if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) && + succeeded(checkExtensionRequirements(type, targetEnv, extensions))) + return type; + + // Otherwise we need to adjust the type, which really means adjusting the + // bitwidth given this is a scalar type. + // TODO(antiagainst): We are unconditionally converting the bitwidth here, + // this might be okay for non-interface types (i.e., types used in + // Private/Function storage classes), but not for interface types (i.e., + // types used in StorageBuffer/Uniform/PushConstant/etc. storage classes).
+ // This is because the latter actually affects the ABI contract with the + // runtime. So we may want to expose a control on SPIRVTypeConverter to fail + // conversion if we cannot change them. + + if (auto floatType = type.dyn_cast()) { + LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n"); + return Builder(targetEnv.getContext()).getF32Type(); + } + + auto intType = type.cast(); + LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n"); + return IntegerType::get(/*width=*/32, intType.getSignedness(), + targetEnv.getContext()); +} + +/// Converts a vector `type` to a suitable type under the given `targetEnv`. +static Optional +convertVectorType(const spirv::TargetEnv &targetEnv, VectorType type, + Optional storageClass = {}) { + if (!spirv::CompositeType::isValid(type)) { + // TODO(antiagainst): One-element vector types can be translated into scalar + // types. Vector types with more than four elements can be translated into + // array types. + LLVM_DEBUG(llvm::dbgs() + << type << " illegal: 1- and > 4-element unimplemented\n"); + return llvm::None; + } + + // Get extension and capability requirements for the given type. + SmallVector, 1> extensions; + SmallVector, 2> capabilities; + type.cast().getExtensions(extensions, storageClass); + type.cast().getCapabilities(capabilities, storageClass); + + // If all requirements are met, then we can accept this type as-is. + if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) && + succeeded(checkExtensionRequirements(type, targetEnv, extensions))) + return type; + + auto elementType = convertScalarType( + targetEnv, type.getElementType().cast(), storageClass); + if (elementType) + return VectorType::get(type.getShape(), *elementType); + return llvm::None; +} + +static Optional convertTensorType(const spirv::TargetEnv &targetEnv, + TensorType type) { + // TODO(ravishankarm) : Handle dynamic shapes.
+ if (!type.hasStaticShape()) { + LLVM_DEBUG(llvm::dbgs() + << type << " illegal: dynamic shape unimplemented\n"); + return llvm::None; + } + + auto scalarType = type.getElementType().dyn_cast(); + if (!scalarType) { + LLVM_DEBUG(llvm::dbgs() + << type << " illegal: cannot convert non-scalar element type\n"); + return llvm::None; + } + + Optional scalarSize = getTypeNumBytes(scalarType); + Optional tensorSize = getTypeNumBytes(type); + if (!scalarSize || !tensorSize) { + LLVM_DEBUG(llvm::dbgs() + << type << " illegal: cannot deduce element count\n"); + return llvm::None; + } + + auto arrayElemCount = *tensorSize / *scalarSize; + auto arrayElemType = convertScalarType(targetEnv, scalarType); + if (!arrayElemType) + return llvm::None; + Optional arrayElemSize = getTypeNumBytes(*arrayElemType); + if (!arrayElemSize) { + LLVM_DEBUG(llvm::dbgs() + << type << " illegal: cannot deduce converted element size\n"); + return llvm::None; + } + + return spirv::ArrayType::get(*arrayElemType, arrayElemCount, *arrayElemSize); +} + +static Optional convertMemrefType(const spirv::TargetEnv &targetEnv, + MemRefType type) { + // TODO(ravishankarm) : Handle dynamic shapes.
+ if (!type.hasStaticShape()) { + LLVM_DEBUG(llvm::dbgs() + << type << " illegal: dynamic shape unimplemented\n"); + return llvm::None; + } + + auto scalarType = type.getElementType().dyn_cast(); + if (!scalarType) { + LLVM_DEBUG(llvm::dbgs() + << type << " illegal: cannot convert non-scalar element type\n"); + return llvm::None; + } + + Optional scalarSize = getTypeNumBytes(scalarType); + Optional memrefSize = getTypeNumBytes(type); + if (!scalarSize || !memrefSize) { + LLVM_DEBUG(llvm::dbgs() + << type << " illegal: cannot deduce element count\n"); + return llvm::None; + } + + auto arrayElemCount = *memrefSize / *scalarSize; + + auto storageClass = + SPIRVTypeConverter::getStorageClassForMemorySpace(type.getMemorySpace()); + if (!storageClass) { + LLVM_DEBUG(llvm::dbgs() + << type << " illegal: cannot convert memory space\n"); + return llvm::None; + } + + auto arrayElemType = convertScalarType(targetEnv, scalarType, storageClass); + if (!arrayElemType) + return llvm::None; + Optional arrayElemSize = getTypeNumBytes(*arrayElemType); + if (!arrayElemSize) { + LLVM_DEBUG(llvm::dbgs() + << type << " illegal: cannot deduce converted element size\n"); + return llvm::None; + } + + auto arrayType = + spirv::ArrayType::get(*arrayElemType, arrayElemCount, *arrayElemSize); + + // Wrap in a struct to satisfy Vulkan interface requirements. + auto structType = spirv::StructType::get(arrayType, 0); + return spirv::PointerType::get(structType, *storageClass); +} + SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr) : targetEnv(targetAttr) { - addConversion([](Type type) -> Optional { - // If the type is already valid in SPIR-V, directly return. - return type.isa() ? type : Optional(); - }); + // Add conversions. The order matters here: later ones will be tried earlier. + + // Tried last: all other cases failed, so we cannot convert this type. + addConversion([](Type type) { return llvm::None; }); + + // Allow all SPIR-V dialect specific types.
 This assumes all standard types + // adopted in the SPIR-V dialect (i.e., IntegerType, FloatType, VectorType) + // were tried before. + // + // TODO(antiagainst): this assumes that the SPIR-V types are valid to use in + // the given target environment, which should be the case if the whole + // pipeline is driven by the same target environment. Still, we probably + // want to validate and convert to be safe. + addConversion([](spirv::SPIRVType type) { return type; }); + addConversion([](IndexType indexType) { return SPIRVTypeConverter::getIndexType(indexType.getContext()); }); - addConversion([this](MemRefType memRefType) -> Type { - auto elementType = convertType(memRefType.getElementType()); - if (!elementType) - return Type(); - - auto elementSize = getTypeNumBytes(elementType); - if (!elementSize) - return Type(); - - // TODO(ravishankarm) : Handle dynamic shapes. - if (memRefType.hasStaticShape()) { - auto arraySize = getTypeNumBytes(memRefType); - if (!arraySize) - return Type(); - - auto arrayType = spirv::ArrayType::get( - elementType, arraySize.getValue() / elementSize.getValue(), - elementSize.getValue()); - - // Wrap in a struct to satisfy Vulkan interface requirements.
- auto structType = spirv::StructType::get(arrayType, 0); - if (auto sc = getStorageClassForMemorySpace(memRefType.getMemorySpace())) - return spirv::PointerType::get(structType, *sc); - return Type(); - } - return Type(); + + addConversion([this](IntegerType intType) -> Optional { + if (auto scalarType = intType.dyn_cast()) + return convertScalarType(targetEnv, scalarType); + return llvm::None; + }); + + addConversion([this](FloatType floatType) -> Optional { + if (auto scalarType = floatType.dyn_cast()) + return convertScalarType(targetEnv, scalarType); + return llvm::None; + }); + + addConversion([this](VectorType vectorType) { + return convertVectorType(targetEnv, vectorType); }); - addConversion([this](TensorType tensorType) -> Type { - // TODO(ravishankarm) : Handle dynamic shapes. - if (!tensorType.hasStaticShape()) - return Type(); - - auto elementType = convertType(tensorType.getElementType()); - if (!elementType) - return Type(); - - auto elementSize = getTypeNumBytes(elementType); - if (!elementSize) - return Type(); - - auto tensorSize = getTypeNumBytes(tensorType); - if (!tensorSize) - return Type(); - - return spirv::ArrayType::get(elementType, - tensorSize.getValue() / elementSize.getValue(), - elementSize.getValue()); + + addConversion([this](TensorType tensorType) { + return convertTensorType(targetEnv, tensorType); + }); + + addConversion([this](MemRefType memRefType) { + return convertMemrefType(targetEnv, memRefType); }); } @@ -427,58 +629,6 @@ spirv::TargetEnvAttr targetAttr) : ConversionTarget(*targetAttr.getContext()), targetEnv(targetAttr) {} -/// Checks that `candidates` extension requirements are possible to be satisfied -/// with the given `targetEnv`. -/// -/// `candidates` is a vector of vector for extension requirements following -/// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D)) -/// convention.
-static LogicalResult checkExtensionRequirements( - Operation *op, const spirv::TargetEnv &targetEnv, - const spirv::SPIRVType::ExtensionArrayRefVector &candidates) { - for (const auto &ors : candidates) { - if (targetEnv.allows(ors)) - continue; - - SmallVector extStrings; - for (spirv::Extension ext : ors) - extStrings.push_back(spirv::stringifyExtension(ext)); - - LLVM_DEBUG(llvm::dbgs() << op->getName() - << " illegal: requires at least one extension in [" - << llvm::join(extStrings, ", ") - << "] but none allowed in target environment\n"); - return failure(); - } - return success(); -} - -/// Checks that `candidates`capability requirements are possible to be satisfied -/// with the given `isAllowedFn`. -/// -/// `candidates` is a vector of vector for capability requirements following -/// ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D)) -/// convention. -static LogicalResult checkCapabilityRequirements( - Operation *op, const spirv::TargetEnv &targetEnv, - const spirv::SPIRVType::CapabilityArrayRefVector &candidates) { - for (const auto &ors : candidates) { - if (targetEnv.allows(ors)) - continue; - - SmallVector capStrings; - for (spirv::Capability cap : ors) - capStrings.push_back(spirv::stringifyCapability(cap)); - - LLVM_DEBUG(llvm::dbgs() << op->getName() - << " illegal: requires at least one capability in [" - << llvm::join(capStrings, ", ") - << "] but none allowed in target environment\n"); - return failure(); - } - return success(); -} - bool spirv::SPIRVConversionTarget::isLegalOp(Operation *op) { // Make sure this op is available at the given version. Ops not implementing // QueryMinVersionInterface/QueryMaxVersionInterface are available to all @@ -504,7 +654,7 @@ // implementing QueryExtensionInterface do not require extensions to be // available.
 if (auto extensions = dyn_cast(op)) - if (failed(checkExtensionRequirements(op, this->targetEnv, + if (failed(checkExtensionRequirements(op->getName(), this->targetEnv, extensions.getExtensions()))) return false; @@ -512,7 +662,7 @@ // implementing QueryCapabilityInterface do not require capabilities to be // available. if (auto capabilities = dyn_cast(op)) - if (failed(checkCapabilityRequirements(op, this->targetEnv, + if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv, capabilities.getCapabilities()))) return false; @@ -532,13 +682,14 @@ for (Type valueType : valueTypes) { typeExtensions.clear(); valueType.cast().getExtensions(typeExtensions); - if (failed(checkExtensionRequirements(op, this->targetEnv, typeExtensions))) + if (failed(checkExtensionRequirements(op->getName(), this->targetEnv, + typeExtensions))) return false; typeCapabilities.clear(); valueType.cast().getCapabilities(typeCapabilities); - if (failed( - checkCapabilityRequirements(op, this->targetEnv, typeCapabilities))) + if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv, + typeCapabilities))) return false; } diff --git a/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp b/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp --- a/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp +++ b/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp @@ -71,7 +71,9 @@ return llvm::None; } -MLIRContext *spirv::TargetEnv::getContext() { return targetAttr.getContext(); } +MLIRContext *spirv::TargetEnv::getContext() const { + return targetAttr.getContext(); +} //===----------------------------------------------------------------------===// // Utility functions diff --git a/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir @@ -0,0 +1,552 @@ +// RUN: mlir-opt -split-input-file -convert-std-to-spirv %s -o - | FileCheck %s +
+//===----------------------------------------------------------------------===// +// Integer types +//===----------------------------------------------------------------------===// + +// Check that non-32-bit integer types are converted to 32-bit types if the +// corresponding capabilities are not available. +module attributes { + spv.target_env = #spv.target_env< + #spv.vce, + {max_compute_workgroup_invocations = 128 : i32, + max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}> +} { + +// CHECK-LABEL: spv.func @integer8 +// CHECK-SAME: i32 +// CHECK-SAME: si32 +// CHECK-SAME: ui32 +func @integer8(%arg0: i8, %arg1: si8, %arg2: ui8) { return } + +// CHECK-LABEL: spv.func @integer16 +// CHECK-SAME: i32 +// CHECK-SAME: si32 +// CHECK-SAME: ui32 +func @integer16(%arg0: i16, %arg1: si16, %arg2: ui16) { return } + +// CHECK-LABEL: spv.func @integer64 +// CHECK-SAME: i32 +// CHECK-SAME: si32 +// CHECK-SAME: ui32 +func @integer64(%arg0: i64, %arg1: si64, %arg2: ui64) { return } + +} // end module + +// ----- + +// Check that non-32-bit integer types are kept untouched if the corresponding +// capabilities are available. +module attributes { + spv.target_env = #spv.target_env< + #spv.vce, + {max_compute_workgroup_invocations = 128 : i32, + max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}> +} { + +// CHECK-LABEL: spv.func @integer8 +// CHECK-SAME: i8 +// CHECK-SAME: si8 +// CHECK-SAME: ui8 +func @integer8(%arg0: i8, %arg1: si8, %arg2: ui8) { return } + +// CHECK-LABEL: spv.func @integer16 +// CHECK-SAME: i16 +// CHECK-SAME: si16 +// CHECK-SAME: ui16 +func @integer16(%arg0: i16, %arg1: si16, %arg2: ui16) { return } + +// CHECK-LABEL: spv.func @integer64 +// CHECK-SAME: i64 +// CHECK-SAME: si64 +// CHECK-SAME: ui64 +func @integer64(%arg0: i64, %arg1: si64, %arg2: ui64) { return } + +} // end module + +// ----- + +// Check that unusual bitwidths (e.g., 4, 42, or 128 bits) are not supported.
+module attributes { + spv.target_env = #spv.target_env< + #spv.vce, + {max_compute_workgroup_invocations = 128 : i32, + max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}> +} { + +// CHECK-NOT: spv.func @integer4 +func @integer4(%arg0: i4) { return } + +// CHECK-NOT: spv.func @integer128 +func @integer128(%arg0: i128) { return } + +// CHECK-NOT: spv.func @integer42 +func @integer42(%arg0: i42) { return } + +} // end module +// ----- + +//===----------------------------------------------------------------------===// +// Index type +//===----------------------------------------------------------------------===// + +// The index type is always converted into i32. +module attributes { + spv.target_env = #spv.target_env< + #spv.vce, + {max_compute_workgroup_invocations = 128 : i32, + max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}> +} { + +// CHECK-LABEL: spv.func @index_type +// CHECK-SAME: %{{.*}}: i32 +func @index_type(%arg0: index) { return } + +} // end module + +// ----- + +//===----------------------------------------------------------------------===// +// Float types +//===----------------------------------------------------------------------===// + +// Check that non-32-bit float types are converted to 32-bit types if the +// corresponding capabilities are not available. +module attributes { + spv.target_env = #spv.target_env< + #spv.vce, + {max_compute_workgroup_invocations = 128 : i32, + max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}> +} { + +// CHECK-LABEL: spv.func @float16 +// CHECK-SAME: f32 +func @float16(%arg0: f16) { return } + +// CHECK-LABEL: spv.func @float64 +// CHECK-SAME: f32 +func @float64(%arg0: f64) { return } + +} // end module + +// ----- + +// Check that non-32-bit float types are kept untouched if the corresponding +// capabilities are available.
+module attributes { + spv.target_env = #spv.target_env< + #spv.vce, + {max_compute_workgroup_invocations = 128 : i32, + max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}> +} { + +// CHECK-LABEL: spv.func @float16 +// CHECK-SAME: f16 +func @float16(%arg0: f16) { return } + +// CHECK-LABEL: spv.func @float64 +// CHECK-SAME: f64 +func @float64(%arg0: f64) { return } + +} // end module + +// ----- + +// Check that bf16 is not supported. +module attributes { + spv.target_env = #spv.target_env< + #spv.vce, + {max_compute_workgroup_invocations = 128 : i32, + max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}> +} { + +// CHECK-NOT: spv.func @bf16_type +func @bf16_type(%arg0: bf16) { return } + +} // end module + +// ----- + +//===----------------------------------------------------------------------===// +// Vector types +//===----------------------------------------------------------------------===// + +// Check that capabilities for scalar types affect vector types too: having no +// special capabilities means turning element types into 32-bit ones. +module attributes { + spv.target_env = #spv.target_env< + #spv.vce, + {max_compute_workgroup_invocations = 128 : i32, + max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}> +} { + +// CHECK-LABEL: spv.func @int_vector +// CHECK-SAME: vector<2xi32> +// CHECK-SAME: vector<3xsi32> +// CHECK-SAME: vector<4xui32> +func @int_vector( + %arg0: vector<2xi8>, + %arg1: vector<3xsi16>, + %arg2: vector<4xui64> +) { return } + +// CHECK-LABEL: spv.func @float_vector +// CHECK-SAME: vector<2xf32> +// CHECK-SAME: vector<3xf32> +func @float_vector( + %arg0: vector<2xf16>, + %arg1: vector<3xf64> +) { return } + +} // end module + +// ----- + +// Check that capabilities for scalar types affect vector types too: having +// special capabilities means keeping vector types untouched.
+module attributes { + spv.target_env = #spv.target_env< + #spv.vce, + {max_compute_workgroup_invocations = 128 : i32, + max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}> +} { + +// CHECK-LABEL: spv.func @int_vector +// CHECK-SAME: vector<2xi8> +// CHECK-SAME: vector<3xsi16> +// CHECK-SAME: vector<4xui64> +func @int_vector( + %arg0: vector<2xi8>, + %arg1: vector<3xsi16>, + %arg2: vector<4xui64> +) { return } + +// CHECK-LABEL: spv.func @float_vector +// CHECK-SAME: vector<2xf16> +// CHECK-SAME: vector<3xf64> +func @float_vector( + %arg0: vector<2xf16>, + %arg1: vector<3xf64> +) { return } + +} // end module + +// ----- + +// Check that 1- or > 4-element vectors are not supported. +module attributes { + spv.target_env = #spv.target_env< + #spv.vce, + {max_compute_workgroup_invocations = 128 : i32, + max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}> +} { + +// CHECK-NOT: spv.func @one_element_vector +func @one_element_vector(%arg0: vector<1xi32>) { return } + +// CHECK-NOT: spv.func @large_vector +func @large_vector(%arg0: vector<1024xi32>) { return } + +} // end module + +// ----- + +//===----------------------------------------------------------------------===// +// MemRef types +//===----------------------------------------------------------------------===// + +// Check that using non-32-bit scalar types in interface storage classes +// requires special capability and extension: convert them to 32-bit if not +// satisfied.
+module attributes { + spv.target_env = #spv.target_env< + #spv.vce, + {max_compute_workgroup_invocations = 128 : i32, + max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}> +} { + +// CHECK-LABEL: spv.func @memref_8bit_StorageBuffer +// CHECK-SAME: !spv.ptr [0]>, StorageBuffer> +func @memref_8bit_StorageBuffer(%arg0: memref<16xi8, 0>) { return } + +// CHECK-LABEL: spv.func @memref_8bit_Uniform +// CHECK-SAME: !spv.ptr [0]>, Uniform> +func @memref_8bit_Uniform(%arg0: memref<16xsi8, 1>) { return } + +// CHECK-LABEL: spv.func @memref_8bit_PushConstant +// CHECK-SAME: !spv.ptr [0]>, PushConstant> +func @memref_8bit_PushConstant(%arg0: memref<16xui8, 3>) { return } + +// CHECK-LABEL: spv.func @memref_16bit_StorageBuffer +// CHECK-SAME: !spv.ptr [0]>, StorageBuffer> +func @memref_16bit_StorageBuffer(%arg0: memref<16xi16, 0>) { return } + +// CHECK-LABEL: spv.func @memref_16bit_Uniform +// CHECK-SAME: !spv.ptr [0]>, Uniform> +func @memref_16bit_Uniform(%arg0: memref<16xsi16, 1>) { return } + +// CHECK-LABEL: spv.func @memref_16bit_PushConstant +// CHECK-SAME: !spv.ptr [0]>, PushConstant> +func @memref_16bit_PushConstant(%arg0: memref<16xui16, 3>) { return } + +// CHECK-LABEL: spv.func @memref_16bit_Input +// CHECK-SAME: !spv.ptr [0]>, Input> +func @memref_16bit_Input(%arg3: memref<16xf16, 7>) { return } + +// CHECK-LABEL: spv.func @memref_16bit_Output +// CHECK-SAME: !spv.ptr [0]>, Output> +func @memref_16bit_Output(%arg4: memref<16xf16, 8>) { return } + +} // end module + +// ----- + +// Check that using non-32-bit scalar types in interface storage classes +// requires special capability and extension: keep as-is when the capability +// and extension are available.
+module attributes { + spv.target_env = #spv.target_env< + #spv.vce, + {max_compute_workgroup_invocations = 128 : i32, + max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}> +} { + +// CHECK-LABEL: spv.func @memref_8bit_PushConstant +// CHECK-SAME: !spv.ptr [0]>, PushConstant> +func @memref_8bit_PushConstant(%arg0: memref<16xi8, 3>) { return } + +// CHECK-LABEL: spv.func @memref_16bit_PushConstant +// CHECK-SAME: !spv.ptr [0]>, PushConstant> +// CHECK-SAME: !spv.ptr [0]>, PushConstant> +func @memref_16bit_PushConstant( + %arg0: memref<16xi16, 3>, + %arg1: memref<16xf16, 3> +) { return } + +} // end module + +// ----- + +// Check that using non-32-bit scalar types in interface storage classes +// requires special capability and extension: keep as-is when the capability +// and extension are available. +module attributes { + spv.target_env = #spv.target_env< + #spv.vce, + {max_compute_workgroup_invocations = 128 : i32, + max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}> +} { + +// CHECK-LABEL: spv.func @memref_8bit_StorageBuffer +// CHECK-SAME: !spv.ptr [0]>, StorageBuffer> +func @memref_8bit_StorageBuffer(%arg0: memref<16xi8, 0>) { return } + +// CHECK-LABEL: spv.func @memref_16bit_StorageBuffer +// CHECK-SAME: !spv.ptr [0]>, StorageBuffer> +// CHECK-SAME: !spv.ptr [0]>, StorageBuffer> +func @memref_16bit_StorageBuffer( + %arg0: memref<16xi16, 0>, + %arg1: memref<16xf16, 0> +) { return } + +} // end module + +// ----- + +// Check that using non-32-bit scalar types in interface storage classes +// requires special capability and extension: keep as-is when the capability +// and extension are available.
+module attributes { + spv.target_env = #spv.target_env< + #spv.vce, + {max_compute_workgroup_invocations = 128 : i32, + max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}> +} { + +// CHECK-LABEL: spv.func @memref_8bit_Uniform +// CHECK-SAME: !spv.ptr [0]>, Uniform> +func @memref_8bit_Uniform(%arg0: memref<16xi8, 1>) { return } + +// CHECK-LABEL: spv.func @memref_16bit_Uniform +// CHECK-SAME: !spv.ptr [0]>, Uniform> +// CHECK-SAME: !spv.ptr [0]>, Uniform> +func @memref_16bit_Uniform( + %arg0: memref<16xi16, 1>, + %arg1: memref<16xf16, 1> +) { return } + +} // end module + +// ----- + +// Check that using non-32-bit scalar types in interface storage classes +// requires special capability and extension: keep as-is when the capability +// and extension are available. +module attributes { + spv.target_env = #spv.target_env< + #spv.vce, + {max_compute_workgroup_invocations = 128 : i32, + max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}> +} { + +// CHECK-LABEL: spv.func @memref_16bit_Input +// CHECK-SAME: !spv.ptr [0]>, Input> +func @memref_16bit_Input(%arg3: memref<16xf16, 7>) { return } + +// CHECK-LABEL: spv.func @memref_16bit_Output +// CHECK-SAME: !spv.ptr [0]>, Output> +func @memref_16bit_Output(%arg4: memref<16xi16, 8>) { return } + +} // end module + +// ----- + +// Check that memref offset and strides affect the array size.
+module attributes { + spv.target_env = #spv.target_env< + #spv.vce, + {max_compute_workgroup_invocations = 128 : i32, + max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}> +} { + +// CHECK-LABEL: spv.func @memref_offset_strides +func @memref_offset_strides( +// CHECK-SAME: !spv.array<64 x f32 [4]> [0]>, StorageBuffer> +// CHECK-SAME: !spv.array<72 x f32 [4]> [0]>, StorageBuffer> +// CHECK-SAME: !spv.array<256 x f32 [4]> [0]>, StorageBuffer> +// CHECK-SAME: !spv.array<64 x f32 [4]> [0]>, StorageBuffer> +// CHECK-SAME: !spv.array<88 x f32 [4]> [0]>, StorageBuffer> + %arg0: memref<16x4xf32, offset: 0, strides: [4, 1]>, // tightly packed; row major + %arg1: memref<16x4xf32, offset: 8, strides: [4, 1]>, // offset 8 + %arg2: memref<16x4xf32, offset: 0, strides: [16, 1]>, // pad 12 after each row + %arg3: memref<16x4xf32, offset: 0, strides: [1, 16]>, // tightly packed; col major + %arg4: memref<16x4xf32, offset: 0, strides: [1, 22]>, // pad 6 after each col + +// CHECK-SAME: !spv.array<64 x f16 [2]> [0]>, StorageBuffer> +// CHECK-SAME: !spv.array<72 x f16 [2]> [0]>, StorageBuffer> +// CHECK-SAME: !spv.array<256 x f16 [2]> [0]>, StorageBuffer> +// CHECK-SAME: !spv.array<64 x f16 [2]> [0]>, StorageBuffer> +// CHECK-SAME: !spv.array<88 x f16 [2]> [0]>, StorageBuffer> + %arg5: memref<16x4xf16, offset: 0, strides: [4, 1]>, + %arg6: memref<16x4xf16, offset: 8, strides: [4, 1]>, + %arg7: memref<16x4xf16, offset: 0, strides: [16, 1]>, + %arg8: memref<16x4xf16, offset: 0, strides: [1, 16]>, + %arg9: memref<16x4xf16, offset: 0, strides: [1, 22]> +) { return } + +} // end module + +// ----- + +// Check that dynamic shapes are not supported.
+module attributes { + spv.target_env = #spv.target_env< + #spv.vce, + {max_compute_workgroup_invocations = 128 : i32, + max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}> +} { + +// CHECK-LABEL: func @unranked_memref +// CHECK-SAME: memref<*xi32> +func @unranked_memref(%arg0: memref<*xi32>) { return } + +// CHECK-LABEL: func @dynamic_dim_memref +// CHECK-SAME: memref<8x?xi32> +func @dynamic_dim_memref(%arg0: memref<8x?xi32>) { return } + +} // end module + +// ----- + +//===----------------------------------------------------------------------===// +// Tensor types +//===----------------------------------------------------------------------===// + +// Check that tensor element types are kept untouched with proper capabilities. +module attributes { + spv.target_env = #spv.target_env< + #spv.vce, + {max_compute_workgroup_invocations = 128 : i32, + max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}> +} { + +// CHECK-LABEL: spv.func @int_tensor_types +// CHECK-SAME: !spv.array<32 x i64 [8]> +// CHECK-SAME: !spv.array<32 x i32 [4]> +// CHECK-SAME: !spv.array<32 x i16 [2]> +// CHECK-SAME: !spv.array<32 x i8 [1]> +func @int_tensor_types( + %arg0: tensor<8x4xi64>, + %arg1: tensor<8x4xi32>, + %arg2: tensor<8x4xi16>, + %arg3: tensor<8x4xi8> +) { return } + +// CHECK-LABEL: spv.func @float_tensor_types +// CHECK-SAME: !spv.array<32 x f64 [8]> +// CHECK-SAME: !spv.array<32 x f32 [4]> +// CHECK-SAME: !spv.array<32 x f16 [2]> +func @float_tensor_types( + %arg0: tensor<8x4xf64>, + %arg1: tensor<8x4xf32>, + %arg2: tensor<8x4xf16> +) { return } + +} // end module + +// ----- + +// Check that tensor element types are changed to 32-bit without capabilities.
+module attributes { + spv.target_env = #spv.target_env< + #spv.vce, + {max_compute_workgroup_invocations = 128 : i32, + max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}> +} { + +// CHECK-LABEL: spv.func @int_tensor_types +// CHECK-SAME: !spv.array<32 x i32 [4]> +// CHECK-SAME: !spv.array<32 x i32 [4]> +// CHECK-SAME: !spv.array<32 x i32 [4]> +// CHECK-SAME: !spv.array<32 x i32 [4]> +func @int_tensor_types( + %arg0: tensor<8x4xi64>, + %arg1: tensor<8x4xi32>, + %arg2: tensor<8x4xi16>, + %arg3: tensor<8x4xi8> +) { return } + +// CHECK-LABEL: spv.func @float_tensor_types +// CHECK-SAME: !spv.array<32 x f32 [4]> +// CHECK-SAME: !spv.array<32 x f32 [4]> +// CHECK-SAME: !spv.array<32 x f32 [4]> +func @float_tensor_types( + %arg0: tensor<8x4xf64>, + %arg1: tensor<8x4xf32>, + %arg2: tensor<8x4xf16> +) { return } + +} // end module + +// ----- + +// Check that dynamic shapes are not supported. +module attributes { + spv.target_env = #spv.target_env< + #spv.vce, + {max_compute_workgroup_invocations = 128 : i32, + max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}> +} { + +// CHECK-LABEL: func @unranked_tensor +// CHECK-SAME: tensor<*xi32> +func @unranked_tensor(%arg0: tensor<*xi32>) { return } + +// CHECK-LABEL: func @dynamic_dim_tensor +// CHECK-SAME: tensor<8x?xi32> +func @dynamic_dim_tensor(%arg0: tensor<8x?xi32>) { return } + +} // end module