Index: lib/Transforms/Utils/SimplifyCFG.cpp
===================================================================
--- lib/Transforms/Utils/SimplifyCFG.cpp
+++ lib/Transforms/Utils/SimplifyCFG.cpp
@@ -123,6 +123,15 @@
     cl::desc("Limit maximum recursion depth when calculating costs of "
              "speculatively executed instructions"));
 
+static cl::opt<bool> PeelDominantSwitchCase(
+    "peel-dominant-switch-case", cl::Hidden, cl::init(true),
+    cl::desc("peel off the dominant case in switch statement"));
+
+static cl::opt<unsigned> DominantSwitchCaseThreshold(
+    "dominant-switch-case-threshold", cl::Hidden, cl::init(80),
+    cl::desc("The case is considered to be dominant if its block weight is "
+             "greater than this percent of the total block weights"));
+
 STATISTIC(NumBitMaps, "Number of switch instructions turned into bitmaps");
 STATISTIC(NumLinearMaps,
           "Number of switch instructions turned into linear mapping");
@@ -135,6 +144,8 @@
 STATISTIC(NumSinkCommons,
           "Number of common instructions sunk down to the end block");
 STATISTIC(NumSpeculations, "Number of speculative executed instructions");
+STATISTIC(NumPeeledCases,
+          "Number of dominant cases peeled in switch statements");
 
 namespace {
 
@@ -960,9 +971,10 @@
   // default weight to be the first entry.
   if (BranchInst *BI = dyn_cast<BranchInst>(TI)) {
     assert(Weights.size() == 2);
-    ICmpInst *ICI = cast<ICmpInst>(BI->getCondition());
-    if (ICI->getPredicate() == ICmpInst::ICMP_EQ)
-      std::swap(Weights.front(), Weights.back());
+    if (ICmpInst *ICI = dyn_cast<ICmpInst>(BI->getCondition())) {
+      if (ICI->getPredicate() == ICmpInst::ICMP_EQ)
+        std::swap(Weights.front(), Weights.back());
+    }
   }
 }
 
@@ -976,6 +988,17 @@
   }
 }
 
+/// Return true if \p Weight is strictly greater than
+/// DominantSwitchCaseThreshold percent of \p TotalWeight.
+static bool WeightIsDominant(uint64_t Weight, uint64_t TotalWeight) {
+  assert(Weight <= TotalWeight);
+  // Cross-multiply rather than divide: integer division truncates both sides
+  // (81 out of 100 would not count as more than 80%) and a zero threshold
+  // would divide by zero. Thresholds above 100 can never be exceeded.
+  uint64_t Threshold = std::min<uint64_t>(DominantSwitchCaseThreshold, 100);
+  return Weight * 100 > TotalWeight * Threshold;
+}
+
 /// The specified terminator is a value equality comparison instruction
 /// (either a switch or a branch on "X == c").
 /// See if any of the predecessors of the terminator block are value comparisons
@@ -991,10 +1014,26 @@
   while (!Preds.empty()) {
     BasicBlock *Pred = Preds.pop_back_val();
 
-    // See if the predecessor is a comparison with the same value.
+    // Don't fold into predecessors that end in a strongly biased branch: with
+    // PeelDominantSwitchCase on, that is the shape PeelDominantCase produces,
+    // and folding would merge the peeled case back into the switch. Skip only
+    // this predecessor; the remaining ones may still be foldable.
     TerminatorInst *PTI = Pred->getTerminator();
-    Value *PCV = isValueEqualityComparison(PTI); // PredCondVal
+    SmallVector<uint64_t, 8> Weights;
+    bool PredHasWeights = HasBranchWeights(PTI);
+    if (PredHasWeights) {
+      GetBranchWeights(PTI, Weights);
+      if (PeelDominantSwitchCase && isa<BranchInst>(PTI)) {
+        assert(Weights.size() == 2);
+        uint64_t TotalWeight = Weights[0] + Weights[1];
+        if (WeightIsDominant(Weights[0], TotalWeight) ||
+            WeightIsDominant(Weights[1], TotalWeight))
+          continue;
+      }
+    }
 
+    // See if the predecessor is a comparison with the same value.
+    Value *PCV = isValueEqualityComparison(PTI); // PredCondVal
     if (PCV == CV && TI != PTI) {
       SmallSetVector<BasicBlock*, 4> FailBlocks;
       if (!SafeToMergeTerminators(TI, PTI, &FailBlocks)) {
@@ -1017,12 +1056,9 @@
       SmallVector<BasicBlock *, 16> NewSuccessors;
 
       // Update the branch weight metadata along the way
-      SmallVector<uint64_t, 8> Weights;
-      bool PredHasWeights = HasBranchWeights(PTI);
       bool SuccHasWeights = HasBranchWeights(TI);
 
       if (PredHasWeights) {
-        GetBranchWeights(PTI, Weights);
         // branch-weight metadata is inconsistent here.
         if (Weights.size() != 1 + PredCases.size())
           PredHasWeights = SuccHasWeights = false;
@@ -4492,6 +4528,138 @@
   return Changed;
 }
 
+// Peel off the dominant case in the switch statement.
+// For a switch statement like the following:
+//   switch %cond, label %default [
+//     constant1: label %c1
+//     constant2: label %c2
+//     constant3: label %c3
+//     ...
+//   ],
+// if label %c2 is the dominant target, transform it into:
+//   %cmp = icmp eq %cond, constant2
+//   br %cmp, label %c2, label %peeled.switch
+// peeled.switch:
+//   switch %cond, label %default [
+//     constant1: label %c1
+//     constant3: label %c3
+//     ...
+//   ],
+// A case is dominant when WeightIsDominant() accepts its weight against the
+// sum of all of the switch's branch weights (default included). Returns true
+// if the switch was changed.
+static bool PeelDominantCase(SwitchInst *SI, IRBuilder<> &Builder) {
+  if (!PeelDominantSwitchCase)
+    return false;
+
+  // Only consider switches with well-formed branch weights: one for the
+  // default destination followed by one per case.
+  if (!HasBranchWeights(SI))
+    return false;
+  MDNode *MD = SI->getMetadata(LLVMContext::MD_prof);
+  if (MD->getNumOperands() != 2 + SI->getNumCases())
+    return false;
+
+  // Collect the weights and locate the heaviest destination. Weight index 0
+  // belongs to the default destination and index I > 0 to case I - 1.
+  SmallVector<uint32_t, 8> Weights;
+  uint64_t TotalWeight = 0;
+  uint32_t MaxWeight = 0;
+  unsigned MaxWeightIndex = 0; // Stays 0 (the default) if all weights are 0.
+  for (unsigned MD_i = 1, MD_e = MD->getNumOperands(); MD_i < MD_e; ++MD_i) {
+    ConstantInt *CI = mdconst::extract<ConstantInt>(MD->getOperand(MD_i));
+    uint32_t Weight = CI->getValue().getZExtValue();
+    TotalWeight += Weight;
+    if (Weight > MaxWeight) {
+      MaxWeight = Weight;
+      MaxWeightIndex = MD_i - 1;
+    }
+    Weights.push_back(Weight);
+  }
+
+  // Nothing to do if the default destination is the heaviest one or if no
+  // single case dominates.
+  if (MaxWeightIndex == 0 || !WeightIsDominant(MaxWeight, TotalWeight))
+    return false;
+
+  // Don't peel again if the only predecessor already branches on an equality
+  // test of the switch condition against a constant (e.g. a previously
+  // peeled case). Unconditional branches have no condition to inspect.
+  BasicBlock *BB = SI->getParent();
+  if (BasicBlock *Pred = BB->getSinglePredecessor()) {
+    BranchInst *BI = dyn_cast<BranchInst>(Pred->getTerminator());
+    if (BI && BI->isConditional()) {
+      ICmpInst *ICI = dyn_cast<ICmpInst>(BI->getCondition());
+      if (ICI && ICI->getPredicate() == ICmpInst::ICMP_EQ &&
+          ICI->getOperand(0) == SI->getCondition() &&
+          isa<ConstantInt>(ICI->getOperand(1)))
+        return false;
+    }
+  }
+
+  DEBUG(dbgs() << "Peel DominantCase case.\n== Basic Block Before ==" << *BB);
+
+  // Locate the dominant case; it is case number MaxWeightIndex - 1.
+  auto PeeledCase = SI->case_begin();
+  for (unsigned Index = 1; Index != MaxWeightIndex; ++Index)
+    ++PeeledCase;
+  ConstantInt *CaseVal = PeeledCase->getCaseValue();
+  BasicBlock *CaseSuccessor = PeeledCase->getCaseSuccessor();
+
+  // Each PHI in CaseSuccessor has one incoming entry per edge from BB. After
+  // the split below, the peeled edge comes from the new conditional branch in
+  // BB while any other switch edge to CaseSuccessor comes from the split-off
+  // block, so detach one BB entry now and re-add it once the branch exists.
+  SmallVector<std::pair<PHINode *, Value *>, 4> PeeledIncoming;
+  for (auto It = CaseSuccessor->begin(); isa<PHINode>(It); ++It) {
+    PHINode *PN = cast<PHINode>(It);
+    PeeledIncoming.push_back({PN, PN->getIncomingValueForBlock(BB)});
+    PN->removeIncomingValue(BB, /*DeletePHIIfEmpty=*/false);
+  }
+
+  // removeCase() moves the last case into the removed slot; mirror that in
+  // the weights so that every remaining case keeps its own weight.
+  SI->removeCase(PeeledCase);
+  std::swap(Weights[MaxWeightIndex], Weights.back());
+  Weights.pop_back();
+  SI->setMetadata(LLVMContext::MD_prof,
+                  MDBuilder(SI->getContext()).createBranchWeights(Weights));
+
+  // Emit the comparison right before the switch so that it stays in BB when
+  // the switch is split off below.
+  Builder.SetInsertPoint(SI);
+  Value *Cmp =
+      Builder.CreateICmpEQ(SI->getCondition(), CaseVal, "switch.peeledcmp");
+
+  // Branch weights are 32-bit. The dominant share is a single case weight and
+  // always fits; halve both shares while the remainder does not (this is only
+  // possible with low dominance thresholds).
+  uint64_t PeeledWeight = MaxWeight;
+  uint64_t RestWeight = TotalWeight - MaxWeight;
+  while (RestWeight > UINT32_MAX) {
+    PeeledWeight >>= 1;
+    RestWeight >>= 1;
+  }
+  MDNode *BranchWeights = MDBuilder(SI->getContext()).createBranchWeights(
+      static_cast<uint32_t>(PeeledWeight), static_cast<uint32_t>(RestWeight));
+
+  BasicBlock *NewSwitchBB =
+      BB->splitBasicBlock(SI->getIterator(), "peeled.switch");
+
+  // splitBasicBlock() ended BB with an unconditional branch to the new block;
+  // replace it with the conditional branch on the peeled case.
+  TerminatorInst *OldTI = BB->getTerminator();
+  IRBuilder<> BrBuilder(OldTI);
+  BrBuilder.CreateCondBr(Cmp, CaseSuccessor, NewSwitchBB, BranchWeights);
+  OldTI->eraseFromParent();
+  for (auto &Incoming : PeeledIncoming)
+    Incoming.first->addIncoming(Incoming.second, BB);
+
+  DEBUG(dbgs() << "== Basic Block After ==" << *BB << *NewSwitchBB << "\n");
+  ++NumPeeledCases;
+  return true;
+}
+
 /// Return true if the backend will be able to handle
 /// initializing an array of constants like C.
 static bool ValidLookupTableConstant(Constant *C, const TargetTransformInfo &TTI) {
@@ -5552,6 +5720,9 @@
   if (ForwardSwitchConditionToPHI(SI))
     return SimplifyCFG(BB, TTI, BonusInstThreshold, AC) | true;
 
+  if (PeelDominantCase(SI, Builder))
+    return SimplifyCFG(BB, TTI, BonusInstThreshold, AC) | true;
+
   // The conversion from switch to lookup tables results in difficult
   // to analyze code and makes pruning branches much harder.
   // This is a problem of the switch expression itself can still be
Index: test/Transforms/SimplifyCFG/peel_switch_case.ll
===================================================================
--- test/Transforms/SimplifyCFG/peel_switch_case.ll
+++ test/Transforms/SimplifyCFG/peel_switch_case.ll
@@ -0,0 +1,81 @@
+; RUN: opt -simplifycfg -S < %s | FileCheck %s
+
+; The dominant case (8 -> %return) is peeled off into a conditional branch and
+; every remaining case keeps its own weight in the peeled switch.
+define i32 @foo(i16 signext %n) local_unnamed_addr !prof !1 {
+entry:
+  %conv = sext i16 %n to i32
+  switch i32 %conv, label %sw.epilog [
+    i32 8, label %return
+    i32 -8826, label %sw.bb1
+    i32 18312, label %sw.bb3
+    i32 18568, label %sw.bb5
+    i32 129, label %sw.bb7
+  ], !prof !2
+; CHECK-LABEL: @foo(
+; CHECK: %switch.peeledcmp = icmp eq i32 %conv, 8
+; CHECK: br i1 %switch.peeledcmp, label %return, label %peeled.switch, !prof ![[BRANCH_COUNT:[0-9]+]]
+; CHECK: peeled.switch:
+; CHECK-NEXT: switch i32 %conv, label %sw.epilog [
+; CHECK-NEXT: i32 129, label
+; CHECK-NEXT: i32 -8826, label
+; CHECK-NEXT: i32 18312, label
+; CHECK-NEXT: i32 18568, label
+; CHECK-NEXT: ], !prof ![[SWITCH_COUNT:[0-9]+]]
+
+sw.bb1:
+  br label %return
+
+sw.bb3:
+  br label %return
+
+sw.bb5:
+  br label %return
+
+sw.bb7:
+  br label %return
+
+sw.epilog:
+  br label %return
+
+return:
+  %retval = phi i32 [ 0, %sw.epilog ], [ 5, %sw.bb7 ], [ 4, %sw.bb5 ], [ 3, %sw.bb3 ], [ 2, %sw.bb1 ], [ 1, %entry ]
+  ret i32 %retval
+}
+
+declare void @f(i32)
+
+; The dominant case shares its destination with another case: the PHI in that
+; destination must end up with exactly one entry per incoming edge.
+define void @shared_dest(i32 %x) {
+entry:
+  switch i32 %x, label %sw.default [
+    i32 1, label %common
+    i32 2, label %common
+    i32 3, label %other
+  ], !prof !3
+; CHECK-LABEL: @shared_dest(
+; CHECK: %switch.peeledcmp = icmp eq i32 %x, 1
+; CHECK-NEXT: br i1 %switch.peeledcmp, label %common, label %peeled.switch
+; CHECK: common:
+; CHECK-NEXT: %v = phi i32 [ 7, %peeled.switch ], [ 7, %entry ]
+
+common:
+  %v = phi i32 [ 7, %entry ], [ 7, %entry ]
+  call void @f(i32 %v)
+  ret void
+
+other:
+  call void @f(i32 3)
+  ret void
+
+sw.default:
+  call void @f(i32 0)
+  ret void
+}
+
+!1 = !{!"function_entry_count", i64 101000}
+!2 = !{!"branch_weights", i32 0, i32 100000, i32 0, i32 1000, i32 0, i32 0}
+!3 = !{!"branch_weights", i32 1, i32 1000, i32 10, i32 5}
+; CHECK: ![[BRANCH_COUNT]] = !{!"branch_weights", i32 100000, i32 1000}
+; CHECK: ![[SWITCH_COUNT]] = !{!"branch_weights", i32 0, i32 0, i32 0, i32 1000, i32 0}
Index: test/Transforms/SimplifyCFG/preserve-branchweights-switch-create.ll
===================================================================
--- test/Transforms/SimplifyCFG/preserve-branchweights-switch-create.ll
+++ test/Transforms/SimplifyCFG/preserve-branchweights-switch-create.ll
@@ -1,4 +1,6 @@
-; RUN: opt -simplifycfg -S -o - < %s | FileCheck %s
+; Dominant-case peeling stops strongly biased branches from being folded into
+; switches; disable it so that this test is not affected by it.
+; RUN: opt -simplifycfg -peel-dominant-switch-case=false -S -o - < %s | FileCheck %s
 
 declare void @func2(i32)
 declare void @func4(i32)