diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
--- a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
@@ -437,6 +437,7 @@
     ```
   }];
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -465,6 +466,7 @@
     ```
   }];
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticCanonicalization.td b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticCanonicalization.td
--- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticCanonicalization.td
+++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticCanonicalization.td
@@ -136,4 +136,32 @@
 def ExtSIOfExtUI :
     Pat<(Arith_ExtSIOp (Arith_ExtUIOp $x)), (Arith_ExtUIOp $x)>;
 
+//===----------------------------------------------------------------------===//
+// AndIOp
+//===----------------------------------------------------------------------===//
+
+// and extui(x), extui(y) -> extui(and(x,y)), only when x and y share a type
+def AndOfExtUI :
+    Pat<(Arith_AndIOp (Arith_ExtUIOp $x), (Arith_ExtUIOp $y)), (Arith_ExtUIOp (Arith_AndIOp $x, $y)),
+        [(Constraint<CPred<"$0.getType() == $1.getType()">> $x, $y)]>;
+
+// and extsi(x), extsi(y) -> extsi(and(x,y)), only when x and y share a type
+def AndOfExtSI :
+    Pat<(Arith_AndIOp (Arith_ExtSIOp $x), (Arith_ExtSIOp $y)), (Arith_ExtSIOp (Arith_AndIOp $x, $y)),
+        [(Constraint<CPred<"$0.getType() == $1.getType()">> $x, $y)]>;
+
+//===----------------------------------------------------------------------===//
+// OrIOp
+//===----------------------------------------------------------------------===//
+
+// or extui(x), extui(y) -> extui(or(x,y)), only when x and y share a type
+def OrOfExtUI :
+    Pat<(Arith_OrIOp (Arith_ExtUIOp $x), (Arith_ExtUIOp $y)), (Arith_ExtUIOp (Arith_OrIOp $x, $y)),
+        [(Constraint<CPred<"$0.getType() == $1.getType()">> $x, $y)]>;
+
+// or extsi(x), extsi(y) -> extsi(or(x,y)), only when x and y share a type
+def OrOfExtSI :
+    Pat<(Arith_OrIOp (Arith_ExtSIOp $x), (Arith_ExtSIOp $y)), (Arith_ExtSIOp (Arith_OrIOp $x, $y)),
+        [(Constraint<CPred<"$0.getType() == $1.getType()">> $x, $y)]>;
+
 #endif // ARITHMETIC_PATTERNS
diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
--- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
+++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
@@ -891,6 +891,24 @@
   return checkWidthChangeCast<std::less, FloatType>(inputs, outputs);
 }
 
+//===----------------------------------------------------------------------===//
+// AndIOp
+//===----------------------------------------------------------------------===//
+
+void arith::AndIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                                MLIRContext *context) {
+  patterns.add<AndOfExtUI, AndOfExtSI>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// OrIOp
+//===----------------------------------------------------------------------===//
+
+void arith::OrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                               MLIRContext *context) {
+  patterns.add<OrOfExtUI, OrOfExtSI>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // Verifiers for casts between integers and floats.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Arithmetic/canonicalize.mlir b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
--- a/mlir/test/Dialect/Arithmetic/canonicalize.mlir
+++ b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
@@ -99,6 +99,65 @@
 
 // -----
 
+// CHECK-LABEL: @andOfExtSI
+// CHECK: %[[comb:.+]] = arith.andi %arg0, %arg1 : i8
+// CHECK: %[[ext:.+]] = arith.extsi %[[comb]] : i8 to i64
+// CHECK: return %[[ext]]
+func @andOfExtSI(%arg0: i8, %arg1: i8) -> i64 {
+  %ext0 = arith.extsi %arg0 : i8 to i64
+  %ext1 = arith.extsi %arg1 : i8 to i64
+  %res = arith.andi %ext0, %ext1 : i64
+  return %res : i64
+}
+
+// CHECK-LABEL: @andOfExtUI
+// CHECK: %[[comb:.+]] = arith.andi %arg0, %arg1 : i8
+// CHECK: %[[ext:.+]] = arith.extui %[[comb]] : i8 to i64
+// CHECK: return %[[ext]]
+func @andOfExtUI(%arg0: i8, %arg1: i8) -> i64 {
+  %ext0 = arith.extui %arg0 : i8 to i64
+  %ext1 = arith.extui %arg1 : i8 to i64
+  %res = arith.andi %ext0, %ext1 : i64
+  return %res : i64
+}
+
+// CHECK-LABEL: @orOfExtSI
+// CHECK: %[[comb:.+]] = arith.ori %arg0, %arg1 : i8
+// CHECK: %[[ext:.+]] = arith.extsi %[[comb]] : i8 to i64
+// CHECK: return %[[ext]]
+func @orOfExtSI(%arg0: i8, %arg1: i8) -> i64 {
+  %ext0 = arith.extsi %arg0 : i8 to i64
+  %ext1 = arith.extsi %arg1 : i8 to i64
+  %res = arith.ori %ext0, %ext1 : i64
+  return %res : i64
+}
+
+// CHECK-LABEL: @orOfExtUI
+// CHECK: %[[comb:.+]] = arith.ori %arg0, %arg1 : i8
+// CHECK: %[[ext:.+]] = arith.extui %[[comb]] : i8 to i64
+// CHECK: return %[[ext]]
+func @orOfExtUI(%arg0: i8, %arg1: i8) -> i64 {
+  %ext0 = arith.extui %arg0 : i8 to i64
+  %ext1 = arith.extui %arg1 : i8 to i64
+  %res = arith.ori %ext0, %ext1 : i64
+  return %res : i64
+}
+
+// Extensions from different source types must not be combined.
+// CHECK-LABEL: @andOfExtUIDifferentSourceTypes
+// CHECK: %[[ext0:.+]] = arith.extui %arg0 : i8 to i64
+// CHECK: %[[ext1:.+]] = arith.extui %arg1 : i16 to i64
+// CHECK: %[[res:.+]] = arith.andi %[[ext0]], %[[ext1]] : i64
+// CHECK: return %[[res]]
+func @andOfExtUIDifferentSourceTypes(%arg0: i8, %arg1: i16) -> i64 {
+  %ext0 = arith.extui %arg0 : i8 to i64
+  %ext1 = arith.extui %arg1 : i16 to i64
+  %res = arith.andi %ext0, %ext1 : i64
+  return %res : i64
+}
+
+// -----
+
 // CHECK-LABEL: @indexCastOfSignExtend
 // CHECK: %[[res:.+]] = arith.index_cast %arg0 : i8 to index
 // CHECK: return %[[res]]