Index: llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
===================================================================
--- llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -4684,12 +4684,37 @@
   return false;
 }
 
+// If V is a bitwise NOT of some value X, at least in the bits set in Mask,
+// return X; otherwise return an empty SDValue. Only bits set in Mask must be
+// negated, other bits may be arbitrary. The any_extend form below also needs
+// Mask to be a (splat) constant. The result always has the same type as V.
+static SDValue getBitwiseNotOperand(SDValue V, SDValue Mask) {
+  if (isBitwiseNot(V, true))
+    return V.getOperand(0);
+
+  // Handle any_extend (not (truncate X)) pattern, where Mask only sets
+  // bits in the non-extended part (the high bits of an any_extend are
+  // undefined, so Mask must not select any of them).
+  ConstantSDNode *MaskC = isConstOrConstSplat(Mask);
+  if (!MaskC || V.getOpcode() != ISD::ANY_EXTEND)
+    return SDValue();
+  SDValue ExtArg = V.getOperand(0);
+  // Also require X to have V's type, as in the plain xor case above.
+  if (ExtArg.getScalarValueSizeInBits() >=
+          MaskC->getAPIntValue().getActiveBits() &&
+      isBitwiseNot(ExtArg, true) &&
+      ExtArg.getOperand(0).getOpcode() == ISD::TRUNCATE &&
+      ExtArg.getOperand(0).getOperand(0).getValueType() == V.getValueType())
+    return ExtArg.getOperand(0).getOperand(0);
+  return SDValue();
+}
+
 static bool haveNoCommonBitsSetCommutative(SDValue A, SDValue B) {
   // Match masked merge pattern (X & ~M) op (Y & M)
   // Including degenerate case (X & ~M) op M
-  auto MatchNoCommonBitsPattern = [&](SDValue NotM, SDValue Other) {
-    if (isBitwiseNot(NotM, true)) {
-      SDValue NotOperand = NotM->getOperand(0);
+  auto MatchNoCommonBitsPattern = [&](SDValue Not, SDValue Mask,
+                                      SDValue Other) {
+    if (SDValue NotOperand = getBitwiseNotOperand(Not, Mask)) {
       if (Other == NotOperand)
         return true;
       if (Other->getOpcode() == ISD::AND)
@@ -4699,8 +4724,8 @@
     return false;
   };
   if (A->getOpcode() == ISD::AND)
-    return MatchNoCommonBitsPattern(A->getOperand(0), B) ||
-           MatchNoCommonBitsPattern(A->getOperand(1), B);
+    return MatchNoCommonBitsPattern(A->getOperand(0), A->getOperand(1), B) ||
+           MatchNoCommonBitsPattern(A->getOperand(1), A->getOperand(0), B);
   return false;
 }
 
Index: llvm/test/CodeGen/X86/add-and-not.ll
===================================================================
--- llvm/test/CodeGen/X86/add-and-not.ll
+++ llvm/test/CodeGen/X86/add-and-not.ll
@@ -121,7 +121,7 @@
 ; CHECK-NEXT:    movl %edi, %eax
 ; CHECK-NEXT:    notl %eax
 ; CHECK-NEXT:    andl $1, %eax
-; CHECK-NEXT:    addq %rdi, %rax
+; CHECK-NEXT:    orq %rdi, %rax
 ; CHECK-NEXT:    retq
   %xor = xor i64 %x, -1
   %and = and i64 %xor, 1
@@ -148,7 +148,7 @@
 ; CHECK-NEXT:    movl %edi, %eax
 ; CHECK-NEXT:    notl %eax
 ; CHECK-NEXT:    andl $1, %eax
-; CHECK-NEXT:    addq %rdi, %rax
+; CHECK-NEXT:    orq %rdi, %rax
 ; CHECK-NEXT:    retq
   %trunc = trunc i64 %x to i32
   %xor = xor i32 %trunc, -1
@@ -195,7 +195,7 @@
 ; CHECK-NEXT:    movl %edi, %eax
 ; CHECK-NEXT:    notl %eax
 ; CHECK-NEXT:    andl $1, %eax
-; CHECK-NEXT:    addq %rdi, %rax
+; CHECK-NEXT:    orq %rdi, %rax
 ; CHECK-NEXT:    retq
   %old = ptrtoint i8* %a to i64
   %old.not = and i64 %old, 1