diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -728,13 +728,16 @@
     V = CmpLHS;
     C1Log = C1->logBase2();
     IsEqualZero = IC->getPredicate() == ICmpInst::ICMP_EQ;
-  } else if (IC->getPredicate() == ICmpInst::ICMP_SLT ||
-             IC->getPredicate() == ICmpInst::ICMP_SGT) {
+  } else {
     // We also need to recognize (icmp slt (trunc (X)), 0) and
     // (icmp sgt (trunc (X)), -1).
-    IsEqualZero = IC->getPredicate() == ICmpInst::ICMP_SGT;
-    if ((IsEqualZero && !match(CmpRHS, m_AllOnes())) ||
-        (!IsEqualZero && !match(CmpRHS, m_Zero())))
+    // SGT -1 means the sign bit is clear (IsEqualZero), SLT 0 that it is set.
+    if (IC->getPredicate() == ICmpInst::ICMP_SGT && match(CmpRHS, m_AllOnes()))
+      IsEqualZero = true;
+    else if (IC->getPredicate() == ICmpInst::ICMP_SLT &&
+             match(CmpRHS, m_Zero()))
+      IsEqualZero = false;
+    else
       return nullptr;
 
     if (!match(CmpLHS, m_OneUse(m_Trunc(m_Value(V)))))
@@ -742,30 +745,30 @@
 
     C1Log = CmpLHS->getType()->getScalarSizeInBits() - 1;
     NeedAnd = true;
-  } else {
-    return nullptr;
   }
 
+  Value *Or, *Y;
   const APInt *C2;
-  bool OrOnTrueVal = false;
-  bool OrOnFalseVal = match(FalseVal, m_Or(m_Specific(TrueVal), m_Power2(C2)));
-  if (!OrOnFalseVal)
-    OrOnTrueVal = match(TrueVal, m_Or(m_Specific(FalseVal), m_Power2(C2)));
-
-  if (!OrOnFalseVal && !OrOnTrueVal)
+  bool NeedXor;
+  if (match(FalseVal, m_Or(m_Specific(TrueVal), m_Power2(C2)))) {
+    Y = TrueVal;
+    Or = FalseVal;
+    NeedXor = !IsEqualZero;
+  } else if (match(TrueVal, m_Or(m_Specific(FalseVal), m_Power2(C2)))) {
+    Y = FalseVal;
+    Or = TrueVal;
+    NeedXor = IsEqualZero;
+  } else {
     return nullptr;
-
-  Value *Y = OrOnFalseVal ? TrueVal : FalseVal;
+  }
 
   unsigned C2Log = C2->logBase2();
 
-  bool NeedXor = (!IsEqualZero && OrOnFalseVal) || (IsEqualZero && OrOnTrueVal);
   bool NeedShift = C1Log != C2Log;
   bool NeedZExtTrunc = Y->getType()->getScalarSizeInBits() !=
                        V->getType()->getScalarSizeInBits();
 
   // Make sure we don't create more instructions than we save.
-  Value *Or = OrOnFalseVal ? FalseVal : TrueVal;
   if ((NeedShift + NeedXor + NeedZExtTrunc) >
       (IC->hasOneUse() + Or->hasOneUse()))
     return nullptr;