// Patch summary: adds BlockFrequencyInfo::setBlockFreq (which creates a new BlockNode for a block that BFI has no node for yet) and MachineBranchProbabilityInfo::normalizeEdgeWeights, and makes JumpThreading, when the function has an entry count, build its own BranchProbabilityInfo/BlockFrequencyInfo and update block frequencies and the MD_prof branch weights of the block it threads through.
// NOTE(review): this copy of the patch appears to have lost its line breaks and all text between angle brackets (e.g. `template <...>`, `std::unique_ptr<...>`, `#include <...>`, `dyn_cast<...>`), so it cannot be applied as-is.
// NOTE(review): BlockFrequencyInfoImplBase::setBlockFreq writes only Freqs[Node.Index].Integer and leaves .Scaled unchanged -- confirm nothing reads the floating-point frequency (getFloatingBlockFreq) of a block after setBlockFreq.
// NOTE(review): in the JumpThreading hunk that sets HasProfileData, BPI and BFI are built from a local LoopInfo (and a temporary DominatorTree) that go out of scope at the end of the `if (HasProfileData)` block -- confirm neither analysis keeps a reference to them.
// NOTE(review): lib/Transforms/Scalar/JumpThreading.cpp includes llvm/CodeGen/MachineBranchProbabilityInfo.h to use normalizeEdgeWeights -- check that a Transforms -> CodeGen header dependency is acceptable, and that the header includes <numeric> for std::accumulate.
Index: llvm/trunk/include/llvm/Analysis/BlockFrequencyInfo.h =================================================================== --- llvm/trunk/include/llvm/Analysis/BlockFrequencyInfo.h +++ llvm/trunk/include/llvm/Analysis/BlockFrequencyInfo.h @@ -45,6 +45,9 @@ /// floating points. BlockFrequency getBlockFreq(const BasicBlock *BB) const; + /// setBlockFreq - set the frequency of the given basic block. + void setBlockFreq(const BasicBlock *BB, uint64_t Freq); + /// calculate - compute block frequency info for the given function. void calculate(const Function &F, const BranchProbabilityInfo &BPI, const LoopInfo &LI); Index: llvm/trunk/include/llvm/Analysis/BlockFrequencyInfoImpl.h =================================================================== --- llvm/trunk/include/llvm/Analysis/BlockFrequencyInfoImpl.h +++ llvm/trunk/include/llvm/Analysis/BlockFrequencyInfoImpl.h @@ -477,6 +477,8 @@ BlockFrequency getBlockFreq(const BlockNode &Node) const; + void setBlockFreq(const BlockNode &Node, uint64_t Freq); + raw_ostream &printBlockFreq(raw_ostream &OS, const BlockNode &Node) const; raw_ostream &printBlockFreq(raw_ostream &OS, const BlockFrequency &Freq) const; @@ -913,6 +915,7 @@ BlockFrequency getBlockFreq(const BlockT *BB) const { return BlockFrequencyInfoImplBase::getBlockFreq(getNode(BB)); } + void setBlockFreq(const BlockT *BB, uint64_t Freq); Scaled64 getFloatingBlockFreq(const BlockT *BB) const { return BlockFrequencyInfoImplBase::getFloatingBlockFreq(getNode(BB)); } @@ -965,6 +968,21 @@ finalizeMetrics(); } +template +void BlockFrequencyInfoImpl::setBlockFreq(const BlockT *BB, uint64_t Freq) { + if (Nodes.count(BB)) + BlockFrequencyInfoImplBase::setBlockFreq(getNode(BB), Freq); + else { + // If BB is a newly added block after BFI is done, we need to create a new + // BlockNode for it assigned with a new index. The index can be determined + // by the size of Freqs.
+ BlockNode NewNode(Freqs.size()); + Nodes[BB] = NewNode; + Freqs.emplace_back(); + BlockFrequencyInfoImplBase::setBlockFreq(NewNode, Freq); + } +} + template void BlockFrequencyInfoImpl::initializeRPOT() { const BlockT *Entry = &F->front(); RPOT.reserve(F->size()); Index: llvm/trunk/include/llvm/CodeGen/MachineBranchProbabilityInfo.h =================================================================== --- llvm/trunk/include/llvm/CodeGen/MachineBranchProbabilityInfo.h +++ llvm/trunk/include/llvm/CodeGen/MachineBranchProbabilityInfo.h @@ -83,8 +83,35 @@ raw_ostream &printEdgeProbability(raw_ostream &OS, const MachineBasicBlock *Src, const MachineBasicBlock *Dst) const; + + // Normalize a list of weights in place by scaling them down so that their + // sum doesn't exceed UINT32_MAX. Return the scale (1 if nothing was scaled). + template + static uint32_t normalizeEdgeWeights(WeightListIter Begin, + WeightListIter End); }; +template +uint32_t +MachineBranchProbabilityInfo::normalizeEdgeWeights(WeightListIter Begin, + WeightListIter End) { + // First we compute the sum with 64-bits of precision. + uint64_t Sum = std::accumulate(Begin, End, uint64_t(0)); + + // If the computed sum fits in 32-bits, we're done. + if (Sum <= UINT32_MAX) + return 1; + + // Otherwise, compute the scale necessary to cause the weights to fit, and + // divide each weight by that scale in place. + assert((Sum / UINT32_MAX) < UINT32_MAX && + "The sum of weights exceeds UINT32_MAX^2!"); + uint32_t Scale = (Sum / UINT32_MAX) + 1; + for (auto I = Begin; I != End; ++I) + *I /= Scale; + return Scale; +} + } Index: llvm/trunk/lib/Analysis/BlockFrequencyInfo.cpp =================================================================== --- llvm/trunk/lib/Analysis/BlockFrequencyInfo.cpp +++ llvm/trunk/lib/Analysis/BlockFrequencyInfo.cpp @@ -129,6 +129,12 @@ return BFI ?
BFI->getBlockFreq(BB) : 0; } +void BlockFrequencyInfo::setBlockFreq(const BasicBlock *BB, + uint64_t Freq) { + assert(BFI && "Expected analysis to be available"); + BFI->setBlockFreq(BB, Freq); +} + /// Pop up a ghostview window with the current block frequency propagation /// rendered using dot. void BlockFrequencyInfo::view() const { Index: llvm/trunk/lib/Analysis/BlockFrequencyInfoImpl.cpp =================================================================== --- llvm/trunk/lib/Analysis/BlockFrequencyInfoImpl.cpp +++ llvm/trunk/lib/Analysis/BlockFrequencyInfoImpl.cpp @@ -530,6 +530,13 @@ return Freqs[Node.Index].Scaled; } +void BlockFrequencyInfoImplBase::setBlockFreq(const BlockNode &Node, + uint64_t Freq) { + assert(Node.isValid() && "Expected valid node"); + assert(Node.Index < Freqs.size() && "Expected legal index"); + Freqs[Node.Index].Integer = Freq; +} + std::string BlockFrequencyInfoImplBase::getBlockName(const BlockNode &Node) const { return std::string(); Index: llvm/trunk/lib/Transforms/Scalar/JumpThreading.cpp =================================================================== --- llvm/trunk/lib/Transforms/Scalar/JumpThreading.cpp +++ llvm/trunk/lib/Transforms/Scalar/JumpThreading.cpp @@ -20,14 +20,20 @@ #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/CFG.h" +#include "llvm/Analysis/BlockFrequencyInfo.h" +#include "llvm/Analysis/BlockFrequencyInfoImpl.h" +#include "llvm/Analysis/BranchProbabilityInfo.h" #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LazyValueInfo.h" #include "llvm/Analysis/Loads.h" +#include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/CodeGen/MachineBranchProbabilityInfo.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/LLVMContext.h" +#include "llvm/IR/MDBuilder.h" #include "llvm/IR/Metadata.h" #include "llvm/IR/ValueHandle.h" #include
"llvm/Pass.h" @@ -37,6 +43,8 @@ #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/SSAUpdater.h" +#include +#include using namespace llvm; #define DEBUG_TYPE "jump-threading" @@ -81,6 +89,9 @@ class JumpThreading : public FunctionPass { TargetLibraryInfo *TLI; LazyValueInfo *LVI; + std::unique_ptr BFI; + std::unique_ptr BPI; + bool HasProfileData; #ifdef NDEBUG SmallPtrSet LoopHeaders; #else @@ -119,6 +130,11 @@ AU.addRequired(); } + void releaseMemory() override { + BFI.reset(); + BPI.reset(); + } + void FindLoopHeaders(Function &F); bool ProcessBlock(BasicBlock *BB); bool ThreadEdge(BasicBlock *BB, const SmallVectorImpl &PredBBs, @@ -139,6 +155,12 @@ bool SimplifyPartiallyRedundantLoad(LoadInst *LI); bool TryToUnfoldSelect(CmpInst *CondCmp, BasicBlock *BB); + + private: + BasicBlock *SplitBlockPreds(BasicBlock *BB, ArrayRef Preds, + const char *Suffix); + void UpdateBlockFreqAndEdgeWeight(BasicBlock *PredBB, BasicBlock *BB, + BasicBlock *NewBB, BasicBlock *SuccBB); }; } @@ -162,6 +184,16 @@ DEBUG(dbgs() << "Jump threading on function '" << F.getName() << "'\n"); TLI = &getAnalysis().getTLI(); LVI = &getAnalysis(); + BFI.reset(); + BPI.reset(); + // When profile data is available, we need to update edge weights after + // successful jump threading, which requires both BPI and BFI being available. + HasProfileData = F.getEntryCount().hasValue(); + if (HasProfileData) { + LoopInfo LI{DominatorTree(F)}; + BPI.reset(new BranchProbabilityInfo(F, LI)); + BFI.reset(new BlockFrequencyInfo(F, *BPI, LI)); + } // Remove unreachable blocks from function as they may result in infinite // loop. We do threading if we found something profitable. Jump threading a @@ -977,8 +1009,7 @@ } // Split them out to their own block.
- UnavailablePred = - SplitBlockPredecessors(LoadBB, PredsToSplit, "thread-pre-split"); + UnavailablePred = SplitBlockPreds(LoadBB, PredsToSplit, "thread-pre-split"); } // If the value isn't available in all predecessors, then there will be @@ -1403,7 +1434,7 @@ else { DEBUG(dbgs() << " Factoring out " << PredBBs.size() << " common predecessors.\n"); - PredBB = SplitBlockPredecessors(BB, PredBBs, ".thr_comm"); + PredBB = SplitBlockPreds(BB, PredBBs, ".thr_comm"); } // And finally, do it! @@ -1424,6 +1455,13 @@ BB->getParent(), BB); NewBB->moveAfter(PredBB); + // Set the block frequency of NewBB to the frequency of the edge PredBB->BB. + if (HasProfileData) { + auto NewBBFreq = + BFI->getBlockFreq(PredBB) * BPI->getEdgeProbability(PredBB, BB); + BFI->setBlockFreq(NewBB, NewBBFreq.getFrequency()); + } + BasicBlock::iterator BI = BB->begin(); for (; PHINode *PN = dyn_cast(BI); ++BI) ValueMapping[PN] = PN->getIncomingValueForBlock(PredBB); @@ -1447,7 +1485,7 @@ // We didn't copy the terminator from BB over to NewBB, because there is now // an unconditional jump to SuccBB. Insert the unconditional jump. - BranchInst *NewBI =BranchInst::Create(SuccBB, NewBB); + BranchInst *NewBI = BranchInst::Create(SuccBB, NewBB); NewBI->setDebugLoc(BB->getTerminator()->getDebugLoc()); // Check to see if SuccBB has PHI nodes. If so, we need to add entries to the @@ -1508,11 +1546,86 @@ // frequently happens because of phi translation. SimplifyInstructionsInBlock(NewBB, TLI); + // Update the edge weight from BB to SuccBB, which should be less than before. + UpdateBlockFreqAndEdgeWeight(PredBB, BB, NewBB, SuccBB); + // Threaded an edge! ++NumThreads; return true; } +/// Create a new basic block that will be the predecessor of BB and successor of +/// all blocks in Preds. When profile data is available, update the frequency of +/// this new block.
+BasicBlock *JumpThreading::SplitBlockPreds(BasicBlock *BB, + ArrayRef Preds, + const char *Suffix) { + // Sum the frequencies of the edges Pred->BB over all Preds; the sum becomes + // the frequency of the new block created below. + BlockFrequency PredBBFreq(0); + if (HasProfileData) + for (auto Pred : Preds) + PredBBFreq += BFI->getBlockFreq(Pred) * BPI->getEdgeProbability(Pred, BB); + + BasicBlock *PredBB = SplitBlockPredecessors(BB, Preds, Suffix); + + // Set the block frequency of the newly created PredBB, which is the sum of + // the frequencies of the edges Pred->BB. + if (HasProfileData) + BFI->setBlockFreq(PredBB, PredBBFreq.getFrequency()); + return PredBB; +} + +/// Update the block frequency of BB and branch weight and the metadata on the +/// edge BB->SuccBB. This is done by scaling the weight of BB->SuccBB by 1 - +/// Freq(PredBB->BB) / Freq(BB->SuccBB). +void JumpThreading::UpdateBlockFreqAndEdgeWeight(BasicBlock *PredBB, + BasicBlock *BB, + BasicBlock *NewBB, + BasicBlock *SuccBB) { + if (!HasProfileData) + return; + + assert(BFI && BPI && "BFI & BPI should have been created here"); + + // As the edge from PredBB to BB is deleted, we have to update the block + // frequency of BB: its old frequency minus the frequency of NewBB. + auto BBOrigFreq = BFI->getBlockFreq(BB); + auto NewBBFreq = BFI->getBlockFreq(NewBB); + auto BB2SuccBBFreq = BBOrigFreq * BPI->getEdgeProbability(BB, SuccBB); + auto BBNewFreq = BBOrigFreq - NewBBFreq; + BFI->setBlockFreq(BB, BBNewFreq.getFrequency()); + + // Collect updated outgoing edges' frequencies from BB and use them to update + // edge weights. NOTE(review): if BB has several edges to one successor, this loop presumably visits it once per edge while getEdgeProbability(BB, *I) is keyed only by the destination block -- confirm this cannot double count. + SmallVector BBSuccFreq; + for (auto I = succ_begin(BB), E = succ_end(BB); I != E; ++I) { + auto SuccFreq = (*I == SuccBB) + ?
BB2SuccBBFreq - NewBBFreq + : BBOrigFreq * BPI->getEdgeProbability(BB, *I); + BBSuccFreq.push_back(SuccFreq.getFrequency()); + } + + // Normalize edge weights in BBSuccFreq so that their sum fits in 32 bits. + MachineBranchProbabilityInfo::normalizeEdgeWeights(BBSuccFreq.begin(), + BBSuccFreq.end()); + + SmallVector Weights; + for (auto Freq : BBSuccFreq) + Weights.push_back(static_cast(Freq)); + + // Update edge weights in BPI. + for (int I = 0, E = Weights.size(); I < E; I++) + BPI->setEdgeWeight(BB, I, Weights[I]); + + if (Weights.size() >= 2) { + auto TI = BB->getTerminator(); + TI->setMetadata( + LLVMContext::MD_prof, + MDBuilder(TI->getParent()->getContext()).createBranchWeights(Weights)); + } +} + /// DuplicateCondBranchOnPHIIntoPred - PredBB contains an unconditional branch /// to BB which contains an i1 PHI node and a conditional branch on that PHI. /// If we can duplicate the contents of BB up into PredBB do so now, this @@ -1546,7 +1659,7 @@ else { DEBUG(dbgs() << " Factoring out " << PredBBs.size() << " common predecessors.\n"); - PredBB = SplitBlockPredecessors(BB, PredBBs, ".thr_comm"); + PredBB = SplitBlockPreds(BB, PredBBs, ".thr_comm"); } // Okay, we decided to do this! Clone all the instructions in BB onto the end Index: llvm/trunk/test/Transforms/JumpThreading/update-edge-weight.ll =================================================================== --- llvm/trunk/test/Transforms/JumpThreading/update-edge-weight.ll +++ llvm/trunk/test/Transforms/JumpThreading/update-edge-weight.ll @@ -0,0 +1,43 @@ +; RUN: opt -S -jump-threading %s | FileCheck %s + +; Test that edge weights are properly updated after jump threading.
+ +; CHECK: !2 = !{!"branch_weights", i32 22, i32 7} + +define void @foo(i32 %n) !prof !0 { +entry: + %cmp = icmp sgt i32 %n, 10 + br i1 %cmp, label %if.then.1, label %if.else.1, !prof !1 + +if.then.1: + tail call void @a() + br label %if.cond + +if.else.1: + tail call void @b() + br label %if.cond + +if.cond: + %cmp1 = icmp sgt i32 %n, 5 + br i1 %cmp1, label %if.then.2, label %if.else.2, !prof !2 + +if.then.2: + tail call void @c() + br label %if.end + +if.else.2: + tail call void @d() + br label %if.end + +if.end: + ret void +} + +declare void @a() +declare void @b() +declare void @c() +declare void @d() + +!0 = !{!"function_entry_count", i64 1} +!1 = !{!"branch_weights", i32 10, i32 5} +!2 = !{!"branch_weights", i32 10, i32 1}