Index: lib/Transforms/Scalar/JumpThreading.cpp
===================================================================
--- lib/Transforms/Scalar/JumpThreading.cpp
+++ lib/Transforms/Scalar/JumpThreading.cpp
@@ -96,6 +96,13 @@
            "condition to use to thread over a weaker condition"),
   cl::init(3), cl::Hidden);
 
+// Follow the default branch weights that the lower-expect-intrinsic pass sets
+// for __builtin_expect (likely:unlikely = 2000:1).
+static cl::opt<unsigned>
+BiasedBranchThreshold("jump-threading-biased-branch-threshold",
+          cl::desc("Threshold in branch weight metadata for a biased branch"),
+          cl::init(2000), cl::Hidden);
+
 static cl::opt<bool> PrintLVIAfterJumpThreading(
     "print-lvi-after-jump-threading",
     cl::desc("Print the LazyValueInfo cache after JumpThreading"), cl::init(false),
@@ -411,6 +418,168 @@
   return EverChanged;
 }
 
+/// Read the "branch_weights" profile metadata of a conditional branch.
+///
+/// On success, WeightTrue receives the weight of the true edge (successor 0),
+/// WeightFalse the weight of the false edge (successor 1), and true is
+/// returned. Otherwise false is returned and both outputs are left untouched;
+/// that happens when TI is null or not a conditional branch, or when its !prof
+/// metadata is not of the form !{!"branch_weights", W0, W1} with W0 and W1
+/// integer constants that fit in 32 bits.
+static bool
+getCondBranchWeight(TerminatorInst *TI, uint64_t &WeightTrue,
+                    uint64_t &WeightFalse) {
+  BranchInst *BI = dyn_cast_or_null<BranchInst>(TI);
+  if (!BI || !BI->isConditional())
+    return false;
+
+  // Check the operand count before reading any operand, and do not assume
+  // operand 0 is an MDString: malformed or other kinds of !prof metadata
+  // must be rejected rather than dereferenced.
+  MDNode *ProfMD = BI->getMetadata(LLVMContext::MD_prof);
+  if (!ProfMD || ProfMD->getNumOperands() != 3)
+    return false;
+
+  MDString *MDS = dyn_cast<MDString>(ProfMD->getOperand(0));
+  if (!MDS || !MDS->getString().equals("branch_weights"))
+    return false;
+
+  // mdconst::extract asserts on a non-ConstantInt operand; dyn_extract
+  // returns null instead, so such metadata is rejected gracefully.
+  ConstantInt *CI1 = mdconst::dyn_extract<ConstantInt>(ProfMD->getOperand(1));
+  ConstantInt *CI2 = mdconst::dyn_extract<ConstantInt>(ProfMD->getOperand(2));
+  if (!CI1 || !CI2)
+    return false;
+
+  // Branch weights are i32 in practice. Rejecting wider values keeps
+  // getZExtValue() from asserting and guarantees that a weight multiplied by
+  // the unsigned BiasedBranchThreshold cannot overflow uint64_t.
+  const APInt &W1 = CI1->getValue();
+  const APInt &W2 = CI2->getValue();
+  if (!W1.isIntN(32) || !W2.isIntN(32))
+    return false;
+
+  WeightTrue = W1.getZExtValue();
+  WeightFalse = W2.getZExtValue();
+  return true;
+}
+
+/// If the conditional branch TI carries a branch hint (i.e. biased
+/// "branch_weights" metadata) and the cold edge is the only incoming edge of
+/// its target, that target is a rarely executed cold block.
+///
+/// Returns the cold block, or nullptr if none is identified. On success,
+/// WeightHot and WeightCold receive the larger and the smaller of the two
+/// weights; otherwise they are left untouched. A branch counts as biased when
+/// its hot weight is strictly larger than, and at least BiasedBranchThreshold
+/// times, its cold weight.
+static BasicBlock *
+findColdBlockBasedOnMetaData(TerminatorInst *TI,
+                             uint64_t &WeightHot, uint64_t &WeightCold) {
+  uint64_t WeightTrue, WeightFalse;
+  if (!getCondBranchWeight(TI, WeightTrue, WeightFalse))
+    return nullptr;
+
+  // getCondBranchWeight limits both weights to 32 bits, so the product below
+  // cannot overflow. Requiring Hot > Cold keeps equal weights (e.g. 0:0) from
+  // being treated as biased.
+  auto IsBiased = [](uint64_t Hot, uint64_t Cold) {
+    return Hot > Cold && Hot >= Cold * BiasedBranchThreshold;
+  };
+
+  BasicBlock *ColdBlock = nullptr;
+  if (IsBiased(WeightTrue, WeightFalse))
+    ColdBlock = TI->getSuccessor(1);
+  else if (IsBiased(WeightFalse, WeightTrue))
+    ColdBlock = TI->getSuccessor(0);
+
+  // Only a block whose single incoming edge is the cold edge is known cold.
+  if (!ColdBlock || !ColdBlock->getSinglePredecessor())
+    return nullptr;
+
+  WeightHot = std::max(WeightTrue, WeightFalse);
+  WeightCold = std::min(WeightTrue, WeightFalse);
+
+  DEBUG(dbgs() << "  JT: BB '" << ColdBlock->getName()
+               << "' is identified as a cold block based on metadata\n  ");
+  DEBUG(TI->getMetadata(LLVMContext::MD_prof)->dump());
+  return ColdBlock;
+}
+
+/// Propagate a branch hint backwards from a block known to be cold.
+///
+/// The cold region starts as {InitialColdBlock} and repeatedly absorbs every
+/// predecessor all of whose successors are already in the region, so every
+/// path from a region block leads to InitialColdBlock. Each conditional branch
+/// outside the region that jumps into it and has no !prof metadata yet then
+/// gets weights that make its edge into the region cold (WeightCold) and its
+/// other edge hot (WeightHot).
+///
+/// The region is recomputed from the current CFG rather than taken from an
+/// existing (post)dominator tree, because the CFG may already have changed if
+/// several branches carried branch hints.
+static void
+setMetaDataBasedOnColdBlock(BasicBlock *InitialColdBlock, unsigned WeightHot,
+                            unsigned WeightCold) {
+  // ColdBlocks is the region in discovery order and doubles as the worklist;
+  // ColdSet mirrors it for constant-time membership tests.
+  SmallVector<BasicBlock *, 16> ColdBlocks;
+  SmallPtrSet<BasicBlock *, 16> ColdSet;
+  ColdBlocks.push_back(InitialColdBlock);
+  ColdSet.insert(InitialColdBlock);
+  auto IsCold = [&](BasicBlock *B) { return ColdSet.count(B) != 0; };
+
+  // Grow the region: a block whose successors are all cold is cold as well.
+  for (unsigned Checked = 0; Checked < ColdBlocks.size(); ++Checked) {
+    BasicBlock *CBB = ColdBlocks[Checked];
+    for (BasicBlock *PrevBB : predecessors(CBB)) {
+      if (IsCold(PrevBB) || !llvm::all_of(successors(PrevBB), IsCold))
+        continue;
+
+      ColdBlocks.push_back(PrevBB);
+      ColdSet.insert(PrevBB);
+      DEBUG(dbgs() << "  added BB '" << PrevBB->getName()
+                   << "' in the cold region\n");
+    }
+  }
+
+  // Give branch weights to the conditional branches on the region's boundary,
+  // i.e. outside the region but with an edge into it.
+  for (BasicBlock *CBB : ColdBlocks)
+    for (BasicBlock *PrevBB : predecessors(CBB)) {
+      if (IsCold(PrevBB))
+        continue;
+
+      BranchInst *BI = dyn_cast<BranchInst>(PrevBB->getTerminator());
+      if (!BI || !BI->isConditional())
+        continue;
+
+      // Never override a branch's existing !prof metadata.
+      if (BI->getMetadata(LLVMContext::MD_prof))
+        continue;
+
+      // PrevBB stayed outside the region, so at least one of its successors
+      // is outside the region too (otherwise it would have been absorbed).
+      MDBuilder MDB(BI->getContext());
+      MDNode *Weights;
+      if (BI->getSuccessor(0) == CBB) {
+        assert(!IsCold(BI->getSuccessor(1)) &&
+               "At least one successor of a non-cold BB must be non-cold");
+        Weights = MDB.createBranchWeights(WeightCold, WeightHot);
+      } else {
+        assert(!IsCold(BI->getSuccessor(0)) &&
+               "At least one successor of a non-cold BB must be non-cold");
+        Weights = MDB.createBranchWeights(WeightHot, WeightCold);
+      }
+
+      BI->setMetadata(LLVMContext::MD_prof, Weights);
+      DEBUG(dbgs() << "  JT: Set branch hint in BB '" << PrevBB->getName()
+                   << "'\n  ");
+      DEBUG(Weights->dump());
+      DEBUG(BI->dump());
+    }
+}
+
 // Replace uses of Cond with ToVal when safe to do so. If all uses are
 // replaced, we can remove Cond. We cannot blindly replace all uses of Cond
 // because we may incorrectly replace uses when guards/assumes are uses of
@@ -1464,19 +1633,33 @@
 
   // Okay, now we know the most popular destination.  If there is more than one
   // destination, we need to determine one.  This is arbitrary, but we need
-  // to make a deterministic decision.  Pick the first one that appears in the
-  // successor list.
+  // to make a deterministic decision.
   if (!SamePopularity.empty()) {
     SamePopularity.push_back(MostPopularDest);
     TerminatorInst *TI = BB->getTerminator();
-    for (unsigned i = 0; ; ++i) {
-      assert(i != TI->getNumSuccessors() && "Didn't find any successor!");
 
-      if (!is_contained(SamePopularity, TI->getSuccessor(i)))
-        continue;
+    // If the conditional branch carries a branch hint, prefer its colder
+    // successor: threading towards it moves the hint upward, towards the
+    // blocks that lead into the cold block. Equal weights give no preference.
+    BasicBlock *ColderSucc = nullptr;
+    uint64_t WeightTrue = 0, WeightFalse = 0;
+    if (getCondBranchWeight(TI, WeightTrue, WeightFalse) &&
+        WeightTrue != WeightFalse)
+      ColderSucc = TI->getSuccessor(WeightTrue < WeightFalse ? 0 : 1);
+
+    if (ColderSucc && is_contained(SamePopularity, ColderSucc)) {
+      MostPopularDest = ColderSucc;
+    } else {
+      // Otherwise pick the first one that appears in the successor list.
+      for (unsigned i = 0; ; ++i) {
+        assert(i != TI->getNumSuccessors() && "Didn't find any successor!");
 
-      MostPopularDest = TI->getSuccessor(i);
-      break;
+        if (!is_contained(SamePopularity, TI->getSuccessor(i)))
+          continue;
+
+        MostPopularDest = TI->getSuccessor(i);
+        break;
+      }
     }
   }
 
@@ -1867,11 +2050,20 @@
   NewBB->moveAfter(PredBB);
 
   // Set the block frequency of NewBB.
+  uint64_t WeightHot = 0, WeightCold = 0;
+  BasicBlock *ColdBlock = nullptr;
   if (HasProfileData) {
     auto NewBBFreq =
         BFI->getBlockFreq(PredBB) * BPI->getEdgeProbability(PredBB, BB);
     BFI->setBlockFreq(NewBB, NewBBFreq.getFrequency());
-  }
+  } else {
+    // Even without full profile information, BB's terminator may carry a
+    // branch hint (e.g. from __builtin_expect) as !prof metadata. If it is
+    // highly biased, remember the cold block so that the hint can be
+    // propagated to other conditional branches once the edge is threaded.
+    ColdBlock = findColdBlockBasedOnMetaData(BB->getTerminator(),
+                                             WeightHot, WeightCold);
+  }
 
   BasicBlock::iterator BI = BB->begin();
   for (; PHINode *PN = dyn_cast<PHINode>(BI); ++BI)
@@ -1959,6 +2151,11 @@
   // Update the edge weight from BB to SuccBB, which should be less than before.
   UpdateBlockFreqAndEdgeWeight(PredBB, BB, NewBB, SuccBB);
 
+  // If we know there is a cold block based on a branch hint associated with
+  // the branch we are deleting, we try to move the branch hint upward.
+  if (ColdBlock)
+    setMetaDataBasedOnColdBlock(ColdBlock, WeightHot, WeightCold);
+
   // Threaded an edge!
   ++NumThreads;
   return true;
Index: test/Transforms/JumpThreading/branch-metadata.ll
===================================================================
--- /dev/null
+++ test/Transforms/JumpThreading/branch-metadata.ll
@@ -0,0 +1,162 @@
+; RUN: opt < %s -jump-threading -S | FileCheck %s
+
+; Check that when jump threading eliminates a branch carrying a branch hint
+; (biased branch_weights metadata), the hint is propagated to earlier branches.
+
+define i32 @func(i32 %flag) {
+; CHECK-LABEL: @func
+; CHECK: br i1 %tobool.i, label %if.hot, label %if.cold, !prof !0
+; CHECK-DAG: if.hot:
+; CHECK-DAG: if.cold:
+entry:
+  %tobool.i = icmp eq i32 %flag, 0
+  br i1 %tobool.i, label %to.be.eliminated, label %if.then
+
+if.then:
+  br label %to.be.eliminated
+
+to.be.eliminated:
+  %retval.0.i = phi i32 [ 0, %if.then ], [ 1, %entry ]
+  %cmp = icmp eq i32 %retval.0.i, 0
+  br i1 %cmp, label %if.cold, label %if.hot, !prof !1
+
+if.cold:
+  call void @cold_func()
+  br label %return
+
+if.hot:
+  call void @hot_func()
+  br label %return
+
+return:
+  %retval.0 = phi i32 [ 0, %if.cold ], [ 1, %if.hot ]
+  ret i32 %retval.0
+}
+; Same as @func, but with the hint on the eliminated branch reversed.
+define i32 @func2(i32 %flag) {
+; CHECK-LABEL: @func2
+; CHECK: br i1 %tobool.i, label %if.cold, label %if.hot, !prof !1
+; CHECK-DAG: if.hot:
+; CHECK-DAG: if.cold:
+entry:
+  %tobool.i = icmp eq i32 %flag, 0
+  br i1 %tobool.i, label %to.be.eliminated, label %if.then
+
+if.then:
+  br label %to.be.eliminated
+
+to.be.eliminated:
+  %retval.0.i = phi i32 [ 0, %if.then ], [ 1, %entry ]
+  %cmp = icmp eq i32 %retval.0.i, 0
+  br i1 %cmp, label %if.hot, label %if.cold, !prof !0
+
+if.cold:
+  call void @cold_func()
+  br label %return
+
+if.hot:
+  call void @hot_func()
+  br label %return
+
+return:
+  %retval.0 = phi i32 [ 0, %if.cold ], [ 1, %if.hot ]
+  ret i32 %retval.0
+}
+
+
+define signext i32 @func3(i32 signext %flag) {
+; This is a test with a more complicated CFG; the equivalent C code follows.
+; The hint on the eliminated branch should be propagated to the unhinted
+; branches for "if (flag & 1)" and "if (flag & 2)".
+;
+; inline int bar(int flag) {
+;   if (flag & 1) {
+;     cold_func();
+;     return 0;
+;   }
+;   if (flag & 2) {
+;     if (__builtin_expect(flag & 4, 0)) cold_func();
+;     cold_func();
+;     return 0;
+;   }
+;   if (__builtin_expect(flag & 8, 0)) cold_func();
+;   return 1;
+; }
+;
+; int func3(int flag) {
+;   if (__builtin_expect(0 == bar(flag), 0)) {
+;     cold_func();
+;     return 0;
+;   }
+;   hot_func();
+;   return 1;
+; }
+
+; CHECK-LABEL: @func3
+; CHECK: br i1 %tobool.i, label %if.end.i, label %if.then.i, !prof !0
+; CHECK: if.then.i:
+; CHECK: br i1 %tobool2.i, label %if.end8.i, label %if.then3.i, !prof !0
+; CHECK: if.then3.i:
+
+entry:
+  %and.i = and i32 %flag, 1
+  %tobool.i = icmp eq i32 %and.i, 0
+  br i1 %tobool.i, label %if.end.i, label %if.then.i
+
+if.then.i:
+  call void @cold_func()
+  br label %bar.exit
+
+if.end.i:
+  %and1.i = and i32 %flag, 2
+  %tobool2.i = icmp eq i32 %and1.i, 0
+  br i1 %tobool2.i, label %if.end8.i, label %if.then3.i
+
+if.then3.i:
+  %and4.i = and i32 %flag, 4
+  %tobool5.i = icmp eq i32 %and4.i, 0
+  br i1 %tobool5.i, label %if.end7.i, label %if.then6.i, !prof !0
+
+if.then6.i:
+  call void @cold_func()
+  br label %if.end7.i
+
+if.end7.i:
+  call void @cold_func()
+  br label %bar.exit
+
+if.end8.i:
+  %and9.i = and i32 %flag, 8
+  %tobool12.i = icmp eq i32 %and9.i, 0
+  br i1 %tobool12.i, label %bar.exit, label %if.then13.i, !prof !0
+
+if.then13.i:
+  call void @cold_func()
+  br label %bar.exit
+
+bar.exit:
+  %retval.0.i = phi i32 [ 0, %if.then.i ], [ 0, %if.end7.i ], [ 1, %if.end8.i ], [ 1, %if.then13.i ]
+  %cmp = icmp eq i32 %retval.0.i, 0
+  br i1 %cmp, label %if.then, label %if.end, !prof !1
+
+if.then:
+  call void @cold_func()
+  br label %return
+
+if.end:
+  call void @hot_func()
+  br label %return
+
+return:
+  %retval.0 = phi i32 [ 0, %if.then ], [ 1, %if.end ]
+  ret i32 %retval.0
+}
+
+; CHECK-DAG: !0 = !{!"branch_weights", i32 2000, i32 1}
+; CHECK-DAG: !1 = !{!"branch_weights", i32 1, i32 2000}
+
+
+declare void @hot_func()
+declare void @cold_func()
+!0 = !{!"branch_weights", i32 2000, i32 1}
+!1 = !{!"branch_weights", i32 1, i32 2000}