diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td --- a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td +++ b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td @@ -860,6 +860,7 @@ }]; let hasFolder = 1; + let hasCanonicalizer = 1; let verifier = [{ return verifyTruncateOp(*this); }]; } diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticCanonicalization.td b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticCanonicalization.td --- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticCanonicalization.td +++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticCanonicalization.td @@ -106,6 +106,16 @@ (Arith_ConstantOp ConstantAttr<I1Attr, "1">)), (Arith_CmpIOp (InvertPredicate $pred), $a, $b)>; +def XOrIntAttrs : NativeCodeCall<"xorIntegerAttrs($_builder, $0, $1, $2)">; + +// xor(xor(x, c0), c1) = xor(x, xor(c0, c1)) +def XOrXOrConstant : + Pat<(Arith_XOrIOp:$res + (Arith_XOrIOp $x, (Arith_ConstantOp APIntAttr:$c0)), + (Arith_ConstantOp APIntAttr:$c1) + ), + (Arith_XOrIOp $x, (Arith_ConstantOp (XOrIntAttrs $res, $c0, $c1)))>; + //===----------------------------------------------------------------------===// // IndexCastOp //===----------------------------------------------------------------------===// @@ -128,4 +138,12 @@ def BitcastOfBitcast : Pat<(Arith_BitcastOp (Arith_BitcastOp $x)), (replaceWithValue $x)>; +//===----------------------------------------------------------------------===// +// TruncIOp +//===----------------------------------------------------------------------===// + +// trunc(trunc(x, a), b) -> trunc(x, b) +def DoubleTruncI : + Pat<(Arith_TruncIOp (Arith_TruncIOp $x)), (Arith_TruncIOp $x)>; + #endif // ARITHMETIC_PATTERNS diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp --- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp +++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp @@ -37,6 +37,14 @@ rhs.cast<IntegerAttr>().getInt()); } +/// Returns lhs ^ rhs as an IntegerAttr of res's type (APInt, any width). +static 
IntegerAttr xorIntegerAttrs(PatternRewriter &builder, Value res, + Attribute lhs, Attribute rhs) { + return builder.getIntegerAttr(res.getType(), + lhs.cast<IntegerAttr>().getValue() ^ + rhs.cast<IntegerAttr>().getValue()); +} + /// Invert an integer comparison predicate. static arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred) { switch (pred) { @@ -536,7 +544,7 @@ void arith::XOrIOp::getCanonicalizationPatterns( OwningRewritePatternList &patterns, MLIRContext *context) { - patterns.insert<XOrINotCmpI>(context); + patterns.insert<XOrINotCmpI, XOrXOrConstant>(context); } //===----------------------------------------------------------------------===// @@ -831,6 +839,11 @@ return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs); } +void arith::TruncIOp::getCanonicalizationPatterns( + OwningRewritePatternList &patterns, MLIRContext *context) { + patterns.insert<DoubleTruncI>(context); +} + //===----------------------------------------------------------------------===// // TruncFOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Arithmetic/canonicalize.mlir b/mlir/test/Dialect/Arithmetic/canonicalize.mlir --- a/mlir/test/Dialect/Arithmetic/canonicalize.mlir +++ b/mlir/test/Dialect/Arithmetic/canonicalize.mlir @@ -77,6 +77,15 @@ return %tr : i16 } +// CHECK-LABEL: @truncTrunc +// CHECK: %[[cres:.+]] = arith.trunci %arg0 : i64 to i8 +// CHECK: return %[[cres]] +func @truncTrunc(%arg0: i64) -> i8 { + %tr1 = arith.trunci %arg0 : i64 to i32 + %tr2 = arith.trunci %tr1 : i32 to i8 + return %tr2 : i8 +} + // CHECK-LABEL: @truncFPConstant // CHECK: %[[cres:.+]] = arith.constant 1.000000e+00 : bf16 // CHECK: return %[[cres]] @@ -316,6 +325,18 @@ // ----- +// CHECK-LABEL: @xorxor( +// CHECK-NOT: xori +// CHECK: return %arg0 +func @xorxor(%cmp : i1) -> i1 { + %true = arith.constant true + %ncmp = arith.xori %cmp, %true : i1 + %nncmp = arith.xori %ncmp, %true : i1 + return %nncmp : i1 +} + +// ----- + // CHECK-LABEL: @bitcastSameType( // CHECK-SAME: %[[ARG:[a-zA-Z0-9_]*]] func @bitcastSameType(%arg : 
f32) -> f32 {