diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -1269,6 +1269,33 @@
   }];
 }
 
+def Vector_TransposeOp :
+  Vector_Op<"transpose", [NoSideEffect,
+    PredOpTrait<"operand and result have same element type",
+                TCresVTEtIsSameAsOpBase<0, 0>>]>,
+    Arguments<(ins AnyVector:$vector)>,
+    Results<(outs AnyVector:$result)> {
+  let summary = "vector transpose operation";
+  let description = [{
+    Takes a 2-D vector and returns the transposed 2-D vector.
+
+      %1 = vector.transpose %0 : vector<2x3xf32>
+
+      example:  [ [a, b, c],       [ [a, d],
+                  [d, e, f] ]  ->    [b, e],
+                                     [c, f] ]
+
+    Note that this operation is restricted to 2-D vectors to remain
+    close to efficient transpose implementations during lowering. The
+    restriction on matrices may be relaxed in the future though.
+  }];
+  let extraClassDeclaration = [{
+    VectorType getVectorType() {
+      return vector().getType().cast<VectorType>();
+    }
+  }];
+}
+
 def Vector_TupleGetOp :
   Vector_Op<"tuple_get", [NoSideEffect]>,
   Arguments<(ins TupleOf<[AnyVector]>:$vectors, APIntAttr:$index)>,
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -1520,6 +1520,45 @@ static LogicalResult verify(TupleOp op) {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// TransposeOp
+//===----------------------------------------------------------------------===//
+
+/// Returns the vector type whose shape is the reversed shape of `vectorType`,
+/// i.e. the transposed type for a 2-D vector. The parser calls this before the
+/// verifier has checked the rank, so it must not index a fixed dimension.
+static VectorType transposeType(VectorType vectorType) {
+  ArrayRef<int64_t> shape = vectorType.getShape();
+  SmallVector<int64_t, 2> transShape(shape.rbegin(), shape.rend());
+  return VectorType::get(transShape, vectorType.getElementType());
+}
+
+/// Parses `vector.transpose %operand : vector-type`. Only the operand type is
+/// spelled out; the result type is derived from it with transposeType.
+static ParseResult parseTransposeOp(OpAsmParser &parser,
+                                    OperationState &result) {
+  OpAsmParser::OperandType operandInfo;
+  VectorType vectorType;
+  return failure(
+      parser.parseOperand(operandInfo) || parser.parseColonType(vectorType) ||
+      parser.resolveOperand(operandInfo, vectorType, result.operands) ||
+      parser.addTypeToList(transposeType(vectorType), result.types));
+}
+
+/// Prints the op as `vector.transpose %operand : vector-type`.
+static void print(OpAsmPrinter &p, TransposeOp op) {
+  p << op.getOperationName() << " " << op.vector() << " : "
+    << op.vector().getType();
+}
+
+/// Rejects any operand that is not a 2-D vector.
+static LogicalResult verify(TransposeOp op) {
+  int64_t rank = op.getVectorType().getRank();
+  if (rank != 2)
+    return op.emitOpError("unsupported transposition rank: ") << rank;
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TupleGetOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1049,6 +1049,20 @@
 
 // -----
 
+func @transpose_unsupported_rank(%arg0: vector<4x16x11xf32>) {
+  // expected-error@+1 {{'vector.transpose' op unsupported transposition rank: 3}}
+  %0 = vector.transpose %arg0 : vector<4x16x11xf32>
+}
+
+// -----
+
+func @transpose_unsupported_rank_1d(%arg0: vector<4xf32>) {
+  // expected-error@+1 {{'vector.transpose' op unsupported transposition rank: 1}}
+  %0 = vector.transpose %arg0 : vector<4xf32>
+}
+
+// -----
+
 func @type_cast_layout(%arg0: memref<4x3xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s0 + d1 * s1 + s2)>>) {
   // expected-error@+1 {{expects operand to be a memref with no layout}}
   %0 = vector.type_cast %arg0: memref<4x3xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s0 + d1 * s1 + s2)>> to memref<vector<4x3xf32>>
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -315,3 +315,19 @@
   // CHECK: return %[[X]] : i32
   return %0 : i32
 }
+
+// CHECK-LABEL: transpose_2d_fp
+func @transpose_2d_fp(%arg0: vector<3x7xf32>) -> vector<7x3xf32> {
+  // CHECK: %[[X:.*]] = vector.transpose %{{.*}} : vector<3x7xf32>
+  %0 = vector.transpose %arg0 : vector<3x7xf32>
+  // CHECK: return %[[X]] : vector<7x3xf32>
+  return %0 : vector<7x3xf32>
+}
+
+// CHECK-LABEL: transpose_2d_int
+func @transpose_2d_int(%arg0: vector<11x7xi32>) -> vector<7x11xi32> {
+  // CHECK: %[[X:.*]] = vector.transpose %{{.*}} : vector<11x7xi32>
+  %0 = vector.transpose %arg0 : vector<11x7xi32>
+  // CHECK: return %[[X]] : vector<7x11xi32>
+  return %0 : vector<7x11xi32>
+}