[SCEV] Track users of backedge-taken counts instead of per-loop operand sets

Previously every BackedgeTakenInfo recorded all subexpressions of its exit
counts in an Operands set. forgetMemoizedResults() then scanned both
backedge-taken-count maps and erased every entry whose Operands contained one
of the forgotten SCEVs.

This patch inverts the relation:
- computeBackedgeTakenCount() records, in BECountUsers, which
  (loop, predicated) pairs directly use each non-constant exact exit count.
- forgetMemoizedResultsImpl() invalidates exactly those entries.
- forgetLoop() now drops cached counts through forgetBackedgeTakenCounts(), so
  BECountUsers stays consistent when a loop's entry is erased.
- forgetAllLoops() clears BECountUsers together with the other caches.
- The constant max needs no tracking, because the BackedgeTakenInfo
  constructor asserts it is a SCEVConstant or SCEVCouldNotCompute.

NOTE(review): only the exit-count SCEV itself is recorded, not its
subexpressions. Forgetting a subexpression therefore invalidates the loop's
count only if forgetMemoizedResults() expands the forgotten set to SCEV users.
This is presumably done by the Worklist loop that this diff does not show;
confirm before landing.

Index: llvm/include/llvm/Analysis/ScalarEvolution.h
===================================================================
--- llvm/include/llvm/Analysis/ScalarEvolution.h
+++ llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -1378,6 +1378,8 @@
   /// includes an exact count and a maximum count.
   ///
   class BackedgeTakenInfo {
+    friend class ScalarEvolution;
+
     /// A list of computable exits and their not-taken counts. Loops almost
     /// never have more than one computable exit.
     SmallVector<ExitNotTakenInfo, 1> ExitNotTaken;
@@ -1398,9 +1400,6 @@
     /// True iff the backedge is taken either exactly Max or zero times.
     bool MaxOrZero = false;
 
-    /// SCEV expressions used in any of the ExitNotTakenInfo counts.
-    SmallPtrSet<const SCEV *, 4> Operands;
-
     bool isComplete() const { return IsComplete; }
     const SCEV *getConstantMax() const { return ConstantMax; }
 
@@ -1466,10 +1465,6 @@
     /// Return true if the number of times this backedge is taken is either the
     /// value returned by getConstantMax or zero.
     bool isConstantMaxOrZero(ScalarEvolution *SE) const;
-
-    /// Return true if any backedge taken count expressions refer to the given
-    /// subexpression.
-    bool hasOperand(const SCEV *S) const;
   };
 
   /// Cache the backedge-taken count of the loops for this function as they
@@ -1480,6 +1475,10 @@
   /// function as they are computed.
   DenseMap<const Loop *, BackedgeTakenInfo> PredicatedBackedgeTakenCounts;
 
+  /// Loops whose backedge taken counts directly use this non-constant SCEV.
+  DenseMap<const SCEV *, SmallPtrSet<PointerIntPair<const Loop *, 1, bool>, 4>>
+      BECountUsers;
+
   /// This map contains entries for all of the PHI instructions that we
   /// attempt to compute constant evolutions for. This allows us to avoid
   /// potentially expensive recomputation of these properties. An instruction
@@ -1906,6 +1905,9 @@
   bool splitBinaryAdd(const SCEV *Expr, const SCEV *&L, const SCEV *&R,
                       SCEV::NoWrapFlags &Flags);
 
+  /// Forget predicated/non-predicated backedge taken counts for the given loop.
+  void forgetBackedgeTakenCounts(const Loop *L, bool Predicated);
+
   /// Drop memoized information for all \p SCEVs.
   void forgetMemoizedResults(ArrayRef<const SCEV *> SCEVs);
 
Index: llvm/lib/Analysis/ScalarEvolution.cpp
===================================================================
--- llvm/lib/Analysis/ScalarEvolution.cpp
+++ llvm/lib/Analysis/ScalarEvolution.cpp
@@ -7603,6 +7603,7 @@
   // result.
   BackedgeTakenCounts.clear();
   PredicatedBackedgeTakenCounts.clear();
+  BECountUsers.clear();
   LoopPropertiesCache.clear();
   ConstantEvolutionLoopExitValue.clear();
   ValueExprMap.clear();
@@ -7628,8 +7629,8 @@
     auto *CurrL = LoopWorklist.pop_back_val();
 
     // Drop any stored trip count value.
-    BackedgeTakenCounts.erase(CurrL);
-    PredicatedBackedgeTakenCounts.erase(CurrL);
+    forgetBackedgeTakenCounts(CurrL, /* Predicated */ false);
+    forgetBackedgeTakenCounts(CurrL, /* Predicated */ true);
 
     // Drop information about predicated SCEV rewrites for this loop.
     for (auto I = PredicatedSCEVRewrites.begin();
@@ -7803,10 +7804,6 @@
   return MaxOrZero && !any_of(ExitNotTaken, PredicateNotAlwaysTrue);
 }
 
-bool ScalarEvolution::BackedgeTakenInfo::hasOperand(const SCEV *S) const {
-  return Operands.contains(S);
-}
-
 ScalarEvolution::ExitLimit::ExitLimit(const SCEV *E)
     : ExitLimit(E, E, false, None) {
 }
@@ -7847,19 +7844,6 @@
     : ExitLimit(E, M, MaxOrZero, None) {
 }
 
-class SCEVRecordOperands {
-  SmallPtrSetImpl<const SCEV *> &Operands;
-
-public:
-  SCEVRecordOperands(SmallPtrSetImpl<const SCEV *> &Operands)
-      : Operands(Operands) {}
-  bool follow(const SCEV *S) {
-    Operands.insert(S);
-    return true;
-  }
-  bool isDone() { return false; }
-};
-
 /// Allocate memory for BackedgeTakenInfo and copy the not-taken count of each
 /// computable exit into a persistent ExitNotTakenInfo array.
 ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo(
@@ -7888,14 +7872,6 @@
   assert((isa<SCEVCouldNotCompute>(ConstantMax) ||
           isa<SCEVConstant>(ConstantMax)) &&
          "No point in having a non-constant max backedge taken count!");
-
-  SCEVRecordOperands RecordOperands(Operands);
-  SCEVTraversal<SCEVRecordOperands> ST(RecordOperands);
-  if (!isa<SCEVConstant>(ConstantMax))
-    ST.visitAll(ConstantMax);
-  for (auto &ENT : ExitNotTaken)
-    if (!isa<SCEVConstant>(ENT.ExactNotTaken))
-      ST.visitAll(ENT.ExactNotTaken);
 }
 
 /// Compute the number of times the backedge of the specified loop will execute.
@@ -7941,8 +7917,11 @@
       // We couldn't compute an exact value for this exit, so
       // we won't be able to compute an exact value for the loop.
       CouldComputeBECount = false;
-    else
+    else {
       ExitCounts.emplace_back(ExitBB, EL);
+      if (!isa<SCEVConstant>(EL.ExactNotTaken))
+        BECountUsers[EL.ExactNotTaken].insert({L, AllowPredicates});
+    }
 
     // 2. Derive the loop's MaxBECount from each exit's max number of
     // non-exiting iterations. Partition the loop exits into two kinds:
@@ -12877,6 +12856,23 @@
   return SCEVExprContains(S, [&](const SCEV *Expr) { return Expr == Op; });
 }
 
+void ScalarEvolution::forgetBackedgeTakenCounts(const Loop *L,
+                                                bool Predicated) {
+  auto &BECounts =
+      Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts;
+  auto It = BECounts.find(L);
+  if (It != BECounts.end()) {
+    for (const ExitNotTakenInfo &ENT : It->second.ExitNotTaken) {
+      if (!isa<SCEVConstant>(ENT.ExactNotTaken)) {
+        auto UserIt = BECountUsers.find(ENT.ExactNotTaken);
+        assert(UserIt != BECountUsers.end());
+        UserIt->second.erase({L, Predicated});
+      }
+    }
+    BECounts.erase(It);
+  }
+}
+
 void ScalarEvolution::forgetMemoizedResults(ArrayRef<const SCEV *> SCEVs) {
   SmallPtrSet<const SCEV *, 8> ToForget(SCEVs.begin(), SCEVs.end());
   SmallVector<const SCEV *> Worklist(ToForget.begin(), ToForget.end());
@@ -12901,21 +12897,6 @@
     else
       ++I;
   }
-
-  auto RemoveSCEVFromBackedgeMap = [&ToForget](
-      DenseMap<const Loop *, BackedgeTakenInfo> &Map) {
-    for (auto I = Map.begin(), E = Map.end(); I != E;) {
-      BackedgeTakenInfo &BEInfo = I->second;
-      if (any_of(ToForget,
-                 [&BEInfo](const SCEV *S) { return BEInfo.hasOperand(S); }))
-        Map.erase(I++);
-      else
-        ++I;
-    }
-  };
-
-  RemoveSCEVFromBackedgeMap(BackedgeTakenCounts);
-  RemoveSCEVFromBackedgeMap(PredicatedBackedgeTakenCounts);
 }
 
 void ScalarEvolution::forgetMemoizedResultsImpl(const SCEV *S) {
@@ -12938,6 +12919,15 @@
     }
     ExprValueMap.erase(ExprIt);
   }
+
+  auto BEUsersIt = BECountUsers.find(S);
+  if (BEUsersIt != BECountUsers.end()) {
+    // Work on a copy, as forgetBackedgeTakenCounts() will modify the original.
+    auto Copy = BEUsersIt->second;
+    for (const auto &Pair : Copy)
+      forgetBackedgeTakenCounts(Pair.getPointer(), Pair.getInt());
+    BECountUsers.erase(BEUsersIt);
+  }
 }
 
 void