diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
@@ -3045,7 +3045,7 @@
 def SPV_Int32 : TypeAlias<I32, "Int32">;
 def SPV_Float : FloatOfWidths<[16, 32, 64]>;
 def SPV_Float16or32 : FloatOfWidths<[16, 32]>;
-def SPV_Vector : VectorOfLengthAndType<[2, 3, 4],
+def SPV_Vector : VectorOfLengthAndType<[2, 3, 4, 8, 16],
                                        [SPV_Bool, SPV_Integer, SPV_Float]>;
 // Component type check is done in the type parser for the following SPIR-V
 // dialect-specific types so we use "Any" here.
@@ -3083,10 +3083,10 @@
                     "Cooperative Matrix">;

 class SPV_ScalarOrVectorOf<Type type> :
-    AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4], [type]>]>;
+    AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>]>;

 class SPV_ScalarOrVectorOrCoopMatrixOf<Type type> :
-    AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4], [type]>,
+    AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>,
                SPV_CoopMatrixOfType<[type]>]>;

 def SPV_ScalarOrVector : AnyTypeOf<[SPV_Scalar, SPV_Vector]>;
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
--- a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
@@ -172,8 +172,17 @@
 }

 bool CompositeType::isValid(VectorType type) {
-  return type.getRank() == 1 && type.getElementType().isa<ScalarType>() &&
-         type.getNumElements() >= 2 && type.getNumElements() <= 4;
+  switch (type.getNumElements()) {
+  case 2:
+  case 3:
+  case 4:
+  case 8:
+  case 16:
+    break;
+  default:
+    return false;
+  }
+  return type.getRank() == 1 && type.getElementType().isa<ScalarType>();
 }

 Type CompositeType::getElementType(unsigned index) const {
@@ -233,6 +242,12 @@
             StructType>(
           [&](auto type) { type.getCapabilities(capabilities, storage); })
       .Case<VectorType>([&](VectorType type) {
+        auto vecSize = getNumElements();
+        if (vecSize == 8 || vecSize == 16) {
+          static const Capability caps[] = {Capability::Vector16};
+          ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps));
+          capabilities.push_back(ref);
+        }
         return type.getElementType().cast<ScalarType>().getCapabilities(
             capabilities, storage);
       })
diff --git a/mlir/test/Dialect/SPIRV/Serialization/ocl-ops.mlir b/mlir/test/Dialect/SPIRV/Serialization/ocl-ops.mlir
--- a/mlir/test/Dialect/SPIRV/Serialization/ocl-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/Serialization/ocl-ops.mlir
@@ -14,4 +14,10 @@
     %0 = spv.OCL.s_abs %arg0 : i32
     spv.Return
   }
+
+  spv.func @vector_size16(%arg0 : vector<16xf32>) "None" {
+    // CHECK: {{%.*}} = spv.OCL.fabs {{%.*}} : vector<16xf32>
+    %0 = spv.OCL.fabs %arg0 : vector<16xf32>
+    spv.Return
+  }
 }
diff --git a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
--- a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
@@ -121,6 +121,18 @@
 }
 }

+// Using 16-element vectors requires Vector16.
+// NOTE(review): the #spv.vce<...> payloads below were stripped by text
+// extraction; reconstructed as Vector16 (from the 16-wide IAdd) plus Shader
+// (from Logical GLSL450) — confirm against the upstream test.
+// CHECK: requires #spv.vce<v1.0, [Vector16, Shader], []>
+spv.module Logical GLSL450 attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [Vector16, Shader], []>, {}>
+} {
+  spv.func @iadd_v16_function(%val : vector<16xi32>) -> vector<16xi32> "None" {
+    %0 = spv.IAdd %val, %val : vector<16xi32>
+    spv.ReturnValue %0: vector<16xi32>
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // Extension
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/ops.mlir b/mlir/test/Dialect/SPIRV/ops.mlir
--- a/mlir/test/Dialect/SPIRV/ops.mlir
+++ b/mlir/test/Dialect/SPIRV/ops.mlir
@@ -843,7 +843,7 @@

 func @logicalUnary(%arg0 : i32)
 {
-  // expected-error @+1 {{operand #0 must be bool or vector of bool values of length 2/3/4, but got 'i32'}}
+  // expected-error @+1 {{operand #0 must be bool or vector of bool values of length 2/3/4/8/16, but got 'i32'}}
   %0 = spv.LogicalNot %arg0 : i32
   return
 }