diff --git a/llvm/include/llvm/Analysis/BranchProbabilityInfo.h b/llvm/include/llvm/Analysis/BranchProbabilityInfo.h
--- a/llvm/include/llvm/Analysis/BranchProbabilityInfo.h
+++ b/llvm/include/llvm/Analysis/BranchProbabilityInfo.h
@@ -62,7 +62,8 @@
   }
 
   BranchProbabilityInfo(BranchProbabilityInfo &&Arg)
-      : Probs(std::move(Arg.Probs)), LastF(Arg.LastF),
+      : Probs(std::move(Arg.Probs)), MaxSuccIdx(std::move(Arg.MaxSuccIdx)),
+        LastF(Arg.LastF),
         PostDominatedByUnreachable(std::move(Arg.PostDominatedByUnreachable)),
         PostDominatedByColdCall(std::move(Arg.PostDominatedByColdCall)) {}
 
@@ -72,6 +73,8 @@
   BranchProbabilityInfo &operator=(BranchProbabilityInfo &&RHS) {
     releaseMemory();
     Probs = std::move(RHS.Probs);
+    MaxSuccIdx = std::move(RHS.MaxSuccIdx);
+    LastF = RHS.LastF; // Keep in sync with the move constructor.
     PostDominatedByColdCall = std::move(RHS.PostDominatedByColdCall);
     PostDominatedByUnreachable = std::move(RHS.PostDominatedByUnreachable);
     return *this;
@@ -273,6 +276,10 @@
 
   DenseMap<Edge, BranchProbability> Probs;
 
+  /// Largest successor index with a probability set in Probs, per source
+  /// block; lets eraseBlock() find all of a block's entries.
+  DenseMap<const BasicBlock *, unsigned> MaxSuccIdx;
+
   /// Track the last function we run over for printing.
   const Function *LastF = nullptr;
 
diff --git a/llvm/lib/Analysis/BranchProbabilityInfo.cpp b/llvm/lib/Analysis/BranchProbabilityInfo.cpp
--- a/llvm/lib/Analysis/BranchProbabilityInfo.cpp
+++ b/llvm/lib/Analysis/BranchProbabilityInfo.cpp
@@ -1031,6 +1031,7 @@
 
 void BranchProbabilityInfo::releaseMemory() {
   Probs.clear();
+  MaxSuccIdx.clear();
   Handles.clear();
 }
 
@@ -1136,6 +1137,11 @@
   LLVM_DEBUG(dbgs() << "set edge " << Src->getName() << " -> "
                     << IndexInSuccessors << " successor probability to " << Prob
                     << "\n");
+
+  // Track the largest successor index with a recorded probability so that
+  // eraseBlock() can find all of Src's entries in Probs.
+  unsigned &MaxIdx = MaxSuccIdx[Src]; // Value-initialized to 0 on first use.
+  MaxIdx = std::max(MaxIdx, IndexInSuccessors);
 }
 
 /// Set the edge probability for all edges at once.
@@ -1173,11 +1179,16 @@
 }
 
 void BranchProbabilityInfo::eraseBlock(const BasicBlock *BB) {
-  for (const_succ_iterator I = succ_begin(BB), E = succ_end(BB); I != E; ++I) {
-    auto MapI = Probs.find(std::make_pair(BB, I.getSuccessorIndex()));
-    if (MapI != Probs.end())
-      Probs.erase(MapI);
-  }
+  // Probs is keyed by (block, successor index). Walk the indices recorded in
+  // MaxSuccIdx rather than BB's successor list, which presumably may have
+  // changed since the probabilities were set.
+  auto It = MaxSuccIdx.find(BB);
+  if (It == MaxSuccIdx.end())
+    return;
+
+  for (unsigned I = 0, E = It->second; I <= E; ++I)
+    Probs.erase(std::make_pair(BB, I));
+  MaxSuccIdx.erase(It);
 }
 
 void BranchProbabilityInfo::calculate(const Function &F, const LoopInfo &LI,