Split BranchProbabilityInfo into a plain analysis object and a legacy
FunctionPass wrapper (BranchProbabilityInfoWrapperPass).

BranchProbabilityInfo no longer derives from FunctionPass. Its results are
computed by BranchProbabilityInfo::calculate(F, LI), which takes LoopInfo as
an explicit argument (forwarded to calcLoopBranchHeuristics) instead of
fetching it through getAnalysis in runOnFunction. The wrapper pass owns a
BranchProbabilityInfo, runs calculate() from runOnFunction(), forwards
releaseMemory() and print(), and exposes the object through getBPI(). The
legacy clients BlockFrequencyInfoWrapperPass, SelectionDAGISel and
InductiveRangeCheckElimination now require the wrapper and call getBPI().

NOTE(review): the removed runOnFunction assigned a LoopInfo pointer member
`LI`; no hunk here deletes that member from the header, so it appears to be
left unused. Confirm and drop it if so.
NOTE(review): print() reads LastF, which only calculate() assigns; confirm
print() is never reached before calculate() has run.

Index: include/llvm/Analysis/BranchProbabilityInfo.h
===================================================================
--- include/llvm/Analysis/BranchProbabilityInfo.h
+++ include/llvm/Analysis/BranchProbabilityInfo.h
@@ -25,9 +25,9 @@
 class LoopInfo;
 class raw_ostream;
 
-/// \brief Analysis pass providing branch probability information.
+/// \brief Analysis providing branch probability information.
 ///
-/// This is a function analysis pass which provides information on the relative
+/// This is a function analysis which provides information on the relative
 /// probabilities of each "edge" in the function's CFG where such an edge is
 /// defined by a pair (PredBlock and an index in the successors). The
 /// probability of an edge from one block is always relative to the
@@ -37,20 +37,11 @@
 /// identify an edge, since we can have multiple edges from Src to Dst.
 /// As an example, we can have a switch which jumps to Dst with value 0 and
 /// value 10.
-class BranchProbabilityInfo : public FunctionPass {
+class BranchProbabilityInfo {
 public:
-  static char ID;
-
-  BranchProbabilityInfo() : FunctionPass(ID) {
-    initializeBranchProbabilityInfoPass(*PassRegistry::getPassRegistry());
-  }
-
-  void getAnalysisUsage(AnalysisUsage &AU) const override;
-  bool runOnFunction(Function &F) override;
+  void releaseMemory();
 
-  void releaseMemory() override;
-
-  void print(raw_ostream &OS, const Module *M = nullptr) const override;
+  void print(raw_ostream &OS) const;
 
   /// \brief Get an edge's probability, relative to other out-edges of the Src.
   ///
@@ -118,6 +109,8 @@
     return IsLikely ? (1u << 20) - 1 : 1;
   }
 
+  void calculate(Function &F, const LoopInfo &LI);
+
 private:
   // Since we allow duplicate edges from one basic block to another, we use
   // a pair (PredBlock and an index in the successors) to specify an edge.
@@ -152,12 +145,33 @@
   bool calcMetadataWeights(BasicBlock *BB);
   bool calcColdCallHeuristics(BasicBlock *BB);
   bool calcPointerHeuristics(BasicBlock *BB);
-  bool calcLoopBranchHeuristics(BasicBlock *BB);
+  bool calcLoopBranchHeuristics(BasicBlock *BB, const LoopInfo &LI);
   bool calcZeroHeuristics(BasicBlock *BB);
   bool calcFloatingPointHeuristics(BasicBlock *BB);
   bool calcInvokeHeuristics(BasicBlock *BB);
 };
 
+/// \brief Legacy analysis pass which computes \c BranchProbabilityInfo.
+class BranchProbabilityInfoWrapperPass : public FunctionPass {
+  BranchProbabilityInfo BPI;
+
+public:
+  static char ID;
+
+  BranchProbabilityInfoWrapperPass() : FunctionPass(ID) {
+    initializeBranchProbabilityInfoWrapperPassPass(
+        *PassRegistry::getPassRegistry());
+  }
+
+  BranchProbabilityInfo &getBPI() { return BPI; }
+  const BranchProbabilityInfo &getBPI() const { return BPI; }
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override;
+  bool runOnFunction(Function &F) override;
+  void releaseMemory() override;
+  void print(raw_ostream &OS, const Module *M = nullptr) const override;
+};
+
 }
 
 #endif
Index: include/llvm/InitializePasses.h
===================================================================
--- include/llvm/InitializePasses.h
+++ include/llvm/InitializePasses.h
@@ -82,7 +82,7 @@
 void initializeBlockFrequencyInfoWrapperPassPass(PassRegistry&);
 void initializeBoundsCheckingPass(PassRegistry&);
 void initializeBranchFolderPassPass(PassRegistry&);
-void initializeBranchProbabilityInfoPass(PassRegistry&);
+void initializeBranchProbabilityInfoWrapperPassPass(PassRegistry&);
 void initializeBreakCriticalEdgesPass(PassRegistry&);
 void initializeCallGraphPrinterPass(PassRegistry&);
 void initializeCallGraphViewerPass(PassRegistry&);
Index: lib/Analysis/Analysis.cpp
===================================================================
--- lib/Analysis/Analysis.cpp
+++ lib/Analysis/Analysis.cpp
@@ -28,7 +28,7 @@
   initializeNoAAPass(Registry);
   initializeBasicAliasAnalysisPass(Registry);
   initializeBlockFrequencyInfoWrapperPassPass(Registry);
-  initializeBranchProbabilityInfoPass(Registry);
+  initializeBranchProbabilityInfoWrapperPassPass(Registry);
   initializeCostModelAnalysisPass(Registry);
   initializeCFGViewerPass(Registry);
   initializeCFGPrinterPass(Registry);
Index: lib/Analysis/BlockFrequencyInfo.cpp
===================================================================
--- lib/Analysis/BlockFrequencyInfo.cpp
+++ lib/Analysis/BlockFrequencyInfo.cpp
@@ -162,7 +162,7 @@
 
 INITIALIZE_PASS_BEGIN(BlockFrequencyInfoWrapperPass, "block-freq",
                       "Block Frequency Analysis", true, true)
-INITIALIZE_PASS_DEPENDENCY(BranchProbabilityInfo)
+INITIALIZE_PASS_DEPENDENCY(BranchProbabilityInfoWrapperPass)
 INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
 INITIALIZE_PASS_END(BlockFrequencyInfoWrapperPass, "block-freq",
                     "Block Frequency Analysis", true, true)
@@ -183,7 +183,7 @@
 }
 
 void BlockFrequencyInfoWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
-  AU.addRequired<BranchProbabilityInfo>();
+  AU.addRequired<BranchProbabilityInfoWrapperPass>();
   AU.addRequired<LoopInfoWrapperPass>();
   AU.setPreservesAll();
 }
@@ -191,7 +191,8 @@
 void BlockFrequencyInfoWrapperPass::releaseMemory() { BFI.releaseMemory(); }
 
 bool BlockFrequencyInfoWrapperPass::runOnFunction(Function &F) {
-  BranchProbabilityInfo &BPI = getAnalysis<BranchProbabilityInfo>();
+  BranchProbabilityInfo &BPI =
+      getAnalysis<BranchProbabilityInfoWrapperPass>().getBPI();
   LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
   BFI.calculate(F, BPI, LI);
   return false;
Index: lib/Analysis/BranchProbabilityInfo.cpp
===================================================================
--- lib/Analysis/BranchProbabilityInfo.cpp
+++ lib/Analysis/BranchProbabilityInfo.cpp
@@ -27,13 +27,13 @@
 
 #define DEBUG_TYPE "branch-prob"
 
-INITIALIZE_PASS_BEGIN(BranchProbabilityInfo, "branch-prob",
+INITIALIZE_PASS_BEGIN(BranchProbabilityInfoWrapperPass, "branch-prob",
                       "Branch Probability Analysis", false, true)
 INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
-INITIALIZE_PASS_END(BranchProbabilityInfo, "branch-prob",
+INITIALIZE_PASS_END(BranchProbabilityInfoWrapperPass, "branch-prob",
                     "Branch Probability Analysis", false, true)
 
-char BranchProbabilityInfo::ID = 0;
+char BranchProbabilityInfoWrapperPass::ID = 0;
 
 // Weights are for internal use only. They are used by heuristics to help to
 // estimate edges' probability. Example:
@@ -319,8 +319,9 @@
 
 // Calculate Edge Weights using "Loop Branch Heuristics". Predict backedges
 // as taken, exiting edges as not-taken.
-bool BranchProbabilityInfo::calcLoopBranchHeuristics(BasicBlock *BB) {
-  Loop *L = LI->getLoopFor(BB);
+bool BranchProbabilityInfo::calcLoopBranchHeuristics(BasicBlock *BB,
+                                                     const LoopInfo &LI) {
+  Loop *L = LI.getLoopFor(BB);
   if (!L)
     return false;
 
@@ -504,50 +505,11 @@
   return true;
 }
 
-void BranchProbabilityInfo::getAnalysisUsage(AnalysisUsage &AU) const {
-  AU.addRequired<LoopInfoWrapperPass>();
-  AU.setPreservesAll();
-}
-
-bool BranchProbabilityInfo::runOnFunction(Function &F) {
-  DEBUG(dbgs() << "---- Branch Probability Info : " << F.getName()
-               << " ----\n\n");
-  LastF = &F; // Store the last function we ran on for printing.
-  LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
-  assert(PostDominatedByUnreachable.empty());
-  assert(PostDominatedByColdCall.empty());
-
-  // Walk the basic blocks in post-order so that we can build up state about
-  // the successors of a block iteratively.
-  for (auto BB : post_order(&F.getEntryBlock())) {
-    DEBUG(dbgs() << "Computing probabilities for " << BB->getName() << "\n");
-    if (calcUnreachableHeuristics(BB))
-      continue;
-    if (calcMetadataWeights(BB))
-      continue;
-    if (calcColdCallHeuristics(BB))
-      continue;
-    if (calcLoopBranchHeuristics(BB))
-      continue;
-    if (calcPointerHeuristics(BB))
-      continue;
-    if (calcZeroHeuristics(BB))
-      continue;
-    if (calcFloatingPointHeuristics(BB))
-      continue;
-    calcInvokeHeuristics(BB);
-  }
-
-  PostDominatedByUnreachable.clear();
-  PostDominatedByColdCall.clear();
-  return false;
-}
-
 void BranchProbabilityInfo::releaseMemory() {
   Weights.clear();
 }
 
-void BranchProbabilityInfo::print(raw_ostream &OS, const Module *) const {
+void BranchProbabilityInfo::print(raw_ostream &OS) const {
   OS << "---- Branch Probabilities ----\n";
   // We print the probabilities from the last function the analysis ran over,
   // or the function it is currently running over.
@@ -688,3 +650,54 @@
 
   return OS;
 }
+
+void BranchProbabilityInfo::calculate(Function &F, const LoopInfo &LI) {
+  DEBUG(dbgs() << "---- Branch Probability Info : " << F.getName()
+               << " ----\n\n");
+  LastF = &F; // Store the last function we ran on for printing.
+  assert(PostDominatedByUnreachable.empty());
+  assert(PostDominatedByColdCall.empty());
+
+  // Walk the basic blocks in post-order so that we can build up state about
+  // the successors of a block iteratively.
+  for (auto BB : post_order(&F.getEntryBlock())) {
+    DEBUG(dbgs() << "Computing probabilities for " << BB->getName() << "\n");
+    if (calcUnreachableHeuristics(BB))
+      continue;
+    if (calcMetadataWeights(BB))
+      continue;
+    if (calcColdCallHeuristics(BB))
+      continue;
+    if (calcLoopBranchHeuristics(BB, LI))
+      continue;
+    if (calcPointerHeuristics(BB))
+      continue;
+    if (calcZeroHeuristics(BB))
+      continue;
+    if (calcFloatingPointHeuristics(BB))
+      continue;
+    calcInvokeHeuristics(BB);
+  }
+
+  PostDominatedByUnreachable.clear();
+  PostDominatedByColdCall.clear();
+}
+
+void BranchProbabilityInfoWrapperPass::getAnalysisUsage(
+    AnalysisUsage &AU) const {
+  AU.addRequired<LoopInfoWrapperPass>();
+  AU.setPreservesAll();
+}
+
+bool BranchProbabilityInfoWrapperPass::runOnFunction(Function &F) {
+  const LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
+  BPI.calculate(F, LI);
+  return false;
+}
+
+void BranchProbabilityInfoWrapperPass::releaseMemory() { BPI.releaseMemory(); }
+
+void BranchProbabilityInfoWrapperPass::print(raw_ostream &OS,
+                                             const Module *) const {
+  BPI.print(OS);
+}
Index: lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp
===================================================================
--- lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp
+++ lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp
@@ -351,7 +351,8 @@
   DAGSize(0) {
     initializeGCModuleInfoPass(*PassRegistry::getPassRegistry());
     initializeAliasAnalysisAnalysisGroup(*PassRegistry::getPassRegistry());
-    initializeBranchProbabilityInfoPass(*PassRegistry::getPassRegistry());
+    initializeBranchProbabilityInfoWrapperPassPass(
+        *PassRegistry::getPassRegistry());
     initializeTargetLibraryInfoWrapperPassPass(
         *PassRegistry::getPassRegistry());
   }
@@ -369,7 +370,7 @@
   AU.addPreserved<GCModuleInfo>();
   AU.addRequired<TargetLibraryInfoWrapperPass>();
   if (UseMBPI && OptLevel != CodeGenOpt::None)
-    AU.addRequired<BranchProbabilityInfo>();
+    AU.addRequired<BranchProbabilityInfoWrapperPass>();
   MachineFunctionPass::getAnalysisUsage(AU);
 }
 
@@ -449,7 +450,7 @@
   FuncInfo->set(Fn, *MF, CurDAG);
 
   if (UseMBPI && OptLevel != CodeGenOpt::None)
-    FuncInfo->BPI = &getAnalysis<BranchProbabilityInfo>();
+    FuncInfo->BPI = &getAnalysis<BranchProbabilityInfoWrapperPass>().getBPI();
   else
     FuncInfo->BPI = nullptr;
 
Index: lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp
===================================================================
--- lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp
+++ lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp
@@ -215,7 +215,7 @@
     AU.addRequiredID(LoopSimplifyID);
     AU.addRequiredID(LCSSAID);
     AU.addRequired<ScalarEvolution>();
-    AU.addRequired<BranchProbabilityInfo>();
+    AU.addRequired<BranchProbabilityInfoWrapperPass>();
   }
 
   bool runOnLoop(Loop *L, LPPassManager &LPM) override;
@@ -1400,7 +1400,8 @@
   InductiveRangeCheck::AllocatorTy IRCAlloc;
   SmallVector<InductiveRangeCheck *, 16> RangeChecks;
   ScalarEvolution &SE = getAnalysis<ScalarEvolution>();
-  BranchProbabilityInfo &BPI = getAnalysis<BranchProbabilityInfo>();
+  BranchProbabilityInfo &BPI =
+      getAnalysis<BranchProbabilityInfoWrapperPass>().getBPI();
 
   for (auto BBI : L->getBlocks())
     if (BranchInst *TBI = dyn_cast<BranchInst>(BBI->getTerminator()))