Index: lib/Transforms/Scalar/JumpThreading.cpp =================================================================== --- lib/Transforms/Scalar/JumpThreading.cpp +++ lib/Transforms/Scalar/JumpThreading.cpp @@ -18,6 +18,8 @@ #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/BlockFrequencyInfo.h" +#include "llvm/Analysis/BranchProbabilityInfo.h" #include "llvm/Analysis/CFG.h" #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/InstructionSimplify.h" @@ -27,6 +29,7 @@ #include "llvm/IR/DataLayout.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/LLVMContext.h" +#include "llvm/IR/MDBuilder.h" #include "llvm/IR/Metadata.h" #include "llvm/IR/ValueHandle.h" #include "llvm/Pass.h" @@ -80,6 +83,8 @@ class JumpThreading : public FunctionPass { TargetLibraryInfo *TLI; LazyValueInfo *LVI; + BlockFrequencyInfo *BFI; + BranchProbabilityInfo *BPI; #ifdef NDEBUG SmallPtrSet LoopHeaders; #else @@ -114,6 +119,8 @@ void getAnalysisUsage(AnalysisUsage &AU) const override { AU.addRequired(); AU.addPreserved(); + AU.addRequired(); + AU.addRequired(); AU.addRequired(); } @@ -144,6 +151,8 @@ INITIALIZE_PASS_BEGIN(JumpThreading, "jump-threading", "Jump Threading", false, false) INITIALIZE_PASS_DEPENDENCY(LazyValueInfo) +INITIALIZE_PASS_DEPENDENCY(BlockFrequencyInfo) +INITIALIZE_PASS_DEPENDENCY(BranchProbabilityInfo) INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) INITIALIZE_PASS_END(JumpThreading, "jump-threading", "Jump Threading", false, false) @@ -160,6 +169,8 @@ DEBUG(dbgs() << "Jump threading on function '" << F.getName() << "'\n"); TLI = &getAnalysis().getTLI(); LVI = &getAnalysis(); + BFI = &getAnalysis(); + BPI = &getAnalysis(); // Remove unreachable blocks from function as they may result in infinite // loop. We do threading if we found something profitable. 
Jump threading a
@@ -1387,6 +1398,12 @@
     return false;
   }
 
+  // Sum, over every block in PredBBs, the frequency of the edge Pred->BB
+  // (block frequency times edge probability). Used below to rescale BB->SuccBB.
+  BlockFrequency PredBBFreq(0);
+  for (auto Pred : PredBBs)
+    PredBBFreq += BFI->getBlockFreq(Pred) * BPI->getEdgeProbability(Pred, BB);
+
   // And finally, do it! Start by factoring the predecessors is needed.
   BasicBlock *PredBB;
   if (PredBBs.size() == 1)
@@ -1499,6 +1516,50 @@
   // frequently happens because of phi translation.
   SimplifyInstructionsInBlock(NewBB, TLI);
 
+  // Update the branch weight metadata on the edge BB->SuccBB by scaling its
+  // weight by 1 - Freq(PredBBs->BB) / Freq(BB->SuccBB), where the numerator is
+  // the PredBBFreq collected above.
+  auto *TI = BB->getTerminator();
+  MDNode *MD = TI->getMetadata(LLVMContext::MD_prof);
+  // The loop maps metadata operand I to successor I - 1, so only proceed when
+  // there is exactly one weight operand per successor.
+  if (MD && MD->getNumOperands() == TI->getNumSuccessors() + 1) {
+    SmallVector<uint32_t, 4> Weights;
+    for (unsigned I = 1, E = MD->getNumOperands(); I != E; ++I) {
+      ConstantInt *CI = mdconst::extract<ConstantInt>(MD->getOperand(I));
+      uint64_t W = CI->getValue().getZExtValue();
+      if (TI->getSuccessor(I - 1) == SuccBB) {
+        uint64_t Threaded = PredBBFreq.getFrequency();
+        uint64_t Total =
+            (BFI->getBlockFreq(BB) * BPI->getEdgeProbability(BB, SuccBB))
+                .getFrequency();
+        if (Total > 0 && Threaded >= Total) {
+          // The ratio would be >= 1 (rounding can push it past 1); clamp the
+          // weight to zero instead of letting the subtraction wrap around.
+          W = 0;
+        } else if (Total > 0) {
+          // BranchProbability takes 32-bit operands while frequencies are
+          // 64-bit. Shift both down together so the ratio is kept and neither
+          // operand is silently truncated.
+          while (Total > UINT32_MAX) {
+            Threaded >>= 1;
+            Total >>= 1;
+          }
+          // Multiply BlockFrequency by BranchProbability to avoid overflow.
+          W -= (BlockFrequency(W) *
+                BranchProbability(static_cast<uint32_t>(Threaded),
+                                  static_cast<uint32_t>(Total)))
+                   .getFrequency();
+        }
+      }
+      Weights.push_back(static_cast<uint32_t>(W));
+    }
+    if (Weights.size() >= 2)
+      TI->setMetadata(LLVMContext::MD_prof,
+                      MDBuilder(TI->getParent()->getContext())
+                          .createBranchWeights(Weights));
+  }
+
   // Threaded an edge!
   ++NumThreads;
   return true;