Patch summary (review notes; free text placed before the first file header
line is ignored by patch(1) and git apply).

What this patch does, as visible in the hunks below:
* ScalarEvolution.h declares `void addToLoopUseLists(const SCEV *S)`.
  It adds a member `LoopUsers` that maps a loop to the SCEV expressions that
  (transitively) use it, and adds an include of "llvm/ADT/ChunkedList.h".
* ScalarEvolution.cpp calls addToLoopUseLists(S) right after
  UniqueSCEVs.InsertNode(S, IP) at every shown site that creates a new
  uniqued SCEV node:
  - truncate;
  - zero-extend (two sites) and sign-extend (two sites);
  - add and mul;
  - udiv;
  - add-recurrence;
  - smax and umax.
* addToLoopUseLists walks S with SCEVTraversal and collects the loop of
  every add-recurrence it reaches. It then appends S to LoopUsers[L] for
  each collected loop.
* forgetLoop looks up CurrL in LoopUsers. If CurrL is present, it calls
  forgetMemoizedResults on every recorded user and erases that loop's
  entry. This happens before the existing PushLoopPHIs(CurrL, Worklist)
  step. (Presumably CurrL ranges over the loop and its subloops; that part
  is outside the visible hunk.)
* The move constructor moves LoopUsers along with the other members.
* ScalarEvolutionTest.cpp builds the add-recurrence {5,+,1}, which does not
  correspond to any PHI in the IR. It checks that getSCEVAtScope(AR, nullptr)
  is the constant 1004 before forgetLoop(Loop). After the loop is rewritten
  (the exit count becomes 1999), it checks for 2004. Presumably the point is
  that the cached value is only reachable through the new LoopUsers path,
  not through PHI-based invalidation. TODO confirm.

Review notes (NOTE(review): confirm before applying):
* This copy has lost its line breaks, so each hunk's lines are run together.
  It cannot be applied as-is. Hunk bodies must be re-split into lines so
  that the counts in each hunk header match.
* Text inside angle brackets appears to have been stripped. Examples:
  `SmallVectorImpl &Ops`, `FoldingSet UniqueSCEVs`, `DenseMap> LoopUsers`,
  `SmallPtrSet LoopsUsed`, `dyn_cast(S)`, `SCEVTraversal(F)`,
  `const_cast(this)`, `isa(...)`, `cast(...)`.
  Restore these from the original source. Presumably they include
  `dyn_cast<SCEVAddRecExpr>(S)` and `SCEVTraversal<FindUsedLoops>(F)`.
  TODO confirm the exact arguments, especially the value type of LoopUsers
  (likely a ChunkedList, given the new include).
* The patch only touches three files and does not add ChunkedList.h.
  Confirm that header exists in the target tree, or lands first.
* A SCEV that uses several loops is appended to each of those loops' lists.
  forgetLoop only erases the CurrL entry, so the other loops' lists still
  refer to that SCEV. Within this patch, a loop's list only shrinks when
  forgetLoop processes that loop.

Index: include/llvm/Analysis/ScalarEvolution.h =================================================================== --- include/llvm/Analysis/ScalarEvolution.h +++ include/llvm/Analysis/ScalarEvolution.h @@ -23,6 +23,7 @@ #include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/ChunkedList.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseMapInfo.h" #include "llvm/ADT/FoldingSet.h" @@ -1761,10 +1762,18 @@ const SCEV *getOrCreateMulExpr(SmallVectorImpl &Ops, SCEV::NoWrapFlags Flags); + /// Find all of the loops transitively used in \p S, and update \c LoopUsers + /// accordingly. + void addToLoopUseLists(const SCEV *S); + FoldingSet UniqueSCEVs; FoldingSet UniquePreds; BumpPtrAllocator SCEVAllocator; + /// This maps loops to a list of SCEV expressions that (transitively) use said + /// loop. + DenseMap> LoopUsers; + /// Cache tentative mappings from UnknownSCEVs in a Loop, to a SCEV expression /// they can be rewritten into under certain predicates. DenseMap, Index: lib/Analysis/ScalarEvolution.cpp =================================================================== --- lib/Analysis/ScalarEvolution.cpp +++ lib/Analysis/ScalarEvolution.cpp @@ -1290,6 +1290,7 @@ SCEV *S = new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator), Op, Ty); UniqueSCEVs.InsertNode(S, IP); + addToLoopUseLists(S); return S; } @@ -1580,6 +1581,7 @@ SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator), Op, Ty); UniqueSCEVs.InsertNode(S, IP); + addToLoopUseLists(S); return S; } @@ -1766,6 +1768,7 @@ SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator), Op, Ty); UniqueSCEVs.InsertNode(S, IP); + addToLoopUseLists(S); return S; } @@ -1803,6 +1806,7 @@ SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator), Op, Ty); UniqueSCEVs.InsertNode(S, IP); + addToLoopUseLists(S); return S; } @@ -2014,6 +2018,7 @@ SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator), Op, Ty); 
UniqueSCEVs.InsertNode(S, IP); + addToLoopUseLists(S); return S; } @@ -2662,6 +2667,7 @@ S = new (SCEVAllocator) SCEVAddExpr(ID.Intern(SCEVAllocator), O, Ops.size()); UniqueSCEVs.InsertNode(S, IP); + addToLoopUseLists(S); } S->setNoWrapFlags(Flags); return S; @@ -2683,6 +2689,7 @@ S = new (SCEVAllocator) SCEVMulExpr(ID.Intern(SCEVAllocator), O, Ops.size()); UniqueSCEVs.InsertNode(S, IP); + addToLoopUseLists(S); } S->setNoWrapFlags(Flags); return S; @@ -3135,6 +3142,7 @@ SCEV *S = new (SCEVAllocator) SCEVUDivExpr(ID.Intern(SCEVAllocator), LHS, RHS); UniqueSCEVs.InsertNode(S, IP); + addToLoopUseLists(S); return S; } @@ -3315,6 +3323,7 @@ S = new (SCEVAllocator) SCEVAddRecExpr(ID.Intern(SCEVAllocator), O, Operands.size(), L); UniqueSCEVs.InsertNode(S, IP); + addToLoopUseLists(S); } S->setNoWrapFlags(Flags); return S; @@ -3470,6 +3479,7 @@ SCEV *S = new (SCEVAllocator) SCEVSMaxExpr(ID.Intern(SCEVAllocator), O, Ops.size()); UniqueSCEVs.InsertNode(S, IP); + addToLoopUseLists(S); return S; } @@ -3571,6 +3581,7 @@ SCEV *S = new (SCEVAllocator) SCEVUMaxExpr(ID.Intern(SCEVAllocator), O, Ops.size()); UniqueSCEVs.InsertNode(S, IP); + addToLoopUseLists(S); return S; } @@ -6393,6 +6404,13 @@ ++I; } + auto LoopUsersItr = LoopUsers.find(CurrL); + if (LoopUsersItr != LoopUsers.end()) { + for (auto *S : LoopUsersItr->second) + forgetMemoizedResults(S); + LoopUsers.erase(LoopUsersItr); + } + // Drop information about expressions based on loop-header PHIs. 
PushLoopPHIs(CurrL, Worklist); @@ -10575,6 +10593,7 @@ UniqueSCEVs(std::move(Arg.UniqueSCEVs)), UniquePreds(std::move(Arg.UniquePreds)), SCEVAllocator(std::move(Arg.SCEVAllocator)), + LoopUsers(std::move(Arg.LoopUsers)), PredicatedSCEVRewrites(std::move(Arg.PredicatedSCEVRewrites)), FirstUnknown(Arg.FirstUnknown) { Arg.FirstUnknown = nullptr; @@ -11017,6 +11036,25 @@ ExitLimits.erase(I); } +void ScalarEvolution::addToLoopUseLists(const SCEV *S) { + struct FindUsedLoops { + SmallPtrSet LoopsUsed; + bool follow(const SCEV *S) { + if (auto *AR = dyn_cast(S)) + LoopsUsed.insert(AR->getLoop()); + return true; + } + + bool isDone() const { return false; } + }; + + FindUsedLoops F; + SCEVTraversal(F).visitAll(S); + + for (auto *L : F.LoopsUsed) + LoopUsers[L].push_back(S); +} + void ScalarEvolution::verify() const { ScalarEvolution &SE = *const_cast(this); ScalarEvolution SE2(F, TLI, AC, DT, LI); Index: unittests/Analysis/ScalarEvolutionTest.cpp =================================================================== --- unittests/Analysis/ScalarEvolutionTest.cpp +++ unittests/Analysis/ScalarEvolutionTest.cpp @@ -856,6 +856,17 @@ EXPECT_TRUE(isa(EC)); EXPECT_EQ(cast(EC)->getAPInt().getLimitedValue(), 999u); + // The add recurrence {5,+,1} does not correspond to any PHI in the IR, and + // that is relevant to this test. 
+ auto *Five = SE.getConstant(APInt(/*numBits=*/64, 5)); + auto *AR = + SE.getAddRecExpr(Five, SE.getOne(T_int64), Loop, SCEV::FlagAnyWrap); + const SCEV *ARAtLoopExit = SE.getSCEVAtScope(AR, nullptr); + EXPECT_FALSE(isa(ARAtLoopExit)); + EXPECT_TRUE(isa(ARAtLoopExit)); + EXPECT_EQ(cast(ARAtLoopExit)->getAPInt().getLimitedValue(), + 1004u); + SE.forgetLoop(Loop); Br->eraseFromParent(); Cond->eraseFromParent(); @@ -868,6 +879,11 @@ EXPECT_FALSE(isa(NewEC)); EXPECT_TRUE(isa(NewEC)); EXPECT_EQ(cast(NewEC)->getAPInt().getLimitedValue(), 1999u); + const SCEV *NewARAtLoopExit = SE.getSCEVAtScope(AR, nullptr); + EXPECT_FALSE(isa(NewARAtLoopExit)); + EXPECT_TRUE(isa(NewARAtLoopExit)); + EXPECT_EQ(cast(NewARAtLoopExit)->getAPInt().getLimitedValue(), + 2004u); } // Make sure that SCEV invalidates exit limits after invalidating the values it