diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -222,9 +222,42 @@
   }
 };
 
+/// Removes a `tosa.transpose` whose constant permutation is the identity
+/// ([0, 1, ..., rank - 1]) by forwarding the transpose's input.
+struct NoOpOptimization : public OpRewritePattern<tosa::TransposeOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::TransposeOp op,
+                                PatternRewriter &rewriter) const override {
+    // Forwarding the input is only legal when it already has the result type.
+    // An identity permutation can still carry a different (e.g. less refined)
+    // result type; replacing it would change the type seen by its users.
+    if (op.input1().getType() != op.getType()) {
+      return failure();
+    }
+
+    DenseIntElementsAttr permAttr;
+    if (!matchPattern(op.perms(), m_Constant(&permAttr))) {
+      return failure();
+    }
+
+    // The transpose is a no-op only if every dimension maps to itself.
+    int64_t expectedDim = 0;
+    for (const APInt &dim : permAttr.getValues<APInt>()) {
+      if (dim.getSExtValue() != expectedDim++) {
+        return failure();
+      }
+    }
+
+    rewriter.replaceOp(op, op.input1());
+    return success();
+  }
+};
+
 void TransposeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                               MLIRContext *context) {
   results.insert<ConstantTransposeOptimization>(context);
+  results.insert<NoOpOptimization>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -233,7 +233,7 @@
 // CHECK-LABEL: @transpose_nofold_shape
 func @transpose_nofold_shape(%arg0: tensor<3x4xf32>) -> tensor<?x?xf32> {
   // CHECK: "tosa.transpose"
-  %0 = arith.constant dense<[0, 1]> : tensor<2xi32>
+  %0 = arith.constant dense<[1, 0]> : tensor<2xi32>
   %1 = "tosa.transpose"(%arg0, %0) { perms = [1, 0] }: (tensor<3x4xf32>, tensor<2xi32>) -> tensor<?x?xf32>
   return %1 : tensor<?x?xf32>
 }
@@ -325,3 +325,25 @@
   %0 = "tosa.transpose"(%input, %perms) : (tensor<1x1x1x16xi8>, tensor<4xi32>) -> tensor<1x1x16x1x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,2.100000e+00,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01}>>
   return %0: tensor<1x1x16x1x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,2.100000e+00,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01}>>
 }
+
+// -----
+
+// CHECK-LABEL: @transpose_no_op
+func @transpose_no_op(%arg0: tensor<3x4x5x6xf32>) -> tensor<3x4x5x6xf32> {
+  // CHECK-NOT: tosa.transpose
+  // CHECK: return %arg0
+  %perms = "tosa.const"() {value = dense<[0, 1, 2, 3]> : tensor<4xi32>} : () -> tensor<4xi32>
+  %1 = "tosa.transpose"(%arg0, %perms) : (tensor<3x4x5x6xf32>, tensor<4xi32>) -> tensor<3x4x5x6xf32>
+  return %1 : tensor<3x4x5x6xf32>
+}
+
+// -----
+
+// An identity transpose must not be removed when its result type differs from its input type.
+// CHECK-LABEL: @transpose_nofold_identity_dynamic
+func @transpose_nofold_identity_dynamic(%arg0: tensor<3x4xf32>) -> tensor<?x?xf32> {
+  // CHECK: "tosa.transpose"
+  %perms = "tosa.const"() {value = dense<[0, 1]> : tensor<2xi32>} : () -> tensor<2xi32>
+  %1 = "tosa.transpose"(%arg0, %perms) : (tensor<3x4xf32>, tensor<2xi32>) -> tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}