diff --git a/mlir/test/Dialect/Standard/ops.mlir b/mlir/test/Dialect/Standard/ops.mlir
--- a/mlir/test/Dialect/Standard/ops.mlir
+++ b/mlir/test/Dialect/Standard/ops.mlir
@@ -62,3 +62,10 @@
   %result = constant [0.1 : f64, -1.0 : f64] : complex<f64>
   return %result : complex<f64>
 }
+
+// CHECK-LABEL: func @vector_splat_0d(
+func @vector_splat_0d(%a: f32) -> vector<f32> {
+  // CHECK: splat %{{.*}} : vector<f32>
+  %0 = splat %a : vector<f32>
+  return %0 : vector<f32>
+}
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -16,6 +16,13 @@
 
 // -----
 
+func @broadcast_rank_too_high_0d(%arg0: vector<1xf32>) {
+  // expected-error@+1 {{'vector.broadcast' op source rank higher than destination rank}}
+  %1 = vector.broadcast %arg0 : vector<1xf32> to vector<f32>
+}
+
+// -----
+
 func @broadcast_dim1_mismatch(%arg0: vector<7xf32>) {
   // expected-error@+1 {{'vector.broadcast' op dimension mismatch (7 vs. 3)}}
   %1 = vector.broadcast %arg0 : vector<7xf32> to vector<3xf32>
@@ -79,7 +86,7 @@
 }
 
 // -----
-  
+
 func @extract_element(%arg0: vector<4xf32>) {
   %c = arith.constant 3 : i32
   // expected-error@+1 {{expected position for 1-D vector}}
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -149,16 +149,20 @@
 }
 
 // CHECK-LABEL: @vector_broadcast
-func @vector_broadcast(%a: f32, %b: vector<16xf32>, %c: vector<1x16xf32>, %d: vector<8x1xf32>) -> vector<8x16xf32> {
+func @vector_broadcast(%a: f32, %b: vector<f32>, %c: vector<16xf32>, %d: vector<1x16xf32>, %e: vector<8x1xf32>) -> vector<8x16xf32> {
+  // CHECK: vector.broadcast %{{.*}} : f32 to vector<f32>
+  %0 = vector.broadcast %a : f32 to vector<f32>
+  // CHECK: vector.broadcast %{{.*}} : vector<f32> to vector<4xf32>
+  %1 = vector.broadcast %b : vector<f32> to vector<4xf32>
   // CHECK: vector.broadcast %{{.*}} : f32 to vector<16xf32>
-  %0 = vector.broadcast %a : f32 to vector<16xf32>
+  %2 = vector.broadcast %a : f32 to vector<16xf32>
   // CHECK-NEXT: vector.broadcast %{{.*}} : vector<16xf32> to vector<8x16xf32>
-  %1 = vector.broadcast %b : vector<16xf32> to vector<8x16xf32>
+  %3 = vector.broadcast %c : vector<16xf32> to vector<8x16xf32>
   // CHECK-NEXT: vector.broadcast %{{.*}} : vector<1x16xf32> to vector<8x16xf32>
-  %2 = vector.broadcast %c : vector<1x16xf32> to vector<8x16xf32>
+  %4 = vector.broadcast %d : vector<1x16xf32> to vector<8x16xf32>
   // CHECK-NEXT: vector.broadcast %{{.*}} : vector<8x1xf32> to vector<8x16xf32>
-  %3 = vector.broadcast %d : vector<8x1xf32> to vector<8x16xf32>
-  return %3 : vector<8x16xf32>
+  %5 = vector.broadcast %e : vector<8x1xf32> to vector<8x16xf32>
+  return %5 : vector<8x16xf32>
 }
 
 // CHECK-LABEL: @shuffle1D