diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -647,13 +647,16 @@
   const int64_t shift = resultETy.isa<IntegerType>() ? getShift() : 0;
+  // A splat-zero operand may be broadcast (e.g. tensor<1x1xf32> times a
+  // tensor<2x3xf32> result), so it is resized to the result type. Dense
+  // attributes need a static shape, so skip that fold for dynamic results.
   if (rhsTy == resultTy) {
-    if (isSplatZero(resultETy, lhsAttr))
-      return lhsAttr;
+    if (isSplatZero(resultETy, lhsAttr) && resultTy.hasStaticShape())
+      return lhsAttr.resizeSplat(resultTy);
     if (isSplatOne(resultETy, lhsAttr, shift))
       return rhs;
   }
   if (lhsTy == resultTy) {
-    if (isSplatZero(resultETy, rhsAttr))
-      return rhsAttr;
+    if (isSplatZero(resultETy, rhsAttr) && resultTy.hasStaticShape())
+      return rhsAttr.resizeSplat(resultTy);
     if (isSplatOne(resultETy, rhsAttr, shift))
       return lhs;
   }
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -203,6 +203,29 @@
   return %1 : tensor<2x3xi32>
 }
 
+// CHECK-LABEL: @mul_zero_broadcast
+func.func @mul_zero_broadcast(%arg0: tensor<2x3xf32>) -> (tensor<2x3xf32>, tensor<2x3xf32>) {
+  // CHECK: %[[ZERO:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<2x3xf32>}> : () -> tensor<2x3xf32>
+  // CHECK-NOT: tosa.mul
+  %zeros = "tosa.const"() {value = dense<0.0> : tensor<1x1xf32>} : () -> tensor<1x1xf32>
+  %1 = "tosa.mul"(%arg0, %zeros) {shift = 0 : i32} : (tensor<2x3xf32>, tensor<1x1xf32>) -> tensor<2x3xf32>
+
+  // CHECK-NOT: tosa.mul
+  // CHECK: return %[[ZERO]], %[[ZERO]]
+  %2 = "tosa.mul"(%zeros, %arg0) {shift = 0 : i32} : (tensor<1x1xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
+  return %1, %2 : tensor<2x3xf32>, tensor<2x3xf32>
+}
+
+// CHECK-LABEL: @mul_zero_dynamic_nofold
+func.func @mul_zero_dynamic_nofold(%arg0: tensor<?x17xf32>) -> tensor<?x17xf32> {
+  // CHECK: %[[ZERO:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1x1xf32>}> : () -> tensor<1x1xf32>
+  // CHECK: %[[MUL:.*]] = "tosa.mul"(%arg0, %[[ZERO]])
+  // CHECK: return %[[MUL]]
+  %0 = "tosa.const"() {value = dense<0.0> : tensor<1x1xf32>} : () -> tensor<1x1xf32>
+  %1 = "tosa.mul"(%arg0, %0) {shift = 0 : i32} : (tensor<?x17xf32>, tensor<1x1xf32>) -> tensor<?x17xf32>
+  return %1 : tensor<?x17xf32>
+}
+
 // CHECK-LABEL: @select_same_value
 func.func @select_same_value(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> {
   %0 = "tosa.select"(%arg0, %arg1, %arg1) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>