diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
--- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
+++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
@@ -574,6 +574,18 @@
   APInt intValue;
   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
       intValue.isAllOnes())
     return getLhs();
+  /// and(x, not(x)) -> 0
+  /// The zero is built with getZeroAttr: the result type may be a vector or
+  /// tensor of integers, which IntegerAttr::get does not accept.
+  if (matchPattern(getRhs(), m_Op<XOrIOp>(matchers::m_Val(getLhs()),
+                                          m_ConstantInt(&intValue))) &&
+      intValue.isAllOnes())
+    return Builder(getContext()).getZeroAttr(getType());
+  /// and(not(x), x) -> 0
+  if (matchPattern(getLhs(), m_Op<XOrIOp>(matchers::m_Val(getRhs()),
+                                          m_ConstantInt(&intValue))) &&
+      intValue.isAllOnes())
+    return Builder(getContext()).getZeroAttr(getType());
   return constFoldBinaryOp<IntegerAttr>(
       operands, [](APInt a, const APInt &b) { return std::move(a) & b; });
diff --git a/mlir/test/Dialect/Arithmetic/canonicalize.mlir b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
--- a/mlir/test/Dialect/Arithmetic/canonicalize.mlir
+++ b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
@@ -1532,3 +1532,44 @@
   %0 = arith.remf %v1, %v2 : vector<4xf32>
   return %0 : vector<4xf32>
 }
+
+// -----
+
+// CHECK-LABEL: @test_andi_not_fold_rhs(
+// CHECK-SAME: %[[ARG0:[[:alnum:]]+]]
+// CHECK: %[[C:.*]] = arith.constant 0 : index
+// CHECK: return %[[C]]
+
+func.func @test_andi_not_fold_rhs(%arg0 : index) -> index {
+  %0 = arith.constant -1 : index
+  %1 = arith.xori %arg0, %0 : index
+  %2 = arith.andi %arg0, %1 : index
+  return %2 : index
+}
+
+// -----
+
+// CHECK-LABEL: @test_andi_not_fold_lhs(
+// CHECK-SAME: %[[ARG0:[[:alnum:]]+]]
+// CHECK: %[[C:.*]] = arith.constant 0 : index
+// CHECK: return %[[C]]
+
+func.func @test_andi_not_fold_lhs(%arg0 : index) -> index {
+  %0 = arith.constant -1 : index
+  %1 = arith.xori %arg0, %0 : index
+  %2 = arith.andi %1, %arg0 : index
+  return %2 : index
+}
+
+// -----
+
+// CHECK-LABEL: @test_andi_not_fold_vector(
+// CHECK: %[[C:.*]] = arith.constant dense<0> : vector<4xi32>
+// CHECK: return %[[C]]
+
+func.func @test_andi_not_fold_vector(%arg0 : vector<4xi32>) -> vector<4xi32> {
+  %0 = arith.constant dense<-1> : vector<4xi32>
+  %1 = arith.xori %arg0, %0 : vector<4xi32>
+  %2 = arith.andi %arg0, %1 : vector<4xi32>
+  return %2 : vector<4xi32>
+}