diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -339,6 +339,23 @@
   }
 };
 
+/// Lowers identity-style TOSA ops (`tosa.identity`, `tosa.identityn`) by
+/// replacing every result of the op with the operand at the same position,
+/// which also erases the op.
+template <typename SrcOp>
+class IdentityNConverter : public OpRewritePattern<SrcOp> {
+public:
+  using OpRewritePattern<SrcOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(SrcOp op,
+                                PatternRewriter &rewriter) const final {
+    // replaceOp needs exactly one replacement value per result; this assumes
+    // SrcOp has as many operands as results, with matching types.
+    rewriter.replaceOp(op, op->getOperands());
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
@@ -358,6 +375,7 @@
       PointwiseConverter, PointwiseConverter,
       PointwiseConverter, PointwiseConverter,
       PointwiseConverter,
-      PointwiseConverter, PointwiseConverter>(
-      context);
+      PointwiseConverter, PointwiseConverter,
+      IdentityNConverter<tosa::IdentityOp>,
+      IdentityNConverter<tosa::IdentityNOp>>(context);
 }
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -255,5 +255,11 @@
   // CHECK: select
   %15 = "tosa.reluN"(%0) {max_int = 5 : i64, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32>
 
+  // CHECK-NOT: tosa.identity
+  %16 = "tosa.identity"(%0) : (tensor<1xi32>) -> tensor<1xi32>
+
+  // CHECK-NOT: tosa.identityn
+  %17:2 = "tosa.identityn"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> (tensor<1xi32>, tensor<1xi32>)
+
   return
 }