diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4588,6 +4588,13 @@
     return emitOpError("incompatible input/initial value shapes");
   }
 
+  // Verify that the reduction kind is supported for the destination element type.
+  Type eltType = getDestType().getElementType();
+  if (!isSupportedCombiningKind(getKind(), eltType))
+    return emitOpError("unsupported reduction type ")
+           << eltType << " for kind '" << stringifyCombiningKind(getKind())
+           << "'";
+
   return success();
 }
 
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1523,6 +1523,15 @@
 
 // -----
 
+func @scan_unsupported_kind(%arg0: vector<2x3xf32>, %arg1: vector<3xf32>) -> vector<2x3xf32> {
+  // expected-error@+1 {{'vector.scan' op unsupported reduction type 'f32' for kind 'xor'}}
+  %0:2 = vector.scan <xor>, %arg0, %arg1 {inclusive = true, reduction_dim = 0} :
+    vector<2x3xf32>, vector<3xf32>
+  return %0#0 : vector<2x3xf32>
+}
+
+// -----
+
 func @invalid_splat(%v : f32) {
   // expected-error@+1 {{invalid kind of type specified}}
   vector.splat %v : memref<8xf32>