Index: include/llvm/Analysis/ScalarEvolution.h
===================================================================
--- include/llvm/Analysis/ScalarEvolution.h
+++ include/llvm/Analysis/ScalarEvolution.h
@@ -978,6 +978,16 @@
   /// Test whether the condition described by Pred, LHS, and RHS is true
   /// whenever the condition described by Pred, FoundLHS, and FoundRHS is
+  /// true. Here LHS is an operation that includes FoundLHS as one of its
+  /// arguments; currently only LHS = FoundLHS / Denum (sdiv), optionally
+  /// behind sign extensions, is recognized.
+  bool isImpliedViaOperations(ICmpInst::Predicate Pred,
+                              const SCEV *LHS, const SCEV *RHS,
+                              const SCEV *FoundLHS,
+                              const SCEV *FoundRHS);
+
+  /// Test whether the condition described by Pred, LHS, and RHS is true
+  /// whenever the condition described by Pred, FoundLHS, and FoundRHS is
   /// true.
   bool isImpliedCondOperandsHelper(ICmpInst::Predicate Pred,
                                    const SCEV *LHS, const SCEV *RHS,
                                    const SCEV *FoundLHS,
Index: lib/Analysis/ScalarEvolution.cpp
===================================================================
--- lib/Analysis/ScalarEvolution.cpp
+++ lib/Analysis/ScalarEvolution.cpp
@@ -8492,6 +8492,92 @@
 }
 
 bool
+ScalarEvolution::isImpliedViaOperations(ICmpInst::Predicate Pred,
+                                        const SCEV *LHS, const SCEV *RHS,
+                                        const SCEV *FoundLHS,
+                                        const SCEV *FoundRHS) {
+  // The rules below only reason about signed integer values. Pointer-typed
+  // operands would make getIntegerBitWidth() below assert, so give up on
+  // them. (A pointer-typed FoundLHS can never match an sdiv numerator.)
+  if (!FoundLHS->getType()->isIntegerTy() ||
+      !FoundRHS->getType()->isIntegerTy())
+    return false;
+
+  // Look through a sign extension. Only sext preserves the signed ordering
+  // that the SGT/SGE rules below depend on; zext or trunc would not.
+  auto GetOpFromSExt = [](const SCEV *S) {
+    if (auto *Ext = dyn_cast<SCEVSignExtendExpr>(S))
+      return Ext->getOperand();
+    return S;
+  };
+
+  // Sign-extend S to the integer type T, which is at least as wide as S.
+  auto ExtendType = [&](const SCEV *S, Type *T) {
+    if (S->getType() == T)
+      return S;
+    return getSignExtendExpr(S, T);
+  };
+
+  // Return the wider of the integer types of S1 and S2.
+  auto MaxType = [](const SCEV *S1, const SCEV *S2) {
+    return S1->getType()->getIntegerBitWidth() >
+                   S2->getType()->getIntegerBitWidth()
+               ? S1->getType()
+               : S2->getType();
+  };
+
+  // Acquire values from extensions.
+  LHS = GetOpFromSExt(LHS);
+  FoundLHS = GetOpFromSExt(FoundLHS);
+
+  if (auto *Operation = dyn_cast<SCEVUnknown>(LHS)) {
+    Value *Op1, *Op2;
+    // FIXME: Once we have SDiv implemented, we can get rid of this matching.
+    using namespace llvm::PatternMatch;
+    if (match(Operation->getValue(),
+              m_SDiv(m_Value(Op1), m_Value(Op2)))) {
+      // Rules for division: LHS = Num / Denum.
+      // NOTE(review): getSCEV() may construct new SCEVs for Op1/Op2 while
+      // implication reasoning is already in progress; consider restricting
+      // this to already-computed SCEVs or to constant denominators.
+      const SCEV *Num = getSCEV(Op1);
+      const SCEV *Denum = getSCEV(Op2);
+
+      auto *Ty1 = MaxType(Num, FoundLHS);
+      auto *NumExt = ExtendType(Num, Ty1);
+      auto *FoundLHSExt = ExtendType(FoundLHS, Ty1);
+      // Require that LHS = FoundLHS / Denum.
+      if (!HasSameValue(NumExt, FoundLHSExt))
+        return false;
+
+      // Try to prove the following rules:
+      // FoundLHS > FoundRHS > 0 => FoundLHS / 2 > 0 >= RHS.
+      // FoundLHS >= FoundRHS > 0 => FoundLHS / 2 >= 0 >= RHS.
+      if ((Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SGE) &&
+          isKnownPositive(FoundRHS) &&
+          isKnownPredicate(ICmpInst::ICMP_SLE, RHS,
+                           getConstant(RHS->getType(), 0u, true)) &&
+          HasSameValue(Denum, getConstant(Denum->getType(), 2u, true)))
+        return true;
+
+      // Try to prove the following rule:
+      // FoundLHS >= FoundRHS, 0 < Denum <= FoundRHS =>
+      // LHS = FoundLHS / Denum >= 1 >= RHS.
+      auto *Ty2 = MaxType(Denum, FoundRHS);
+      auto *DenumExt = ExtendType(Denum, Ty2);
+      auto *FoundRHSExt = ExtendType(FoundRHS, Ty2);
+      if (Pred == ICmpInst::ICMP_SGE && isKnownPositive(Denum) &&
+          isKnownPredicate(ICmpInst::ICMP_SLE, RHS,
+                           getConstant(RHS->getType(), 1u, true)) &&
+          isKnownPredicate(ICmpInst::ICMP_SLE, DenumExt, FoundRHSExt))
+        return true;
+    }
+  }
+
+  return false;
+}
+
+bool
 ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred,
                                              const SCEV *LHS, const SCEV *RHS,
                                              const SCEV *FoundLHS,
@@ -8537,6 +8623,9 @@
     break;
   }
 
+  if (isImpliedViaOperations(Pred, LHS, RHS, FoundLHS, FoundRHS))
+    return true;
+
   return false;
 }
 