diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticCanonicalization.td b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticCanonicalization.td --- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticCanonicalization.td +++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticCanonicalization.td @@ -106,6 +106,16 @@ (Arith_ConstantOp ConstantAttr<I1Attr, "1">)), (Arith_CmpIOp (InvertPredicate $pred), $a, $b)>; +def XOrIntAttrs : NativeCodeCall<"xorIntegerAttrs($_builder, $0, $1, $2)">; + +// xor(xor(x, c0), c1) = xor(x, xor(c0, c1)) +def XOrXOrConstant : + Pat<(Arith_XOrIOp:$res + (Arith_XOrIOp $x, (Arith_ConstantOp APIntAttr:$c0)), + (Arith_ConstantOp APIntAttr:$c1) + ), + (Arith_XOrIOp $x, (Arith_ConstantOp (XOrIntAttrs $res, $c0, $c1)))>; + //===----------------------------------------------------------------------===// // IndexCastOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp --- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp +++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp @@ -39,6 +39,13 @@ rhs.cast<IntegerAttr>().getInt()); } +static IntegerAttr xorIntegerAttrs(PatternRewriter &builder, Value res, + Attribute lhs, Attribute rhs) { + return builder.getIntegerAttr(res.getType(), + lhs.cast<IntegerAttr>().getValue() ^ + rhs.cast<IntegerAttr>().getValue()); +} + /// Invert an integer comparison predicate.
static arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred) { switch (pred) { @@ -557,6 +564,10 @@ /// xor(x, x) -> 0 if (getLhs() == getRhs()) return Builder(getContext()).getZeroAttr(getType()); + /// xor(xor(x, a), a) -> x + if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>()) + if (prev.getRhs() == getRhs()) + return prev.getLhs(); return constFoldBinaryOp<IntegerAttr>( operands, [](APInt a, const APInt &b) { return std::move(a) ^ b; }); @@ -564,7 +575,7 @@ void arith::XOrIOp::getCanonicalizationPatterns( OwningRewritePatternList &patterns, MLIRContext *context) { - patterns.insert<XOrINotCmpI>(context); + patterns.insert<XOrINotCmpI, XOrXOrConstant>(context); } //===----------------------------------------------------------------------===// @@ -859,13 +870,19 @@ //===----------------------------------------------------------------------===// OpFoldResult arith::TruncIOp::fold(ArrayRef<Attribute> operands) { + assert(operands.size() == 1 && "unary operation takes one operand"); + // trunci(zexti(a)) -> a // trunci(sexti(a)) -> a if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) || matchPattern(getOperand(), m_Op<arith::ExtSIOp>())) return getOperand().getDefiningOp()->getOperand(0); - assert(operands.size() == 1 && "unary operation takes one operand"); + // trunci(trunci(a)) -> trunci(a) + if (matchPattern(getOperand(), m_Op<arith::TruncIOp>())) { + setOperand(getOperand().getDefiningOp()->getOperand(0)); + return getResult(); + } if (!operands[0]) return {}; diff --git a/mlir/test/Dialect/Arithmetic/canonicalize.mlir b/mlir/test/Dialect/Arithmetic/canonicalize.mlir --- a/mlir/test/Dialect/Arithmetic/canonicalize.mlir +++ b/mlir/test/Dialect/Arithmetic/canonicalize.mlir @@ -172,6 +172,15 @@ return %tr : i16 } +// CHECK-LABEL: @truncTrunc +// CHECK: %[[cres:.+]] = arith.trunci %arg0 : i64 to i8 +// CHECK: return %[[cres]] +func @truncTrunc(%arg0: i64) -> i8 { + %tr1 = arith.trunci %arg0 : i64 to i32 + %tr2 = arith.trunci %tr1 : i32 to i8 + return %tr2 : i8 +} + // CHECK-LABEL: @truncFPConstant // CHECK: %[[cres:.+]] = arith.constant 1.000000e+00 : bf16 // 
CHECK: return %[[cres]] @@ -427,6 +436,18 @@ // ----- +// CHECK-LABEL: @xorxor( +// CHECK-NOT: xori +// CHECK: return %arg0 +func @xorxor(%cmp : i1) -> i1 { + %true = arith.constant true + %ncmp = arith.xori %cmp, %true : i1 + %nncmp = arith.xori %ncmp, %true : i1 + return %nncmp : i1 +} + +// ----- + // CHECK-LABEL: @bitcastSameType( // CHECK-SAME: %[[ARG:[a-zA-Z0-9_]*]] func @bitcastSameType(%arg : f32) -> f32 {