diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h --- a/llvm/include/llvm/Analysis/ScalarEvolution.h +++ b/llvm/include/llvm/Analysis/ScalarEvolution.h @@ -1793,13 +1793,13 @@ /// /// If \p AllowPredicates is set, this call will try to use a minimal set of /// SCEV predicates in order to return an exact answer. + /// + /// \p Pred must be a less-than or greater-than predicate. If it is a + /// greater-than predicate, the code will use a binary NOT on the operands, + /// and adjust other logic to try to be more precise. ExitLimit howManyLessThans(const SCEV *LHS, const SCEV *RHS, const Loop *L, - bool isSigned, bool ControlsExit, - bool AllowPredicates = false); - - ExitLimit howManyGreaterThans(const SCEV *LHS, const SCEV *RHS, const Loop *L, - bool isSigned, bool IsSubExpr, - bool AllowPredicates = false); + ICmpInst::Predicate Pred, bool ControlsExit, + bool AllowPredicates); /// Return a predecessor of BB (which may not be an immediate predecessor) /// which has exactly one successor from which BB is reachable, or null if @@ -2038,11 +2038,6 @@ /// the stride. bool canIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride, bool IsSigned); - /// Verify if an linear IV with negative stride can overflow when in a - /// greater-than comparison, knowing the invariant term of the comparison, - /// the stride. - bool canIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride, bool IsSigned); - /// Get add expr already created or create a new one. 
const SCEV *getOrCreateAddExpr(ArrayRef Ops, SCEV::NoWrapFlags Flags); diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -7966,19 +7966,12 @@ break; } case ICmpInst::ICMP_SLT: - case ICmpInst::ICMP_ULT: { // while (X < Y) - bool IsSigned = Pred == ICmpInst::ICMP_SLT; - ExitLimit EL = howManyLessThans(LHS, RHS, L, IsSigned, ControlsExit, - AllowPredicates); - if (EL.hasAnyInfo()) return EL; - break; - } + case ICmpInst::ICMP_ULT: case ICmpInst::ICMP_SGT: - case ICmpInst::ICMP_UGT: { // while (X > Y) - bool IsSigned = Pred == ICmpInst::ICMP_SGT; + case ICmpInst::ICMP_UGT: { + // while (X > Y) or while (X < Y) ExitLimit EL = - howManyGreaterThans(LHS, RHS, L, IsSigned, ControlsExit, - AllowPredicates); + howManyLessThans(LHS, RHS, L, Pred, ControlsExit, AllowPredicates); if (EL.hasAnyInfo()) return EL; break; } @@ -11325,29 +11318,6 @@ return (std::move(MaxValue) - MaxStrideMinusOne).ult(MaxRHS); } -bool ScalarEvolution::canIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride, - bool IsSigned) { - - unsigned BitWidth = getTypeSizeInBits(RHS->getType()); - const SCEV *One = getOne(Stride->getType()); - - if (IsSigned) { - APInt MinRHS = getSignedRangeMin(RHS); - APInt MinValue = APInt::getSignedMinValue(BitWidth); - APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One)); - - // SMinRHS - SMaxStrideMinusOne < SMinValue => overflow! - return (std::move(MinValue) + MaxStrideMinusOne).sgt(MinRHS); - } - - APInt MinRHS = getUnsignedRangeMin(RHS); - APInt MinValue = APInt::getMinValue(BitWidth); - APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One)); - - // UMinRHS - UMaxStrideMinusOne < UMinValue => overflow! 
- return (std::move(MinValue) + MaxStrideMinusOne).ugt(MinRHS); -} - const SCEV *ScalarEvolution::computeBECount(const SCEV *Delta, const SCEV *Step) { const SCEV *One = getOne(Step->getType()); @@ -11399,10 +11369,22 @@ ScalarEvolution::ExitLimit ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS, - const Loop *L, bool IsSigned, + const Loop *L, ICmpInst::Predicate Pred, bool ControlsExit, bool AllowPredicates) { SmallPtrSet Predicates; + assert((ICmpInst::isGT(Pred) || ICmpInst::isLT(Pred)) && "Unexpected pred"); + bool IsSigned = ICmpInst::isSigned(Pred); + bool Invert = ICmpInst::isGT(Pred); + + // If this is a greater-than comparison, invert the LHS/RHS to make it a + // less-than comparison. + const SCEV *OrigRHS = RHS; + if (Invert) { + LHS = getNotSCEV(LHS); + RHS = getNotSCEV(RHS); + } + const SCEVAddRecExpr *IV = dyn_cast(LHS); bool PredicatedIV = false; @@ -11515,8 +11497,6 @@ return getCouldNotCompute(); } - ICmpInst::Predicate Cond = IsSigned ? ICmpInst::ICMP_SLT - : ICmpInst::ICMP_ULT; const SCEV *Start = IV->getStart(); const SCEV *End = RHS; // When the RHS is not invariant, we do not know the end bound of the loop and @@ -11543,17 +11523,35 @@ // as if the backedge is taken at least once max(End,Start) is End and so the // result is as above, and if not max(End,Start) is Start so we get a backedge // count of zero. + // + // Specialize for less-than versus greater-than. This isn't a matter of + // correctness: we inverted the inputs at the beginning of the function, so + // both paths are theoretically equivalent. But isLoopEntryGuardedByCond + // isn't very powerful, so we want to look for more likely patterns. const SCEV *BECount; - if (isLoopEntryGuardedByCond(L, Cond, getMinusSCEV(Start, Stride), RHS)) + const SCEV *GuardRHS = OrigRHS; + const SCEV *GuardStart = Invert ? 
getNotSCEV(Start) : Start; + const SCEV *GuardStartMinusStride = getMinusSCEV(Start, Stride); + if (Invert) + GuardStartMinusStride = getNotSCEV(GuardStartMinusStride); + if (isLoopEntryGuardedByCond(L, Pred, GuardStartMinusStride, GuardRHS)) { BECount = BECountIfBackedgeTaken; - else { + } else { // If we know that RHS >= Start in the context of loop, then we know that // max(RHS, Start) = RHS at this point. - if (isLoopEntryGuardedByCond( - L, IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE, RHS, Start)) + if (isLoopEntryGuardedByCond(L, ICmpInst::getInversePredicate(Pred), GuardRHS, + GuardStart)) { End = RHS; - else - End = IsSigned ? getSMaxExpr(RHS, Start) : getUMaxExpr(RHS, Start); + } else { + if (Invert) { + // smin expressions are likely simpler than the equivalent smax + // expression for greater-than compares. + End = IsSigned ? getNotSCEV(getSMinExpr(GuardRHS, GuardStart)) + : getNotSCEV(getUMinExpr(GuardRHS, GuardStart)); + } else { + End = IsSigned ? getSMaxExpr(RHS, Start) : getUMaxExpr(RHS, Start); + } + } BECount = computeBECount(getMinusSCEV(End, Start), Stride); } @@ -11579,88 +11577,6 @@ return ExitLimit(BECount, MaxBECount, MaxOrZero, Predicates); } -ScalarEvolution::ExitLimit -ScalarEvolution::howManyGreaterThans(const SCEV *LHS, const SCEV *RHS, - const Loop *L, bool IsSigned, - bool ControlsExit, bool AllowPredicates) { - SmallPtrSet Predicates; - // We handle only IV > Invariant - if (!isLoopInvariant(RHS, L)) - return getCouldNotCompute(); - - const SCEVAddRecExpr *IV = dyn_cast(LHS); - if (!IV && AllowPredicates) - // Try to make this an AddRec using runtime tests, in the first X - // iterations of this loop, where X is the SCEV expression found by the - // algorithm below. - IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates); - - // Avoid weird loops - if (!IV || IV->getLoop() != L || !IV->isAffine()) - return getCouldNotCompute(); - - bool NoWrap = ControlsExit && - IV->getNoWrapFlags(IsSigned ? 
SCEV::FlagNSW : SCEV::FlagNUW); - - const SCEV *Stride = getNegativeSCEV(IV->getStepRecurrence(*this)); - - // Avoid negative or zero stride values - if (!isKnownPositive(Stride)) - return getCouldNotCompute(); - - // Avoid proven overflow cases: this will ensure that the backedge taken count - // will not generate any unsigned overflow. Relaxed no-overflow conditions - // exploit NoWrapFlags, allowing to optimize in presence of undefined - // behaviors like the case of C language. - if (!Stride->isOne() && !NoWrap) - if (canIVOverflowOnGT(RHS, Stride, IsSigned)) - return getCouldNotCompute(); - - ICmpInst::Predicate Cond = IsSigned ? ICmpInst::ICMP_SGT - : ICmpInst::ICMP_UGT; - - const SCEV *Start = IV->getStart(); - const SCEV *End = RHS; - if (!isLoopEntryGuardedByCond(L, Cond, getAddExpr(Start, Stride), RHS)) { - // If we know that Start >= RHS in the context of loop, then we know that - // min(RHS, Start) = RHS at this point. - if (isLoopEntryGuardedByCond( - L, IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE, Start, RHS)) - End = RHS; - else - End = IsSigned ? getSMinExpr(RHS, Start) : getUMinExpr(RHS, Start); - } - - const SCEV *BECount = computeBECount(getMinusSCEV(Start, End), Stride); - - APInt MaxStart = IsSigned ? getSignedRangeMax(Start) - : getUnsignedRangeMax(Start); - - APInt MinStride = IsSigned ? getSignedRangeMin(Stride) - : getUnsignedRangeMin(Stride); - - unsigned BitWidth = getTypeSizeInBits(LHS->getType()); - APInt Limit = IsSigned ? APInt::getSignedMinValue(BitWidth) + (MinStride - 1) - : APInt::getMinValue(BitWidth) + (MinStride - 1); - - // Although End can be a MIN expression we estimate MinEnd considering only - // the case End = RHS. This is safe because in the other case (Start - End) - // is zero, leading to a zero maximum backedge taken count. - APInt MinEnd = - IsSigned ? APIntOps::smax(getSignedRangeMin(RHS), Limit) - : APIntOps::umax(getUnsignedRangeMin(RHS), Limit); - - const SCEV *MaxBECount = isa(BECount) - ? 
BECount - : computeBECount(getConstant(MaxStart - MinEnd), - getConstant(MinStride)); - - if (isa(MaxBECount)) - MaxBECount = BECount; - - return ExitLimit(BECount, MaxBECount, false, Predicates); -} - const SCEV *SCEVAddRecExpr::getNumIterationsInRange(const ConstantRange &Range, ScalarEvolution &SE) const { if (Range.isFullSet()) // Infinite loop.