diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -542,6 +542,8 @@
   let results = (outs
     Tosa_Int32Tensor:$output
   );
+
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -501,6 +501,55 @@
                                                             lhsTy);
 }
 
+// Folds tosa.div over integer tensors when the result is statically known:
+//   0 / x   -> 0
+//   x / 1   -> x
+//   c1 / c2 -> splat constant (signed division, truncating toward zero)
+// A zero divisor or INT_MIN / -1 is left unfolded: TOSA leaves both
+// undefined, and APInt::sdiv asserts on a zero divisor.
+OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) {
+  auto lhsTy = getInput1().getType().dyn_cast<RankedTensorType>();
+  auto rhsTy = getInput2().getType().dyn_cast<RankedTensorType>();
+  auto resultTy = getType().dyn_cast<RankedTensorType>();
+  if (!lhsTy || !rhsTy || !resultTy)
+    return {};
+  // Every fold below yields a value or attribute typed like the lhs, so the
+  // operand and result types must all agree (no broadcast or refinement).
+  if (lhsTy != rhsTy || lhsTy != resultTy)
+    return {};
+  if (!resultTy.getElementType().isa<IntegerType>())
+    return {};
+
+  auto lhsAttr = operands[0].dyn_cast_or_null<DenseElementsAttr>();
+  auto rhsAttr = operands[1].dyn_cast_or_null<DenseElementsAttr>();
+  const bool lhsIsSplat = lhsAttr && lhsAttr.isSplat();
+  const bool rhsIsSplat = rhsAttr && rhsAttr.isSplat();
+
+  // 0 / x -> 0. (0 / 0 is undefined, so folding it to 0 is also permitted.)
+  if (lhsIsSplat && lhsAttr.getSplatValue<APInt>().isZero())
+    return lhsAttr;
+
+  // x / 1 -> x.
+  if (rhsIsSplat && rhsAttr.getSplatValue<APInt>().isOne())
+    return getInput1();
+
+  // c1 / c2 -> constant splat.
+  if (lhsIsSplat && rhsIsSplat) {
+    APInt lhs = lhsAttr.getSplatValue<APInt>();
+    APInt rhs = rhsAttr.getSplatValue<APInt>();
+    // Division by zero would trip the assertion in APInt::sdiv.
+    if (rhs.isZero())
+      return {};
+    bool overflow = false;
+    APInt result = lhs.sdiv_ov(rhs, overflow);
+    if (overflow)
+      return {};
+    return DenseElementsAttr::get(resultTy, result);
+  }
+
+  return {};
+}
+
 OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
   if (getInput().getType() == getType())
     return getInput();
diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
--- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir
+++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
@@ -164,6 +164,61 @@
 
 // -----
 
+// CHECK-LABEL: @fold_div_zero_lhs_i32
+func.func @fold_div_zero_lhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
+  %zero = "tosa.const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+  // CHECK: %[[ZERO:.+]] = "tosa.const"() {value = dense<0> : tensor<i32>}
+  %div = "tosa.div"(%zero, %arg0) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+  // CHECK: return %[[ZERO]]
+  return %div : tensor<i32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_div_one_rhs_i32
+func.func @fold_div_one_rhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
+  %one = "tosa.const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+  %div = "tosa.div"(%arg0, %one) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+  // CHECK: return %arg0
+  return %div : tensor<i32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_div_splat_i32
+func.func @fold_div_splat_i32() -> tensor<i32> {
+  %lhs = "tosa.const"() {value = dense<10> : tensor<i32>} : () -> tensor<i32>
+  %rhs = "tosa.const"() {value = dense<-3> : tensor<i32>} : () -> tensor<i32>
+  // CHECK: %[[SPLAT:.+]] = "tosa.const"() {value = dense<-3> : tensor<i32>}
+  %div = "tosa.div"(%lhs, %rhs) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+  // CHECK: return %[[SPLAT]]
+  return %div : tensor<i32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_div_zero_rhs_i32
+func.func @fold_div_zero_rhs_i32() -> tensor<i32> {
+  %lhs = "tosa.const"() {value = dense<10> : tensor<i32>} : () -> tensor<i32>
+  %rhs = "tosa.const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+  // CHECK: "tosa.div"
+  %div = "tosa.div"(%lhs, %rhs) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+  return %div : tensor<i32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_div_overflow_i32
+func.func @fold_div_overflow_i32() -> tensor<i32> {
+  %lhs = "tosa.const"() {value = dense<-2147483648> : tensor<i32>} : () -> tensor<i32>
+  %rhs = "tosa.const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
+  // CHECK: "tosa.div"
+  %div = "tosa.div"(%lhs, %rhs) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+  return %div : tensor<i32>
+}
+
+// -----
+
 // CHECK-LABEL: @slice_splat
 func.func @slice_splat() -> tensor<1x1x1xi32> {
   // CHECK: %[[SLICE:.+]] = "tosa.const"() {value = dense<42> : tensor<1x1x1xi32>}