diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
--- a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
@@ -1104,6 +1104,7 @@
   }];
 
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticCanonicalization.td b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticCanonicalization.td
--- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticCanonicalization.td
+++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticCanonicalization.td
@@ -106,6 +106,36 @@
                       (Arith_ConstantOp ConstantAttr<I1Attr, "1">)),
         (Arith_CmpIOp (InvertPredicate $pred), $a, $b)>;
 
+//===----------------------------------------------------------------------===//
+// CmpIOp
+//===----------------------------------------------------------------------===//
+
+// cmpi(eq|ne, extsi(a), extsi(b)) -> cmpi(eq|ne, a, b)
+// Sign extension is injective, so the extended values are equal exactly when
+// the narrower sources are; the sources must share a type for the new cmpi.
+def CmpIExtSI :
+    Pat<(Arith_CmpIOp $pred,
+          (Arith_ExtSIOp $a),
+          (Arith_ExtSIOp $b)),
+        (Arith_CmpIOp $pred, $a, $b),
+        [(Constraint<CPred<"$0.getType() == $1.getType()">> $a, $b),
+         (Constraint<CPred<
+            "$0.getValue() == arith::CmpIPredicate::eq || "
+            "$0.getValue() == arith::CmpIPredicate::ne">> $pred)]>;
+
+// cmpi(eq|ne, extui(a), extui(b)) -> cmpi(eq|ne, a, b)
+// Zero extension is injective, so the extended values are equal exactly when
+// the narrower sources are; the sources must share a type for the new cmpi.
+def CmpIExtUI :
+    Pat<(Arith_CmpIOp $pred,
+          (Arith_ExtUIOp $a),
+          (Arith_ExtUIOp $b)),
+        (Arith_CmpIOp $pred, $a, $b),
+        [(Constraint<CPred<"$0.getType() == $1.getType()">> $a, $b),
+         (Constraint<CPred<
+            "$0.getValue() == arith::CmpIPredicate::eq || "
+            "$0.getValue() == arith::CmpIPredicate::ne">> $pred)]>;
+
 //===----------------------------------------------------------------------===//
 // IndexCastOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
--- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
+++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
@@ -1322,6 +1322,13 @@
   return BoolAttr::get(getContext(), val);
 }
 
+/// Registers the DRR patterns that rewrite eq/ne comparisons of two
+/// identically-typed extsi (or extui) values into comparisons of the sources.
+void arith::CmpIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                                MLIRContext *context) {
+  patterns.insert<CmpIExtSI, CmpIExtUI>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // CmpFOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Arithmetic/canonicalize.mlir b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
--- a/mlir/test/Dialect/Arithmetic/canonicalize.mlir
+++ b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
@@ -176,6 +176,72 @@
 
 // -----
 
+// CHECK-LABEL: @cmpIExtSINE
+// CHECK: %[[comb:.+]] = arith.cmpi ne, %arg0, %arg1 : i8
+// CHECK: return %[[comb]]
+func @cmpIExtSINE(%arg0: i8, %arg1: i8) -> i1 {
+  %ext0 = arith.extsi %arg0 : i8 to i64
+  %ext1 = arith.extsi %arg1 : i8 to i64
+  %res = arith.cmpi ne, %ext0, %ext1 : i64
+  return %res : i1
+}
+
+// CHECK-LABEL: @cmpIExtSIEQ
+// CHECK: %[[comb:.+]] = arith.cmpi eq, %arg0, %arg1 : i8
+// CHECK: return %[[comb]]
+func @cmpIExtSIEQ(%arg0: i8, %arg1: i8) -> i1 {
+  %ext0 = arith.extsi %arg0 : i8 to i64
+  %ext1 = arith.extsi %arg1 : i8 to i64
+  %res = arith.cmpi eq, %ext0, %ext1 : i64
+  return %res : i1
+}
+
+// CHECK-LABEL: @cmpIExtUINE
+// CHECK: %[[comb:.+]] = arith.cmpi ne, %arg0, %arg1 : i8
+// CHECK: return %[[comb]]
+func @cmpIExtUINE(%arg0: i8, %arg1: i8) -> i1 {
+  %ext0 = arith.extui %arg0 : i8 to i64
+  %ext1 = arith.extui %arg1 : i8 to i64
+  %res = arith.cmpi ne, %ext0, %ext1 : i64
+  return %res : i1
+}
+
+// CHECK-LABEL: @cmpIExtUIEQ
+// CHECK: %[[comb:.+]] = arith.cmpi eq, %arg0, %arg1 : i8
+// CHECK: return %[[comb]]
+func @cmpIExtUIEQ(%arg0: i8, %arg1: i8) -> i1 {
+  %ext0 = arith.extui %arg0 : i8 to i64
+  %ext1 = arith.extui %arg1 : i8 to i64
+  %res = arith.cmpi eq, %ext0, %ext1 : i64
+  return %res : i1
+}
+
+// Signed ordering is not preserved by zero extension; must not be narrowed.
+// CHECK-LABEL: @cmpIExtUISLT
+// CHECK: %[[ext0:.+]] = arith.extui %arg0 : i8 to i64
+// CHECK: %[[ext1:.+]] = arith.extui %arg1 : i8 to i64
+// CHECK: %[[res:.+]] = arith.cmpi slt, %[[ext0]], %[[ext1]] : i64
+// CHECK: return %[[res]]
+func @cmpIExtUISLT(%arg0: i8, %arg1: i8) -> i1 {
+  %ext0 = arith.extui %arg0 : i8 to i64
+  %ext1 = arith.extui %arg1 : i8 to i64
+  %res = arith.cmpi slt, %ext0, %ext1 : i64
+  return %res : i1
+}
+
+// Sources of different widths cannot feed a single narrow cmpi.
+// CHECK-LABEL: @cmpIExtSIMixedTypes
+// CHECK: %[[ext0:.+]] = arith.extsi %arg0 : i8 to i64
+// CHECK: %[[ext1:.+]] = arith.extsi %arg1 : i16 to i64
+// CHECK: %[[res:.+]] = arith.cmpi eq, %[[ext0]], %[[ext1]] : i64
+// CHECK: return %[[res]]
+func @cmpIExtSIMixedTypes(%arg0: i8, %arg1: i16) -> i1 {
+  %ext0 = arith.extsi %arg0 : i8 to i64
+  %ext1 = arith.extsi %arg1 : i16 to i64
+  %res = arith.cmpi eq, %ext0, %ext1 : i64
+  return %res : i1
+}
+
+// -----
+
 // CHECK-LABEL: @andOfExtSI
 // CHECK: %[[comb:.+]] = arith.andi %arg0, %arg1 : i8
 // CHECK: %[[ext:.+]] = arith.extsi %[[comb]] : i8 to i64