diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -14937,7 +14937,7 @@ /// At the moment only rewriting SCEVUnknown and SCEVZeroExtendExpr is /// supported. class SCEVLoopGuardRewriter : public SCEVRewriteVisitor { - const DenseMap ⤅ + DenseMap ⤅ public: SCEVLoopGuardRewriter(ScalarEvolution &SE, @@ -14958,16 +14958,30 @@ if (I == Map.end()) return SCEVRewriteVisitor::visitZeroExtendExpr( Expr); - return I->second; + // Remove Expr's entry so rewriting its replacement cannot recurse back into + // it. Copy the replacement first: erasing the entry destroys I->second. + const SCEV *RewriteTo = I->second; + Map.erase(I); + const SCEV *Result = SCEVRewriteVisitor::visit(RewriteTo); + // It is possible that we visit Expr again and its result is cached. Erase + // it to ensure this visit's result is cached. + RewriteResults.erase(Expr); + return Result; } }; const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) { - SmallVector ExprsToRewrite; + DenseMap RewriteMap; + + // Check whether LHS has already been rewritten. In that case we want to + // chain further rewrites onto the already rewritten value. + auto GetRewrittenLHS = [&](const SCEV *LHS) { + auto I = RewriteMap.find(LHS); + return I == RewriteMap.end() ? LHS : I->second; + }; + auto CollectCondition = [&](ICmpInst::Predicate Predicate, const SCEV *LHS, - const SCEV *RHS, - DenseMap - &RewriteMap) { + const SCEV *RHS) { // WARNING: It is generally unsound to apply any wrap flags to the proposed // replacement SCEV which isn't directly implied by the structure of that // SCEV. In particular, using contextual facts to imply flags is *NOT* @@ -14983,7 +14997,7 @@ // create this form when combining two checks of the form (X u< C2 + C1) and // (X >=u C1). auto MatchRangeCheckIdiom = [this, Predicate, LHS, RHS, &RewriteMap, - &ExprsToRewrite]() { + &GetRewrittenLHS]() { auto *AddExpr = dyn_cast(LHS); if (!AddExpr || AddExpr->getNumOperands() != 2) return false; @@ -15001,12 +15015,10 @@ // Bail out, unless we have a non-wrapping, monotonic range.
if (ExactRegion.isWrappedSet() || ExactRegion.isFullSet()) return false; - auto I = RewriteMap.find(LHSUnknown); - const SCEV *RewrittenLHS = I != RewriteMap.end() ? I->second : LHSUnknown; + const SCEV *RewrittenLHS = GetRewrittenLHS(LHSUnknown); RewriteMap[LHSUnknown] = getUMaxExpr( getConstant(ExactRegion.getUnsignedMin()), getUMinExpr(RewrittenLHS, getConstant(ExactRegion.getUnsignedMax()))); - ExprsToRewrite.push_back(LHSUnknown); return true; }; if (MatchRangeCheckIdiom()) @@ -15023,9 +15035,9 @@ const SCEV *URemRHS = nullptr; if (matchURem(LHS, URemLHS, URemRHS)) { if (const SCEVUnknown *LHSUnknown = dyn_cast(URemLHS)) { - const auto *Multiple = getMulExpr(getUDivExpr(URemLHS, URemRHS), URemRHS); + const SCEV *RewrittenLHS = GetRewrittenLHS(LHSUnknown); + const auto *Multiple = getMulExpr(getUDivExpr(RewrittenLHS, URemRHS), URemRHS); RewriteMap[LHSUnknown] = Multiple; - ExprsToRewrite.push_back(LHSUnknown); return; } } @@ -15045,11 +15057,7 @@ if (!isa(LHS) && !isa(LHS)) return; - // Check whether LHS has already been rewritten. In that case we want to - // chain further rewrites onto the already rewritten value. - auto I = RewriteMap.find(LHS); - const SCEV *RewrittenLHS = I != RewriteMap.end() ? I->second : LHS; - + const SCEV *RewrittenLHS = GetRewrittenLHS(LHS); const SCEV *RewrittenRHS = nullptr; switch (Predicate) { case CmpInst::ICMP_ULT: @@ -15093,11 +15101,8 @@ break; } - if (RewrittenRHS) { + if (RewrittenRHS) RewriteMap[LHS] = RewrittenRHS; - if (LHS == RewrittenLHS) - ExprsToRewrite.push_back(LHS); - } }; BasicBlock *Header = L->getHeader(); @@ -15134,7 +15139,6 @@ // Conditions are processed in reverse order, so the earliest conditions is // processed first. This ensures the SCEVs with the shortest dependency chains // are constructed first. - DenseMap RewriteMap; for (auto [Term, EnterIfTrue] : reverse(Terms)) { SmallVector Worklist; SmallPtrSet Visited; @@ -15149,7 +15153,7 @@ EnterIfTrue ?
Cmp->getPredicate() : Cmp->getInversePredicate(); const auto *LHS = getSCEV(Cmp->getOperand(0)); const auto *RHS = getSCEV(Cmp->getOperand(1)); - CollectCondition(Predicate, LHS, RHS, RewriteMap); + CollectCondition(Predicate, LHS, RHS); continue; } @@ -15165,18 +15169,6 @@ if (RewriteMap.empty()) return Expr; - // Now that all rewrite information is collect, rewrite the collected - // expressions with the information in the map. This applies information to - // sub-expressions. - if (ExprsToRewrite.size() > 1) { - for (const SCEV *Expr : ExprsToRewrite) { - const SCEV *RewriteTo = RewriteMap[Expr]; - RewriteMap.erase(Expr); - SCEVLoopGuardRewriter Rewriter(*this, RewriteMap); - RewriteMap.insert({Expr, Rewriter.visit(RewriteTo)}); - } - } - SCEVLoopGuardRewriter Rewriter(*this, RewriteMap); return Rewriter.visit(Expr); }