Index: lib/Transforms/Scalar/JumpThreading.cpp =================================================================== --- lib/Transforms/Scalar/JumpThreading.cpp +++ lib/Transforms/Scalar/JumpThreading.cpp @@ -63,6 +63,11 @@ "condition to use to thread over a weaker condition"), cl::init(3), cl::Hidden); +static cl::opt +BiasedBranchThreshold("jump-threading-biased-branch-threshold", + cl::desc("Threshold in branch weight metadata to identify a biased branch"), + cl::init(1000), cl::Hidden); + namespace { /// This pass performs 'jump threading', which looks at blocks that have /// multiple predecessors and multiple successors. If one or more of the @@ -253,6 +258,130 @@ return EverChanged; } +/// Utility function that reads branch weight metadata for a conditional branch. +/// Returned WeightTrue and WeightFalse are valid only if the function returns true. +static bool +getCondBranchWeight(TerminatorInst *TI, uint64_t &WeightTrue, uint64_t &WeightFalse) { + BranchInst *BI = dyn_cast(TI); + if (!(BI && BI->isConditional())) + return false; + + MDNode *ProfMD = BI->getMetadata(LLVMContext::MD_prof); + if (!ProfMD) + return false; + + MDString *MDS = dyn_cast(ProfMD->getOperand(0)); + if (!MDS || !MDS->getString().equals("branch_weights") || + ProfMD->getNumOperands() != 3) + return false; + + ConstantInt *CI1 = mdconst::dyn_extract(ProfMD->getOperand(1)); + ConstantInt *CI2 = mdconst::dyn_extract(ProfMD->getOperand(2)); + if (CI1 == NULL || CI2 == NULL) + return false; + + WeightTrue = CI1->getValue().getZExtValue(); + WeightFalse = CI2->getValue().getZExtValue(); + + return true; +} + +/// If the specified conditional branch has a branch hint (i.e. biased branch metadata), +/// and the cold edge is the only incoming edge to the target BB, +/// the target BB of the cold edge is a rarely-executed cold block. +/// The threshold for biased branch is controlled by the BiasedBranchThreshold parameter. +/// Clang sets branch weights of 1:2000 for the __builtin_expect intrinsic. 
+static BasicBlock* +findColdBlockBasedOnMetaData(TerminatorInst *TI, + uint64_t &WeightHot, uint64_t &WeightCold) { + BasicBlock* ColdBlock = NULL; + uint64_t WeightTrue, WeightFalse; + if (TI == NULL || !getCondBranchWeight(TI, WeightTrue, WeightFalse)) + return NULL; + + if (WeightTrue > WeightFalse * BiasedBranchThreshold) { + WeightHot = WeightTrue; + WeightCold = WeightFalse; + ColdBlock = TI->getSuccessor(1); + } + else if (WeightFalse > WeightTrue * BiasedBranchThreshold) { + WeightCold = WeightTrue; + WeightHot = WeightFalse; + ColdBlock = TI->getSuccessor(0); + } + if (ColdBlock == NULL || ColdBlock->getSinglePredecessor() == NULL) + return NULL; + + DEBUG(dbgs() << " JT: BB '" << ColdBlock->getName() + << "' is identified as a cold block based on metadata\n "); + DEBUG(TI->getMetadata(LLVMContext::MD_prof)->dump()); + return ColdBlock; +} + +/// If a cold block is known based on a branch hint, we first identify the cold region, +/// i.e. the region post-dominated by the cold block. Then we set the branch hint for the +/// conditional branches that jump into the cold region from the non-cold region. +static void +setMetaDataBasedOnColdBlock(BasicBlock* ColdBlock, unsigned WeightHot, unsigned WeightCold) { + SmallVector ColdBlocks; + ColdBlocks.push_back(ColdBlock); + + unsigned Checked = 0; + while (Checked < ColdBlocks.size()) { + BasicBlock *CBB = ColdBlocks[Checked++]; + for (BasicBlock *PrevBB : predecessors(CBB)) { + if (is_contained(ColdBlocks, PrevBB)) + continue; + + // If all the successors are cold, this block is also cold. 
+ bool NonColdSucc = false; + for (BasicBlock *SBB : successors(PrevBB)) + if (!is_contained(ColdBlocks, SBB)) { + NonColdSucc = true; + break; + } + + if (!NonColdSucc) { + ColdBlocks.push_back(PrevBB); + DEBUG(dbgs() << " added BB '" << PrevBB->getName() << "' in the cold region\n"); + } + } + } + + for (BasicBlock *CBB : ColdBlocks) + for (BasicBlock *PrevBB : predecessors(CBB)) { + // If a previous block is not in the cold region, + // this block is at the dominance frontier. + // We set the branch hint to show that this CFG edge is cold. + if (is_contained(ColdBlocks, PrevBB)) + continue; + + BranchInst *BI = dyn_cast(PrevBB->getTerminator()); + if (!(BI && BI->isConditional() && + BI->getMetadata(LLVMContext::MD_prof) == NULL)) + continue; + + MDNode *Metadata = NULL; + MDBuilder MDB = MDBuilder(BI->getContext()); + if (BI->getSuccessor(0) == CBB) { + assert(!is_contained(ColdBlocks, BI->getSuccessor(1)) && + "At least one successor of a non-cold BB must be non-cold"); + Metadata = MDB.createBranchWeights(WeightCold, WeightHot); + } + else { + assert(!is_contained(ColdBlocks, BI->getSuccessor(0)) && + "At least one successor of a non-cold BB must be non-cold"); + Metadata = MDB.createBranchWeights(WeightHot, WeightCold); + } + + BI->setMetadata(LLVMContext::MD_prof, Metadata); + DEBUG(dbgs() << " JT: Set branch hint in BB '" << PrevBB->getName() + << "'\n "); + DEBUG(Metadata->dump()); + DEBUG(BI->dump()); + } +} + /// Return the cost of duplicating a piece of this block from first non-phi /// and before StopAt instruction to thread across it. Stop scanning the block /// when exceeding the threshold. If duplication is impossible, returns ~0U. 
@@ -1204,14 +1333,24 @@ if (!SamePopularity.empty()) { SamePopularity.push_back(MostPopularDest); TerminatorInst *TI = BB->getTerminator(); - for (unsigned i = 0; ; ++i) { - assert(i != TI->getNumSuccessors() && "Didn't find any successor!"); + uint64_t WeightTrue, WeightFalse; + if (getCondBranchWeight(TI, WeightTrue, WeightFalse)) { + // If we have a branch hint on a biased conditional branch, + // we select the colder block as the destination to move the branch hint + // upward to the dominance frontier of this cold block. + if (WeightTrue < WeightFalse) MostPopularDest = TI->getSuccessor(0); + else MostPopularDest = TI->getSuccessor(1); + } + else { + for (unsigned i = 0; ; ++i) { + assert(i != TI->getNumSuccessors() && "Didn't find any successor!"); - if (!is_contained(SamePopularity, TI->getSuccessor(i))) - continue; + if (!is_contained(SamePopularity, TI->getSuccessor(i))) + continue; - MostPopularDest = TI->getSuccessor(i); - break; + MostPopularDest = TI->getSuccessor(i); + break; + } } } @@ -1545,11 +1684,21 @@ NewBB->moveAfter(PredBB); // Set the block frequency of NewBB. + uint64_t WeightHot, WeightCold; + BasicBlock *ColdBlock = NULL; if (HasProfileData) { auto NewBBFreq = BFI->getBlockFreq(PredBB) * BPI->getEdgeProbability(PredBB, BB); BFI->setBlockFreq(NewBB, NewBBFreq.getFrequency()); } + else { + // Even if the full profile information is not available, + // we may have a branch hint (e.g. __builtin_expect) as metadata. + // If the branch weight is highly biased, we try to propagate + // the information to other conditional branches. + ColdBlock = findColdBlockBasedOnMetaData(BB->getTerminator(), + WeightHot, WeightCold); + } BasicBlock::iterator BI = BB->begin(); for (; PHINode *PN = dyn_cast(BI); ++BI) @@ -1638,6 +1787,11 @@ // Update the edge weight from BB to SuccBB, which should be less than before. 
UpdateBlockFreqAndEdgeWeight(PredBB, BB, NewBB, SuccBB); + // If we know there is a cold block based on a branch hint we are deleting, + // we try to move the branch hint. + if (ColdBlock) + setMetaDataBasedOnColdBlock(ColdBlock, WeightHot, WeightCold); + // Threaded an edge! ++NumThreads; return true; Index: test/Transforms/JumpThreading/branch-metadata.ll =================================================================== --- /dev/null +++ test/Transforms/JumpThreading/branch-metadata.ll @@ -0,0 +1,162 @@ +; RUN: opt < %s -jump-threading -S | FileCheck %s + +; This test examines how jump-threading handles branch weights metadata +; if the branch to be eliminated has the metadata. + +define i32 @func(i32 %flag) { +; CHECK-LABEL: @func +; CHECK: br i1 %tobool.i, label %if.hot, label %if.cold, !prof !0 +; CHECK-DAG: if.hot: +; CHECK-DAG: if.cold: +entry: + %tobool.i = icmp eq i32 %flag, 0 + br i1 %tobool.i, label %to.be.eliminated, label %if.then + +if.then: + br label %to.be.eliminated + +to.be.eliminated: + %retval.0.i = phi i32 [ 0, %if.then ], [ 1, %entry ] + %cmp = icmp eq i32 %retval.0.i, 0 + br i1 %cmp, label %if.cold, label %if.hot, !prof !1 + +if.cold: + call void @cold_func() + br label %return + +if.hot: + call void @hot_func() + br label %return + +return: + %retval.0 = phi i32 [ 0, %if.cold ], [ 1, %if.hot ] + ret i32 %retval.0 +} + +define i32 @func2(i32 %flag) { +; CHECK-LABEL: @func2 +; CHECK: br i1 %tobool.i, label %if.cold, label %if.hot, !prof !1 +; CHECK-DAG: if.hot: +; CHECK-DAG: if.cold: +entry: + %tobool.i = icmp eq i32 %flag, 0 + br i1 %tobool.i, label %to.be.eliminated, label %if.then + +if.then: + br label %to.be.eliminated + +to.be.eliminated: + %retval.0.i = phi i32 [ 0, %if.then ], [ 1, %entry ] + %cmp = icmp eq i32 %retval.0.i, 0 + br i1 %cmp, label %if.hot, label %if.cold, !prof !0 + +if.cold: + call void @cold_func() + br label %return + +if.hot: + call void @hot_func() + br label %return + +return: + %retval.0 = phi i32 [ 0, %if.cold 
], [ 1, %if.hot ] + ret i32 %retval.0 +} + + +define signext i32 @func3(i32 signext %flag) { +; This is a test with more complicated CFG. +; The following is the equivalent C code. +; We want to set branch weights for "if (flag & 1)" and "if (flag & 2)". +; +; inline int bar(int flag) { +; if (flag & 1) { +; cold_func(); +; return 0; +; } +; if (flag & 2) { +; if (__builtin_expect(flag & 4, 0)) cold_func(); +; cold_func(); +; return 0; +; } +; if (__builtin_expect(flag & 8, 0)) cold_func(); +; return 1; +; } +; +; int func3(int flag) { +; if (__builtin_expect(0 == bar(flag), 0)) { +; cold_func(); +; return 0; +; } +; hot_func(); +; return 1; +; } + +; CHECK-LABEL: @func3 +; CHECK: br i1 %tobool.i, label %if.end.i, label %if.then.i, !prof !0 +; CHECK: if.then.i: +; CHECK: br i1 %tobool2.i, label %if.end8.i, label %if.then3.i, !prof !0 +; CHECK: if.then3.i: + +entry: + %and.i = and i32 %flag, 1 + %tobool.i = icmp eq i32 %and.i, 0 + br i1 %tobool.i, label %if.end.i, label %if.then.i + +if.then.i: + call void @cold_func() + br label %bar.exit + +if.end.i: + %and1.i = and i32 %flag, 2 + %tobool2.i = icmp eq i32 %and1.i, 0 + br i1 %tobool2.i, label %if.end8.i, label %if.then3.i + +if.then3.i: + %and4.i = and i32 %flag, 4 + %tobool5.i = icmp eq i32 %and4.i, 0 + br i1 %tobool5.i, label %if.end7.i, label %if.then6.i, !prof !0 + +if.then6.i: + call void @cold_func() + br label %if.end7.i + +if.end7.i: + call void @cold_func() + br label %bar.exit + +if.end8.i: + %and9.i = and i32 %flag, 8 + %tobool12.i = icmp eq i32 %and9.i, 0 + br i1 %tobool12.i, label %bar.exit, label %if.then13.i, !prof !0 + +if.then13.i: + call void @cold_func() + br label %bar.exit + +bar.exit: + %retval.0.i = phi i32 [ 0, %if.then.i ], [ 0, %if.end7.i ], [ 1, %if.end8.i ], [ 1, %if.then13.i ] + %cmp = icmp eq i32 %retval.0.i, 0 + br i1 %cmp, label %if.then, label %if.end, !prof !1 + +if.then: + call void @cold_func() + br label %return + +if.end: + call void @hot_func() + br label %return + +return: + 
%retval.0 = phi i32 [ 0, %if.then ], [ 1, %if.end ] + ret i32 %retval.0 +} + +; CHECK-DAG: !0 = !{!"branch_weights", i32 2000, i32 1} +; CHECK-DAG: !1 = !{!"branch_weights", i32 1, i32 2000} + + +declare void @hot_func() +declare void @cold_func() +!0 = !{!"branch_weights", i32 2000, i32 1} +!1 = !{!"branch_weights", i32 1, i32 2000}