Index: lib/Transforms/Scalar/LoopUnswitch.cpp
===================================================================
--- lib/Transforms/Scalar/LoopUnswitch.cpp
+++ lib/Transforms/Scalar/LoopUnswitch.cpp
@@ -158,7 +158,8 @@
     // Redistribute unswitching quotas.
     // Note, that new loop data is stored inside the VMap.
     void cloneData(const Loop *NewLoop, const Loop *OldLoop,
-                   const ValueToValueMapTy &VMap);
+                   const ValueToValueMapTy &VMap, uint64_t NewLoopWeight,
+                   uint64_t OldLoopWeight);
   };
 
   class LoopUnswitch : public LoopPass {
@@ -179,6 +180,13 @@
     BlockFrequencyInfo BFI;
     BlockFrequency ColdEntryFreq;
 
+    // Weights used to split the remaining unswitch quota between the old
+    // and the new loop when no "branch_weights" metadata is available
+    // (always the case for switch and select unswitching). Equal defaults
+    // request an even split.
+    uint64_t DefaultOldLoopWeight = 50;
+    uint64_t DefaultNewLoopWeight = 50;
+
     bool OptimizeForSize;
     bool redoLoop;
 
@@ -243,10 +251,13 @@
     bool TryTrivialLoopUnswitch(bool &Changed);
 
     bool UnswitchIfProfitable(Value *LoopCond, Constant *Val,
+                              uint64_t NewLoopWeight, uint64_t OldLoopWeight,
                               TerminatorInst *TI = nullptr);
     void UnswitchTrivialCondition(Loop *L, Value *Cond, Constant *Val,
                                   BasicBlock *ExitBlock, TerminatorInst *TI);
     void UnswitchNontrivialCondition(Value *LIC, Constant *OnVal, Loop *L,
+                                     uint64_t NewLoopWeight,
+                                     uint64_t OldLoopWeight,
                                      TerminatorInst *TI);
 
     void RewriteLoopBodyWithConditionConstant(Loop *L, Value *LIC,
@@ -349,7 +360,9 @@
 // Redistribute unswitching quotas.
 // Note, that new loop data is stored inside the VMap.
void LUAnalysisCache::cloneData(const Loop *NewLoop, const Loop *OldLoop, - const ValueToValueMapTy &VMap) { + const ValueToValueMapTy &VMap, + uint64_t NewLoopWeight, + uint64_t OldLoopWeight) { LoopProperties &NewLoopProps = LoopsProperties[NewLoop]; LoopProperties &OldLoopProps = *CurrentLoopProperties; @@ -361,8 +374,18 @@ ++OldLoopProps.WasUnswitchedCount; NewLoopProps.WasUnswitchedCount = 0; unsigned Quota = OldLoopProps.CanBeUnswitchedCount; - NewLoopProps.CanBeUnswitchedCount = Quota / 2; - OldLoopProps.CanBeUnswitchedCount = Quota - Quota / 2; + + // Spill quota based on branch weight. Notice that integer + // division will round to floor. Therefore to make sure the + // hotter branch always get a higher quota, add one if + // NewLoopWeight is greater OldLoopWeight (if they are equal, + // OldLoop get the extra count which matches the original + // behavior without branch weight). + NewLoopProps.CanBeUnswitchedCount = + Quota * NewLoopWeight / (NewLoopWeight + OldLoopWeight) + + NewLoopWeight > OldLoopWeight ? 1 : 0; + OldLoopProps.CanBeUnswitchedCount = + Quota - NewLoopProps.CanBeUnswitchedCount; NewLoopProps.SizeEstimation = OldLoopProps.SizeEstimation; @@ -544,8 +567,27 @@ // unswitch on it if we desire. Value *LoopCond = FindLIVLoopCondition(BI->getCondition(), currentLoop, Changed); + + // Get branch weights from metadata. Use default if metadata + // is not available. 
+        uint64_t NewLoopWeight = DefaultNewLoopWeight,
+                 OldLoopWeight = DefaultOldLoopWeight;
+        MDNode *MD = BI->getMetadata(LLVMContext::MD_prof);
+        if (MD && MD->getNumOperands() == 3) {
+          // Operand 0 names the kind; operands 1 and 2 weight the true and
+          // false edges and become the new and old loop weights respectively.
+          MDString *MDS = dyn_cast_or_null<MDString>(MD->getOperand(0).get());
+          auto *TrueW = mdconst::dyn_extract<ConstantInt>(MD->getOperand(1));
+          auto *FalseW = mdconst::dyn_extract<ConstantInt>(MD->getOperand(2));
+          if (MDS && MDS->getString() == "branch_weights" && TrueW && FalseW) {
+            NewLoopWeight = TrueW->getZExtValue();
+            OldLoopWeight = FalseW->getZExtValue();
+          }
+        }
+
         if (LoopCond &&
-            UnswitchIfProfitable(LoopCond, ConstantInt::getTrue(Context), TI)) {
+            UnswitchIfProfitable(LoopCond, ConstantInt::getTrue(Context),
+                                 NewLoopWeight, OldLoopWeight, TI)) {
           ++NumBranches;
           return true;
         }
@@ -575,7 +617,8 @@
         if (!UnswitchVal)
           continue;
 
-        if (UnswitchIfProfitable(LoopCond, UnswitchVal)) {
+        if (UnswitchIfProfitable(LoopCond, UnswitchVal, DefaultNewLoopWeight,
+                                 DefaultOldLoopWeight)) {
           ++NumSwitches;
           return true;
         }
@@ -589,7 +632,9 @@
         Value *LoopCond = FindLIVLoopCondition(SI->getCondition(),
                                                currentLoop, Changed);
         if (LoopCond && UnswitchIfProfitable(LoopCond,
-                                             ConstantInt::getTrue(Context))) {
+                                             ConstantInt::getTrue(Context),
+                                             DefaultNewLoopWeight,
+                                             DefaultOldLoopWeight)) {
           ++NumSelects;
           return true;
         }
@@ -652,6 +697,8 @@
 /// simplify the loop.  If we decide that this is profitable,
 /// unswitch the loop, reprocess the pieces, then return true.
 bool LoopUnswitch::UnswitchIfProfitable(Value *LoopCond, Constant *Val,
+                                        uint64_t NewLoopWeight,
+                                        uint64_t OldLoopWeight,
                                         TerminatorInst *TI) {
   // Check to see if it would be profitable to unswitch current loop.
   if (!BranchesInfo.CostAllowsUnswitching()) {
@@ -663,7 +710,8 @@
     return false;
   }
 
-  UnswitchNontrivialCondition(LoopCond, Val, currentLoop, TI);
+  UnswitchNontrivialCondition(LoopCond, Val, currentLoop, NewLoopWeight,
+                              OldLoopWeight, TI);
   return true;
 }
 
@@ -969,7 +1017,9 @@
 /// Split it into loop versions and test the condition outside of either loop.
 /// Return the loops created as Out1/Out2.
void LoopUnswitch::UnswitchNontrivialCondition(Value *LIC, Constant *Val, - Loop *L, TerminatorInst *TI) { + Loop *L, uint64_t NewLoopWeight, + uint64_t OldLoopWeight, + TerminatorInst *TI) { Function *F = loopHeader->getParent(); DEBUG(dbgs() << "loop-unswitch: Unswitching loop %" << loopHeader->getName() << " [" << L->getBlocks().size() @@ -1031,7 +1081,7 @@ // Recalculate unswitching quota, inherit simplified switches info for NewBB, // Probably clone more loop-unswitch related loop properties. - BranchesInfo.cloneData(NewLoop, L, VMap); + BranchesInfo.cloneData(NewLoop, L, VMap, NewLoopWeight, OldLoopWeight); Loop *ParentLoop = L->getParentLoop(); if (ParentLoop) { Index: test/Transforms/LoopUnswitch/unswitch-with-branch-weight.ll =================================================================== --- /dev/null +++ test/Transforms/LoopUnswitch/unswitch-with-branch-weight.ll @@ -0,0 +1,78 @@ +; RUN: opt < %s -loop-unswitch -loop-unswitch-threshold=16 -simplifycfg -S < %s 2>&1 | FileCheck %s + + +;; This test is verify that loop unswitch takes advantage of branch weight +;; to spill unswitch quota to old loop and newly generated loop. The test +;; is equivalent to the following code: +;; for (...) +;; if (cond1) +;; dummy1() +;; else +;; dummy2() +;; if (cond2) +;; dummy3() +;; else +;; break +;; +;; This can be unswitched twice based on cond1 and cond2, ending with 4 +;; loops if we have enough quota. However, if we dont have enough quota, +;; we should spend most of them on the hot branch based on branch weight. +;; In this test, the branch to dummy1() is hotter than dummy2(), so we +;; give more quota to that branch and end up like: +;; +;; if (cond1) +;; if (cond2) +;; for (...) +;; dummy1() +;; dummy3() +;; else +;; for (...) +;; dummy1() +;; break +;; else +;; for (...) +;; dummy2() +;; if (cond2) +;; dummy3() +;; else +;; break +;; +;; If we don't spill quota based on branch weight, the colder branch +;; will get unswitched. 
+
+define i32 @test(i1 %cond1, i1 %cond2) {
+; Hot loop (cond1 true) is also unswitched on cond2: dummy3 follows dummy1.
+; CHECK: call void @dummy1() #0
+; CHECK-NEXT: call void @dummy3() #0
+; Cold loop (cond1 false) keeps its in-loop branch on cond2.
+; CHECK: call void @dummy2() #0
+; CHECK-NEXT: br i1 %cond2
+
+  br label %loop_begin
+
+loop_begin:
+  br i1 %cond1, label %loop_br1, label %loop_br2, !prof !0
+
+loop_br1:
+  call void @dummy1() nounwind
+  br label %loop_continue
+loop_br2:
+  call void @dummy2() nounwind
+  br label %loop_continue
+
+loop_continue:
+  br i1 %cond2, label %loop_br3, label %loop_exit
+
+loop_br3:
+  call void @dummy3() nounwind
+  br label %loop_begin
+
+loop_exit:
+  ret i32 0
+}
+
+declare void @dummy1()
+declare void @dummy2()
+declare void @dummy3()
+
+!0 = !{!"branch_weights", i32 1000, i32 1} ; loop_br1 (true) vs loop_br2