diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp --- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -714,8 +714,8 @@ Value *CmpRHS = IC->getOperand(1); unsigned C1Log; - bool IsEqualZero; bool NeedAnd = false; + CmpInst::Predicate Pred = IC->getPredicate(); if (IC->isEquality()) { if (!match(CmpRHS, m_Zero())) return nullptr; @@ -725,17 +725,13 @@ return nullptr; C1Log = C1->logBase2(); - IsEqualZero = IC->getPredicate() == ICmpInst::ICMP_EQ; } else { - // We also need to recognize (icmp slt X, 0) and (icmp sgt X, -1). - if (IC->getPredicate() == ICmpInst::ICMP_SGT && match(CmpRHS, m_AllOnes())) - IsEqualZero = true; - if (IC->getPredicate() == ICmpInst::ICMP_SLT && match(CmpRHS, m_Zero())) - IsEqualZero = false; - else + APInt C1; + if (!decomposeBitTestICmp(CmpLHS, CmpRHS, Pred, CmpLHS, C1) || + !C1.isPowerOf2()) return nullptr; - C1Log = CmpLHS->getType()->getScalarSizeInBits() - 1; + C1Log = C1.logBase2(); NeedAnd = true; } @@ -745,11 +741,11 @@ if (match(FalseVal, m_Or(m_Specific(TrueVal), m_Power2(C2)))) { Y = TrueVal; Or = FalseVal; - NeedXor = !IsEqualZero; + NeedXor = Pred == ICmpInst::ICMP_NE; } else if (match(TrueVal, m_Or(m_Specific(FalseVal), m_Power2(C2)))) { Y = FalseVal; Or = TrueVal; - NeedXor = IsEqualZero; + NeedXor = Pred == ICmpInst::ICMP_EQ; } else { return nullptr; }