diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -188,6 +188,20 @@
     if (!matchPattern(op.getPerms(), m_Constant(&permAttr)))
       return rewriter.notifyMatchFailure(op, "Non-constant permutation");
 
+    // Do not match when the input is itself produced by a transpose: as the
+    // failure message records, the two transposes can be composed instead.
+    if (op.getInput1().getDefiningOp<tosa::TransposeOp>())
+      return rewriter.notifyMatchFailure(
+          op, "Src is from transpose, can compose transposes");
+
+    // Likewise, do not match when any user of the result is a transpose.
+    Value result = op.getResult();
+    for (Operation *subop : result.getUsers()) {
+      if (dyn_cast_or_null<tosa::TransposeOp>(subop))
+        return rewriter.notifyMatchFailure(
+            op, "Dest is used by transpose, can compose transposes");
+    }
+
     auto input = op.getInput1();
     auto inputTy = input.getType().cast<ShapedType>();
     if (!inputTy.hasRank())
diff --git a/mlir/test/IR/transpose-fold.mlir b/mlir/test/Dialect/Tosa/transpose-fold.mlir
rename from mlir/test/IR/transpose-fold.mlir
rename to mlir/test/Dialect/Tosa/transpose-fold.mlir
--- a/mlir/test/IR/transpose-fold.mlir
+++ b/mlir/test/Dialect/Tosa/transpose-fold.mlir
@@ -42,3 +42,22 @@
   %3 = "tosa.transpose"(%1, %2) : (tensor<3x4x2x5xi32>, tensor<4xi32>) -> tensor<5x4x3x2xi32>
   return %3 : tensor<5x4x3x2xi32>
 }
+
+// -----
+
+// Two chained transposes (perms [1, 2, 0, 3] then [3, 1, 0, 2]) are expected
+// to be composed into a single transpose with perms [3, 2, 1, 0].
+// CHECK-LABEL:   func.func @test_prefer_compose_transpose(
+// CHECK-SAME:                 %[[VAL_0:.*]]: tensor<1x2x3x4xi32>) -> tensor<4x3x2x1xi32> {
+// CHECK:           %[[VAL_1:.*]] = arith.constant dense<[3, 2, 1, 0]> : tensor<4xi32>
+// CHECK:           %[[VAL_2:.*]] = "tosa.transpose"(%[[VAL_0]], %[[VAL_1]]) : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> tensor<4x3x2x1xi32>
+// CHECK:           return %[[VAL_2]] : tensor<4x3x2x1xi32>
+// CHECK:         }
+
+func.func @test_prefer_compose_transpose(%arg0: tensor<1x2x3x4xi32>) -> (tensor<4x3x2x1xi32>) {
+  %0 = arith.constant dense<[1, 2, 0, 3]> : tensor<4xi32>
+  %1 = "tosa.transpose"(%arg0, %0) : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> (tensor<2x3x1x4xi32>)
+  %2 = arith.constant dense<[3, 1, 0, 2]> : tensor<4xi32>
+  %3 = "tosa.transpose"(%1, %2) : (tensor<2x3x1x4xi32>, tensor<4xi32>) -> tensor<4x3x2x1xi32>
+  return %3 : tensor<4x3x2x1xi32>
+}