diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h --- a/llvm/include/llvm/Analysis/ScalarEvolution.h +++ b/llvm/include/llvm/Analysis/ScalarEvolution.h @@ -2017,11 +2017,27 @@ createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI); /// Compute the backedge taken count knowing the interval difference, and - /// the stride for an inequality. Result takes the form: - /// (Delta + (Stride - 1)) udiv Stride. - /// Caller must ensure that this expression either does not overflow or - /// that the result is undefined if it does. - const SCEV *computeBECount(const SCEV *Delta, const SCEV *Stride); + /// the stride for an inequality. + /// + /// Caller must ensure that a non-negative N exists such that + /// (Start + Stride * N) >= End, and that computing "(Start + Stride * N)" + /// doesn't overflow. In other words: + /// 1. Start is less than or equal to End + /// 2. If Stride is not positive, End is equal to Start + /// 3. The index variable doesn't overflow. + /// + /// If the preconditions hold, the backedge taken count is the smallest such N. + /// + /// IsSigned determines whether End, Start, and Stride are treated as + /// signed values, for the purpose of optimizing the form of the result. + /// + /// The result is usually of the form: + /// ((End - Start) + (Stride - 1)) /u Stride + /// + /// The function will use an alternate form if it can't prove the addition + /// of "Stride - 1" doesn't overflow. + const SCEV *computeBECount(bool IsSigned, const SCEV *End, const SCEV *Start, + const SCEV *Stride); /// Compute the maximum backedge count based on the range of values /// permitted by Start, End, and Stride.
This is for loops of the form diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -11434,11 +11434,111 @@ return (std::move(MinValue) + MaxStrideMinusOne).ugt(MinRHS); } -const SCEV *ScalarEvolution::computeBECount(const SCEV *Delta, - const SCEV *Step) { - const SCEV *One = getOne(Step->getType()); - Delta = getAddExpr(Delta, getMinusSCEV(Step, One)); - return getUDivExpr(Delta, Step); +const SCEV *ScalarEvolution::computeBECount(bool IsSigned, const SCEV *End, + const SCEV *Start, + const SCEV *Stride) { + // The basic formula here is ceil((End - Start) / Stride). Since SCEV + // doesn't natively have division that rounds up, we need to convert to + // floor division. + // + // MayOverflow is whether adding (End - Start) + (Stride - 1) + // can overflow if Stride is positive. It's a precondition of the + // function that "End - Start" doesn't overflow. We handle the case where + // Stride isn't positive later. + // + // In practice, the arithmetic almost never overflows, but we have to prove + // it. We have a variety of ways to come up with a proof. + const SCEV *One = getOne(Stride->getType()); + bool MayOverflow = [&] { + if (auto *StrideC = dyn_cast<SCEVConstant>(Stride)) { + if (StrideC->getAPInt().isPowerOf2()) { + // If the stride is a power of two, the maximum backedge-taken count + // is (2^bitwidth / Stride) - 1. After that, the addrec repeats values. + // + // If (End - Start) + (Stride - 1) overflows, that means it's greater + // than or equal to 2^bitwidth. In that case, + // floor(((End - Start) + (Stride - 1)) / Stride) must be greater than + // or equal to (2^bitwidth / Stride). But this is impossible: it's + // greater than the maximum number of iterations.
+ // + // Note this proof requires that Stride is a factor of 2^bitwidth: + // in general, the maximum backedge-taken count is + // ceil(2^bitwidth / Stride) - 1, and the inequalities don't hold. + return false; + } + } + // Figure out the "base" start for MayOverflow checks. If the start is + // of the form "min(BaseStart, End)", then either Delta is equal to + // zero, or Start == BaseStart. If Delta is zero, adding Stride - 1 + // trivially can't overflow. + const SCEV *BaseStart = Start; + if (IsSigned) { + if (auto *SMinStart = dyn_cast<SCEVSMinExpr>(Start)) { + if (SMinStart->getNumOperands() == 2 && SMinStart->getOperand(1) == End) + BaseStart = SMinStart->getOperand(0); + } + } else { + if (auto *UMinStart = dyn_cast<SCEVUMinExpr>(Start)) { + if (UMinStart->getNumOperands() == 2 && UMinStart->getOperand(1) == End) + BaseStart = UMinStart->getOperand(0); + } + } + if (BaseStart == Stride || BaseStart == getMinusSCEV(Stride, One)) { + // If Start is equal to Stride, (End - Start) + Stride - 1 < End. + // If Start is equal to Stride - 1, (End - Start) + Stride - 1 == End. + // If IsSigned is true, we assume the stride is positive. + return false; + } + if (auto *StrideC = dyn_cast<SCEVConstant>(Stride)) { + if (auto *BaseStartC = dyn_cast<SCEVConstant>(BaseStart)) { + // Start >= Stride implies End - Start + Stride - 1 < End. + if (!IsSigned && BaseStartC->getAPInt().uge(StrideC->getAPInt())) + return false; + if (IsSigned && BaseStartC->getAPInt().sge(StrideC->getAPInt())) + return false; + } + if (auto *EndC = dyn_cast<SCEVConstant>(End)) { + // For unsigned: + // End <= 2^bitwidth - Stride implies + // End < 2^bitwidth - Stride + 1 implies + // End + Stride - 1 < 2^bitwidth implies + // (End - Start) + (Stride - 1) < 2^bitwidth, i.e. no overflow. + // + // FIXME: Signed equivalent? + if (!IsSigned && EndC->getAPInt().ule(-StrideC->getAPInt())) + return false; + } + } + if (IsSigned && isKnownNonNegative(BaseStart)) { + // If Start is non-negative, End - Start is also non-negative.
So in + // unsigned arithmetic, "(End - Start) + (Stride - 1)" easily fits. + return false; + } + return true; + }(); + + // Force the stride to at least one, so we don't divide by zero. The stride + // can be zero if Delta is zero. We don't actually care what value we use + // for Stride in this case, as long as it isn't zero. + // + // FIXME: Try to use loop guards to prove Stride is non-zero + // FIXME: Try to use loop guards to prove Start != End. + Stride = getUMaxExpr(Stride, One); + + const SCEV *Delta = getMinusSCEV(End, Start); + if (!MayOverflow) { + // floor((D + (S - 1)) / S) + // We prefer this formulation if it's legal because it's fewer operations. + return getUDivExpr(getAddExpr(Delta, getMinusSCEV(Stride, One)), Stride); + } + // umin(D, 1) + floor((D - umin(D, 1)) / S) + // This is equivalent to "1 + floor((D - 1) / S)" for D != 0. The umin + // expression fixes the case of D=0. + // + // FIXME: Try to use loop guards to prove D is non-zero. + const SCEV *MinDeltaOne = getUMinExpr(Delta, One); + const SCEV *DeltaMinusOne = getMinusSCEV(Delta, MinDeltaOne); + return getAddExpr(MinDeltaOne, getUDivExpr(DeltaMinusOne, Stride)); } const SCEV *ScalarEvolution::computeMaxBECountForLT(const SCEV *Start, @@ -11477,8 +11577,9 @@ APInt MaxEnd = IsSigned ? APIntOps::smin(getSignedRangeMax(End), Limit) : APIntOps::umin(getUnsignedRangeMax(End), Limit); - MaxBECount = computeBECount(getConstant(MaxEnd - MinStart) /* Delta */, - getConstant(StrideForMaxBECount) /* Step */); + MaxBECount = + computeBECount(IsSigned, getConstant(MaxEnd), getConstant(MinStart), + getConstant(StrideForMaxBECount)); return MaxBECount; } @@ -11636,7 +11737,7 @@ // is the LHS value of the less-than comparison the first time it is evaluated // and End is the RHS.
const SCEV *BECountIfBackedgeTaken = - computeBECount(getMinusSCEV(End, Start), Stride); + computeBECount(IsSigned, End, Start, Stride); // If the loop entry is guarded by the result of the backedge test of the // first loop iteration, then we know the backedge will be taken at least // once and so the backedge taken count is as above. If not then we use the @@ -11655,7 +11756,7 @@ End = RHS; else End = IsSigned ? getSMaxExpr(RHS, Start) : getUMaxExpr(RHS, Start); - BECount = computeBECount(getMinusSCEV(End, Start), Stride); + BECount = computeBECount(IsSigned, End, Start, Stride); } const SCEV *MaxBECount; @@ -11741,7 +11842,7 @@ return End; } - const SCEV *BECount = computeBECount(getMinusSCEV(Start, End), Stride); + const SCEV *BECount = computeBECount(IsSigned, Start, End, Stride); APInt MaxStart = IsSigned ? getSignedRangeMax(Start) : getUnsignedRangeMax(Start); @@ -11760,10 +11861,11 @@ IsSigned ? APIntOps::smax(getSignedRangeMin(RHS), Limit) : APIntOps::umax(getUnsignedRangeMin(RHS), Limit); - const SCEV *MaxBECount = isa(BECount) - ? BECount - : computeBECount(getConstant(MaxStart - MinEnd), - getConstant(MinStride)); + const SCEV *MaxBECount = + isa(BECount) + ? BECount + : computeBECount(IsSigned, getConstant(MaxStart), getConstant(MinEnd), + getConstant(MinStride)); if (isa(MaxBECount)) MaxBECount = BECount; diff --git a/llvm/test/Analysis/ScalarEvolution/trip-count-unknown-stride.ll b/llvm/test/Analysis/ScalarEvolution/trip-count-unknown-stride.ll --- a/llvm/test/Analysis/ScalarEvolution/trip-count-unknown-stride.ll +++ b/llvm/test/Analysis/ScalarEvolution/trip-count-unknown-stride.ll @@ -4,8 +4,8 @@ ; ScalarEvolution should be able to compute trip count of the loop by proving ; that this is not an infinite loop with side effects. 
-; CHECK: Determining loop execution counts for: @foo1 -; CHECK: backedge-taken count is ((-1 + %n) /u %s) +; CHECK-LABEL: Determining loop execution counts for: @foo1 +; CHECK: backedge-taken count is ((-1 + (-1 * %s) + (1 umax %s) + %n) /u (1 umax %s)) ; We should have a conservative estimate for the max backedge taken count for ; loops with unknown stride. @@ -34,8 +34,8 @@ ; Check that we are able to compute trip count of a loop without an entry guard. -; CHECK: Determining loop execution counts for: @foo2 -; CHECK: backedge-taken count is ((-1 + (%n smax %s)) /u %s) +; CHECK-LABEL: Determining loop execution counts for: @foo2 +; CHECK: backedge-taken count is ((-1 + (-1 * %s) + (1 umax %s) + (%n smax %s)) /u (1 umax %s)) ; We should have a conservative estimate for the max backedge taken count for ; loops with unknown stride. @@ -61,7 +61,7 @@ ; Check that without mustprogress we don't make assumptions about infinite ; loops being UB. -; CHECK: Determining loop execution counts for: @foo3 +; CHECK-LABEL: Determining loop execution counts for: @foo3 ; CHECK: Loop %for.body: Unpredictable backedge-taken count. ; CHECK: Loop %for.body: Unpredictable max backedge-taken count. @@ -84,8 +84,8 @@ } ; Same as foo2, but with mustprogress on loop, not function -; CHECK: Determining loop execution counts for: @foo4 -; CHECK: backedge-taken count is ((-1 + (%n smax %s)) /u %s) +; CHECK-LABEL: Determining loop execution counts for: @foo4 +; CHECK: backedge-taken count is ((-1 + (-1 * %s) + (1 umax %s) + (%n smax %s)) /u (1 umax %s)) ; CHECK: max backedge-taken count is -1 define void @foo4(i32* nocapture %A, i32 %n, i32 %s) { @@ -106,5 +106,31 @@ ret void } +; A more complex case with pre-increment compare instead of post-increment. 
+; CHECK-LABEL: Determining loop execution counts for: @foo5 +; CHECK: Loop %for.body: backedge-taken count is ((((-1 * (1 umin ((-1 * %start) + (%n smax %start)))) + (-1 * %start) + (%n smax %start)) /u (1 umax %s)) + (1 umin ((-1 * %start) + (%n smax %start)))) + +; We should have a conservative estimate for the max backedge taken count for +; loops with unknown stride. +; CHECK: max backedge-taken count is -1 + +define void @foo5(i32* nocapture %A, i32 %n, i32 %s, i32 %start) mustprogress { +entry: + br label %for.body + +for.body: ; preds = %entry, %for.body + %i.05 = phi i32 [ %add, %for.body ], [ %start, %entry ] + %arrayidx = getelementptr inbounds i32, i32* %A, i32 %i.05 + %0 = load i32, i32* %arrayidx, align 4 + %inc = add nsw i32 %0, 1 + store i32 %inc, i32* %arrayidx, align 4 + %add = add nsw i32 %i.05, %s + %cmp = icmp slt i32 %i.05, %n + br i1 %cmp, label %for.body, label %for.end + +for.end: ; preds = %for.body + ret void +} + !8 = distinct !{!8, !9} !9 = !{!"llvm.loop.mustprogress"}