Index: include/llvm/Analysis/ScalarEvolution.h =================================================================== --- include/llvm/Analysis/ScalarEvolution.h +++ include/llvm/Analysis/ScalarEvolution.h @@ -1856,6 +1856,9 @@ /// This maps loops to a list of SCEV expressions that (transitively) use said /// loop. DenseMap> LoopUsers; + /// The inverse of LoopUsers map; maps a SCEV expression to a set of + /// (transitively) referenced Loops. + DenseMap> LoopsRefd; /// Cache tentative mappings from UnknownSCEVs in a Loop, to a SCEV expression /// they can be rewritten into under certain predicates. Index: lib/Analysis/ScalarEvolution.cpp =================================================================== --- lib/Analysis/ScalarEvolution.cpp +++ lib/Analysis/ScalarEvolution.cpp @@ -11760,6 +11760,7 @@ ExprValueMap.erase(S); HasRecMap.erase(S); MinTrailingZerosCache.erase(S); + LoopsRefd.erase(S); for (auto I = PredicatedSCEVRewrites.begin(); I != PredicatedSCEVRewrites.end();) { @@ -11790,10 +11791,17 @@ ScalarEvolution::getUsedLoops(const SCEV *S, SmallPtrSetImpl &LoopsUsed) { struct FindUsedLoops { - FindUsedLoops(SmallPtrSetImpl &LoopsUsed) - : LoopsUsed(LoopsUsed) {} + FindUsedLoops(SmallPtrSetImpl &LoopsUsed, ScalarEvolution &SE) + : LoopsUsed(LoopsUsed), SE(SE) {} SmallPtrSetImpl &LoopsUsed; + ScalarEvolution &SE; + bool follow(const SCEV *S) { + auto It = SE.LoopsRefd.find(S); + if (It != SE.LoopsRefd.end() && &It->second != &LoopsUsed) { + LoopsUsed.insert(It->second.begin(), It->second.end()); + return false; + } if (auto *AR = dyn_cast(S)) LoopsUsed.insert(AR->getLoop()); return true; @@ -11802,12 +11810,14 @@ bool isDone() const { return false; } }; - FindUsedLoops F(LoopsUsed); + FindUsedLoops F(LoopsUsed, *this); SCEVTraversal(F).visitAll(S); } void ScalarEvolution::addToLoopUseLists(const SCEV *S) { - SmallPtrSet LoopsUsed; + assert(LoopsRefd.find(S) == LoopsRefd.end() && + "addToLoopUseLists should be called exactly once per every new SCEV"); + SmallPtrSetImpl &LoopsUsed = LoopsRefd[S]; getUsedLoops(S, LoopsUsed); for (auto *L : LoopsUsed) LoopUsers[L].push_back(S);