diff --git a/llvm/lib/Analysis/BranchProbabilityInfo.cpp b/llvm/lib/Analysis/BranchProbabilityInfo.cpp
--- a/llvm/lib/Analysis/BranchProbabilityInfo.cpp
+++ b/llvm/lib/Analysis/BranchProbabilityInfo.cpp
@@ -1113,11 +1113,10 @@
 BranchProbabilityInfo::getEdgeProbability(const BasicBlock *Src,
                                           unsigned IndexInSuccessors) const {
   auto I = Probs.find(std::make_pair(Src, IndexInSuccessors));
-  assert((Probs.end() == Probs.find(std::make_pair(Src, 0))) ==
-             (Probs.end() == I) &&
-         "Probability for I-th successor must always be defined along with the "
-         "probability for the first successor");
+  // The probability may be defined for the first successor but not for the
+  // I-th successor. This happens when BPI is updated lossily through loop
+  // passes.
 
   if (I != Probs.end())
     return I->second;
 
@@ -1139,9 +1138,12 @@
     return BranchProbability(llvm::count(successors(Src), Dst), succ_size(Src));
 
   auto Prob = BranchProbability::getZero();
+  // Do not dereference Probs.find() directly: the entry for a successor may
+  // be missing if BPI was not updated when loop passes added new blocks. Use
+  // getEdgeProbability(), which checks for a missing entry.
   for (const_succ_iterator I = succ_begin(Src), E = succ_end(Src); I != E; ++I)
     if (*I == Dst)
-      Prob += Probs.find(std::make_pair(Src, I.getSuccessorIndex()))->second;
+      Prob += getEdgeProbability(Src, I.getSuccessorIndex());
 
   return Prob;
 }
diff --git a/llvm/unittests/Analysis/BranchProbabilityInfoTest.cpp b/llvm/unittests/Analysis/BranchProbabilityInfoTest.cpp
--- a/llvm/unittests/Analysis/BranchProbabilityInfoTest.cpp
+++ b/llvm/unittests/Analysis/BranchProbabilityInfoTest.cpp
@@ -13,12 +13,15 @@
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/Dominators.h"
 #include "llvm/IR/Function.h"
+#include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/Module.h"
 #include "llvm/Support/DataTypes.h"
 #include "llvm/Support/SourceMgr.h"
 #include "llvm/Support/raw_ostream.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+#include "llvm/Transforms/Utils/BreakCriticalEdges.h"
 #include "gtest/gtest.h"
 
 namespace llvm {
@@ -83,5 +86,96 @@
   EXPECT_TRUE(BPI.isEdgeHot(EntryBB, ExitBB));
 }
 
+/// Parses \p ModuleStr with parseAssemblyString() in \p Context. Diagnostics
+/// written to the local SMDiagnostic are not reported to the caller.
+static std::unique_ptr<Module>
+makeLLVMModuleWithCustomFunc(LLVMContext &Context, const char *ModuleStr) {
+  SMDiagnostic Err;
+  return parseAssemblyString(ModuleStr, Err, Context);
+}
+
+/// Returns the block of \p F whose name equals \p Name; reaching the end of
+/// the function without a match is treated as unreachable.
+static BasicBlock *getBasicBlockByName(Function &F, StringRef Name) {
+  for (BasicBlock &BB : F)
+    if (BB.getName() == Name)
+      return &BB;
+  llvm_unreachable("Expected to find basic block!");
+}
+
+// We may update BPI lossily in loop passes, which means newly created blocks do
+// not have BPI.
+TEST_F(BranchProbabilityInfoTest, TestLossyBPIInLoop) {
+  const char *ModuleStr =
+      "target datalayout = \"e-m:o-i64:64-f80:128-n8:16:32:64-S128\"\n"
+      "define void @foo(i32 %n, i32 %m, i1 %inv_cond) {\n"
+      "entry:\n"
+      " br i1 undef, label %for.end, label %for.cond\n"
+      "for.cond:\n"
+      " %i.0 = phi i32 [ 0, %entry ], [ %inc, %latch.1 ], [ %inc, %latch.2 ]\n"
+      " %inc = add nsw i32 %i.0, 1\n"
+      " %cmp = icmp slt i32 %i.0, %n\n"
+      " br i1 %inv_cond, label %deadbb, label %latch.1\n"
+      "deadbb:\n"
+      " br label %latch.1\n"
+      "latch.1:\n"
+      " br i1 %cmp, label %for.cond, label %latch.2\n"
+      "latch.2:\n"
+      " %cmp2 = icmp slt i32 %i.0, %m\n"
+      " br i1 %cmp2, label %for.cond, label %for.end\n"
+      "for.end:\n"
+      " %non.lcssa.phi = phi i1 [ %cmp2, %latch.2 ], [ %inv_cond, %entry ]\n"
+      " ret void\n"
+      "}\n";
+
+  // Parse the module.
+  LLVMContext Context;
+  std::unique_ptr<Module> M = makeLLVMModuleWithCustomFunc(Context, ModuleStr);
+  Function &F = *(M->getFunction("foo"));
+  auto *ExitBB = getBasicBlockByName(F, "for.end");
+  auto *Header = getBasicBlockByName(F, "for.cond");
+  auto *Latch2 = getBasicBlockByName(F, "latch.2");
+  auto *Latch1 = getBasicBlockByName(F, "latch.1");
+  auto *DeadBB = getBasicBlockByName(F, "deadbb");
+
+  BranchProbabilityInfo &BPI = buildBPI(F);
+
+  ASSERT_NE(BPI.getEdgeProbability(Latch2, ExitBB),
+            BranchProbability::getZero());
+  ASSERT_NE(BPI.getEdgeProbability(Header, Latch1),
+            BranchProbability::getZero());
+  // A loop pass may decide to preserve LCSSA form and split the critical edge
+  // to the exit block. BPI is deliberately not passed along here.
+  auto *CritExit = SplitCriticalEdge(
+      Latch2, ExitBB,
+      CriticalEdgeSplittingOptions().unsetPreserveLoopSimplify());
+  // Replace the header's terminator with a branch to latch.1; drop deadbb.
+  auto *OldBr = Header->getTerminator();
+  IRBuilder<> Builder(OldBr);
+  Builder.CreateBr(Latch1);
+  OldBr->eraseFromParent();
+  DeadBB->eraseFromParent();
+
+  // We should not fail here when retrieving the edge probability of newly
+  // created blocks or blocks whose successors have changed.
+  EXPECT_NE(BPI.getEdgeProbability(Latch2, CritExit),
+            BranchProbability::getZero());
+  EXPECT_NE(BPI.getEdgeProbability(Header, Latch1),
+            BranchProbability::getZero());
+  EXPECT_NE(BPI.getEdgeProbability(Latch1, Latch2),
+            BranchProbability::getZero());
+  EXPECT_NE(BPI.getEdgeProbability(Latch1, Header),
+            BranchProbability::getZero());
+  EXPECT_NE(BPI.getEdgeProbability(CritExit, ExitBB),
+            BranchProbability::getZero());
+
+  BasicBlock *Latch2Pred =
+      splitBlockBefore(Latch2, Latch2->getTerminator(), nullptr, nullptr,
+                       nullptr, "latch2.pred");
+  EXPECT_NE(BPI.getEdgeProbability(Latch2Pred, Latch2),
+            BranchProbability::getZero());
+  EXPECT_NE(BPI.getEdgeProbability(Latch1, Latch2Pred),
+            BranchProbability::getZero());
+}
 } // end anonymous namespace
 } // end namespace llvm