diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2229,12 +2229,13 @@
     DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
     PredOpTrait<"operand and result have same element type",
                  TCresVTEtIsSameAsOpBase<0, 0>>]>,
-    Arguments<(ins AnyVector:$vector, I64ArrayAttr:$transp)>,
-    Results<(outs AnyVector:$result)> {
+    Arguments<(ins AnyVectorOfAnyRank:$vector, I64ArrayAttr:$transp)>,
+    Results<(outs AnyVectorOfAnyRank:$result)> {
   let summary = "vector transpose operation";
   let description = [{
     Takes a n-D vector and returns the transposed n-D vector defined by
-    the permutation of ranks in the n-sized integer array attribute.
+    the permutation of ranks in the n-sized integer array attribute (in case
+    of 0-D vectors the array attribute must be empty).
     In the operation
 
     ```mlir
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1760,6 +1760,8 @@
 // CHECK: %[[result:.*]] = arith.cmpi slt, %[[indices]], %[[bounds]] : vector<4xi32>
 // CHECK: return %[[result]] : vector<4xi1>
 
+// -----
+
 func.func @create_mask_1d_scalable(%a : index) -> vector<[4]xi1> {
   %v = vector.create_mask %a : vector<[4]xi1>
   return %v: vector<[4]xi1>
@@ -1776,6 +1778,17 @@
 
 // -----
 
+func.func @transpose_0d(%arg0: vector<f32>) -> vector<f32> {
+  %0 = vector.transpose %arg0, [] : vector<f32> to vector<f32>
+  return %0 : vector<f32>
+}
+
+// CHECK-LABEL: func @transpose_0d
+// CHECK-SAME: %[[A:.*]]: vector<f32>
+// CHECK: return %[[A]] : vector<f32>
+
+// -----
+
 func.func @flat_transpose(%arg0: vector<16xf32>) -> vector<16xf32> {
   %0 = vector.flat_transpose %arg0 { rows = 4: i32, columns = 4: i32 }
      : vector<16xf32> -> vector<16xf32>
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1145,11 +1145,25 @@
 
 // -----
 
+func.func @transpose_rank_mismatch_0d(%arg0: vector<f32>) {
+  // expected-error@+1 {{'vector.transpose' op vector result rank mismatch: 1}}
+  %0 = vector.transpose %arg0, [] : vector<f32> to vector<100xf32>
+}
+
+// -----
+
 func.func @transpose_rank_mismatch(%arg0: vector<4x16x11xf32>) {
   // expected-error@+1 {{'vector.transpose' op vector result rank mismatch: 1}}
   %0 = vector.transpose %arg0, [2, 1, 0] : vector<4x16x11xf32> to vector<100xf32>
 }
 
+// -----
+
+func.func @transpose_length_mismatch_0d(%arg0: vector<f32>) {
+  // expected-error@+1 {{'vector.transpose' op transposition length mismatch: 1}}
+  %0 = vector.transpose %arg0, [1] : vector<f32> to vector<f32>
+}
+
 // -----
 
 func.func @transpose_length_mismatch(%arg0: vector<4x4xf32>) {
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -570,6 +570,22 @@
   return %0 : vector<2x11x7x3xi32>
 }
 
+// CHECK-LABEL: @transpose_fp_0d
+func.func @transpose_fp_0d(%arg0: vector<f32>) -> vector<f32> {
+  // CHECK: %[[X:.*]] = vector.transpose %{{.*}}, [] : vector<f32> to vector<f32>
+  %0 = vector.transpose %arg0, [] : vector<f32> to vector<f32>
+  // CHECK: return %[[X]] : vector<f32>
+  return %0 : vector<f32>
+}
+
+// CHECK-LABEL: @transpose_int_0d
+func.func @transpose_int_0d(%arg0: vector<i32>) -> vector<i32> {
+  // CHECK: %[[X:.*]] = vector.transpose %{{.*}}, [] : vector<i32> to vector<i32>
+  %0 = vector.transpose %arg0, [] : vector<i32> to vector<i32>
+  // CHECK: return %[[X]] : vector<i32>
+  return %0 : vector<i32>
+}
+
 // CHECK-LABEL: @flat_transpose_fp
 func.func @flat_transpose_fp(%arg0: vector<16xf32>) -> vector<16xf32> {
   // CHECK: %[[X:.*]] = vector.flat_transpose %{{.*}} {columns = 4 : i32, rows = 4 : i32} : vector<16xf32> -> vector<16xf32>
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
@@ -120,6 +120,13 @@
   return
 }
 
+func.func @transpose_0d(%arg: vector<i32>) {
+  %1 = vector.transpose %arg, [] : vector<i32> to vector<i32>
+  // CHECK: ( 42 )
+  vector.print %1: vector<i32>
+  return
+}
+
 func.func @entry() {
   %0 = arith.constant 42.0 : f32
   %1 = arith.constant dense<0.0> : vector<f32>
@@ -151,6 +158,8 @@
 
   %5 = arith.constant dense<4.0> : vector<f32>
   call @fma_0d(%5) : (vector<f32>) -> ()
+  %6 = arith.constant dense<42> : vector<i32>
+  call @transpose_0d(%6) : (vector<i32>) -> ()
 
   return
 }