diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -675,8 +675,8 @@
   Value *CmpRHS = IC->getOperand(1);
 
   unsigned C1Log;
-  bool IsEqualZero;
   bool NeedAnd = false;
+  CmpInst::Predicate Pred = IC->getPredicate();
   if (IC->isEquality()) {
     if (!match(CmpRHS, m_Zero()))
       return nullptr;
@@ -686,18 +686,21 @@
       return nullptr;
 
     C1Log = C1->logBase2();
-    IsEqualZero = IC->getPredicate() == ICmpInst::ICMP_EQ;
   } else {
-    // We also need to recognize (icmp slt X, 0) and (icmp sgt X, -1).
-    if (IC->getPredicate() == ICmpInst::ICMP_SGT && match(CmpRHS, m_AllOnes()))
-      IsEqualZero = true;
-    if (IC->getPredicate() == ICmpInst::ICMP_SLT && match(CmpRHS, m_Zero()))
-      IsEqualZero = false;
-    else
+    // Recognize other single-bit tests, e.g. (icmp slt X, 0) and
+    // (icmp sgt X, -1). decomposeBitTestICmp rewrites the compare as
+    // ((X & C1) Pred 0) and updates Pred to ICMP_EQ or ICMP_NE.
+    Value *Unused;
+    APInt C1;
+    if (!decomposeBitTestICmp(CmpLHS, CmpRHS, Pred, Unused, C1) ||
+        !C1.isPowerOf2())
       return nullptr;
 
-    C1Log = CmpLHS->getType()->getScalarSizeInBits() - 1;
-    NeedAnd = true;
+    C1Log = C1.logBase2();
+    // The mask is only implied by the predicate; V is the unmasked CmpLHS,
+    // so it always has to be masked. Comparing CmpRHS with C1 is not a valid
+    // shortcut: (icmp ult i8 X, 128) decomposes to C1 == 128 == CmpRHS.
+    NeedAnd = true;
   }
 
   Value *Or, *Y, *V = CmpLHS;
@@ -706,11 +709,11 @@
   if (match(FalseVal, m_Or(m_Specific(TrueVal), m_Power2(C2)))) {
     Y = TrueVal;
     Or = FalseVal;
-    NeedXor = !IsEqualZero;
+    NeedXor = Pred == ICmpInst::ICMP_NE;
   } else if (match(TrueVal, m_Or(m_Specific(FalseVal), m_Power2(C2)))) {
     Y = FalseVal;
     Or = TrueVal;
-    NeedXor = IsEqualZero;
+    NeedXor = Pred == ICmpInst::ICMP_EQ;
   } else {
     return nullptr;
   }