diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -301,12 +301,6 @@ DenseElementsAttr input1Attr; if (matchPattern(input1, m_Constant(&input1Attr)) && input1Attr.isSplat() && input2.getType() == op.getType()) { - if (input1Attr.getType().getElementType().isa() && - input1Attr.getSplatValue().isZero()) { - rewriter.replaceOp(op, op.input2()); - return success(); - } - if (input1Attr.getType().getElementType().isa() && input1Attr.getSplatValue().isZero()) { rewriter.replaceOp(op, op.input2()); @@ -317,12 +311,6 @@ DenseElementsAttr input2Attr; if (matchPattern(input2, m_Constant(&input2Attr)) && input2Attr.isSplat() && input1.getType() == op.getType()) { - if (input2Attr.getType().getElementType().isa() && - input2Attr.getSplatValue().isZero()) { - rewriter.replaceOp(op, op.input1()); - return success(); - } - if (input2Attr.getType().getElementType().isa() && input2Attr.getSplatValue().isZero()) { rewriter.replaceOp(op, op.input1()); diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir --- a/mlir/test/Dialect/Tosa/canonicalize.mlir +++ b/mlir/test/Dialect/Tosa/canonicalize.mlir @@ -10,23 +10,13 @@ // ----- // CHECK-LABEL: @add_zero_different_shape -func @add_zero_different_shape(%arg0: tensor<2x3xf32>) -> tensor<4x2x3xf32> { +func @add_zero_different_shape(%arg0: tensor<2x3xi32>) -> tensor<4x2x3xi32> { // CHECK: tosa.add - %zeros = "tosa.const"() {value = dense<0.0> : tensor<4x2x3xf32>} : () -> tensor<4x2x3xf32> - %1 = "tosa.add"(%arg0, %zeros) : (tensor<2x3xf32>, tensor<4x2x3xf32>) -> tensor<4x2x3xf32> - return %1 : tensor<4x2x3xf32> + %zeros = "tosa.const"() {value = dense<0> : tensor<4x2x3xi32>} : () -> tensor<4x2x3xi32> + %1 = "tosa.add"(%arg0, %zeros) : (tensor<2x3xi32>, tensor<4x2x3xi32>) -> tensor<4x2x3xi32> + return %1 : tensor<4x2x3xi32> } -// ----- - -// CHECK-LABEL: @add_zero_float -func @add_zero_float(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { - // CHECK: return %arg0 - // CHECK-NOT: tosa.add - %zeros = "tosa.const"() {value = dense<0.0> : tensor<2x3xf32>} : () -> tensor<2x3xf32> - %1 = "tosa.add"(%arg0, %zeros) : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32> - return %1 : tensor<2x3xf32> -} // -----