diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -1315,10 +1315,12 @@
     VectorType getResultType() {
       return result().getType().cast<VectorType>();
     }
+    void getTransp(SmallVectorImpl<int64_t> &results);
   }];
   let assemblyFormat = [{
     $vector `,` $transp attr-dict `:` type($vector) `to` type($result)
   }];
+  let hasFolder = 1;
 }
 
 def Vector_TupleGetOp :
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -1524,6 +1524,23 @@
 // TransposeOp
 //===----------------------------------------------------------------------===//
 
+// Eliminates transpose operations which produce values identical to their
+// input values. This happens when the dimensions of the input vector remain in
+// their original order after the transpose operation.
+OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
+  SmallVector<int64_t, 4> transp;
+  getTransp(transp);
+
+  // Check if the permutation of the dimensions contains sequential values:
+  // {0, 1, 2, ...}. Any other permutation is left untouched (null result).
+  for (int64_t i = 0, e = transp.size(); i < e; ++i) {
+    if (transp[i] != i)
+      return {};
+  }
+
+  return vector();
+}
+
 static LogicalResult verify(TransposeOp op) {
   VectorType vectorType = op.getVectorType();
   VectorType resultType = op.getResultType();
@@ -1549,6 +1566,11 @@
   return success();
 }
 
+// Fills `results` from the `transp` attribute via populateFromInt64AttrArray.
+void TransposeOp::getTransp(SmallVectorImpl<int64_t> &results) {
+  populateFromInt64AttrArray(transp(), results);
+}
+
 //===----------------------------------------------------------------------===//
 // TupleGetOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -87,3 +87,72 @@
   // CHECK: vector.constant_mask [1, 1] : vector<2x1xi1>
   return %1 : vector<2x1xi1>
 }
+
+// -----
+
+// CHECK-LABEL: transpose_1D_identity
+// CHECK-SAME: ([[ARG:%.*]]: vector<4xf32>)
+func @transpose_1D_identity(%arg : vector<4xf32>) -> vector<4xf32> {
+  // CHECK-NOT: transpose
+  %0 = vector.transpose %arg, [0] : vector<4xf32> to vector<4xf32>
+  // CHECK-NEXT: return [[ARG]]
+  return %0 : vector<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: transpose_2D_identity
+// CHECK-SAME: ([[ARG:%.*]]: vector<4x3xf32>)
+func @transpose_2D_identity(%arg : vector<4x3xf32>) -> vector<4x3xf32> {
+  // CHECK-NOT: transpose
+  %0 = vector.transpose %arg, [0, 1] : vector<4x3xf32> to vector<4x3xf32>
+  // CHECK-NEXT: return [[ARG]]
+  return %0 : vector<4x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: transpose_3D_identity
+// CHECK-SAME: ([[ARG:%.*]]: vector<4x3x2xf32>)
+func @transpose_3D_identity(%arg : vector<4x3x2xf32>) -> vector<4x3x2xf32> {
+  // CHECK-NOT: transpose
+  %0 = vector.transpose %arg, [0, 1, 2] : vector<4x3x2xf32> to vector<4x3x2xf32>
+  // CHECK-NEXT: return [[ARG]]
+  return %0 : vector<4x3x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: transpose_2D_sequence
+// CHECK-SAME: ([[ARG:%.*]]: vector<4x3xf32>)
+func @transpose_2D_sequence(%arg : vector<4x3xf32>) -> vector<3x4xf32> {
+  // CHECK-NOT: transpose
+  %0 = vector.transpose %arg, [0, 1] : vector<4x3xf32> to vector<4x3xf32>
+  // CHECK: [[T1:%.*]] = vector.transpose [[ARG]], [1, 0]
+  %1 = vector.transpose %0, [1, 0] : vector<4x3xf32> to vector<3x4xf32>
+  // CHECK-NOT: transpose
+  %2 = vector.transpose %1, [0, 1] : vector<3x4xf32> to vector<3x4xf32>
+  // CHECK: [[ADD:%.*]] = addf [[T1]], [[T1]]
+  %4 = addf %1, %2 : vector<3x4xf32>
+  // CHECK-NEXT: return [[ADD]]
+  return %4 : vector<3x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: transpose_3D_sequence
+// CHECK-SAME: ([[ARG:%.*]]: vector<4x3x2xf32>)
+func @transpose_3D_sequence(%arg : vector<4x3x2xf32>) -> vector<2x3x4xf32> {
+  // CHECK: [[T0:%.*]] = vector.transpose [[ARG]], [1, 2, 0]
+  %0 = vector.transpose %arg, [1, 2, 0] : vector<4x3x2xf32> to vector<3x2x4xf32>
+  // CHECK-NOT: transpose
+  %1 = vector.transpose %0, [0, 1, 2] : vector<3x2x4xf32> to vector<3x2x4xf32>
+  // CHECK: [[T2:%.*]] = vector.transpose [[T0]], [1, 0, 2]
+  %2 = vector.transpose %1, [1, 0, 2] : vector<3x2x4xf32> to vector<2x3x4xf32>
+  // CHECK: [[ADD:%.*]] = addf [[T2]], [[T2]]
+  %3 = addf %2, %2 : vector<2x3x4xf32>
+  // CHECK-NOT: transpose
+  %4 = vector.transpose %3, [0, 1, 2] : vector<2x3x4xf32> to vector<2x3x4xf32>
+  // CHECK-NEXT: return [[ADD]]
+  return %4 : vector<2x3x4xf32>
+}