diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2440,7 +2440,12 @@
 }
 
 def Vector_PrintOp :
-  Vector_Op<"print", []>, Arguments<(ins AnyType:$source)> {
+  Vector_Op<"print", []>,
+    // Accepts vectors of any rank and integer, index, or floating-point scalars.
+    Arguments<(ins Type<Or<[
+      AnyVectorOfAnyRank.predicate,
+      AnyInteger.predicate, Index.predicate, AnyFloat.predicate
+    ]>, "vector, integer, index, or floating-point type">:$source)> {
   let summary = "print operation (for testing and debugging)";
   let description = [{
     Prints the source vector (or scalar) to stdout in human readable
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -987,6 +987,14 @@
 
 // -----
 
+func.func private @print_needs_vector(%arg0: tensor<8xf32>) {
+  // expected-error@+1 {{op operand #0 must be vector, integer, index, or floating-point type, but got 'tensor<8xf32>'}}
+  vector.print %arg0 : tensor<8xf32>
+  return
+}
+
+// -----
+
 func.func @reshape_bad_input_shape(%arg0 : vector<3x2x4xf32>) {
   %c2 = arith.constant 2 : index
   %c3 = arith.constant 3 : index
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -401,13 +401,20 @@
   return
 }
 
-// CHECK-LABEL: @vector_print
-func.func @vector_print(%arg0: vector<8x4xf32>) {
+// CHECK-LABEL: @vector_print_on_vector
+func.func @vector_print_on_vector(%arg0: vector<8x4xf32>) {
   // CHECK: vector.print %{{.*}} : vector<8x4xf32>
   vector.print %arg0 : vector<8x4xf32>
   return
 }
 
+// CHECK-LABEL: @vector_print_on_scalar
+func.func @vector_print_on_scalar(%arg0: i64) {
+  // CHECK: vector.print %{{.*}} : i64
+  vector.print %arg0 : i64
+  return
+}
+
 // CHECK-LABEL: @reshape
 func.func @reshape(%arg0 : vector<3x2x4xf32>) -> (vector<2x3x4xf32>) {
   // CHECK: %[[C2:.*]] = arith.constant 2 : index