[mlir][arith] Canonicalize xor of two extensions into an extension of the xor.

  xor(extui(x), extui(y)) -> extui(xor(x, y))
  xor(extsi(x), extsi(y)) -> extsi(xor(x, y))

Both patterns are guarded so that they only fire when $x and $y have the same
type; otherwise the narrow arith.xori built by the rewrite would be ill-typed.
The patterns are registered in arith::XOrIOp::getCanonicalizationPatterns and
covered by two new tests in canonicalize.mlir.

NOTE(review): this patch text was recovered from a copy in which all newlines
and all angle-bracketed template arguments had been stripped. The restored
template arguments and the leading whitespace of context lines follow upstream
LLVM conventions; if a hunk fails to apply, retry with whitespace-insensitive
matching (e.g. `git apply --ignore-whitespace`) and confirm against the tree.

diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -164,6 +164,16 @@
           (ConstantLikeMatcher ConstantAttr<I1Attr, "1">)),
         (Arith_CmpIOp (InvertPredicate $pred), $a, $b)>;
 
+// xor extui(x), extui(y) -> extui(xor(x,y))
+def XOrIOfExtUI :
+    Pat<(Arith_XOrIOp (Arith_ExtUIOp $x), (Arith_ExtUIOp $y)), (Arith_ExtUIOp (Arith_XOrIOp $x, $y)),
+        [(Constraint<CPred<"$0.getType() == $1.getType()">> $x, $y)]>;
+
+// xor extsi(x), extsi(y) -> extsi(xor(x,y))
+def XOrIOfExtSI :
+    Pat<(Arith_XOrIOp (Arith_ExtSIOp $x), (Arith_ExtSIOp $y)), (Arith_ExtSIOp (Arith_XOrIOp $x, $y)),
+        [(Constraint<CPred<"$0.getType() == $1.getType()">> $x, $y)]>;
+
 //===----------------------------------------------------------------------===//
 // CmpIOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -841,7 +841,7 @@
 
 void arith::XOrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                 MLIRContext *context) {
-  patterns.add<XOrINotCmpI>(context);
+  patterns.add<XOrINotCmpI, XOrIOfExtUI, XOrIOfExtSI>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -1036,6 +1036,28 @@
   return %nncmp : i1
 }
 
+// CHECK-LABEL: @xorOfExtSI
+// CHECK: %[[comb:.+]] = arith.xori %arg0, %arg1 : i8
+// CHECK: %[[ext:.+]] = arith.extsi %[[comb]] : i8 to i64
+// CHECK: return %[[ext]]
+func.func @xorOfExtSI(%arg0: i8, %arg1: i8) -> i64 {
+  %ext0 = arith.extsi %arg0 : i8 to i64
+  %ext1 = arith.extsi %arg1 : i8 to i64
+  %res = arith.xori %ext0, %ext1 : i64
+  return %res : i64
+}
+
+// CHECK-LABEL: @xorOfExtUI
+// CHECK: %[[comb:.+]] = arith.xori %arg0, %arg1 : i8
+// CHECK: %[[ext:.+]] = arith.extui %[[comb]] : i8 to i64
+// CHECK: return %[[ext]]
+func.func @xorOfExtUI(%arg0: i8, %arg1: i8) -> i64 {
+  %ext0 = arith.extui %arg0 : i8 to i64
+  %ext1 = arith.extui %arg1 : i8 to i64
+  %res = arith.xori %ext0, %ext1 : i64
+  return %res : i64
+}
+
 // -----
 
 // CHECK-LABEL: @bitcastSameType(