diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -4806,7 +4806,11 @@
   if (op->getOperand(1).getType() != factorTy)
     return op->emitOpError("requires the same type for both vector operands");
 
+  // The 'format' attribute is only expected for packed integer operands;
+  // any extra attribute is rejected after the type-specific checks.
+  unsigned expectedNumAttrs = 0;
   if (auto intTy = factorTy.dyn_cast<IntegerType>()) {
+    ++expectedNumAttrs;
     auto packedVectorFormat =
         op->getAttr(kPackedVectorFormatAttrName)
             .dyn_cast_or_null<spirv::PackedVectorFormatAttr>();
@@ -4816,15 +4820,21 @@
 
     assert(packedVectorFormat.getValue() ==
                spirv::PackedVectorFormat::PackedVectorFormat4x8Bit &&
-           "unknown Packed Vector format");
+           "Unknown Packed Vector Format");
 
     if (intTy.getWidth() != 32)
       return op->emitOpError(
           llvm::formatv("with specified Packed Vector Format ({0}) requires "
                         "integer vector operands to be 32-bits wide",
                         packedVectorFormat.getValue()));
+  } else {
+    // Vector operands are not packed, so they must not carry a format.
+    if (op->hasAttr(kPackedVectorFormatAttrName))
+      return op->emitOpError(llvm::formatv(
+          "with invalid format attribute for vector operands of type '{0}'",
+          factorTy));
   }
 
-  if (op->getAttrs().size() > 1)
+  if (op->getAttrs().size() > expectedNumAttrs)
     return op->emitError(
         "op only supports the 'format' #spirv.packed_vector_format attribute");
diff --git a/mlir/test/Dialect/SPIRV/IR/integer-dot-product-ops.mlir b/mlir/test/Dialect/SPIRV/IR/integer-dot-product-ops.mlir
--- a/mlir/test/Dialect/SPIRV/IR/integer-dot-product-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/integer-dot-product-ops.mlir
@@ -49,6 +49,15 @@
   %r = spirv.SDot %a, %b : (i32, i64) -> i32
   return %r : i32
 }
+
+// -----
+
+func.func @sdot_vector_4xi8_bad_attr(%a: vector<4xi8>, %b: vector<4xi8>) -> i32 {
+  // expected-error @+1 {{op with invalid format attribute for vector operands of type 'vector<4xi8>'}}
+  %r = spirv.SDot %a, %b {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}:
+      (vector<4xi8>, vector<4xi8>) -> i32
+  return %r : i32
+}
 
 // -----
 
@@ -61,6 +70,14 @@
 
 // -----
 
+func.func @udot_vector_4xi8_bad_attr(%a: vector<4xi8>, %b: vector<4xi8>) -> i32 {
+  // expected-error @+1 {{op only supports the 'format' #spirv.packed_vector_format attribute}}
+  %r = spirv.UDot %a, %b {volatile = #spirv.decoration<Volatile>}: (vector<4xi8>, vector<4xi8>) -> i32
+  return %r : i32
+}
+
+// -----
+
 func.func @sdot_scalar_bad_types(%a: i32, %b: i32) -> i16 {
   // expected-error @+1 {{op result type has insufficient bit-width (16 bits) for the specified vector operand type (32 bits)}}
   %r = spirv.SDot %a, %b {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32) -> i16