[mlir][vector] Restrict vector.fma operands and result to floating-point vectors

The `lhs`, `rhs` and `acc` operands and the `result` of vector.fma change from
AnyVectorOfAnyRank to VectorOfAnyRankOf<[AnyFloat]>. The op keeps its
AllTypesMatch<["lhs", "rhs", "acc", "result"]> trait, so all four must still
have the same type. An integer-element vector such as vector<4xi32> is now
rejected during operand type verification.

A negative test, @fma_vector_4xi32, is added to
mlir/test/Dialect/Vector/invalid.mlir. It expects the diagnostic
"'vector.fma' op operand #0 must be vector of floating-point value".

NOTE(review): the patch below appears to have been collapsed onto a single
line, so it will not apply as-is. It needs to be re-split into one diff line
per source line before use.

NOTE(review): the `DeclareOpInterfaceMethods` context line appears to have
lost its template arguments. They are presumably
`<VectorUnrollOpInterface, ["getShapeForUnroll"]>`; confirm this against the
current VectorOps.td. The context line must match the file exactly for the
hunk to apply.

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -629,10 +629,10 @@ Pure, AllTypesMatch<["lhs", "rhs", "acc", "result"]>, DeclareOpInterfaceMethods ] # ElementwiseMappable.traits>, - Arguments<(ins AnyVectorOfAnyRank:$lhs, - AnyVectorOfAnyRank:$rhs, - AnyVectorOfAnyRank:$acc)>, - Results<(outs AnyVectorOfAnyRank:$result)> { + Arguments<(ins VectorOfAnyRankOf<[AnyFloat]>:$lhs, + VectorOfAnyRankOf<[AnyFloat]>:$rhs, + VectorOfAnyRankOf<[AnyFloat]>:$acc)>, + Results<(outs VectorOfAnyRankOf<[AnyFloat]>:$result)> { let summary = "vector fused multiply-add"; let description = [{ Multiply-add expressions operate on n-D vectors and compute a fused diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -44,6 +44,13 @@ // ----- +func.func @fma_vector_4xi32(%arg0: vector<4xi32>) { + // expected-error@+1 {{'vector.fma' op operand #0 must be vector of floating-point value}} + %1 = vector.fma %arg0, %arg0, %arg0 : vector<4xi32> +} + +// ----- + func.func @shuffle_elt_type_mismatch(%arg0: vector<2xf32>, %arg1: vector<2xi32>) { // expected-error@+1 {{'vector.shuffle' op failed to verify that second operand v2 and result have same element type}} %1 = vector.shuffle %arg0, %arg1 [0, 1] : vector<2xf32>, vector<2xi32>