[SCEV] Pass SymbolicMaxNotTaken explicitly to ExitLimit constructors

Both multi-argument ScalarEvolution::ExitLimit constructors gain a
`const SCEV *SymbolicMaxNotTaken` parameter, placed between
ConstantMaxNotTaken and MaxOrZero. The single-argument ExitLimit(E)
constructor now forwards E as the exact, constant-max and symbolic-max
count.

The constructor no longer derives the symbolic max itself. The removed
FIXME rule ("exact if computable, otherwise constant max") is now
applied by each caller, mostly written as
    isa<SCEVCouldNotCompute>(Exact) ? ConstantMax : Exact
Callers whose exact count is getCouldNotCompute() pass the constant max
as the symbolic max. Callers with a concrete exact count (R, Distance)
pass that count again. When ConstantMaxNotTaken is zero, the
constructor still overwrites the exact count with that zero, and now
overwrites the symbolic max too.

Two assertions are added next to the existing Exact/Constant Max one.
Assuming the stripped template argument is SCEVCouldNotCompute, together
they require: a computable exact count implies a computable symbolic
max, which implies a computable constant max.

NOTE(review): in the zero case `E = ConstantMaxNotTaken` writes the
by-value parameter E. E does not appear to be read afterwards in the
visible lines, so only the member write matters.
NOTE(review): the parameter SymbolicMaxNotTaken shadows the member of
the same name. The unqualified asserts therefore read the parameter,
which is why it is reassigned alongside this->SymbolicMaxNotTaken.
NOTE(review): the general-equation exit names its value `auto *S`.
Every other site uses SymbolicMax / SymbolicMaxBECount. The repeated
"exact-or-constant-max" ternary could become one small helper.
NOTE(review): the Distance-based exit passes Distance as the symbolic
max with no CouldNotCompute check. If Distance could ever be
CouldNotCompute there, the symbolic max would now be CouldNotCompute.
The old rule would have fallen back to the constant max in that case.
TODO confirm Distance is always computable at that site.
NOTE(review): this copy has its line breaks collapsed and its
angle-bracket template arguments stripped (e.g. isa(...) presumably
read isa<SCEVCouldNotCompute>(...) or isa<SCEVConstant>(...)).
Restore both before applying the patch.

diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h --- a/llvm/include/llvm/Analysis/ScalarEvolution.h +++ b/llvm/include/llvm/Analysis/ScalarEvolution.h @@ -1335,11 +1335,14 @@ /// as arguments and asserts enforce that internally. /*implicit*/ ExitLimit(const SCEV *E); - ExitLimit(const SCEV *E, const SCEV *ConstantMaxNotTaken, bool MaxOrZero, - ArrayRef *> - PredSetList = std::nullopt); - - ExitLimit(const SCEV *E, const SCEV *ConstantMaxNotTaken, bool MaxOrZero, + ExitLimit( + const SCEV *E, const SCEV *ConstantMaxNotTaken, + const SCEV *SymbolicMaxNotTaken, bool MaxOrZero, + ArrayRef *> PredSetList = + std::nullopt); + + ExitLimit(const SCEV *E, const SCEV *ConstantMaxNotTaken, + const SCEV *SymbolicMaxNotTaken, bool MaxOrZero, const SmallPtrSetImpl &PredSet); /// Test whether this ExitLimit contains any computed information, or diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -8603,29 +8603,31 @@ } ScalarEvolution::ExitLimit::ExitLimit(const SCEV *E) - : ExitLimit(E, E, false, std::nullopt) {} + : ExitLimit(E, E, E, false, std::nullopt) {} ScalarEvolution::ExitLimit::ExitLimit( - const SCEV *E, const SCEV *ConstantMaxNotTaken, bool MaxOrZero, + const SCEV *E, const SCEV *ConstantMaxNotTaken, + const SCEV *SymbolicMaxNotTaken, bool MaxOrZero, ArrayRef *> PredSetList) : ExactNotTaken(E), ConstantMaxNotTaken(ConstantMaxNotTaken), - MaxOrZero(MaxOrZero) { + SymbolicMaxNotTaken(SymbolicMaxNotTaken), MaxOrZero(MaxOrZero) { // If we prove the max count is zero, so is the symbolic bound. This happens // in practice due to differences in a) how context sensitive we've chosen // to be and b) how we reason about bounds implied by UB. 
- if (ConstantMaxNotTaken->isZero()) - ExactNotTaken = ConstantMaxNotTaken; - - // FIXME: For now, SymbolicMaxNotTaken is either exact (if available) or - // constant max. In the future, we are planning to make it more powerful. - if (isa(ExactNotTaken)) - SymbolicMaxNotTaken = ConstantMaxNotTaken; - else - SymbolicMaxNotTaken = ExactNotTaken; + if (ConstantMaxNotTaken->isZero()) { + this->ExactNotTaken = E = ConstantMaxNotTaken; + this->SymbolicMaxNotTaken = SymbolicMaxNotTaken = ConstantMaxNotTaken; + } assert((isa(ExactNotTaken) || !isa(ConstantMaxNotTaken)) && - "Exact is not allowed to be less precise than Max"); + "Exact is not allowed to be less precise than Constant Max"); + assert((isa(ExactNotTaken) || + !isa(SymbolicMaxNotTaken)) && + "Exact is not allowed to be less precise than Symbolic Max"); + assert((isa(SymbolicMaxNotTaken) || + !isa(ConstantMaxNotTaken)) && + "Symbolic Max is not allowed to be less precise than Constant Max"); assert((isa(ConstantMaxNotTaken) || isa(ConstantMaxNotTaken)) && "No point in having a non-constant max backedge taken count!"); @@ -8640,9 +8642,11 @@ } ScalarEvolution::ExitLimit::ExitLimit( - const SCEV *E, const SCEV *ConstantMaxNotTaken, bool MaxOrZero, + const SCEV *E, const SCEV *ConstantMaxNotTaken, + const SCEV *SymbolicMaxNotTaken, bool MaxOrZero, const SmallPtrSetImpl &PredSet) - : ExitLimit(E, ConstantMaxNotTaken, MaxOrZero, { &PredSet }) {} + : ExitLimit(E, ConstantMaxNotTaken, SymbolicMaxNotTaken, MaxOrZero, + { &PredSet }) {} /// Allocate memory for BackedgeTakenInfo and copy the not-taken count of each /// computable exit into a persistent ExitNotTakenInfo array. @@ -8980,8 +8984,9 @@ if (isa(ConstantMaxBECount) && !isa(BECount)) ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount)); - - return ExitLimit(BECount, ConstantMaxBECount, false, + const SCEV *SymbolicMaxBECount = + isa(BECount) ? 
ConstantMaxBECount : BECount; + return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, false, { &EL0.Predicates, &EL1.Predicates }); } @@ -9307,7 +9312,7 @@ unsigned BitWidth = getTypeSizeInBits(RHS->getType()); const SCEV *UpperBound = getConstant(getEffectiveSCEVType(RHS->getType()), BitWidth); - return ExitLimit(getCouldNotCompute(), UpperBound, false); + return ExitLimit(getCouldNotCompute(), UpperBound, UpperBound, false); } return getCouldNotCompute(); @@ -10308,7 +10313,7 @@ // should not accept a root of 2. if (auto S = SolveQuadraticAddRecExact(AddRec, *this)) { const auto *R = cast(getConstant(*S)); - return ExitLimit(R, R, false, Predicates); + return ExitLimit(R, R, R, false, Predicates); } return getCouldNotCompute(); } @@ -10373,7 +10378,8 @@ ConstantRange CR = getUnsignedRange(DistancePlusOne); MaxBECount = APIntOps::umin(MaxBECount, CR.getUnsignedMax() - 1); } - return ExitLimit(Distance, getConstant(MaxBECount), false, Predicates); + return ExitLimit(Distance, getConstant(MaxBECount), Distance, false, + Predicates); } // If the condition controls loop exit (the loop exits only if the expression @@ -10385,12 +10391,15 @@ loopHasNoAbnormalExits(AddRec->getLoop())) { const SCEV *Exact = getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step); - const SCEV *Max = getCouldNotCompute(); + const SCEV *ConstantMax = getCouldNotCompute(); if (Exact != getCouldNotCompute()) { APInt MaxInt = getUnsignedRangeMax(applyLoopGuards(Exact, L)); - Max = getConstant(APIntOps::umin(MaxInt, getUnsignedRangeMax(Exact))); + ConstantMax = + getConstant(APIntOps::umin(MaxInt, getUnsignedRangeMax(Exact))); } - return ExitLimit(Exact, Max, false, Predicates); + const SCEV *SymbolicMax = + isa(Exact) ? ConstantMax : Exact; + return ExitLimit(Exact, ConstantMax, SymbolicMax, false, Predicates); } // Solve the general equation. 
@@ -10402,7 +10411,8 @@ APInt MaxWithGuards = getUnsignedRangeMax(applyLoopGuards(E, L)); M = getConstant(APIntOps::umin(MaxWithGuards, getUnsignedRangeMax(E))); } - return ExitLimit(E, M, false, Predicates); + auto *S = isa(E) ? M : E; + return ExitLimit(E, M, S, false, Predicates); } ScalarEvolution::ExitLimit @@ -12812,7 +12822,7 @@ const SCEV *MaxBECount = computeMaxBECountForLT( Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned); return ExitLimit(getCouldNotCompute() /* ExactNotTaken */, MaxBECount, - false /*MaxOrZero*/, Predicates); + MaxBECount, false /*MaxOrZero*/, Predicates); } // We use the expression (max(End,Start)-Start)/Stride to describe the @@ -12975,27 +12985,30 @@ } } - const SCEV *MaxBECount; + const SCEV *ConstantMaxBECount; bool MaxOrZero = false; if (isa(BECount)) { - MaxBECount = BECount; + ConstantMaxBECount = BECount; } else if (BECountIfBackedgeTaken && isa(BECountIfBackedgeTaken)) { // If we know exactly how many times the backedge will be taken if it's // taken at least once, then the backedge count will either be that or // zero. - MaxBECount = BECountIfBackedgeTaken; + ConstantMaxBECount = BECountIfBackedgeTaken; MaxOrZero = true; } else { - MaxBECount = computeMaxBECountForLT( + ConstantMaxBECount = computeMaxBECountForLT( Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned); } - if (isa(MaxBECount) && + if (isa(ConstantMaxBECount) && !isa(BECount)) - MaxBECount = getConstant(getUnsignedRangeMax(BECount)); + ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount)); - return ExitLimit(BECount, MaxBECount, MaxOrZero, Predicates); + const SCEV *SymbolicMaxBECount = + isa(BECount) ? ConstantMaxBECount : BECount; + return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, MaxOrZero, + Predicates); } ScalarEvolution::ExitLimit @@ -13083,15 +13096,19 @@ IsSigned ? 
APIntOps::smax(getSignedRangeMin(RHS), Limit) : APIntOps::umax(getUnsignedRangeMin(RHS), Limit); - const SCEV *MaxBECount = isa(BECount) - ? BECount - : getUDivCeilSCEV(getConstant(MaxStart - MinEnd), - getConstant(MinStride)); + const SCEV *ConstantMaxBECount = + isa(BECount) + ? BECount + : getUDivCeilSCEV(getConstant(MaxStart - MinEnd), + getConstant(MinStride)); - if (isa(MaxBECount)) - MaxBECount = BECount; + if (isa(ConstantMaxBECount)) + ConstantMaxBECount = BECount; + const SCEV *SymbolicMaxBECount = + isa(BECount) ? ConstantMaxBECount : BECount; - return ExitLimit(BECount, MaxBECount, false, Predicates); + return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, false, + Predicates); } const SCEV *SCEVAddRecExpr::getNumIterationsInRange(const ConstantRange &Range,