Index: llvm/include/llvm/Analysis/BranchProbabilityInfo.h
===================================================================
--- llvm/include/llvm/Analysis/BranchProbabilityInfo.h
+++ llvm/include/llvm/Analysis/BranchProbabilityInfo.h
@@ -189,6 +189,11 @@
   /// unset for source.
   void copyEdgeProbabilities(BasicBlock *Src, BasicBlock *Dst);
 
+  /// Swap the probabilities of the two outgoing edges of \p Src, e.g. after
+  /// its branch successors have been swapped. \p Src's terminator must have
+  /// exactly two successors; nothing is done if no probabilities are set.
+  void swapSuccEdgesProbabilities(const BasicBlock *Src);
+
   static BranchProbability getBranchProbStackProtector(bool IsLikely) {
     static const BranchProbability LikelyProb((1u << 20) - 1, 1u << 20);
     return IsLikely ? LikelyProb : LikelyProb.getCompl();
Index: llvm/lib/Analysis/BranchProbabilityInfo.cpp
===================================================================
--- llvm/lib/Analysis/BranchProbabilityInfo.cpp
+++ llvm/lib/Analysis/BranchProbabilityInfo.cpp
@@ -1175,6 +1175,19 @@
   }
 }
 
+void BranchProbabilityInfo::swapSuccEdgesProbabilities(const BasicBlock *Src) {
+  assert(Src->getTerminator()->getNumSuccessors() == 2 &&
+         "Expected a terminator with exactly two successors");
+  auto It0 = Probs.find(std::make_pair(Src, 0));
+  if (It0 == Probs.end())
+    return; // No probability is set for edges from Src.
+  auto It1 = Probs.find(std::make_pair(Src, 1));
+  assert(It1 != Probs.end() && "Probability set for only one of two edges");
+  // Swap through iterators: two operator[] calls could default-insert (and
+  // rehash, invalidating the other reference) if an entry were missing.
+  std::swap(It0->second, It1->second);
+}
+
 raw_ostream &
 BranchProbabilityInfo::printEdgeProbability(raw_ostream &OS,
                                             const BasicBlock *Src,
Index: llvm/unittests/Analysis/BranchProbabilityInfoTest.cpp
===================================================================
--- llvm/unittests/Analysis/BranchProbabilityInfoTest.cpp
+++ llvm/unittests/Analysis/BranchProbabilityInfoTest.cpp
@@ -116,9 +116,10 @@
   EXPECT_LT(ProbEdge0, ProbEdge1);
 
   Branch->swapSuccessors();
-  // TODO: Check the probabilities are swapped as well as the edges
-  EXPECT_EQ(ProbEdge0, BPI->getEdgeProbability(LoopHeaderBB, 0U));
-  EXPECT_EQ(ProbEdge1, BPI->getEdgeProbability(LoopHeaderBB, 1U));
+  BPI->swapSuccEdgesProbabilities(LoopHeaderBB);
+  // The probabilities must follow their edges to the swapped successor slots.
+  EXPECT_EQ(ProbEdge0, BPI->getEdgeProbability(LoopHeaderBB, 1U));
+  EXPECT_EQ(ProbEdge1, BPI->getEdgeProbability(LoopHeaderBB, 0U));
 }
 
 } // end anonymous namespace