diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2470,6 +2470,7 @@
     }
   }];
   let assemblyFormat = "$source attr-dict `:` type($source)";
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5184,6 +5184,34 @@
   populateFromInt64AttrArray(getTransp(), results);
 }
 
+//===----------------------------------------------------------------------===//
+// PrintOp
+//===----------------------------------------------------------------------===//
+
+/// Verifies that the operand type, or its element type when the operand is a
+/// vector, is f32, f64, index, or an integer type.
+LogicalResult vector::PrintOp::verify() {
+  Type tp = getPrintType();
+
+  // Find element type of a vector type. Assume direct scalar type otherwise.
+  VectorType vtp = tp.dyn_cast<VectorType>();
+  Type eltType = vtp ? vtp.getElementType() : tp;
+
+  // Now verify if element type is supported.
+  if (eltType.isF32())
+    return success();
+  if (eltType.isF64())
+    return success();
+  if (eltType.isIndex())
+    return success();
+  if (eltType.isa<IntegerType>())
+    return success();
+
+  // Anything else not supported.
+  return emitError("operand must be a vector or scalar with floating-point or "
+                   "integral type");
+}
+
 //===----------------------------------------------------------------------===//
 // ConstantMaskOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -987,6 +987,14 @@
 
 // -----
 
+func.func private @print_needs_buffer(%arg0: tensor<8xf32>) {
+  // expected-error@+1 {{operand must be a vector or scalar with floating-point or integral type}}
+  vector.print %arg0 : tensor<8xf32>
+  return
+}
+
+// -----
+
 func.func @reshape_bad_input_shape(%arg0 : vector<3x2x4xf32>) {
   %c2 = arith.constant 2 : index
   %c3 = arith.constant 3 : index
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -401,13 +401,20 @@
   return
 }
 
-// CHECK-LABEL: @vector_print
-func.func @vector_print(%arg0: vector<8x4xf32>) {
+// CHECK-LABEL: @vector_print_on_vector
+func.func @vector_print_on_vector(%arg0: vector<8x4xf32>) {
   // CHECK: vector.print %{{.*}} : vector<8x4xf32>
   vector.print %arg0 : vector<8x4xf32>
   return
 }
 
+// CHECK-LABEL: @vector_print_on_scalar
+func.func @vector_print_on_scalar(%arg0: i64) {
+  // CHECK: vector.print %{{.*}} : i64
+  vector.print %arg0 : i64
+  return
+}
+
 // CHECK-LABEL: @reshape
 func.func @reshape(%arg0 : vector<3x2x4xf32>) -> (vector<2x3x4xf32>) {
   // CHECK: %[[C2:.*]] = arith.constant 2 : index