Index: llvm/include/llvm/Analysis/ScalarEvolution.h
===================================================================
--- llvm/include/llvm/Analysis/ScalarEvolution.h
+++ llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -906,11 +906,32 @@
   // bodies).
   void forgetAllLoops();
 
+  /// Describes what is going to happen to the loop passed to forgetLoop.
+  /// `Recompute` means that we change some essential facts about the loop
+  /// (such as exiting blocks), but the loop itself may stay in the IR. This
+  /// reason should also be used when we just request forgetLoop to get more
+  /// precise results, or as a safe default when we don't know whether the loop
+  /// will stay or not.
+  /// `WillBreak` means that we are going to break the loop itself (e.g. break
+  /// its backedge), but its header and some of the contained loops may stay
+  /// reachable. In this case we can safely drop cached information related to
+  /// this immediate loop.
+  /// `WillDelete` means that the header of this loop is no longer reachable,
+  /// and therefore this loop itself and all its contained loops will be
+  /// deleted. In this case we may safely drop cached information related to the
+  /// loop itself and all its subloops recursively.
+  enum ForgetLoopReason {
+    Recompute,
+    WillBreak,
+    WillDelete,
+  };
   /// This method should be called by the client when it has changed a loop in
   /// a way that may effect ScalarEvolution's ability to compute a trip count,
   /// or if the loop is deleted.  This call is potentially expensive for large
-  /// loop bodies.
-  void forgetLoop(const Loop *L);
+  /// loop bodies. \p ForgetReason hints at what will happen next to this loop
+  /// (whether we are going to break it, or delete it with all its subloops),
+  /// which allows cached data to be managed more carefully.
+  void forgetLoop(const Loop *L, ForgetLoopReason ForgetReason = Recompute);
 
   // This method invokes forgetLoop for the outermost loop of the given loop
   // \p L, making ScalarEvolution forget about all this subtree. This needs to
Index: llvm/lib/Analysis/ScalarEvolution.cpp
===================================================================
--- llvm/lib/Analysis/ScalarEvolution.cpp
+++ llvm/lib/Analysis/ScalarEvolution.cpp
@@ -8040,7 +8040,7 @@
   PredicatedSCEVRewrites.clear();
 }
 
-void ScalarEvolution::forgetLoop(const Loop *L) {
+void ScalarEvolution::forgetLoop(const Loop *L, ForgetLoopReason ForgetReason) {
   SmallVector<const Loop *, 16> LoopWorklist(1, L);
   SmallVector<Instruction *, 32> Worklist;
   SmallPtrSet<Instruction *, 16> Visited;
@@ -8068,6 +8068,9 @@
     if (LoopUsersItr != LoopUsers.end()) {
       ToForget.insert(ToForget.end(), LoopUsersItr->second.begin(),
                       LoopUsersItr->second.end());
+      // CurrL is deleted together with L, so stop tracking its users.
+      if (ForgetReason == WillDelete)
+        LoopUsers.erase(LoopUsersItr);
     }
 
     // Drop information about expressions based on loop-header PHIs.
@@ -8094,6 +8097,9 @@
     LoopWorklist.append(CurrL->begin(), CurrL->end());
   }
   forgetMemoizedResults(ToForget);
+  // L is being broken but its subloops may survive: forget only L's users.
+  if (ForgetReason == WillBreak)
+    LoopUsers.erase(L);
 }
 
 void ScalarEvolution::forgetTopmostLoop(const Loop *L) {
Index: llvm/lib/Transforms/Scalar/LoopDeletion.cpp
===================================================================
--- llvm/lib/Transforms/Scalar/LoopDeletion.cpp
+++ llvm/lib/Transforms/Scalar/LoopDeletion.cpp
@@ -460,7 +460,7 @@
     // We need to forget the loop before setting the incoming values of the exit
     // phis to undef, so we properly invalidate the SCEV expressions for those
     // phis.
-    SE.forgetLoop(L);
+    SE.forgetLoop(L, ScalarEvolution::ForgetLoopReason::WillDelete);
     // Set incoming value to undef for phi nodes in the exit block.
     for (PHINode &P : ExitBlock->phis()) {
       std::fill(P.incoming_values().begin(), P.incoming_values().end(),
Index: llvm/lib/Transforms/Utils/LoopUtils.cpp
===================================================================
--- llvm/lib/Transforms/Utils/LoopUtils.cpp
+++ llvm/lib/Transforms/Utils/LoopUtils.cpp
@@ -705,7 +705,7 @@
   auto *Header = L->getHeader();
   Loop *OutermostLoop = getOutermostLoop(L);
 
-  SE.forgetLoop(L);
+  SE.forgetLoop(L, ScalarEvolution::ForgetLoopReason::WillBreak);
 
   std::unique_ptr<MemorySSAUpdater> MSSAU;
   if (MSSA)