Index: include/llvm/Analysis/BlockFrequencyInfo.h =================================================================== --- include/llvm/Analysis/BlockFrequencyInfo.h +++ include/llvm/Analysis/BlockFrequencyInfo.h @@ -45,6 +45,9 @@ /// floating points. BlockFrequency getBlockFreq(const BasicBlock *BB) const; + /// Set the frequency of the given basic block. + void setBlockFreq(const BasicBlock *BB, uint64_t Freq); + /// calculate - compute block frequency info for the given function. void calculate(const Function &F, const BranchProbabilityInfo &BPI, const LoopInfo &LI); Index: include/llvm/Analysis/BlockFrequencyInfoImpl.h =================================================================== --- include/llvm/Analysis/BlockFrequencyInfoImpl.h +++ include/llvm/Analysis/BlockFrequencyInfoImpl.h @@ -477,6 +477,8 @@ BlockFrequency getBlockFreq(const BlockNode &Node) const; + void setBlockFreq(const BlockNode &Node, uint64_t Freq); + raw_ostream &printBlockFreq(raw_ostream &OS, const BlockNode &Node) const; raw_ostream &printBlockFreq(raw_ostream &OS, const BlockFrequency &Freq) const; @@ -913,6 +915,7 @@ BlockFrequency getBlockFreq(const BlockT *BB) const { return BlockFrequencyInfoImplBase::getBlockFreq(getNode(BB)); } + void setBlockFreq(const BlockT *BB, uint64_t Freq); Scaled64 getFloatingBlockFreq(const BlockT *BB) const { return BlockFrequencyInfoImplBase::getFloatingBlockFreq(getNode(BB)); } @@ -965,6 +968,23 @@ finalizeMetrics(); } +template <class BT> +void BlockFrequencyInfoImpl<BT>::setBlockFreq(const BlockT *BB, uint64_t Freq) { + auto I = Nodes.find(BB); + if (I != Nodes.end()) { + BlockFrequencyInfoImplBase::setBlockFreq(I->second, Freq); + return; + } + + // If BB is a newly added block after BFI is done, we need to create a new + // BlockNode for it assigned with a new index. The index can be determined + // by the size of Freqs.
+ BlockNode NewNode(Freqs.size()); + Nodes[BB] = NewNode; + Freqs.emplace_back(); + BlockFrequencyInfoImplBase::setBlockFreq(NewNode, Freq); +} + template void BlockFrequencyInfoImpl::initializeRPOT() { const BlockT *Entry = F->begin(); RPOT.reserve(F->size()); Index: include/llvm/Analysis/BranchProbabilityInfo.h =================================================================== --- include/llvm/Analysis/BranchProbabilityInfo.h +++ include/llvm/Analysis/BranchProbabilityInfo.h @@ -108,6 +108,15 @@ void setEdgeWeight(const BasicBlock *Src, unsigned IndexInSuccessors, uint32_t Weight); + /// \brief Set the raw edge weight for a given edge. + /// + /// In this function the index of Dst in Src's successors will be calculated + /// on the fly. Here we use std::pair<> to represent an edge to avoid + /// ambiguous call with the other setEdgeWeight(). + void + setEdgeWeight(const std::pair<const BasicBlock *, const BasicBlock *> &Edge, + uint32_t Weight); + static uint32_t getBranchWeightStackProtector(bool IsLikely) { return IsLikely ? (1u << 20) - 1 : 1; } Index: lib/Analysis/BlockFrequencyInfo.cpp =================================================================== --- lib/Analysis/BlockFrequencyInfo.cpp +++ lib/Analysis/BlockFrequencyInfo.cpp @@ -129,6 +129,12 @@ return BFI ? BFI->getBlockFreq(BB) : 0; } +void BlockFrequencyInfo::setBlockFreq(const BasicBlock *BB, + uint64_t Freq) { + assert(BFI && "Expected analysis to be available"); + BFI->setBlockFreq(BB, Freq); +} + /// Pop up a ghostview window with the current block frequency propagation /// rendered using dot.
void BlockFrequencyInfo::view() const { Index: lib/Analysis/BlockFrequencyInfoImpl.cpp =================================================================== --- lib/Analysis/BlockFrequencyInfoImpl.cpp +++ lib/Analysis/BlockFrequencyInfoImpl.cpp @@ -530,6 +530,15 @@ return Freqs[Node.Index].Scaled; } +/// Overwrite the integer frequency of \p Node. Only the Integer field is +/// updated; the Scaled (floating) frequency is left unchanged. +void BlockFrequencyInfoImplBase::setBlockFreq(const BlockNode &Node, + uint64_t Freq) { + assert(Node.isValid() && "Expected valid node"); + assert(Node.Index < Freqs.size() && "Expected legal index"); + Freqs[Node.Index].Integer = Freq; +} + std::string BlockFrequencyInfoImplBase::getBlockName(const BlockNode &Node) const { return std::string(); Index: lib/Analysis/BranchProbabilityInfo.cpp =================================================================== --- lib/Analysis/BranchProbabilityInfo.cpp +++ lib/Analysis/BranchProbabilityInfo.cpp @@ -604,7 +604,7 @@ Weight += MapI->second; } } - return (!FoundWeight) ? DEFAULT_WEIGHT : Weight; + return FoundWeight ? Weight : DEFAULT_WEIGHT; } /// Set the edge weight for a given edge specified by PredBlock and an index @@ -618,6 +618,22 @@ << Weight << "\n"); } +/// Set the edge weight for a given edge specified by Src and Dst. +void BranchProbabilityInfo::setEdgeWeight( + const std::pair<const BasicBlock *, const BasicBlock *> &Edge, + uint32_t Weight) { + for (auto I = succ_begin(Edge.first), E = succ_end(Edge.first); + I != E; ++I) { + if (*I == Edge.second) { + setEdgeWeight(Edge.first, I.getSuccessorIndex(), Weight); + return; + } + } + // Dst must be a successor of Src. If it is not (and assertions are + // disabled), leave the weights untouched instead of using a bogus index. + assert(false && "Could not find Dst in Src's successors"); +} + /// Get an edge's probability, relative to other out-edges from Src.
BranchProbability BranchProbabilityInfo:: getEdgeProbability(const BasicBlock *Src, unsigned IndexInSuccessors) const { Index: lib/Transforms/Scalar/JumpThreading.cpp =================================================================== --- lib/Transforms/Scalar/JumpThreading.cpp +++ lib/Transforms/Scalar/JumpThreading.cpp @@ -19,14 +19,19 @@ #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/BlockFrequencyInfo.h" +#include "llvm/Analysis/BlockFrequencyInfoImpl.h" +#include "llvm/Analysis/BranchProbabilityInfo.h" #include "llvm/Analysis/CFG.h" #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LazyValueInfo.h" #include "llvm/Analysis/Loads.h" +#include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/LLVMContext.h" +#include "llvm/IR/MDBuilder.h" #include "llvm/IR/Metadata.h" #include "llvm/IR/ValueHandle.h" #include "llvm/Pass.h" @@ -80,6 +85,9 @@ class JumpThreading : public FunctionPass { TargetLibraryInfo *TLI; LazyValueInfo *LVI; + std::unique_ptr<BlockFrequencyInfo> BFI; + std::unique_ptr<BranchProbabilityInfo> BPI; + bool HasProfileData; #ifdef NDEBUG SmallPtrSet LoopHeaders; #else @@ -137,6 +145,10 @@ bool SimplifyPartiallyRedundantLoad(LoadInst *LI); bool TryToUnfoldSelect(CmpInst *CondCmp, BasicBlock *BB); + BasicBlock *SplitBlockPreds(BasicBlock *BB, ArrayRef<BasicBlock *> Preds, + const char *Suffix); + void UpdateBlockFreqAndEdgeWeight(BasicBlock *BB, BasicBlock *NewBB, + BasicBlock *SuccBB); }; } @@ -160,6 +172,16 @@ DEBUG(dbgs() << "Jump threading on function '" << F.getName() << "'\n"); TLI = &getAnalysis().getTLI(); LVI = &getAnalysis(); + BFI.reset(); + BPI.reset(); + // BFI and BPI are only built when the function has profile data; they are + // used to update block frequencies and branch weights after threading. + HasProfileData = F.getEntryCount().hasValue(); + if (HasProfileData) { + LoopInfo LI{DominatorTree(F)}; + BPI.reset(new BranchProbabilityInfo(F, LI)); + BFI.reset(new BlockFrequencyInfo(F, *BPI, LI)); + } // Remove unreachable blocks from function
as they may result in infinite // loop. We do threading if we found something profitable. Jump threading a @@ -968,8 +990,7 @@ } // Split them out to their own block. - UnavailablePred = - SplitBlockPredecessors(LoadBB, PredsToSplit, "thread-pre-split"); + UnavailablePred = SplitBlockPreds(LoadBB, PredsToSplit, "thread-pre-split"); } // If the value isn't available in all predecessors, then there will be @@ -1394,7 +1415,7 @@ else { DEBUG(dbgs() << " Factoring out " << PredBBs.size() << " common predecessors.\n"); - PredBB = SplitBlockPredecessors(BB, PredBBs, ".thr_comm"); + PredBB = SplitBlockPreds(BB, PredBBs, ".thr_comm"); } // And finally, do it! @@ -1415,6 +1436,14 @@ BB->getParent(), BB); NewBB->moveAfter(PredBB); + // Set the block frequency of NewBB: it receives exactly the flow that used + // to go along the edge PredBB->BB. + if (HasProfileData) { + auto NewBBFreq = + BFI->getBlockFreq(PredBB) * BPI->getEdgeProbability(PredBB, BB); + BFI->setBlockFreq(NewBB, NewBBFreq.getFrequency()); + } + BasicBlock::iterator BI = BB->begin(); for (; PHINode *PN = dyn_cast(BI); ++BI) ValueMapping[PN] = PN->getIncomingValueForBlock(PredBB); @@ -1438,7 +1467,7 @@ // We didn't copy the terminator from BB over to NewBB, because there is now // an unconditional jump to SuccBB. Insert the unconditional jump. - BranchInst *NewBI =BranchInst::Create(SuccBB, NewBB); + BranchInst *NewBI = BranchInst::Create(SuccBB, NewBB); NewBI->setDebugLoc(BB->getTerminator()->getDebugLoc()); // Check to see if SuccBB has PHI nodes. If so, we need to add entries to the @@ -1499,11 +1528,107 @@ // frequently happens because of phi translation. SimplifyInstructionsInBlock(NewBB, TLI); + // BB no longer receives the flow from PredBB (it now goes through NewBB); + // update BB's frequency and the weight of the edge BB->SuccBB. + UpdateBlockFreqAndEdgeWeight(BB, NewBB, SuccBB); + // Threaded an edge! ++NumThreads; return true; } +/// Create a new basic block that will be the predecessor of BB and successor of +/// all blocks in Preds. When profile data is available, update the frequency +/// of this new block.
+BasicBlock *JumpThreading::SplitBlockPreds(BasicBlock *BB, + ArrayRef<BasicBlock *> Preds, + const char *Suffix) { + // Sum the flow that each of Preds sends into BB; after the split all of it + // passes through the new block. + BlockFrequency PredBBFreq(0); + if (HasProfileData) + for (auto Pred : Preds) + PredBBFreq += BFI->getBlockFreq(Pred) * BPI->getEdgeProbability(Pred, BB); + + BasicBlock *PredBB = SplitBlockPredecessors(BB, Preds, Suffix); + + // Set the block frequency of the newly created PredBB, which is the sum of + // the flow from Preds into BB computed above. + if (HasProfileData) + BFI->setBlockFreq(PredBB, PredBBFreq.getFrequency()); + return PredBB; +} + +/// Update the block frequency of BB and the branch weight metadata on the +/// edge BB->SuccBB after the flow of some predecessor edge into BB has been +/// redirected through NewBB (whose frequency must already be set). The weight +/// of BB->SuccBB is scaled by 1 - Freq(NewBB) / Freq(BB->SuccBB), where +/// Freq(BB->SuccBB) is computed from BB's frequency before this update. +void JumpThreading::UpdateBlockFreqAndEdgeWeight(BasicBlock *BB, + BasicBlock *NewBB, + BasicBlock *SuccBB) { + if (!HasProfileData) + return; + + // Read BB's frequency and the frequency of the edge BB->SuccBB *before* + // lowering BB's frequency; NewBB carries exactly the flow being removed. + auto BBOrigFreq = BFI->getBlockFreq(BB); + auto NewBBFreq = BFI->getBlockFreq(NewBB); + auto BB2SuccBBFreq = BBOrigFreq * BPI->getEdgeProbability(BB, SuccBB); + + // Theoretically NewBBFreq should not be greater than BBOrigFreq. However, + // in some cases this may happen when the frequencies are incorrectly + // distributed to CFG nodes. In this case we just set the new frequency of + // BB to zero. + uint64_t BBNewFreq = (NewBBFreq <= BBOrigFreq) + ? BBOrigFreq.getFrequency() - NewBBFreq.getFrequency() + : 0; + BFI->setBlockFreq(BB, BBNewFreq); + + auto TI = BB->getTerminator(); + MDNode *MD = TI->getMetadata(LLVMContext::MD_prof); + if (!MD) + return; + + SmallVector<uint32_t, 4> Weights; + for (unsigned MD_i = 1, MD_e = MD->getNumOperands(); MD_i < MD_e; ++MD_i) { + ConstantInt *CI = mdconst::extract<ConstantInt>(MD->getOperand(MD_i)); + uint64_t W = CI->getValue().getZExtValue(); + unsigned SuccIdx = MD_i - 1; + if (TI->getSuccessor(SuccIdx) == SuccBB && + BB2SuccBBFreq.getFrequency() > 0) { + DEBUG(dbgs() << " Update the weight on the edge from " + << BB->getName() << " to " << SuccBB->getName() << "\n"); + if (NewBBFreq <= BB2SuccBBFreq) { + // BranchProbability takes 32-bit operands. Shift both frequencies + // right until the denominator fits so that the ratio is preserved + // instead of silently truncating the 64-bit values. + uint64_t N = NewBBFreq.getFrequency(); + uint64_t D = BB2SuccBBFreq.getFrequency(); + while (D > UINT32_MAX) { + N >>= 1; + D >>= 1; + } + // Use the multiplication between BlockFrequency and + // BranchProbability to avoid overflow. + W -= (BlockFrequency(W) * BranchProbability(N, D)).getFrequency(); + } else { + W = 0; + } + + // Update BPI by successor index so that every edge to SuccBB (e.g. + // several switch cases) keeps its own scaled weight. + BPI->setEdgeWeight(BB, SuccIdx, W); + } + Weights.push_back(W); + } + if (Weights.size() >= 2) { + TI->setMetadata(LLVMContext::MD_prof, + MDBuilder(TI->getParent()->getContext()) + .createBranchWeights(Weights)); + } +} + /// DuplicateCondBranchOnPHIIntoPred - PredBB contains an unconditional branch /// to BB which contains an i1 PHI node and a conditional branch on that PHI. /// If we can duplicate the contents of BB up into PredBB do so now, this @@ -1537,7 +1662,7 @@ else { DEBUG(dbgs() << " Factoring out " << PredBBs.size() << " common predecessors.\n"); - PredBB = SplitBlockPredecessors(BB, PredBBs, ".thr_comm"); + PredBB = SplitBlockPreds(BB, PredBBs, ".thr_comm"); } // Okay, we decided to do this!
Clone all the instructions in BB onto the end Index: test/Transforms/JumpThreading/update-edge-weight.ll =================================================================== --- /dev/null +++ test/Transforms/JumpThreading/update-edge-weight.ll @@ -0,0 +1,46 @@ +; RUN: opt -S -jump-threading < %s | FileCheck %s + +; Test if edge weights are properly updated after jump threading. +; Threading if.then.1 -> if.then.2 moves more flow off if.cond (10/11 of the +; entry count) than the profile assigns to if.cond -> if.then.2 (10/15), so +; the weight of that edge drops to 0 while if.cond -> if.else.2 keeps 5. + +; CHECK: br i1 %cmp1, label %if.then.2, label %if.else.2, !prof ![[PROF:[0-9]+]] +; CHECK: ![[PROF]] = !{!"branch_weights", i32 0, i32 5} +define void @foo(i32 %n) !prof !0 { +entry: + %cmp = icmp sgt i32 %n, 10 + br i1 %cmp, label %if.then.1, label %if.else.1, !prof !1 + +if.then.1: + tail call void @a() + br label %if.cond + +if.else.1: + tail call void @b() + br label %if.cond + +if.cond: + %cmp1 = icmp sgt i32 %n, 5 + br i1 %cmp1, label %if.then.2, label %if.else.2, !prof !2 + +if.then.2: + tail call void @c() + br label %if.end + +if.else.2: + tail call void @d() + br label %if.end + +if.end: + ret void +} + +declare void @a() +declare void @b() +declare void @c() +declare void @d() + +!0 = !{!"function_entry_count", i64 1} +!1 = !{!"branch_weights", i32 10, i32 1} +!2 = !{!"branch_weights", i32 10, i32 5}