diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1541,6 +1541,12 @@
     outs Tosa_Tensor1Dto6D:$output
   );
 
+  let extraClassDeclaration = [{
+    llvm::SmallVector<int64_t> getConstantPerms();
+    mlir::OpFoldResult foldIdentityTranspose();
+    mlir::OpFoldResult foldCancellableTranspose();
+  }];
+
   let hasCanonicalizer = 1;
   let hasFolder = 1;
 }
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -954,6 +954,59 @@
   return {};
 }
 
+// Returns the permutation held by $perms when it is a build-time constant;
+// returns an empty vector otherwise.
+llvm::SmallVector<int64_t> TransposeOp::getConstantPerms() {
+  // Perms must be constants.
+  DenseIntElementsAttr perms;
+  if (!matchPattern(getPerms(), m_Constant(&perms)))
+    return {};
+
+  return llvm::to_vector<6>(
+      llvm::map_range(perms.getValues<APInt>(),
+                      [](const APInt &val) { return val.getSExtValue(); }));
+}
+
+// Folds transpose(x) -> x when the permutation is the identity.
+OpFoldResult TransposeOp::foldIdentityTranspose() {
+  // Transpose does not change the input type.
+  if (getInput1().getType() != getType())
+    return {};
+
+  // Perms must be constant and equal to [0, 1, ..., rank - 1]. An empty
+  // vector means perms were not constant, so the fold is not safe even if
+  // the types happen to match.
+  auto perms = getConstantPerms();
+  if (perms.empty() ||
+      !llvm::equal(llvm::seq<int64_t>(0, perms.size()), perms))
+    return {};
+
+  return getInput1();
+}
+
+// Folds transpose(transpose(A)) -> A when the two permutations cancel.
+OpFoldResult TransposeOp::foldCancellableTranspose() {
+  // Input is also TransposeOp - transpose(transpose(A)).
+  auto innerTranspose = getInput1().getDefiningOp<tosa::TransposeOp>();
+  if (!innerTranspose)
+    return {};
+
+  auto transposePerms = getConstantPerms();
+  auto innerTransposePerms = innerTranspose.getConstantPerms();
+
+  // Number of Perms values is same and positive.
+  if (transposePerms.size() != innerTransposePerms.size() ||
+      transposePerms.empty())
+    return {};
+
+  // Transpose & inner transpose cancel each other.
+  for (int64_t i = 0, n = transposePerms.size(); i < n; ++i) {
+    if (transposePerms[innerTransposePerms[i]] != i)
+      return {};
+  }
+
+  return innerTranspose.getInput1();
+}
+
 OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
   if (!operands[1])
     return {};
@@ -964,12 +1017,10 @@
       return input.reshape(getType().cast<ShapedType>());
   }
 
-  auto perms = llvm::to_vector<6>(llvm::map_range(
-      operands[1].cast<DenseIntElementsAttr>().getValues<APInt>(),
-      [](const APInt &val) { return val.getSExtValue(); }));
+  if (auto folded = foldIdentityTranspose())
+    return folded;
+  if (auto folded = foldCancellableTranspose())
+    return folded;
 
-  if (llvm::equal(llvm::seq<int64_t>(0, perms.size()), perms) &&
-      getInput1().getType() == getType())
-    return getInput1();
   return {};
 }
diff --git a/mlir/test/IR/transpose-fold.mlir b/mlir/test/IR/transpose-fold.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/IR/transpose-fold.mlir
@@ -0,0 +1,46 @@
+// RUN: mlir-opt %s --canonicalize -split-input-file | FileCheck %s
+
+// CHECK-LABEL: func.func @test_cancel_transpose_transpose(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x2x3xi32>) -> tensor<1x2x3xi32> {
+// CHECK: return %[[VAL_0]] : tensor<1x2x3xi32>
+// CHECK: }
+
+func.func @test_cancel_transpose_transpose(%arg0: tensor<1x2x3xi32>) -> (tensor<1x2x3xi32>) {
+  %0 = arith.constant dense<[1, 2, 0]> : tensor<3xi32>
+  %1 = "tosa.transpose"(%arg0, %0) : (tensor<1x2x3xi32>, tensor<3xi32>) -> (tensor<2x3x1xi32>)
+  %2 = arith.constant dense<[2, 0, 1]> : tensor<3xi32>
+  %3 = "tosa.transpose"(%1, %2) : (tensor<2x3x1xi32>, tensor<3xi32>) -> tensor<1x2x3xi32>
+  return %3 : tensor<1x2x3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_remove_identity_transpose(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x2x3xi32>) -> tensor<1x2x3xi32> {
+// CHECK: return %[[VAL_0]] : tensor<1x2x3xi32>
+// CHECK: }
+
+func.func @test_remove_identity_transpose(%arg0: tensor<1x2x3xi32>) -> (tensor<1x2x3xi32>) {
+  %0 = arith.constant dense<[0, 1, 2]> : tensor<3xi32>
+  %1 = "tosa.transpose"(%arg0, %0) : (tensor<1x2x3xi32>, tensor<3xi32>) -> (tensor<1x2x3xi32>)
+  return %1 : tensor<1x2x3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_do_not_cancel_different_transpose(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x2x3x4xi32>) -> tensor<4x3x2x1xi32> {
+// CHECK: %[[VAL_1:.*]] = arith.constant dense<[3, 1, 0, 2]> : tensor<4xi32>
+// CHECK: %[[VAL_2:.*]] = arith.constant dense<[1, 2, 0, 3]> : tensor<4xi32>
+// CHECK: %[[VAL_3:.*]] = "tosa.transpose"(%[[VAL_0]], %[[VAL_2]]) : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> tensor<2x3x1x4xi32>
+// CHECK: %[[VAL_4:.*]] = "tosa.transpose"(%[[VAL_3]], %[[VAL_1]]) : (tensor<2x3x1x4xi32>, tensor<4xi32>) -> tensor<4x3x2x1xi32>
+// CHECK: return %[[VAL_4]] : tensor<4x3x2x1xi32>
+// CHECK: }
+
+func.func @test_do_not_cancel_different_transpose(%arg0: tensor<1x2x3x4xi32>) -> (tensor<4x3x2x1xi32>) {
+  %0 = arith.constant dense<[1, 2, 0, 3]> : tensor<4xi32>
+  %1 = "tosa.transpose"(%arg0, %0) : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> (tensor<2x3x1x4xi32>)
+  %2 = arith.constant dense<[3, 1, 0, 2]> : tensor<4xi32>
+  %3 = "tosa.transpose"(%1, %2) : (tensor<2x3x1x4xi32>, tensor<4xi32>) -> tensor<4x3x2x1xi32>
+  return %3 : tensor<4x3x2x1xi32>
+}