diff --git a/llvm/include/llvm/Analysis/BranchProbabilityInfo.h b/llvm/include/llvm/Analysis/BranchProbabilityInfo.h
--- a/llvm/include/llvm/Analysis/BranchProbabilityInfo.h
+++ b/llvm/include/llvm/Analysis/BranchProbabilityInfo.h
@@ -189,6 +189,11 @@
   /// unset for source.
   void copyEdgeProbabilities(BasicBlock *Src, BasicBlock *Dst);
 
+  /// Swap the probabilities recorded for the two outgoing edges of \p Src.
+  /// The terminator of \p Src must have exactly two successors. Does nothing
+  /// if no probability is recorded for the first outgoing edge of \p Src.
+  void swapSuccEdgesProbabilities(const BasicBlock *Src);
+
   static BranchProbability getBranchProbStackProtector(bool IsLikely) {
     static const BranchProbability LikelyProb((1u << 20) - 1, 1u << 20);
     return IsLikely ? LikelyProb : LikelyProb.getCompl();
diff --git a/llvm/lib/Analysis/BranchProbabilityInfo.cpp b/llvm/lib/Analysis/BranchProbabilityInfo.cpp
--- a/llvm/lib/Analysis/BranchProbabilityInfo.cpp
+++ b/llvm/lib/Analysis/BranchProbabilityInfo.cpp
@@ -1175,6 +1175,21 @@
   }
 }
 
+void BranchProbabilityInfo::swapSuccEdgesProbabilities(const BasicBlock *Src) {
+  assert(Src->getTerminator()->getNumSuccessors() == 2 &&
+         "Expected a terminator with exactly two successors");
+  // Look each edge up once and swap through the iterators. find() never
+  // inserts, so the map cannot grow between the two lookups and the first
+  // entry stays valid while the second one is located.
+  auto It0 = Probs.find(std::make_pair(Src, 0));
+  if (It0 == Probs.end())
+    return; // No probability is set for edges from Src
+  auto It1 = Probs.find(std::make_pair(Src, 1));
+  assert(It1 != Probs.end() &&
+         "Probability is set for one outgoing edge but not the other");
+  std::swap(It0->second, It1->second);
+}
+
 raw_ostream &
 BranchProbabilityInfo::printEdgeProbability(raw_ostream &OS,
                                             const BasicBlock *Src,
diff --git a/llvm/unittests/Analysis/BranchProbabilityInfoTest.cpp b/llvm/unittests/Analysis/BranchProbabilityInfoTest.cpp
--- a/llvm/unittests/Analysis/BranchProbabilityInfoTest.cpp
+++ b/llvm/unittests/Analysis/BranchProbabilityInfoTest.cpp
@@ -116,9 +116,10 @@
   EXPECT_LT(ProbEdge0, ProbEdge1);
 
   Branch->swapSuccessors();
-  // TODO: Check the probabilities are swapped as well as the edges
-  EXPECT_EQ(ProbEdge0, BPI->getEdgeProbability(LoopHeaderBB, 0U));
-  EXPECT_EQ(ProbEdge1, BPI->getEdgeProbability(LoopHeaderBB, 1U));
+  BPI->swapSuccEdgesProbabilities(LoopHeaderBB);
+  // The probabilities must now be swapped along with the edges.
+  EXPECT_EQ(ProbEdge0, BPI->getEdgeProbability(LoopHeaderBB, 1U));
+  EXPECT_EQ(ProbEdge1, BPI->getEdgeProbability(LoopHeaderBB, 0U));
 }
 
 } // end anonymous namespace