diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -785,6 +785,8 @@
   let results = (outs
     Tosa_Tensor:$output
   );
+
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -501,6 +501,45 @@
                                                             lhsTy);
 }
 
+/// Folds `tosa.sub`: `x - 0` forwards `x` unchanged; otherwise, when both
+/// operands are constants, the elementwise difference is delegated to
+/// BinaryFolder. Returns a null OpFoldResult when no fold applies.
+OpFoldResult SubOp::fold(ArrayRef<Attribute> operands) {
+  auto lhsTy = getInput1().getType().dyn_cast<RankedTensorType>();
+  auto rhsTy = getInput2().getType().dyn_cast<RankedTensorType>();
+  auto resultTy = getType().dyn_cast<RankedTensorType>();
+  if (!lhsTy || !rhsTy || !resultTy)
+    return {};
+  // Every fold below yields a value of type `lhsTy`, so it must match both the
+  // other operand (no broadcasting) and the declared result type.
+  if (lhsTy != rhsTy || lhsTy != resultTy)
+    return {};
+
+  auto resultETy = resultTy.getElementType();
+  auto lhsAttr = operands[0].dyn_cast_or_null<DenseElementsAttr>();
+  auto rhsAttr = operands[1].dyn_cast_or_null<DenseElementsAttr>();
+
+  // x - (+0.0) == x for every x, but x - (-0.0) turns -0.0 into +0.0, so a
+  // negative-zero splat is deliberately left unfolded.
+  if (rhsAttr && rhsAttr.isSplat() && resultETy.isa<FloatType>()) {
+    if (rhsAttr.getSplatValue<APFloat>().isPosZero())
+      return getInput1();
+  }
+
+  // x - 0 == x for integers.
+  if (rhsAttr && rhsAttr.isSplat() && resultETy.isa<IntegerType>()) {
+    if (rhsAttr.getSplatValue<APInt>().isZero())
+      return getInput1();
+  }
+
+  // Constant folding needs both operands to be dense constants.
+  if (!lhsAttr || !rhsAttr)
+    return {};
+
+  return BinaryFolder<std::minus<APInt>, std::minus<APFloat>>(lhsAttr, rhsAttr,
+                                                              lhsTy);
+}
+
 OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
   if (getInput().getType() == getType())
     return getInput();
diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
--- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir
+++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
@@ -164,6 +164,50 @@
 
 // -----
 
+// CHECK-LABEL: @fold_sub_zero_rhs_f32
+func.func @fold_sub_zero_rhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
+  %zero = "tosa.const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
+  %sub = "tosa.sub"(%arg0, %zero) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+  // CHECK: return %arg0
+  return %sub : tensor<f32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_sub_zero_rhs_i32
+func.func @fold_sub_zero_rhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
+  %zero = "tosa.const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+  %sub = "tosa.sub"(%arg0, %zero) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+  // CHECK: return %arg0
+  return %sub : tensor<i32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_sub_splat_i32
+func.func @fold_sub_splat_i32() -> tensor<10xi32> {
+  %one = "tosa.const"() {value = dense<1> : tensor<10xi32>} : () -> tensor<10xi32>
+  %two = "tosa.const"() {value = dense<2> : tensor<10xi32>} : () -> tensor<10xi32>
+  %sub = "tosa.sub"(%one, %two) : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi32>
+  // CHECK: %[[RES:.+]] = "tosa.const"() {value = dense<-1> : tensor<10xi32>}
+  // CHECK: return %[[RES]]
+  return %sub : tensor<10xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_sub_splat_f32
+func.func @fold_sub_splat_f32() -> tensor<10xf32> {
+  %one = "tosa.const"() {value = dense<1.0> : tensor<10xf32>} : () -> tensor<10xf32>
+  %two = "tosa.const"() {value = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32>
+  %sub = "tosa.sub"(%one, %two) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
+  // CHECK: %[[RES:.+]] = "tosa.const"() {value = dense<-1.000000e+00> : tensor<10xf32>}
+  // CHECK: return %[[RES]]
+  return %sub : tensor<10xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @slice_splat
 func.func @slice_splat() -> tensor<1x1x1xi32> {
   // CHECK: %[[SLICE:.+]] = "tosa.const"() {value = dense<42> : tensor<1x1x1xi32>}