Fold arith.xori when an operand cancels a value inside a nested arith.xori.

Besides the existing xor(xor(x, a), a) -> x fold, also handle the commuted
forms xor(xor(a, x), a), xor(a, xor(x, a)) and xor(a, xor(a, x)), all of
which fold to x. One FileCheck test is added per form.

diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
--- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
+++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
@@ -627,9 +627,21 @@
   if (getLhs() == getRhs())
     return Builder(getContext()).getZeroAttr(getType());
   /// xor(xor(x, a), a) -> x
-  if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>())
+  /// xor(xor(a, x), a) -> x
+  if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>()) {
     if (prev.getRhs() == getRhs())
       return prev.getLhs();
+    if (prev.getLhs() == getRhs())
+      return prev.getRhs();
+  }
+  /// xor(a, xor(x, a)) -> x
+  /// xor(a, xor(a, x)) -> x
+  if (arith::XOrIOp prev = getRhs().getDefiningOp<arith::XOrIOp>()) {
+    if (prev.getRhs() == getLhs())
+      return prev.getLhs();
+    if (prev.getLhs() == getLhs())
+      return prev.getRhs();
+  }
 
   return constFoldBinaryOp<IntegerAttr>(
       operands, [](APInt a, const APInt &b) { return std::move(a) ^ b; });
diff --git a/mlir/test/Dialect/Arithmetic/canonicalize.mlir b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
--- a/mlir/test/Dialect/Arithmetic/canonicalize.mlir
+++ b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
@@ -1585,3 +1585,51 @@
   %2 = arith.andi %1, %arg0 : index
   return %2 : index
 }
+
+// -----
+/// xor(xor(x, a), a) -> x
+
+// CHECK-LABEL: @xorxor0(
+// CHECK-NOT: xori
+// CHECK: return %arg0
+func.func @xorxor0(%a : i32, %b : i32) -> i32 {
+  %c = arith.xori %a, %b : i32
+  %res = arith.xori %c, %b : i32
+  return %res : i32
+}
+
+// -----
+/// xor(xor(a, x), a) -> x
+
+// CHECK-LABEL: @xorxor1(
+// CHECK-NOT: xori
+// CHECK: return %arg0
+func.func @xorxor1(%a : i32, %b : i32) -> i32 {
+  %c = arith.xori %b, %a : i32
+  %res = arith.xori %c, %b : i32
+  return %res : i32
+}
+
+// -----
+/// xor(a, xor(x, a)) -> x
+
+// CHECK-LABEL: @xorxor2(
+// CHECK-NOT: xori
+// CHECK: return %arg0
+func.func @xorxor2(%a : i32, %b : i32) -> i32 {
+  %c = arith.xori %a, %b : i32
+  %res = arith.xori %b, %c : i32
+  return %res : i32
+}
+
+// -----
+/// xor(a, xor(a, x)) -> x
+
+// CHECK-LABEL: @xorxor3(
+// CHECK-NOT: xori
+// CHECK: return %arg0
+func.func @xorxor3(%a : i32, %b : i32) -> i32 {
+  %c = arith.xori %b, %a : i32
+  %res = arith.xori %b, %c : i32
+  return %res : i32
+}