diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticCanonicalization.td b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticCanonicalization.td
--- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticCanonicalization.td
+++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticCanonicalization.td
@@ -12,6 +12,9 @@
 include "mlir/IR/PatternBase.td"
 include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.td"
 
+// Create zero attribute of type matching the argument's type.
+def GetZeroAttr : NativeCodeCall<"$_builder.getZeroAttr($0.getType())">;
+
 // Add two integer attributes and create a new one with the result.
 def AddIntAttrs : NativeCodeCall<"addIntegerAttrs($_builder, $0, $1, $2)">;
 
@@ -92,6 +95,11 @@
                            (Arith_SubIOp (Arith_ConstantOp APIntAttr:$c0), $x)),
         (Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)))>;
 
+// subi(subi(a, b), a) -> subi(0, b)
+def SubISubILHSRHSLHS :
+    Pat<(Arith_SubIOp:$res (Arith_SubIOp $x, $y), $x),
+        (Arith_SubIOp (Arith_ConstantOp (GetZeroAttr $y)), $y)>;
+
 //===----------------------------------------------------------------------===//
 // XOrIOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
--- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
+++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
@@ -305,16 +305,25 @@
   if (matchPattern(getRhs(), m_Zero()))
     return getLhs();
 
+  // Both folds below are exact under two's-complement wraparound.
+  if (auto add = getLhs().getDefiningOp<AddIOp>()) {
+    // subi(addi(a, b), b) -> a
+    if (getRhs() == add.getRhs())
+      return add.getLhs();
+    // subi(addi(a, b), a) -> b
+    if (getRhs() == add.getLhs())
+      return add.getRhs();
+  }
+
   return constFoldBinaryOp<IntegerAttr>(
       operands, [](APInt a, const APInt &b) { return std::move(a) - b; });
 }
 
 void arith::SubIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                 MLIRContext *context) {
-  patterns
-      .add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
-           SubIRHSSubConstantLHS, SubILHSSubConstantRHS, SubILHSSubConstantLHS>(
-          context);
+  patterns.add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
+               SubIRHSSubConstantLHS, SubILHSSubConstantRHS,
+               SubILHSSubConstantLHS, SubISubILHSRHSLHS>(context);
 }
 
//===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Arithmetic/canonicalize.mlir b/mlir/test/Dialect/Arithmetic/canonicalize.mlir --- a/mlir/test/Dialect/Arithmetic/canonicalize.mlir +++ b/mlir/test/Dialect/Arithmetic/canonicalize.mlir @@ -480,6 +480,16 @@ return %add2 : index } +// CHECK-LABEL: @subSub0 +// CHECK: %[[c0:.+]] = arith.constant 0 : index +// CHECK: %[[add:.+]] = arith.subi %[[c0]], %arg1 : index +// CHECK: return %[[add]] +func.func @subSub0(%arg0: index, %arg1: index) -> index { + %sub1 = arith.subi %arg0, %arg1 : index + %sub2 = arith.subi %sub1, %arg0 : index + return %sub2 : index +} + // CHECK-LABEL: @tripleSubSub0 // CHECK: %[[cres:.+]] = arith.constant 25 : index // CHECK: %[[add:.+]] = arith.addi %arg0, %[[cres]] : index @@ -528,6 +538,22 @@ return %add2 : index } +// CHECK-LABEL: @subAdd1 +// CHECK-NEXT: return %arg0 +func.func @subAdd1(%arg0: index, %arg1 : index) -> index { + %add = arith.addi %arg0, %arg1 : index + %sub = arith.subi %add, %arg1 : index + return %sub : index +} + +// CHECK-LABEL: @subAdd2 +// CHECK-NEXT: return %arg1 +func.func @subAdd2(%arg0: index, %arg1 : index) -> index { + %add = arith.addi %arg0, %arg1 : index + %sub = arith.subi %add, %arg0 : index + return %sub : index +} + // CHECK-LABEL: @doubleAddSub1 // CHECK-NEXT: return %arg0 func.func @doubleAddSub1(%arg0: index, %arg1 : index) -> index {