diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
--- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
+++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
@@ -193,6 +193,16 @@
   if (matchPattern(getRhs(), m_Zero()))
     return getLhs();
 
+  // add(sub(a, b), b) -> a (exact even on overflow: add/sub wrap mod 2^N)
+  if (auto sub = getLhs().getDefiningOp<SubIOp>())
+    if (getRhs() == sub.getRhs())
+      return sub.getLhs();
+
+  // add(b, sub(a, b)) -> a (commuted form of the fold above)
+  if (auto sub = getRhs().getDefiningOp<SubIOp>())
+    if (getLhs() == sub.getRhs())
+      return sub.getLhs();
+
   return constFoldBinaryOp<IntegerAttr>(operands,
                                         [](APInt a, APInt b) { return a + b; });
 }
diff --git a/mlir/test/Dialect/Arithmetic/canonicalize.mlir b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
--- a/mlir/test/Dialect/Arithmetic/canonicalize.mlir
+++ b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
@@ -204,6 +204,22 @@
   return %add2 : index
 }
 
+// CHECK-LABEL: @doubleAddSub1
+//       CHECK:   return %arg0
+func @doubleAddSub1(%arg0: index, %arg1 : index) -> index {
+  %sub = arith.subi %arg0, %arg1 : index
+  %add = arith.addi %sub, %arg1 : index
+  return %add : index
+}
+
+// CHECK-LABEL: @doubleAddSub2
+//       CHECK:   return %arg0
+func @doubleAddSub2(%arg0: index, %arg1 : index) -> index {
+  %sub = arith.subi %arg0, %arg1 : index
+  %add = arith.addi %arg1, %sub : index
+  return %add : index
+}
+
 // CHECK-LABEL: @notCmpEQ
 //       CHECK:   %[[cres:.+]] = arith.cmpi ne, %arg0, %arg1 : i8
 //       CHECK:   return %[[cres]]