Review notes (text before the first "diff --git" line is ignored by patch tools):

  * This patch factors the loop-guard rewriting in
    ScalarEvolution::applyLoopGuards into a local ApplyCondition lambda and
    additionally applies it to llvm.assume calls that dominate the loop header.
  * The patch text had been collapsed onto two physical lines and the template
    arguments of isa/cast/dyn_cast had been stripped.  They are restored below
    by inference from how each value is used afterwards (Op0 is used as a
    rewrite key, Op1 is adjusted by one, CI is queried with getOperand, Cmp is
    passed as an ICmpInst *) -- TODO confirm against the upstream revision.
  * NOTE(review): ApplyCondition rewrites V as if the compare holds.  The
    ICMP_EQ case maps V to umax(V, C + 1), which does not follow from V == C.
    Presumably it is meant for guards whose loop-entry edge is the false
    successor of an "icmp eq V, 0" branch; an assume condition always holds,
    so this case looks unsound on the assume path.  Verify the branch
    direction handling (outside the visible hunks) before landing.

diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -12535,6 +12535,39 @@
 }

 const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
+  auto ApplyCondition = [&](ICmpInst *Cmp) {
+    Value *V = Cmp->getOperand(0);
+    const SCEV *Op0 = getSCEV(Cmp->getOperand(0));
+    const SCEV *Op1 = getSCEV(Cmp->getOperand(1));
+    auto Predicate = Cmp->getPredicate();
+    if (isa<SCEVConstant>(Op0)) {
+      std::swap(Op0, Op1);
+      Predicate = CmpInst::getSwappedPredicate(Predicate);
+      V = Cmp->getOperand(1);
+    }
+    if (!isa<SCEVUnknown>(Op0) || !isa<SCEVConstant>(Op1))
+      return;
+
+    // TODO: use information from more predicates.
+    switch (Predicate) {
+    case CmpInst::ICMP_ULT: {
+      ValueToSCEVMapTy RewriteMap;
+      RewriteMap[V] =
+          getUMinExpr(Op0, getMinusSCEV(Op1, getOne(Op1->getType())));
+      Expr = SCEVParameterRewriter::rewrite(Expr, *this, RewriteMap);
+      break;
+    }
+
+    case CmpInst::ICMP_EQ: {
+      ValueToSCEVMapTy RewriteMap;
+      RewriteMap[V] = getUMaxExpr(Op0, getAddExpr(Op1, getOne(Op1->getType())));
+      Expr = SCEVParameterRewriter::rewrite(Expr, *this, RewriteMap);
+      break;
+    }
+    default:
+      break;
+    }
+  };
   // Starting at the loop predecessor, climb up the predecessor chain, as long
   // as there are predecessors that can be found that have unique successors
   // leading to the original header.
@@ -12553,23 +12586,20 @@
     if (!Cmp)
       continue;

-    // TODO: use information from more predicates.
-    switch (Cmp->getPredicate()) {
-    case CmpInst::ICMP_ULT: {
-      const SCEV *LHS = getSCEV(Cmp->getOperand(0));
-      const SCEV *RHS = getSCEV(Cmp->getOperand(1));
-      if (isa<SCEVUnknown>(LHS)) {
-        ValueToSCEVMapTy RewriteMap;
-        RewriteMap[Cmp->getOperand(0)] =
-            getUMinExpr(LHS, getMinusSCEV(RHS, getOne(RHS->getType())));
-        Expr = SCEVParameterRewriter::rewrite(Expr, *this, RewriteMap);
-      }
+    ApplyCondition(Cmp);
+  }

-      break;
-    }
-    default:
-      break;
-    }
+  for (auto &AssumeVH : AC.assumptions()) {
+    if (!AssumeVH)
+      continue;
+    auto *CI = cast<CallInst>(AssumeVH);
+    if (!DT.dominates(CI, L->getHeader()))
+      continue;
+    auto *Cmp = dyn_cast<ICmpInst>(CI->getOperand(0));
+    if (!Cmp)
+      continue;
+
+    ApplyCondition(Cmp);
   }

   return Expr;
diff --git a/llvm/test/Analysis/ScalarEvolution/max-backedge-taken-count-guard-info.ll b/llvm/test/Analysis/ScalarEvolution/max-backedge-taken-count-guard-info.ll
--- a/llvm/test/Analysis/ScalarEvolution/max-backedge-taken-count-guard-info.ll
+++ b/llvm/test/Analysis/ScalarEvolution/max-backedge-taken-count-guard-info.ll
@@ -30,7 +30,7 @@
 define void @test_guard_and_assume(i32* nocapture readonly %data, i64 %count) {
 ; CHECK-LABEL: Determining loop execution counts for: @test_guard_and_assume
 ; CHECK-NEXT:  Loop %loop: backedge-taken count is (-1 + %count)
-; CHECK-NEXT:  Loop %loop: max backedge-taken count is -2
+; CHECK-NEXT:  Loop %loop: max backedge-taken count is 3
 ; CHECK-NEXT:  Loop %loop: Predicated backedge-taken count is (-1 + %count)
 ;
 entry: