diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -612,6 +612,28 @@
 // AndIOp
 //===----------------------------------------------------------------------===//
 
+/// Fold `and(a, and(a, b))` to `and(a, b)`.
+///
+/// Bitwise AND is commutative, so either operand of `op` may be the inner
+/// `andi`, and the other operand may match either of the inner operands.
+/// The inner op's result is reused as-is since `a & (a & b) == a & b`.
+/// Returns a null Value when no operand order matches.
+static Value foldAndIofAndI(arith::AndIOp op) {
+  for (bool reversePrev : {false, true}) {
+    auto prev = (reversePrev ? op.getRhs() : op.getLhs())
+                    .getDefiningOp<arith::AndIOp>();
+    if (!prev)
+      continue;
+
+    Value other = (reversePrev ? op.getLhs() : op.getRhs());
+    if (other != prev.getLhs() && other != prev.getRhs())
+      continue;
+
+    return prev.getResult();
+  }
+  return {};
+}
+
 OpFoldResult arith::AndIOp::fold(ArrayRef<Attribute> operands) {
   /// and(x, 0) -> 0
   if (matchPattern(getRhs(), m_Zero()))
@@ -631,6 +653,10 @@
       intValue.isAllOnes())
     return IntegerAttr::get(getType(), 0);
 
+  /// and(a, and(a, b)) -> and(a, b)
+  if (Value result = foldAndIofAndI(*this))
+    return result;
+
   return constFoldBinaryOp<IntegerAttr>(
       operands, [](APInt a, const APInt &b) { return std::move(a) & b; });
 }
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -1682,3 +1682,45 @@
 }
 
 // -----
+
+/// and(a, and(a, b)) -> and(a, b)
+
+// CHECK-LABEL: @andand0
+// CHECK-SAME: (%[[A:.*]]: i32, %[[B:.*]]: i32)
+// CHECK: %[[RES:.*]] = arith.andi %[[A]], %[[B]] : i32
+// CHECK: return %[[RES]]
+func.func @andand0(%a : i32, %b : i32) -> i32 {
+  %c = arith.andi %a, %b : i32
+  %res = arith.andi %a, %c : i32
+  return %res : i32
+}
+
+// CHECK-LABEL: @andand1
+// CHECK-SAME: (%[[A:.*]]: i32, %[[B:.*]]: i32)
+// CHECK: %[[RES:.*]] = arith.andi %[[A]], %[[B]] : i32
+// CHECK: return %[[RES]]
+func.func @andand1(%a : i32, %b : i32) -> i32 {
+  %c = arith.andi %a, %b : i32
+  %res = arith.andi %c, %a : i32
+  return %res : i32
+}
+
+// CHECK-LABEL: @andand2
+// CHECK-SAME: (%[[A:.*]]: i32, %[[B:.*]]: i32)
+// CHECK: %[[RES:.*]] = arith.andi %[[A]], %[[B]] : i32
+// CHECK: return %[[RES]]
+func.func @andand2(%a : i32, %b : i32) -> i32 {
+  %c = arith.andi %a, %b : i32
+  %res = arith.andi %b, %c : i32
+  return %res : i32
+}
+
+// CHECK-LABEL: @andand3
+// CHECK-SAME: (%[[A:.*]]: i32, %[[B:.*]]: i32)
+// CHECK: %[[RES:.*]] = arith.andi %[[A]], %[[B]] : i32
+// CHECK: return %[[RES]]
+func.func @andand3(%a : i32, %b : i32) -> i32 {
+  %c = arith.andi %a, %b : i32
+  %res = arith.andi %c, %b : i32
+  return %res : i32
+}