Index: lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h
===================================================================
--- lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h
+++ lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h
@@ -652,6 +652,8 @@
   void visitUnreachable(const UnreachableInst &I);
 
   // Helpers for visitSwitch
+  MachineBasicBlock *handleHotSwitchCase(CaseVector& Cases,
+                                         const SwitchInst& SI);
   bool handleSmallSwitchRange(CaseRec& CR,
                               CaseRecVector& WorkList,
                               const Value* SV,
Index: lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
===================================================================
--- lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -75,6 +75,11 @@
                  cl::location(LimitFloatPrecision),
                  cl::init(0));
 
+static cl::opt<unsigned>
+HotCaseProb("hot-switch-case-probability", cl::Hidden, cl::init(80),
+            cl::desc("Branch probability necessary to consider a case statement"
+                     " hot (default 80%)"));
+
 // Limit the width of DAG chains. This is important in general to prevent
 // prevent DAG-based analysis from blowing up. For example, alias analysis and
 // load clustering may not complete in reasonable time. It is difficult to
@@ -2068,6 +2073,81 @@
   setValue(&LP, Res);
 }
 
+/// Peel the hottest single-value case into an equality test ahead of the
+/// switch lowering. Returns the new fall-through MBB, or nullptr if none.
+MachineBasicBlock *
+SelectionDAGBuilder::handleHotSwitchCase(CaseVector& Cases,
+                                         const SwitchInst& SI) {
+  BranchProbabilityInfo *BPI = FuncInfo.BPI;
+  if (!BPI)
+    return nullptr;
+
+  // handleSmallSwitchRange handles switches with few cases.
+  if (Cases.size() <= 3)
+    return nullptr;
+
+  // Find the case with the largest weight and sum all weights. HotCase stays
+  // at Cases.end() if every case weight is zero.
+  CaseItr HotCase = Cases.end();
+  uint32_t MaxCaseWeight = 0;
+  uint32_t TotalWeight = BPI->getEdgeWeight(SI.getParent(), SI.getDefaultDest());
+  for (CaseItr I = Cases.begin(), E = Cases.end(); I != E; ++I) {
+    TotalWeight += I->ExtraWeight;
+    if (I->ExtraWeight > MaxCaseWeight) {
+      MaxCaseWeight = I->ExtraWeight;
+      HotCase = I;
+    }
+  }
+
+  // Bail out if no case has weight or the hottest one covers a value range.
+  if (HotCase == Cases.end() || HotCase->High != HotCase->Low)
+    return nullptr;
+
+  // Check if we have a hot case.
+  if (BranchProbability(MaxCaseWeight, TotalWeight) <
+      BranchProbability(HotCaseProb, 100))
+    return nullptr;
+
+  DEBUG(dbgs() << "Found hot switch case: "
+        << cast<ConstantInt>(HotCase->Low)->getValue() << '\n'
+        << "Num cases: " << Cases.size() << "\nProbability: "
+        << BranchProbability(MaxCaseWeight, TotalWeight) << '\n'
+        << "Weight: " << MaxCaseWeight << '\n');
+
+  // Get the MachineFunction which holds the current MBB. This is used when
+  // inserting any additional MBBs necessary to represent the switch.
+  MachineFunction *CurMF = FuncInfo.MF;
+  MachineBasicBlock *SwitchMBB = FuncInfo.MBB;
+
+  // Figure out which block is immediately after the current one.
+  MachineFunction::iterator BBI = SwitchMBB;
+  ++BBI;
+
+  // Create the fall-through block; the CaseBlock built below branches to the
+  // hot case's MBB when SV equals the case value and falls through otherwise.
+  MachineBasicBlock *FallThrough =
+    CurMF->CreateMachineBasicBlock(SwitchMBB->getBasicBlock());
+  CurMF->insert(BBI, FallThrough);
+
+  // Put SV in a virtual register to make it available from the new blocks.
+  const Value *SV = SI.getCondition();
+  ExportFromCurrentBlock(SV);
+
+  assert(HotCase->High == HotCase->Low && "Expected a single case statement.");
+  CaseBlock CB(ISD::SETEQ, SV, HotCase->Low, nullptr, /* truebb */ HotCase->BB,
+               /* falsebb */ FallThrough, /* me */ SwitchMBB,
+               /* trueweight */ HotCase->ExtraWeight,
+               /* falseweight */ TotalWeight - HotCase->ExtraWeight);
+
+  // Push the CaseBlock onto the vector to be later processed by SDISel.
+  SwitchCases.push_back(CB);
+
+  // Remove the hot case from the vector of cases.
+  Cases.erase(HotCase);
+
+  return FallThrough;
+}
+
 /// handleSmallSwitchCaseRange - Emit a series of specific tests (suitable for
 /// small case ranges).
 bool SelectionDAGBuilder::handleSmallSwitchRange(CaseRec& CR,
@@ -2661,7 +2741,6 @@
       for (auto &I : Cases)
         // A range counts double, since it requires two compares.
         numCmps += I.Low != I.High ? 2 : 1;
-
       dbgs() << "Clusterify finished. Total clusters: " << Cases.size()
              << ". Total compares: " << numCmps << '\n';
     });
@@ -2708,6 +2787,12 @@
   CaseVector Cases;
   Clusterify(Cases, SI);
 
+  // Look for a "hot" case and, if found, insert conditional logic to branch
+  // to the hot case prior to jumping into the switching logic.
+  MachineBasicBlock *CaseRecordMBB = SwitchMBB;
+  if (MachineBasicBlock *NewMBB = handleHotSwitchCase(Cases, SI))
+    CaseRecordMBB = NewMBB;
+
   // Get the Value to be switched on and default basic blocks, which will be
   // inserted into CaseBlock records, representing basic blocks in the binary
   // search tree.
@@ -2715,7 +2800,7 @@
 
   // Push the initial CaseRec onto the worklist
   CaseRecVector WorkList;
-  WorkList.push_back(CaseRec(SwitchMBB,nullptr,nullptr,
+  WorkList.push_back(CaseRec(CaseRecordMBB, nullptr, nullptr,
                              CaseRange(Cases.begin(),Cases.end())));
 
   while (!WorkList.empty()) {
Index: test/CodeGen/AArch64/switch-pgo.ll
===================================================================
--- /dev/null
+++ test/CodeGen/AArch64/switch-pgo.ll
@@ -0,0 +1,261 @@
+; RUN: llc -verify-machineinstrs < %s -mtriple=aarch64-linux-gnu | FileCheck %s
+; test1: no weight in !1 dominates, so no hot-case compare is expected up front.
+; CHECK-LABEL: test1:
+; CHECK: cmp w0, #111
+; CHECK: cmp w0, #26
+; CHECK: cmp w0, #10
+; CHECK: cmp w0, #12
+; CHECK: cmp w0, #112
+; CHECK: cmp w0, #27
+; CHECK: cmp w0, #16
+
+define void @test1(i32 %cmp, i32* %ptr) {
+entry:
+  switch i32 %cmp, label %exit [
+    i32 10, label %sw.bb1
+    i32 12, label %sw.bb2
+    i32 16, label %sw.bb3
+    i32 27, label %sw.bb4
+    i32 112, label %sw.bb5
+  ], !prof !1
+
+sw.bb1:
+  store i32 1, i32* %ptr, align 4
+  br label %exit
+
+sw.bb2:
+  store i32 2, i32* %ptr, align 4
+  br label %exit
+
+sw.bb3:
+  store i32 3, i32* %ptr, align 4
+  br label %exit
+
+sw.bb4:
+  store i32 4, i32* %ptr, align 4
+  br label %exit
+
+sw.bb5:
+  store i32 5, i32* %ptr, align 4
+  br label %exit
+
+exit:
+  ret void
+}
+; test2: the 95 weight in !2 is expected to make case 10 hot; its compare comes first.
+; CHECK-LABEL: test2:
+; CHECK: cmp w0, #10
+; CHECK: cmp w0, #111
+; CHECK: cmp w0, #12
+; CHECK: cmp w0, #16
+; CHECK: cmp w0, #112
+; CHECK: cmp w0, #27
+
+define void @test2(i32 %cmp, i32* %ptr) {
+entry:
+  switch i32 %cmp, label %exit [
+    i32 10, label %sw.bb1
+    i32 12, label %sw.bb2
+    i32 16, label %sw.bb3
+    i32 27, label %sw.bb4
+    i32 112, label %sw.bb5
+  ], !prof !2
+
+sw.bb1:
+  store i32 1, i32* %ptr, align 4
+  br label %exit
+
+sw.bb2:
+  store i32 2, i32* %ptr, align 4
+  br label %exit
+
+sw.bb3:
+  store i32 3, i32* %ptr, align 4
+  br label %exit
+
+sw.bb4:
+  store i32 4, i32* %ptr, align 4
+  br label %exit
+
+sw.bb5:
+  store i32 5, i32* %ptr, align 4
+  br label %exit
+
+exit:
+  ret void
+}
+; test3: the 95 weight in !3 is expected to make case 12 hot; its compare comes first.
+; CHECK-LABEL: test3:
+; CHECK: cmp w0, #12
+; CHECK: cmp w0, #111
+; CHECK: cmp w0, #10
+; CHECK: cmp w0, #16
+; CHECK: cmp w0, #112
+; CHECK: cmp w0, #27
+
+define void @test3(i32 %cmp, i32* %ptr) {
+entry:
+  switch i32 %cmp, label %exit [
+    i32 10, label %sw.bb1
+    i32 12, label %sw.bb2
+    i32 16, label %sw.bb3
+    i32 27, label %sw.bb4
+    i32 112, label %sw.bb5
+  ], !prof !3
+
+sw.bb1:
+  store i32 1, i32* %ptr, align 4
+  br label %exit
+
+sw.bb2:
+  store i32 2, i32* %ptr, align 4
+  br label %exit
+
+sw.bb3:
+  store i32 3, i32* %ptr, align 4
+  br label %exit
+
+sw.bb4:
+  store i32 4, i32* %ptr, align 4
+  br label %exit
+
+sw.bb5:
+  store i32 5, i32* %ptr, align 4
+  br label %exit
+
+exit:
+  ret void
+}
+; test4: the 95 weight in !4 is expected to make case 16 hot; its compare comes first.
+; CHECK-LABEL: test4:
+; CHECK: cmp w0, #16
+; CHECK: cmp w0, #111
+; CHECK: cmp w0, #10
+; CHECK: cmp w0, #12
+; CHECK: cmp w0, #112
+; CHECK: cmp w0, #27
+
+define void @test4(i32 %cmp, i32* %ptr) {
+entry:
+  switch i32 %cmp, label %exit [
+    i32 10, label %sw.bb1
+    i32 12, label %sw.bb2
+    i32 16, label %sw.bb3
+    i32 27, label %sw.bb4
+    i32 112, label %sw.bb5
+  ], !prof !4
+
+sw.bb1:
+  store i32 1, i32* %ptr, align 4
+  br label %exit
+
+sw.bb2:
+  store i32 2, i32* %ptr, align 4
+  br label %exit
+
+sw.bb3:
+  store i32 3, i32* %ptr, align 4
+  br label %exit
+
+sw.bb4:
+  store i32 4, i32* %ptr, align 4
+  br label %exit
+
+sw.bb5:
+  store i32 5, i32* %ptr, align 4
+  br label %exit
+
+exit:
+  ret void
+}
+; test5: the 95 weight in !5 is expected to make case 27 hot; its compare comes first.
+; CHECK-LABEL: test5:
+; CHECK: cmp w0, #27
+; CHECK: cmp w0, #111
+; CHECK: cmp w0, #10
+; CHECK: cmp w0, #12
+; CHECK: cmp w0, #112
+; CHECK: cmp w0, #16
+
+define void @test5(i32 %cmp, i32* %ptr) {
+entry:
+  switch i32 %cmp, label %exit [
+    i32 10, label %sw.bb1
+    i32 12, label %sw.bb2
+    i32 16, label %sw.bb3
+    i32 27, label %sw.bb4
+    i32 112, label %sw.bb5
+  ], !prof !5
+
+sw.bb1:
+  store i32 1, i32* %ptr, align 4
+  br label %exit
+
+sw.bb2:
+  store i32 2, i32* %ptr, align 4
+  br label %exit
+
+sw.bb3:
+  store i32 3, i32* %ptr, align 4
+  br label %exit
+
+sw.bb4:
+  store i32 4, i32* %ptr, align 4
+  br label %exit
+
+sw.bb5:
+  store i32 5, i32* %ptr, align 4
+  br label %exit
+
+exit:
+  ret void
+}
+; test6: the 95 weight in !6 is expected to make case 112 hot; its compare comes first.
+; CHECK-LABEL: test6:
+; CHECK: cmp w0, #112
+; CHECK: cmp w0, #26
+; CHECK: cmp w0, #10
+; CHECK: cmp w0, #12
+; CHECK: cmp w0, #27
+; CHECK: cmp w0, #16
+
+define void @test6(i32 %cmp, i32* %ptr) {
+entry:
+  switch i32 %cmp, label %exit [
+    i32 10, label %sw.bb1
+    i32 12, label %sw.bb2
+    i32 16, label %sw.bb3
+    i32 27, label %sw.bb4
+    i32 112, label %sw.bb5
+  ], !prof !6
+
+sw.bb1:
+  store i32 1, i32* %ptr, align 4
+  br label %exit
+
+sw.bb2:
+  store i32 2, i32* %ptr, align 4
+  br label %exit
+
+sw.bb3:
+  store i32 3, i32* %ptr, align 4
+  br label %exit
+
+sw.bb4:
+  store i32 4, i32* %ptr, align 4
+  br label %exit
+
+sw.bb5:
+  store i32 5, i32* %ptr, align 4
+  br label %exit
+
+exit:
+  ret void
+}
+; Profile weights used by test1-test6; !2 through !6 move the single 95 one slot right.
+!1 = metadata !{metadata !"branch_weights", i32 20, i32 20, i32 20, i32 20, i32 10, i32 10}
+!2 = metadata !{metadata !"branch_weights", i32 1, i32 95, i32 1, i32 1, i32 1, i32 1}
+!3 = metadata !{metadata !"branch_weights", i32 1, i32 1, i32 95, i32 1, i32 1, i32 1}
+!4 = metadata !{metadata !"branch_weights", i32 1, i32 1, i32 1, i32 95, i32 1, i32 1}
+!5 = metadata !{metadata !"branch_weights", i32 1, i32 1, i32 1, i32 1, i32 95, i32 1}
+!6 = metadata !{metadata !"branch_weights", i32 1, i32 1, i32 1, i32 1, i32 1, i32 95}