diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -1269,6 +1269,45 @@
   }];
 }
 
+def Vector_TransposeOp :
+  Vector_Op<"transpose", [NoSideEffect,
+    PredOpTrait<"operand and result have same element type",
+                TCresVTEtIsSameAsOpBase<0, 0>>]>,
+    Arguments<(ins AnyVector:$vector, I64ArrayAttr:$transp)>,
+    Results<(outs AnyVector:$result)> {
+  let summary = "vector transpose operation";
+  let description = [{
+    Takes an n-D vector and returns the transposed n-D vector defined by
+    the permutation of ranks in the n-sized integer array attribute.
+    In the operation
+
+      %1 = vector.transpose %0, [i_1, .., i_n]
+        : vector<d_0 x .. x d_{n-1} x f32>
+        to vector<d_{i_1} x .. x d_{i_n} x f32>
+
+    the transp array [i_1, .., i_n] must be a permutation of [0, .., n-1].
+
+    Example:
+
+      %1 = vector.transpose %0, [1, 0] : vector<2x3xf32> to vector<3x2xf32>
+
+       [ [a, b, c],       [ [a, d],
+         [d, e, f] ]  ->    [b, e],
+                            [c, f] ]
+  }];
+  let extraClassDeclaration = [{
+    VectorType getVectorType() {
+      return vector().getType().cast<VectorType>();
+    }
+    VectorType getResultType() {
+      return result().getType().cast<VectorType>();
+    }
+  }];
+  let assemblyFormat = [{
+    $vector `,` $transp attr-dict `:` type($vector) `to` type($result)
+  }];
+}
+
 def Vector_TupleGetOp :
   Vector_Op<"tuple_get", [NoSideEffect]>,
     Arguments<(ins TupleOf<[AnyVector]>:$vectors, APIntAttr:$index)>,
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -1521,6 +1521,35 @@
 static LogicalResult verify(TupleOp op) { return success(); }
 
 //===----------------------------------------------------------------------===//
+// TransposeOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(TransposeOp op) {
+  VectorType vectorType = op.getVectorType();
+  VectorType resultType = op.getResultType();
+  int64_t rank = resultType.getRank();
+  if (vectorType.getRank() != rank)
+    return op.emitOpError("vector result rank mismatch: ") << rank;
+  // Verify transp is a permutation with result dim k == source dim transp[k].
+  auto transpAttr = op.transp().getValue();
+  int64_t size = transpAttr.size();
+  if (rank != size)
+    return op.emitOpError("transposition length mismatch: ") << size;
+  SmallVector<bool, 8> seen(rank, false);
+  for (auto ta : llvm::enumerate(transpAttr)) {
+    int64_t i = ta.value().cast<IntegerAttr>().getInt();
+    if (i < 0 || i >= rank)
+      return op.emitOpError("transposition index out of range: ") << i;
+    if (seen[i])
+      return op.emitOpError("duplicate position index: ") << i;
+    seen[i] = true;
+    if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(i))
+      return op.emitOpError("dimension size mismatch at: ") << i;
+  }
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
 // TupleGetOp
 //===----------------------------------------------------------------------===//
 
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1049,6 +1049,41 @@
 
 // -----
 
+func @transpose_rank_mismatch(%arg0: vector<4x16x11xf32>) {
+  // expected-error@+1 {{'vector.transpose' op vector result rank mismatch: 1}}
+  %0 = vector.transpose %arg0, [2, 1, 0] : vector<4x16x11xf32> to vector<100xf32>
+}
+
+// -----
+
+func @transpose_length_mismatch(%arg0: vector<4x4xf32>) {
+  // expected-error@+1 {{'vector.transpose' op transposition length mismatch: 3}}
+  %0 = vector.transpose %arg0, [2, 0, 1] : vector<4x4xf32> to vector<4x4xf32>
+}
+
+// -----
+
+func @transpose_index_oob(%arg0: vector<4x4xf32>) {
+  // expected-error@+1 {{'vector.transpose' op transposition index out of range: 2}}
+  %0 = vector.transpose %arg0, [2, 0] : vector<4x4xf32> to vector<4x4xf32>
+}
+
+// -----
+
+func @transpose_index_dup(%arg0: vector<4x4xf32>) {
+  // expected-error@+1 {{'vector.transpose' op duplicate position index: 0}}
+  %0 = vector.transpose %arg0, [0, 0] : vector<4x4xf32> to vector<4x4xf32>
+}
+
+// -----
+
+func @transpose_dim_size_mismatch(%arg0: vector<11x7x3x2xi32>) {
+  // expected-error@+1 {{'vector.transpose' op dimension size mismatch at: 0}}
+  %0 = vector.transpose %arg0, [3, 0, 1, 2] : vector<11x7x3x2xi32> to vector<2x3x7x11xi32>
+}
+
+// -----
+
 func @type_cast_layout(%arg0: memref<4x3xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s0 + d1 * s1 + s2)>>) {
   // expected-error@+1 {{expects operand to be a memref with no layout}}
   %0 = vector.type_cast %arg0: memref<4x3xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s0 + d1 * s1 + s2)>> to memref<vector<4x3xf32>>
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -315,3 +315,19 @@
   // CHECK: return %[[X]] : i32
   return %0 : i32
 }
+
+// CHECK-LABEL: transpose_fp
+func @transpose_fp(%arg0: vector<3x7xf32>) -> vector<7x3xf32> {
+  // CHECK: %[[X:.*]] = vector.transpose %{{.*}}, [1, 0] : vector<3x7xf32> to vector<7x3xf32>
+  %0 = vector.transpose %arg0, [1, 0] : vector<3x7xf32> to vector<7x3xf32>
+  // CHECK: return %[[X]] : vector<7x3xf32>
+  return %0 : vector<7x3xf32>
+}
+
+// CHECK-LABEL: transpose_int
+func @transpose_int(%arg0: vector<11x7x3x2xi32>) -> vector<2x11x7x3xi32> {
+  // CHECK: %[[X:.*]] = vector.transpose %{{.*}}, [3, 0, 1, 2] : vector<11x7x3x2xi32> to vector<2x11x7x3xi32>
+  %0 = vector.transpose %arg0, [3, 0, 1, 2] : vector<11x7x3x2xi32> to vector<2x11x7x3xi32>
+  // CHECK: return %[[X]] : vector<2x11x7x3xi32>
+  return %0 : vector<2x11x7x3xi32>
+}