diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
--- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
+++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
@@ -557,6 +557,10 @@
   /// xor(x, x) -> 0
   if (getLhs() == getRhs())
     return Builder(getContext()).getZeroAttr(getType());
+  /// xor(xor(x, a), a) -> x
+  if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>())
+    if (prev.getRhs() == getRhs())
+      return prev.getLhs();
 
   return constFoldBinaryOp<IntegerAttr>(
       operands, [](APInt a, const APInt &b) { return std::move(a) ^ b; });
@@ -859,13 +863,21 @@
 //===----------------------------------------------------------------------===//
 
 OpFoldResult arith::TruncIOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.size() == 1 && "unary operation takes one operand");
+
   // trunci(zexti(a)) -> a
   // trunci(sexti(a)) -> a
   if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) ||
       matchPattern(getOperand(), m_Op<arith::ExtSIOp>()))
     return getOperand().getDefiningOp()->getOperand(0);
 
-  assert(operands.size() == 1 && "unary operation takes one operand");
+  // trunci(trunci(a)) -> trunci(a)
+  // Fold in place: point this op at the inner operand and return our own
+  // result to signal an in-place update.
+  if (matchPattern(getOperand(), m_Op<arith::TruncIOp>())) {
+    setOperand(getOperand().getDefiningOp()->getOperand(0));
+    return getResult();
+  }
 
   if (!operands[0])
     return {};
diff --git a/mlir/test/Dialect/Arithmetic/canonicalize.mlir b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
--- a/mlir/test/Dialect/Arithmetic/canonicalize.mlir
+++ b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
@@ -172,6 +172,15 @@
   return %tr : i16
 }
 
+// CHECK-LABEL: @truncTrunc
+// CHECK: %[[cres:.+]] = arith.trunci %arg0 : i64 to i8
+// CHECK: return %[[cres]]
+func @truncTrunc(%arg0: i64) -> i8 {
+  %tr1 = arith.trunci %arg0 : i64 to i32
+  %tr2 = arith.trunci %tr1 : i32 to i8
+  return %tr2 : i8
+}
+
 // CHECK-LABEL: @truncFPConstant
 // CHECK: %[[cres:.+]] = arith.constant 1.000000e+00 : bf16
 // CHECK: return %[[cres]]
@@ -427,6 +436,18 @@
 
 // -----
 
+// CHECK-LABEL: @xorxor(
+// CHECK-NOT: xori
+// CHECK: return %arg0
+func @xorxor(%cmp : i1) -> i1 {
+  %true = arith.constant true
+  %ncmp = arith.xori %cmp, %true : i1
+  %nncmp = arith.xori %ncmp, %true : i1
+  return %nncmp : i1
+}
+
+// -----
+
 // CHECK-LABEL: @bitcastSameType(
 // CHECK-SAME: %[[ARG:[a-zA-Z0-9_]*]]
 func @bitcastSameType(%arg : f32) -> f32 {