Index: llvm/lib/Analysis/ScalarEvolution.cpp
===================================================================
--- llvm/lib/Analysis/ScalarEvolution.cpp
+++ llvm/lib/Analysis/ScalarEvolution.cpp
@@ -8351,14 +8351,16 @@
       }
       return S;
     };
-    auto *InnerLHS = LHS;
+    SmallVector<const SCEV *> OpStack;
+    OpStack.push_back(LHS);
     while (true) {
-      auto *Temp = peekThroughInvertibleFunctions(InnerLHS);
-      if (Temp == InnerLHS)
+      auto *Last = OpStack.back();
+      auto *Temp = peekThroughInvertibleFunctions(Last);
+      if (Temp == Last)
         break;
-      InnerLHS = Temp;
+      OpStack.push_back(Temp);
     }
-    if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(InnerLHS)) {
+    if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(OpStack.back())) {
       auto *StrideC = dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this));
       if (!AR->hasNoSelfWrap() && AR->getLoop() == L && AR->isAffine() &&
           StrideC && StrideC->getAPInt().isPowerOf2()) {
@@ -8367,6 +8369,42 @@
         SmallVector<const SCEV *> Operands{AR->operands()};
         Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
         setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
+
+        // Get a new SCEV (hopefully an addrec), if we know that replacing
+        // the loop varying operand of S with NewOp is likely to allow
+        // specialization. Note that this has the side effect of removing
+        // old SCEVs from the set which breaks equality comparison between
+        // previously handed out pointers and newly queried ones.
+        auto getWithVaryingOperandReplaced =
+            [&](const SCEV *S, const SCEVAddRecExpr *NewOp) {
+              if (isa<SCEVZeroExtendExpr>(S) && NewOp->hasNoUnsignedWrap()) {
+                UniqueSCEVs.RemoveNode(const_cast<SCEV *>(S));
+                return getZeroExtendExpr(NewOp, S->getType());
+              } else if (isa<SCEVSignExtendExpr>(S) &&
+                         NewOp->hasNoSignedWrap()) {
+                UniqueSCEVs.RemoveNode(const_cast<SCEV *>(S));
+                return getSignExtendExpr(NewOp, S->getType());
+              } else if (auto *SAdd = dyn_cast<SCEVAddExpr>(S)) {
+                assert(isa<SCEVAddRecExpr>(NewOp));
+                UniqueSCEVs.RemoveNode(const_cast<SCEV *>(S));
+                SmallVector<const SCEV *> Operands;
+                for (auto *Op : SAdd->operands())
+                  Operands.push_back(isLoopInvariant(Op, L) ? Op : NewOp);
+                return getAddExpr(Operands, SAdd->getNoWrapFlags());
+              }
+              return S;
+            };
+
+        const SCEVAddRecExpr *Rebuilt = AR;
+        assert(OpStack.back() == AR);
+        OpStack.pop_back();
+        while (!OpStack.empty() && Rebuilt) {
+          auto *Next = OpStack.pop_back_val();
+          auto *S = getWithVaryingOperandReplaced(Next, Rebuilt);
+          Rebuilt = dyn_cast<SCEVAddRecExpr>(S);
+        }
+        if (Rebuilt)
+          LHS = Rebuilt;
       }
     }
   }
Index: llvm/test/Analysis/ScalarEvolution/ne-overflow.ll
===================================================================
--- llvm/test/Analysis/ScalarEvolution/ne-overflow.ll
+++ llvm/test/Analysis/ScalarEvolution/ne-overflow.ll
@@ -236,11 +236,11 @@
 define void @test_zext(i64 %N) mustprogress {
 ; CHECK-LABEL: 'test_zext'
 ; CHECK-NEXT: Determining loop execution counts for: @test_zext
-; CHECK-NEXT: Loop %for.body: Unpredictable backedge-taken count.
-; CHECK-NEXT: Loop %for.body: Unpredictable max backedge-taken count.
+; CHECK-NEXT: Loop %for.body: backedge-taken count is (%N /u 2)
+; CHECK-NEXT: Loop %for.body: max backedge-taken count is 9223372036854775807
 ; CHECK-NEXT: Loop %for.body: Predicated backedge-taken count is (%N /u 2)
 ; CHECK-NEXT: Predicates:
-; CHECK-NEXT: {0,+,2}<%for.body> Added Flags: <nusw>
+; CHECK: Loop %for.body: Trip multiple is 1
 ;
 entry:
   br label %for.body
@@ -307,11 +307,11 @@
 define void @test_zext_offset(i64 %N) mustprogress {
 ; CHECK-LABEL: 'test_zext_offset'
 ; CHECK-NEXT: Determining loop execution counts for: @test_zext_offset
-; CHECK-NEXT: Loop %for.body: Unpredictable backedge-taken count.
-; CHECK-NEXT: Loop %for.body: Unpredictable max backedge-taken count.
+; CHECK-NEXT: Loop %for.body: backedge-taken count is ((-21 + %N) /u 2)
+; CHECK-NEXT: Loop %for.body: max backedge-taken count is 9223372036854775807
 ; CHECK-NEXT: Loop %for.body: Predicated backedge-taken count is ((-21 + %N) /u 2)
 ; CHECK-NEXT: Predicates:
-; CHECK-NEXT: {0,+,2}<%for.body> Added Flags: <nusw>
+; CHECK: Loop %for.body: Trip multiple is 1
 ;
 entry:
   br label %for.body
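
Illustration (not part of the patch): a minimal sketch of the loop shape the
test_zext check lines exercise; the function name and exact body here are
assumptions, not the verbatim test. An i32 IV steps by the power-of-2 stride 2
and is zero-extended before the NE exit compare against the i64 bound, so the
exit test sees zext({0,+,2}) rather than the addrec itself. Previously the
backedge-taken count was only available behind a <nusw> predicate; once nuw is
inferred on the addrec and pushed back through the zext by the rebuild loop
above, (%N /u 2) can be reported directly.

    ; sketch.ll (hypothetical)
    define void @zext_iv_sketch(i64 %N) mustprogress {
    entry:
      br label %for.body

    for.body:
      %iv = phi i32 [ 0, %entry ], [ %iv.next, %for.body ]
      %iv.next = add i32 %iv, 2
      ; the exit compare is against zext({0,+,2}), not the addrec itself
      %iv.wide = zext i32 %iv to i64
      %cmp = icmp ne i64 %iv.wide, %N
      br i1 %cmp, label %for.body, label %exit

    exit:
      ret void
    }

Assuming the sketch is saved as sketch.ll, the exit counts SCEV derives for it
can be inspected with:

    opt -passes='print<scalar-evolution>' -disable-output sketch.ll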