diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -11761,17 +11761,42 @@ // and End is the RHS. const SCEV *BECountIfBackedgeTaken = computeBECount(getMinusSCEV(End, Start), Stride); - // If the loop entry is guarded by the result of the backedge test of the - // first loop iteration, then we know the backedge will be taken at least - // once and so the backedge taken count is as above. If not then we use the - // expression (max(End,Start)-Start)/Stride to describe the backedge count, - // as if the backedge is taken at least once max(End,Start) is End and so the - // result is as above, and if not max(End,Start) is Start so we get a backedge - // count of zero. - const SCEV *BECount; - if (isLoopEntryGuardedByCond(L, Cond, getMinusSCEV(OrigStart, Stride), OrigRHS)) - BECount = BECountIfBackedgeTaken; - else { + + // We use the expression (max(End,Start)-Start)/Stride to describe the + // backedge count, as if the backedge is taken at least once max(End,Start) + // is End and so the result is as above, and if not max(End,Start) is Start + // so we get a backedge count of zero. + const SCEV *BECount = nullptr; + auto *StartMinusStride = getMinusSCEV(OrigStart, Stride); + // Can we prove max(RHS,Start) > Start - Stride? + if (isLoopEntryGuardedByCond(L, Cond, StartMinusStride, Start) && + isLoopEntryGuardedByCond(L, Cond, StartMinusStride, RHS)) { + // In this case, we can use a refined formula for computing backedge taken + // count. The general formula remains: + // "End-Start /uceiling Stride" where "End = max(RHS,Start)" + // We want to use the alternate formula: + // "((End - 1) - (Start - Stride)) /u Stride" + // Let's do a quick case analysis to show these are equivalent under + // our precondition that max(RHS,Start) > Start - Stride. + // * For RHS <= Start, the backedge-taken count must be zero. 
+ // "((End - 1) - (Start - Stride)) /u Stride" reduces to + // "((Start - 1) - (Start - Stride)) /u Stride" which simplifies to + // "(Stride - 1) /u Stride" which is indeed zero for all non-zero values + // of Stride. For 0 stride, we've used umin(1,Stride) above, reducing + // this to the stride of 1 case. + // * For RHS >= Start, the backedge count must be "RHS-Start /uceil Stride". + // "((End - 1) - (Start - Stride)) /u Stride" reduces to + // "((RHS - 1) - (Start - Stride)) /u Stride" reassociates to + // "((RHS - (Start - Stride)) - 1) /u Stride". + // Our preconditions trivially imply no overflow in that form. + const SCEV *MinusOne = getMinusOne(Stride->getType()); + const SCEV *Numerator = + getMinusSCEV(getAddExpr(RHS, MinusOne), StartMinusStride); + if (!isa(Numerator)) { + BECount = getUDivExpr(Numerator, Stride); + } + } + if (!BECount) { auto canProveRHSGreaterThanEqualStart = [&]() { auto CondGE = IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE; if (isLoopEntryGuardedByCond(L, CondGE, OrigRHS, OrigStart)) diff --git a/llvm/test/Analysis/ScalarEvolution/2008-11-18-Stride2.ll b/llvm/test/Analysis/ScalarEvolution/2008-11-18-Stride2.ll --- a/llvm/test/Analysis/ScalarEvolution/2008-11-18-Stride2.ll +++ b/llvm/test/Analysis/ScalarEvolution/2008-11-18-Stride2.ll @@ -1,7 +1,7 @@ ; RUN: opt < %s -analyze -enable-new-pm=0 -scalar-evolution 2>&1 | FileCheck %s ; RUN: opt < %s -disable-output "-passes=print" 2>&1 2>&1 | FileCheck %s -; CHECK: Loop %bb: backedge-taken count is ((999 + (-1 * %x)) /u 3) +; CHECK: Loop %bb: backedge-taken count is ((-1 + (-1 * %x) + (1000 umax (3 + %x))) /u 3) ; CHECK: Loop %bb: max backedge-taken count is 334 diff --git a/llvm/test/Analysis/ScalarEvolution/trip-count-unknown-stride.ll b/llvm/test/Analysis/ScalarEvolution/trip-count-unknown-stride.ll --- a/llvm/test/Analysis/ScalarEvolution/trip-count-unknown-stride.ll +++ b/llvm/test/Analysis/ScalarEvolution/trip-count-unknown-stride.ll @@ -5,7 +5,7 @@ ; that this is not an infinite 
loop with side effects. ; CHECK-LABEL: Determining loop execution counts for: @foo1 -; CHECK: backedge-taken count is ((-1 + %n) /u %s) +; CHECK: backedge-taken count is ((-1 + (%n smax %s)) /u %s) ; We should have a conservative estimate for the max backedge taken count for ; loops with unknown stride. diff --git a/llvm/test/Transforms/LoopReroll/nonconst_lb.ll b/llvm/test/Transforms/LoopReroll/nonconst_lb.ll --- a/llvm/test/Transforms/LoopReroll/nonconst_lb.ll +++ b/llvm/test/Transforms/LoopReroll/nonconst_lb.ll @@ -17,22 +17,24 @@ ; CHECK-NEXT: [[CMP34:%.*]] = icmp slt i32 [[M:%.*]], [[N:%.*]] ; CHECK-NEXT: br i1 [[CMP34]], label [[FOR_BODY_PREHEADER:%.*]], label [[FOR_END:%.*]] ; CHECK: for.body.preheader: -; CHECK-NEXT: [[TMP0:%.*]] = add i32 [[N]], -1 -; CHECK-NEXT: [[TMP1:%.*]] = sub i32 [[TMP0]], [[M]] -; CHECK-NEXT: [[TMP2:%.*]] = lshr i32 [[TMP1]], 2 -; CHECK-NEXT: [[TMP3:%.*]] = shl nuw i32 [[TMP2]], 2 -; CHECK-NEXT: [[TMP4:%.*]] = add nuw nsw i32 [[TMP3]], 3 +; CHECK-NEXT: [[TMP0:%.*]] = add i32 [[M]], 4 +; CHECK-NEXT: [[SMAX:%.*]] = call i32 @llvm.smax.i32(i32 [[N]], i32 [[TMP0]]) +; CHECK-NEXT: [[TMP1:%.*]] = add i32 [[SMAX]], -1 +; CHECK-NEXT: [[TMP2:%.*]] = sub i32 [[TMP1]], [[M]] +; CHECK-NEXT: [[TMP3:%.*]] = lshr i32 [[TMP2]], 2 +; CHECK-NEXT: [[TMP4:%.*]] = shl nuw i32 [[TMP3]], 2 +; CHECK-NEXT: [[TMP5:%.*]] = add nuw nsw i32 [[TMP4]], 3 ; CHECK-NEXT: br label [[FOR_BODY:%.*]] ; CHECK: for.body: ; CHECK-NEXT: [[INDVAR:%.*]] = phi i32 [ 0, [[FOR_BODY_PREHEADER]] ], [ [[INDVAR_NEXT:%.*]], [[FOR_BODY]] ] -; CHECK-NEXT: [[TMP5:%.*]] = add i32 [[M]], [[INDVAR]] -; CHECK-NEXT: [[ARRAYIDX:%.*]] = getelementptr inbounds i32, i32* [[B:%.*]], i32 [[TMP5]] -; CHECK-NEXT: [[TMP6:%.*]] = load i32, i32* [[ARRAYIDX]], align 4 -; CHECK-NEXT: [[MUL:%.*]] = shl nsw i32 [[TMP6]], 2 -; CHECK-NEXT: [[ARRAYIDX2:%.*]] = getelementptr inbounds i32, i32* [[A:%.*]], i32 [[TMP5]] +; CHECK-NEXT: [[TMP6:%.*]] = add i32 [[M]], [[INDVAR]] +; CHECK-NEXT: [[ARRAYIDX:%.*]] = 
getelementptr inbounds i32, i32* [[B:%.*]], i32 [[TMP6]] +; CHECK-NEXT: [[TMP7:%.*]] = load i32, i32* [[ARRAYIDX]], align 4 +; CHECK-NEXT: [[MUL:%.*]] = shl nsw i32 [[TMP7]], 2 +; CHECK-NEXT: [[ARRAYIDX2:%.*]] = getelementptr inbounds i32, i32* [[A:%.*]], i32 [[TMP6]] ; CHECK-NEXT: store i32 [[MUL]], i32* [[ARRAYIDX2]], align 4 ; CHECK-NEXT: [[INDVAR_NEXT]] = add i32 [[INDVAR]], 1 -; CHECK-NEXT: [[EXITCOND:%.*]] = icmp eq i32 [[INDVAR]], [[TMP4]] +; CHECK-NEXT: [[EXITCOND:%.*]] = icmp eq i32 [[INDVAR]], [[TMP5]] ; CHECK-NEXT: br i1 [[EXITCOND]], label [[FOR_END_LOOPEXIT:%.*]], label [[FOR_BODY]] ; CHECK: for.end.loopexit: ; CHECK-NEXT: br label [[FOR_END]]