diff --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
--- a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
+++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
@@ -531,6 +531,8 @@
     %a = complex.sub %b, %c : complex<f32>
     ```
   }];
+
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
--- a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
+++ b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
@@ -138,6 +138,24 @@
   return {};
 }
 
+//===----------------------------------------------------------------------===//
+// SubOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult SubOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.size() == 2 && "binary op takes 2 operands");
+
+  // complex.sub(complex.add(a, b), b) -> a
+  // NOTE: this assumes reassociation is acceptable. Under IEEE-754 rounding
+  // (a + b) - b need not equal a bit-for-bit (e.g. when |b| >> |a|), so the
+  // fold is not exactness-preserving for floating-point element types.
+  if (auto add = getLhs().getDefiningOp<AddOp>())
+    if (getRhs() == add.getRhs())
+      return add.getLhs();
+
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // NegOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Complex/canonicalize.mlir b/mlir/test/Dialect/Complex/canonicalize.mlir
--- a/mlir/test/Dialect/Complex/canonicalize.mlir
+++ b/mlir/test/Dialect/Complex/canonicalize.mlir
@@ -133,4 +133,15 @@
   // CHECK-NEXT: return %[[CPLX:.*]] : complex<f32>
   %add = complex.add %complex1, %complex2 : complex<f32>
   return %add : complex<f32>
-}
\ No newline at end of file
+}
+
+// CHECK-LABEL: func @complex_sub_add_lhs
+func.func @complex_sub_add_lhs() -> complex<f32> {
+  %complex1 = complex.constant [1.0 : f32, 0.0 : f32] : complex<f32>
+  %complex2 = complex.constant [0.0 : f32, 2.0 : f32] : complex<f32>
+  // CHECK: %[[CPLX:.*]] = complex.constant [1.000000e+00 : f32, 0.000000e+00 : f32] : complex<f32>
+  // CHECK-NEXT: return %[[CPLX]] : complex<f32>
+  %add = complex.add %complex1, %complex2 : complex<f32>
+  %sub = complex.sub %add, %complex2 : complex<f32>
+  return %sub : complex<f32>
+}