diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
--- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
+++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
@@ -194,6 +194,16 @@
   if (matchPattern(getRhs(), m_Zero()))
     return getLhs();
 
+  // add(sub(a, b), b) -> a
+  if (auto sub = getLhs().getDefiningOp<SubIOp>())
+    if (getRhs() == sub.getRhs())
+      return sub.getLhs();
+
+  // add(b, sub(a, b)) -> a
+  if (auto sub = getRhs().getDefiningOp<SubIOp>())
+    if (getLhs() == sub.getRhs())
+      return sub.getLhs();
+
   return constFoldBinaryOp<IntegerAttr>(
       operands, [](APInt a, const APInt &b) { return std::move(a) + b; });
 }
diff --git a/mlir/test/Dialect/Arithmetic/canonicalize.mlir b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
--- a/mlir/test/Dialect/Arithmetic/canonicalize.mlir
+++ b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
@@ -299,6 +299,22 @@
   return %add2 : index
 }
 
+// CHECK-LABEL: @doubleAddSub1
+// CHECK-NEXT: return %arg0
+func @doubleAddSub1(%arg0: index, %arg1 : index) -> index {
+  %sub = arith.subi %arg0, %arg1 : index
+  %add = arith.addi %sub, %arg1 : index
+  return %add : index
+}
+
+// CHECK-LABEL: @doubleAddSub2
+// CHECK-NEXT: return %arg0
+func @doubleAddSub2(%arg0: index, %arg1 : index) -> index {
+  %sub = arith.subi %arg0, %arg1 : index
+  %add = arith.addi %arg1, %sub : index
+  return %add : index
+}
+
 // CHECK-LABEL: @notCmpEQ
 // CHECK: %[[cres:.+]] = arith.cmpi ne, %arg0, %arg1 : i8
 // CHECK: return %[[cres]]