diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -1811,7 +1811,11 @@
   /// SCEV predicates in order to return an exact answer.
+  /// If \p OrigCond is non-null, it is the compare controlling the exit; its
+  /// IV operand is inspected for nsw/nuw flags to prove Start - Stride does
+  /// not overflow.
   ExitLimit howManyLessThans(const SCEV *LHS, const SCEV *RHS, const Loop *L,
                              bool isSigned, bool ControlsExit,
-                             bool AllowPredicates = false);
+                             bool AllowPredicates = false,
+                             ICmpInst *OrigCond = nullptr);
 
   ExitLimit howManyGreaterThans(const SCEV *LHS, const SCEV *RHS, const Loop *L,
                                 bool isSigned, bool IsSubExpr,
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -8108,7 +8108,7 @@
   case ICmpInst::ICMP_ULT: { // while (X < Y)
     bool IsSigned = Pred == ICmpInst::ICMP_SLT;
     ExitLimit EL = howManyLessThans(LHS, RHS, L, IsSigned, ControlsExit,
-                                    AllowPredicates);
+                                    AllowPredicates, ExitCond);
     if (EL.hasAnyInfo()) return EL;
     break;
   }
@@ -11573,7 +11573,8 @@
 ScalarEvolution::ExitLimit
 ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
                                   const Loop *L, bool IsSigned,
-                                  bool ControlsExit, bool AllowPredicates) {
+                                  bool ControlsExit, bool AllowPredicates,
+                                  ICmpInst *OrigCond) {
   SmallPtrSet<const SCEVPredicate *, 4> Predicates;
 
   const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
@@ -11765,8 +11766,27 @@
   const SCEV *BECount = nullptr;
   auto *StartMinusStride = getMinusSCEV(OrigStart, Stride);
   // Can we prove (max(RHS,Start) > Start - Stride?
-  if (isLoopEntryGuardedByCond(L, Cond, StartMinusStride, Start) &&
-      isLoopEntryGuardedByCond(L, Cond, StartMinusStride, RHS)) {
+  if (isLoopEntryGuardedByCond(L, Cond, StartMinusStride, RHS)) {
+    auto MaySubOverflow = [&]() {
+      // Start - Stride < Start implies no overflow.
+      if (isLoopEntryGuardedByCond(L, Cond, StartMinusStride, Start))
+        return false;
+      // Check if we have an IR instruction feeding into the branch
+      // "add nsw IVMinusStride, Stride". If that add cannot wrap, its
+      // first-iteration operand (Start - Stride) cannot have wrapped either.
+      // Match AddOperator rather than BinaryOperator: SCEV may model other
+      // opcodes (e.g. "or" of disjoint bits) as an add, and querying nsw/nuw
+      // on a non-OverflowingBinaryOperator asserts.
+      if (OrigCond && getSCEV(OrigCond->getOperand(0)) == IV) {
+        if (auto *Add = dyn_cast<AddOperator>(OrigCond->getOperand(0))) {
+          if (IsSigned ? Add->hasNoSignedWrap() : Add->hasNoUnsignedWrap()) {
+            if (getSCEV(Add->getOperand(1)) == Stride)
+              return false;
+          }
+        }
+      }
+      return true;
+    };
     // In this case, we can use a refined formula for computing backedge taken
     // count.  The general formula remains:
     //   "End-Start /uceiling Stride" where "End = max(RHS,Start)"
@@ -11785,11 +11805,13 @@
     //   "((RHS - 1) - (Start - Stride)) /u Stride" reassociates to
     //   "((RHS - (Start - Stride) - 1) /u Stride".
     //   Our preconditions trivially imply no overflow in that form.
-    const SCEV *MinusOne = getMinusOne(Stride->getType());
-    const SCEV *Numerator =
-        getMinusSCEV(getAddExpr(RHS, MinusOne), StartMinusStride);
-    if (!isa<SCEVCouldNotCompute>(Numerator)) {
-      BECount = getUDivExpr(Numerator, Stride);
+    if (!MaySubOverflow()) {
+      const SCEV *MinusOne = getMinusOne(Stride->getType());
+      const SCEV *Numerator =
+          getMinusSCEV(getAddExpr(RHS, MinusOne), StartMinusStride);
+      if (!isa<SCEVCouldNotCompute>(Numerator)) {
+        BECount = getUDivExpr(Numerator, Stride);
+      }
     }
   }
 
diff --git a/llvm/test/Analysis/ScalarEvolution/trip-count-unknown-stride.ll b/llvm/test/Analysis/ScalarEvolution/trip-count-unknown-stride.ll
--- a/llvm/test/Analysis/ScalarEvolution/trip-count-unknown-stride.ll
+++ b/llvm/test/Analysis/ScalarEvolution/trip-count-unknown-stride.ll
@@ -5,7 +5,7 @@
 ; that this is not an infinite loop with side effects.
 
 ; CHECK-LABEL: Determining loop execution counts for: @foo1
-; CHECK: backedge-taken count is ((-1 + (%n smax %s)) /u %s)
+; CHECK: backedge-taken count is ((-1 + %n) /u %s)
 
 ; We should have a conservative estimate for the max backedge taken count for
 ; loops with unknown stride.