diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -442,6 +442,30 @@
   }
 };
 
+// At the codegen level any identity operations should be removed. Any cases
+// where identity is load-bearing (e.g. cross device computation) should be
+// handled before lowering to codegen.
+//
+// The op is erased by forwarding each operand to the matching result. That is
+// only valid when operand and result types agree one-to-one; otherwise the
+// pattern reports a match failure instead of producing ill-typed IR.
+template <typename SrcOp>
+class IdentityNConverter : public OpRewritePattern<SrcOp> {
+public:
+  using OpRewritePattern<SrcOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(SrcOp op,
+                                PatternRewriter &rewriter) const final {
+    Operation *operation = op.getOperation();
+    if (!llvm::equal(operation->getOperandTypes(),
+                     operation->getResultTypes()))
+      return rewriter.notifyMatchFailure(
+          op, "identity operand types must match result types");
+    rewriter.replaceOp(op, operation->getOperands());
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
@@ -462,5 +486,6 @@
       PointwiseConverter, PointwiseConverter,
       PointwiseConverter, PointwiseConverter,
       PointwiseConverter, PointwiseConverter,
-      ReshapeOpConverter>(context);
+      IdentityNConverter<tosa::IdentityOp>,
+      IdentityNConverter<tosa::IdentityNOp>, ReshapeOpConverter>(context);
 }
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -304,3 +304,16 @@
   %0 = "tosa.reshape"(%arg0) {new_shape = [2, 3]} : (tensor<1x2x3x5x7x11xf32>) -> tensor<6x5x77xf32>
   return %0 : tensor<6x5x77xf32>
 }
+
+// -----
+
+// CHECK-LABEL: @test_identity
+func @test_identity(%arg0: tensor<1xf32>, %arg1: tensor<1xi32>) -> (tensor<1xf32>, tensor<1xi32>) {
+  %0 = "tosa.identity"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>
+  %1 = "tosa.identity"(%arg1) : (tensor<1xi32>) -> tensor<1xi32>
+
+  %2:2 = "tosa.identityn"(%0, %1) : (tensor<1xf32>, tensor<1xi32>) -> (tensor<1xf32>, tensor<1xi32>)
+
+  // CHECK: return %arg0, %arg1
+  return %2#0, %2#1 : tensor<1xf32>, tensor<1xi32>
+}