diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -944,7 +944,11 @@
   ///
   /// We don't have a way to invalidate per-loop/per-block dispositions. Clear
   /// and recompute is simpler.
-  void forgetBlockAndLoopDispositions();
+  ///
+  /// If \p V is null, both caches are cleared. Otherwise only the dispositions
+  /// cached for the SCEV of \p V and for the SCEVs that (transitively) use it
+  /// are invalidated; nothing is done if \p V is not SCEVable or has no SCEV.
+  void forgetBlockAndLoopDispositions(Value *V = nullptr);
 
   /// Determine the minimum number of zero bits that S is guaranteed to end in
   /// (at every loop iteration). It is, at the same time, the minimum number
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -8384,7 +8384,42 @@
 
 void ScalarEvolution::forgetLoopDispositions() { LoopDispositions.clear(); }
 
-void ScalarEvolution::forgetBlockAndLoopDispositions() {
+void ScalarEvolution::forgetBlockAndLoopDispositions(Value *V) {
+  if (V) {
+    // Only SCEVable values can have a SCEV, and getExistingSCEV asserts on
+    // anything else (e.g. floating-point or void instructions).
+    if (!isSCEVable(V->getType()))
+      return;
+
+    // Without an existing SCEV for V there is nothing to invalidate.
+    const SCEV *S = getExistingSCEV(V);
+    if (!S)
+      return;
+
+    // Moving V may change the block and loop dispositions of S and of every
+    // SCEV that (transitively) uses S, so erase both kinds of cached
+    // dispositions along S's users. The walk stops at SCEVs that had nothing
+    // cached, assuming a user's cached disposition implies cached entries for
+    // the operands it was derived from.
+    SmallVector<const SCEV *, 8> Worklist = {S};
+    SmallPtrSet<const SCEV *, 8> Seen = {S};
+    while (!Worklist.empty()) {
+      const SCEV *Curr = Worklist.pop_back_val();
+      bool LoopDispoRemoved = LoopDispositions.erase(Curr);
+      bool BlockDispoRemoved = BlockDispositions.erase(Curr);
+      if (!LoopDispoRemoved && !BlockDispoRemoved)
+        continue;
+
+      auto Users = SCEVUsers.find(Curr);
+      if (Users != SCEVUsers.end())
+        for (const auto *User : Users->second)
+          if (Seen.insert(User).second)
+            Worklist.push_back(User);
+    }
+    return;
+  }
+
+  // No specific value was given: conservatively drop both caches entirely.
   BlockDispositions.clear();
   LoopDispositions.clear();
 }
diff --git a/llvm/lib/Transforms/Scalar/LoopDeletion.cpp b/llvm/lib/Transforms/Scalar/LoopDeletion.cpp
--- a/llvm/lib/Transforms/Scalar/LoopDeletion.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopDeletion.cpp
@@ -89,23 +89,28 @@
       if (!AllOutgoingValuesSame)
         break;
 
+      bool InstrMoved = false;
       if (Instruction *I = dyn_cast<Instruction>(incoming)) {
-        if (!L->makeLoopInvariant(I, Changed, Preheader->getTerminator())) {
+        // makeLoopInvariant can report a change even when it fails (e.g. after
+        // hoisting some operands of I), so record it before checking.
+        bool IsInvariant =
+            L->makeLoopInvariant(I, InstrMoved, Preheader->getTerminator());
+        Changed |= InstrMoved;
+        if (!IsInvariant) {
           AllEntriesInvariant = false;
           break;
         }
-        if (Changed) {
-          // Moving I to a different location may change its block disposition,
-          // so invalidate its SCEV.
-          SE.forgetValue(I);
+        if (InstrMoved) {
+          // Moving I may change the block and loop dispositions of its SCEV
+          // and of the SCEVs that use it, so invalidate them.
+          // TODO: if makeLoopInvariant also hoisted operands of I, their
+          // dispositions are not invalidated here.
+          SE.forgetBlockAndLoopDispositions(I);
         }
       }
     }
   }
 
-  if (Changed)
-    SE.forgetLoopDispositions();
-
   if (!AllEntriesInvariant || !AllOutgoingValuesSame)
     return false;
 
diff --git a/llvm/lib/Transforms/Scalar/LoopSink.cpp b/llvm/lib/Transforms/Scalar/LoopSink.cpp
--- a/llvm/lib/Transforms/Scalar/LoopSink.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopSink.cpp
@@ -312,12 +312,15 @@
     if (!canSinkOrHoistInst(I, &AA, &DT, &L, MSSAU, false, LICMFlags))
       continue;
     if (sinkInstruction(L, I, ColdLoopBBs, LoopBlockNumber, LI, DT, BFI,
-                        &MSSAU))
+                        &MSSAU)) {
       Changed = true;
+      // Sinking I may change the block and loop dispositions of its SCEV and
+      // of the SCEVs that use it.
+      if (SE)
+        SE->forgetBlockAndLoopDispositions(&I);
+    }
   }
 
-  if (Changed && SE)
-    SE->forgetLoopDispositions();
   return Changed;
 }
 
diff --git a/llvm/lib/Transforms/Utils/LoopSimplify.cpp b/llvm/lib/Transforms/Utils/LoopSimplify.cpp
--- a/llvm/lib/Transforms/Utils/LoopSimplify.cpp
+++ b/llvm/lib/Transforms/Utils/LoopSimplify.cpp
@@ -647,20 +647,27 @@
       Instruction *Inst = &*I++;
       if (Inst == CI)
         continue;
-      if (!L->makeLoopInvariant(
-              Inst, AnyInvariant,
-              Preheader ? Preheader->getTerminator() : nullptr, MSSAU)) {
+      // makeLoopInvariant can report a change even when it fails (e.g. after
+      // hoisting some operands of Inst), so record it before checking.
+      bool InstHoisted = false;
+      bool IsInvariant = L->makeLoopInvariant(
+          Inst, InstHoisted, Preheader ? Preheader->getTerminator() : nullptr,
+          MSSAU);
+      AnyInvariant |= InstHoisted;
+      if (!IsInvariant) {
         AllInvariant = false;
         break;
       }
+      if (InstHoisted && SE) {
+        // Hoisting Inst may change the block and loop dispositions of its SCEV
+        // and of all SCEVs that depend on it.
+        // TODO: if makeLoopInvariant also hoisted operands of Inst, their
+        // dispositions are not invalidated here.
+        SE->forgetBlockAndLoopDispositions(Inst);
+      }
     }
-    if (AnyInvariant) {
+    if (AnyInvariant)
       Changed = true;
-      // The loop disposition of all SCEV expressions that depend on any
-      // hoisted values have also changed.
-      if (SE)
-        SE->forgetLoopDispositions();
-    }
     if (!AllInvariant) continue;
 
     // The block has now been cleared of all instructions except for