diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -550,44 +550,40 @@ return {}; } +template +static bool isSplatZero(DenseElementsAttr val) { + return val && val.isSplat() && val.getSplatValue().isZero(); +} + +static bool isSplatZero(Type elemType, DenseElementsAttr val) { + if (elemType.isa()) + return isSplatZero(val); + if (elemType.isa()) + return isSplatZero(val); + return false; +} + OpFoldResult AddOp::fold(FoldAdaptor adaptor) { auto lhsTy = getInput1().getType().dyn_cast(); auto rhsTy = getInput2().getType().dyn_cast(); auto resultTy = getType().dyn_cast(); if (!lhsTy || !rhsTy || !resultTy) return {}; - if (lhsTy != rhsTy) - return {}; auto resultETy = resultTy.getElementType(); auto lhsAttr = adaptor.getInput1().dyn_cast_or_null(); auto rhsAttr = adaptor.getInput2().dyn_cast_or_null(); - if (lhsAttr && lhsAttr.isSplat() && resultETy.isa()) { - if (lhsAttr.getSplatValue().isZero()) - return getInput2(); - } - - if (rhsAttr && rhsAttr.isSplat() && resultETy.isa()) { - if (rhsAttr.getSplatValue().isZero()) - return getInput1(); - } - - if (lhsAttr && lhsAttr.isSplat() && resultETy.isa()) { - if (lhsAttr.getSplatValue().isZero()) - return getInput2(); - } - - if (rhsAttr && rhsAttr.isSplat() && resultETy.isa()) { - if (rhsAttr.getSplatValue().isZero()) - return getInput1(); - } + if (isSplatZero(resultETy, rhsAttr) && lhsTy == resultTy) + return getInput1(); + if (isSplatZero(resultETy, lhsAttr) && rhsTy == resultTy) + return getInput2(); if (!lhsAttr || !rhsAttr) return {}; return binaryFolder, std::plus>(lhsAttr, rhsAttr, - lhsTy); + resultTy); } OpFoldResult DivOp::fold(FoldAdaptor adaptor) { @@ -665,8 +661,6 @@ auto resultTy = getType().dyn_cast(); if (!lhsTy || !rhsTy || !resultTy) return {}; - if (lhsTy != rhsTy) - return {}; auto resultETy = resultTy.getElementType(); auto lhsAttr = adaptor.getInput1().dyn_cast_or_null(); @@ -674,27 +668,27 @@ if (lhsAttr && lhsAttr.isSplat() && resultETy.isa()) { auto val = lhsAttr.getSplatValue(); - if (val.isZero()) + if (val.isZero() && lhsTy == resultTy) return lhsAttr; - if (val.isExactlyValue(1.0)) + if (val.isExactlyValue(1.0) && rhsTy == resultTy) return rhs; } if (rhsAttr && rhsAttr.isSplat() && resultETy.isa()) { auto val = rhsAttr.getSplatValue(); - if (val.isZero()) + if (val.isZero() && rhsTy == resultTy) return rhsAttr; - if (val.isExactlyValue(1.0)) + if (val.isExactlyValue(1.0) && lhsTy == resultTy) return lhs; } if (lhsAttr && lhsAttr.isSplat() && resultETy.isa()) { auto val = lhsAttr.getSplatValue(); - if (val.isZero()) + if (val.isZero() && lhsTy == resultTy) return lhsAttr; const int64_t shift = getShift(); const int64_t shifted = 1LL << shift; - if (val.getSExtValue() == shifted) + if (val.getSExtValue() == shifted && rhsTy == resultTy) return rhs; } @@ -702,13 +696,13 @@ auto val = rhsAttr.getSplatValue(); const int64_t shift = getShift(); const int64_t shifted = 1LL << shift; - if (val.isZero()) + if (val.isZero() && rhsTy == resultTy) return rhsAttr; - if (val.getSExtValue() == shifted) + if (val.getSExtValue() == shifted && lhsTy == resultTy) return lhs; } - return mulBinaryFolder(lhsAttr, rhsAttr, lhsTy, getShift()); + return mulBinaryFolder(lhsAttr, rhsAttr, resultTy, getShift()); } OpFoldResult SubOp::fold(FoldAdaptor adaptor) { @@ -717,28 +711,19 @@ auto resultTy = getType().dyn_cast(); if (!lhsTy || !rhsTy || !resultTy) return {}; - if (lhsTy != rhsTy) - return {}; auto resultETy = resultTy.getElementType(); auto lhsAttr = adaptor.getInput1().dyn_cast_or_null(); auto rhsAttr = adaptor.getInput2().dyn_cast_or_null(); - if (rhsAttr && rhsAttr.isSplat() && resultETy.isa()) { - if (rhsAttr.getSplatValue().isZero()) - return getInput1(); - } - - if (rhsAttr && rhsAttr.isSplat() && resultETy.isa()) { - if (rhsAttr.getSplatValue().isZero()) - return getInput1(); - } + if (isSplatZero(resultETy, rhsAttr) && lhsTy == resultTy) + return getInput1(); if (!lhsAttr || !rhsAttr) return {}; return binaryFolder, std::minus>(lhsAttr, rhsAttr, - lhsTy); + resultTy); } namespace { diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir --- a/mlir/test/Dialect/Tosa/canonicalize.mlir +++ b/mlir/test/Dialect/Tosa/canonicalize.mlir @@ -15,6 +15,14 @@ return %1 : tensor<4x2x3xi32> } +// CHECK-LABEL: @add_bcast_zero_int +func.func @add_bcast_zero_int(%arg0: tensor<4x2x3xi32>) -> tensor<4x2x3xi32> { + // CHECK: return %arg0 + // CHECK-NOT: tosa.add + %zeros = "tosa.const"() {value = dense<0> : tensor<1x1xi32>} : () -> tensor<1x1xi32> + %1 = "tosa.add"(%arg0, %zeros) : (tensor<4x2x3xi32>, tensor<1x1xi32>) -> tensor<4x2x3xi32> + return %1 : tensor<4x2x3xi32> +} // CHECK-LABEL: @add_zero_int func.func @add_zero_int(%arg0: tensor<2x3xi32>) -> tensor<2x3xi32> { @@ -193,6 +201,15 @@ return %1 : tensor<2x3xf32> } +// CHECK-LABEL: @mul_bcast_one_float +func.func @mul_bcast_one_float(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { + // CHECK: return %arg0 + // CHECK-NOT: tosa.mul + %ones = "tosa.const"() {value = dense<1.0> : tensor<1xf32>} : () -> tensor<1xf32> + %1 = "tosa.mul"(%ones, %arg0) {shift = 0 : i32} : (tensor<1xf32>, tensor<2x3xf32>) -> tensor<2x3xf32> + return %1 : tensor<2x3xf32> +} + // CHECK-LABEL: @mul_one_int func.func @mul_one_int(%arg0: tensor<2x3xi32>) -> tensor<2x3xi32> { // CHECK: return %arg0