diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -1320,6 +1320,7 @@
   let assemblyFormat = [{
     $vector `,` $transp attr-dict `:` type($vector) `to` type($result)
   }];
+  let hasCanonicalizer = 1;
   let hasFolder = 1;
 }
 
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -1566,6 +1566,57 @@
   return success();
 }
 
+namespace {
+
+// Rewrites two back-to-back TransposeOp operations into a single TransposeOp.
+class TransposeFolder final : public OpRewritePattern<TransposeOp> {
+public:
+  using OpRewritePattern<TransposeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TransposeOp transposeOp,
+                                PatternRewriter &rewriter) const override {
+    // Wrapper around TransposeOp::getTransp() for cleaner code.
+    auto getPermutation = [](TransposeOp transpose) {
+      SmallVector<int64_t, 4> permutation;
+      transpose.getTransp(permutation);
+      return permutation;
+    };
+
+    // Composes two permutations: result[i] = permutation1[permutation2[i]].
+    // 'permutation1' belongs to the first (parent) transpose, so result[i] is
+    // the original input dimension that lands at position i after both ops.
+    auto composePermutations = [](ArrayRef<int64_t> permutation1,
+                                  ArrayRef<int64_t> permutation2) {
+      SmallVector<int64_t, 4> result;
+      for (auto index : permutation2)
+        result.push_back(permutation1[index]);
+      return result;
+    };
+
+    // Return if the input of 'transposeOp' is not defined by another transpose.
+    TransposeOp parentTransposeOp =
+        dyn_cast_or_null<TransposeOp>(transposeOp.vector().getDefiningOp());
+    if (!parentTransposeOp)
+      return failure();
+
+    SmallVector<int64_t, 4> permutation = composePermutations(
+        getPermutation(parentTransposeOp), getPermutation(transposeOp));
+    // Replace 'transposeOp' with a new transpose operation.
+    rewriter.replaceOpWithNewOp<TransposeOp>(
+        transposeOp, transposeOp.getResult().getType(),
+        parentTransposeOp.vector(),
+        vector::getVectorSubscriptAttr(rewriter, permutation));
+    return success();
+  }
+};
+
+} // end anonymous namespace
+
+void TransposeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                              MLIRContext *context) {
+  results.insert<TransposeFolder>(context);
+}
+
 void TransposeOp::getTransp(SmallVectorImpl<int64_t> &results) {
   populateFromInt64AttrArray(transp(), results);
 }
@@ -1704,7 +1755,8 @@
 
 void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
     OwningRewritePatternList &patterns, MLIRContext *context) {
-  patterns.insert<CreateMaskFolder, StridedSliceConstantMaskFolder>(context);
+  patterns.insert<CreateMaskFolder, StridedSliceConstantMaskFolder,
+                  TransposeFolder>(context);
 }
 
 namespace mlir {
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -125,34 +125,37 @@
 
 // CHECK-LABEL: transpose_2D_sequence
 // CHECK-SAME: ([[ARG:%.*]]: vector<4x3xf32>)
-func @transpose_2D_sequence(%arg : vector<4x3xf32>) -> vector<3x4xf32> {
+func @transpose_2D_sequence(%arg : vector<4x3xf32>) -> vector<4x3xf32> {
   // CHECK-NOT: transpose
-  %0 = vector.transpose %arg, [0, 1] : vector<4x3xf32> to vector<4x3xf32>
-  // CHECK: [[T1:%.*]] = vector.transpose [[ARG]], [1, 0]
-  %1 = vector.transpose %0, [1, 0] : vector<4x3xf32> to vector<3x4xf32>
-  // CHECK-NOT: transpose
-  %2 = vector.transpose %1, [0, 1] : vector<3x4xf32> to vector<3x4xf32>
-  // CHECK: [[ADD:%.*]] = addf [[T1]], [[T1]]
-  %4 = addf %1, %2 : vector<3x4xf32>
+  %0 = vector.transpose %arg, [1, 0] : vector<4x3xf32> to vector<3x4xf32>
+  %1 = vector.transpose %0, [0, 1] : vector<3x4xf32> to vector<3x4xf32>
+  %2 = vector.transpose %1, [1, 0] : vector<3x4xf32> to vector<4x3xf32>
+  %3 = vector.transpose %2, [0, 1] : vector<4x3xf32> to vector<4x3xf32>
+  // CHECK: [[ADD:%.*]] = addf [[ARG]], [[ARG]]
+  %4 = addf %2, %3 : vector<4x3xf32>
   // CHECK-NEXT: return [[ADD]]
-  return %4 : vector<3x4xf32>
+  return %4 : vector<4x3xf32>
 }
 
 // -----
 
 // CHECK-LABEL: transpose_3D_sequence
 // CHECK-SAME: ([[ARG:%.*]]: vector<4x3x2xf32>)
-func @transpose_3D_sequence(%arg : vector<4x3x2xf32>) -> vector<2x3x4xf32> {
-  // CHECK: [[T0:%.*]] = vector.transpose [[ARG]], [1, 2, 0]
+func @transpose_3D_sequence(%arg : vector<4x3x2xf32>) -> vector<4x3x2xf32> {
+  // CHECK: [[T0:%.*]] = vector.transpose [[ARG]], [2, 1, 0]
   %0 = vector.transpose %arg, [1, 2, 0] : vector<4x3x2xf32> to vector<3x2x4xf32>
+  %1 = vector.transpose %0, [1, 0, 2] : vector<3x2x4xf32> to vector<2x3x4xf32>
   // CHECK-NOT: transpose
-  %1 = vector.transpose %0, [0, 1, 2] : vector<3x2x4xf32> to vector<3x2x4xf32>
-  // CHECK: [[T2:%.*]] = vector.transpose [[T0]], [1, 0, 2]
-  %2 = vector.transpose %1, [1, 0, 2] : vector<3x2x4xf32> to vector<2x3x4xf32>
-  // CHECK: [[ADD:%.*]] = addf [[T2]], [[T2]]
-  %3 = addf %2, %2 : vector<2x3x4xf32>
+  %2 = vector.transpose %1, [2, 1, 0] : vector<2x3x4xf32> to vector<4x3x2xf32>
+  %3 = vector.transpose %2, [2, 1, 0] : vector<4x3x2xf32> to vector<2x3x4xf32>
+  // CHECK: [[MUL:%.*]] = mulf [[T0]], [[T0]]
+  %4 = mulf %1, %3 : vector<2x3x4xf32>
+  // CHECK: [[T5:%.*]] = vector.transpose [[MUL]], [2, 1, 0]
+  %5 = vector.transpose %4, [2, 1, 0] : vector<2x3x4xf32> to vector<4x3x2xf32>
   // CHECK-NOT: transpose
-  %4 = vector.transpose %3, [0, 1, 2] : vector<2x3x4xf32> to vector<2x3x4xf32>
+  %6 = vector.transpose %3, [2, 1, 0] : vector<2x3x4xf32> to vector<4x3x2xf32>
+  // CHECK: [[ADD:%.*]] = addf [[T5]], [[ARG]]
+  %7 = addf %5, %6 : vector<4x3x2xf32>
   // CHECK-NEXT: return [[ADD]]
-  return %4 : vector<2x3x4xf32>
+  return %7 : vector<4x3x2xf32>
 }