diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h --- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h +++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h @@ -1688,6 +1688,11 @@ /// constant is canonicalized to be operand 1. bool isBitwiseNot(SDValue V, bool AllowUndefs = false); +/// If \p V is a bitwise not, returns the inverted operand. Otherwise returns +/// an empty SDValue. Only bits set in \p Mask are required to be inverted, +/// other bits may be arbitrary. \p AllowUndefs is passed to isBitwiseNot. +SDValue getBitwiseNotOperand(SDValue V, SDValue Mask, bool AllowUndefs); + /// Returns the SDNode if it is a constant splat BuildVector or constant int. ConstantSDNode *isConstOrConstSplat(SDValue N, bool AllowUndefs = false, bool AllowTruncation = false); diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -6814,11 +6814,14 @@ EVT VT = N0.getValueType(); if (N0.getOpcode() == ISD::AND) { // fold (or (and X, (xor Y, -1)), Y) -> (or X, Y) - if (isBitwiseNot(N0.getOperand(1)) && N0.getOperand(1).getOperand(0) == N1) + // TODO: Set AllowUndefs = true, as MatchNoCommonBitsPattern does. 
+ if (getBitwiseNotOperand(N0.getOperand(1), N0.getOperand(0), + /* AllowUndefs */ false) == N1) return DAG.getNode(ISD::OR, SDLoc(N), VT, N0.getOperand(0), N1); // fold (or (and (xor Y, -1), X), Y) -> (or X, Y) - if (isBitwiseNot(N0.getOperand(0)) && N0.getOperand(0).getOperand(0) == N1) + if (getBitwiseNotOperand(N0.getOperand(0), N0.getOperand(1), + /* AllowUndefs */ false) == N1) return DAG.getNode(ISD::OR, SDLoc(N), VT, N0.getOperand(1), N1); } diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -4685,8 +4685,8 @@ } // Only bits set in Mask must be negated, other bits may be arbitrary. -static SDValue getBitwiseNotOperand(SDValue V, SDValue Mask) { - if (isBitwiseNot(V, true)) +SDValue llvm::getBitwiseNotOperand(SDValue V, SDValue Mask, bool AllowUndefs) { + if (isBitwiseNot(V, AllowUndefs)) return V.getOperand(0); // Handle any_extend (not (truncate X)) pattern, where Mask only sets @@ -4697,7 +4697,7 @@ SDValue ExtArg = V.getOperand(0); if (ExtArg.getScalarValueSizeInBits() >= MaskC->getAPIntValue().getActiveBits() && - isBitwiseNot(ExtArg, true) && + isBitwiseNot(ExtArg, AllowUndefs) && ExtArg.getOperand(0).getOpcode() == ISD::TRUNCATE && ExtArg.getOperand(0).getOperand(0).getValueType() == V.getValueType()) return ExtArg.getOperand(0).getOperand(0); @@ -4709,7 +4709,8 @@ // Including degenerate case (X & ~M) op M auto MatchNoCommonBitsPattern = [&](SDValue Not, SDValue Mask, SDValue Other) { - if (SDValue NotOperand = getBitwiseNotOperand(Not, Mask)) { + if (SDValue NotOperand = + getBitwiseNotOperand(Not, Mask, /* AllowUndefs */ true)) { if (Other == NotOperand) return true; if (Other->getOpcode() == ISD::AND) diff --git a/llvm/test/CodeGen/X86/add-and-not.ll b/llvm/test/CodeGen/X86/add-and-not.ll --- a/llvm/test/CodeGen/X86/add-and-not.ll +++ b/llvm/test/CodeGen/X86/add-and-not.ll @@ -118,10 
+118,8 @@ define i64 @add_and_xor_const(i64 %x) { ; CHECK-LABEL: add_and_xor_const: ; CHECK: # %bb.0: -; CHECK-NEXT: movl %edi, %eax -; CHECK-NEXT: notl %eax -; CHECK-NEXT: andl $1, %eax -; CHECK-NEXT: orq %rdi, %rax +; CHECK-NEXT: movq %rdi, %rax +; CHECK-NEXT: orq $1, %rax ; CHECK-NEXT: retq %xor = xor i64 %x, -1 %and = and i64 %xor, 1 @@ -145,10 +143,8 @@ define i64 @add_and_xor_const_explicit_trunc(i64 %x) { ; CHECK-LABEL: add_and_xor_const_explicit_trunc: ; CHECK: # %bb.0: -; CHECK-NEXT: movl %edi, %eax -; CHECK-NEXT: notl %eax -; CHECK-NEXT: andl $1, %eax -; CHECK-NEXT: orq %rdi, %rax +; CHECK-NEXT: movq %rdi, %rax +; CHECK-NEXT: orq $1, %rax ; CHECK-NEXT: retq %trunc = trunc i64 %x to i32 %xor = xor i32 %trunc, -1 @@ -192,10 +188,8 @@ define i8* @gep_and_xor_const(i8* %a) { ; CHECK-LABEL: gep_and_xor_const: ; CHECK: # %bb.0: -; CHECK-NEXT: movl %edi, %eax -; CHECK-NEXT: notl %eax -; CHECK-NEXT: andl $1, %eax -; CHECK-NEXT: orq %rdi, %rax +; CHECK-NEXT: movq %rdi, %rax +; CHECK-NEXT: orq $1, %rax ; CHECK-NEXT: retq %old = ptrtoint i8* %a to i64 %old.not = and i64 %old, 1