Index: llvm/lib/Transforms/Utils/SimplifyCFG.cpp =================================================================== --- llvm/lib/Transforms/Utils/SimplifyCFG.cpp +++ llvm/lib/Transforms/Utils/SimplifyCFG.cpp @@ -857,15 +857,16 @@ return false; } -// Set branch weights on SwitchInst. This sets the metadata if there is at -// least one non-zero weight. -static void setBranchWeights(SwitchInst *SI, ArrayRef Weights) { +// Set branch weights on SwitchInst/BranchInst/SelectInst. +// This sets the metadata if there is at least one non-zero weight. +static void setBranchWeights(Instruction *I, ArrayRef Weights) { + assert(isa(I) || isa(I) || isa(I)); // Check that there is at least one non-zero weight. Otherwise, pass // nullptr to setMetadata which will erase the existing metadata. MDNode *N = nullptr; if (llvm::any_of(Weights, [](uint32_t W) { return W != 0; })) - N = MDBuilder(SI->getParent()->getContext()).createBranchWeights(Weights); - SI->setMetadata(LLVMContext::MD_prof, N); + N = MDBuilder(I->getParent()->getContext()).createBranchWeights(Weights); + I->setMetadata(LLVMContext::MD_prof, N); } // Similar to the above, but for branch and select instructions that take @@ -1091,6 +1092,16 @@ } } +/// Adjust the weights so that the sum of weights can fit in uint32_t. +static void FitTotalWeights(MutableArrayRef Weights) { + uint64_t Sum = std::accumulate(Weights.begin(), Weights.end(), (uint64_t)0); + if (Sum > UINT_MAX) { + unsigned Offset = 32 - countLeadingZeros(Sum); + for (uint64_t &I : Weights) + I >>= Offset; + } +} + static void CloneInstructionsIntoPredecessorBlockAndUpdateSSAUses( BasicBlock *BB, BasicBlock *PredBlock, ValueToValueMapTy &VMap) { Instruction *PTI = PredBlock->getTerminator(); @@ -3062,6 +3073,111 @@ return nullptr; } +/// We want to update the probability of all the branches of "PredBB Taken" +/// (PredBB->BB) to be its original value times the probability of "BB +/// Taken" (BB->Dest). 
The CFG is as follows: +/// +/// PredBB +/// / \\ +/// / \\ +/// BB +/// // \ +/// // \ +/// Dest +/// +/// The way to achieve that is by scaling all the branch weights of "PredBB Not +/// Taken" (from PredBB but not to BB) by a factor X, which is: +/// +/// X = 1 + (BNTW / BTW) * (TotW / TotNTW) +/// +/// in which: +/// "BTW" is the branch weight from BB to Dest +/// "BNTW" is the branch weight from BB to the other successor +/// "TotW" is the sum of branch weights from PredBB to all successors +/// including BB +/// "TotNTW" is the sum of branch weights from PredBB to all other successors +/// excepting BB +/// The value will not overflow if TotW is within uint32_t range +static void ScaleWeights(uint64_t BITakenWeight, uint64_t BINotTakenWeight, + SmallVector &Weights, + const SmallDenseSet &PTITakenIndexes) { + uint64_t TotWeights = + std::accumulate(Weights.begin(), Weights.end(), (uint64_t)0); + assert(TotWeights <= UINT32_MAX); + uint64_t TotTakenWeights = 0; + // PTITakenIndexes stores the indexes of BB among all the successors of + // PredBB, which are also the indexes of the weights that correspond to + // PredBB->BB in the PredBB Terminator Instruction's weight array + for (unsigned Idx : PTITakenIndexes) + TotTakenWeights += Weights[Idx]; + uint64_t TotNotTakenWeights = TotWeights - TotTakenWeights; + + for (unsigned Idx = 0, End = Weights.size(); Idx != End; Idx++) { + // Skip the branches from PredBB to BB + if (PTITakenIndexes.find(Idx) != PTITakenIndexes.end()) + continue; + + // Now considering the branch indexed by "Idx". This branch is from PredBB + // but not to BB. 
+ // We need to scale it as described above, while now we do small adjustment + // by doing add to avoid overflow: + // + // newWeights[Idx] = Weights[Idx] + (TotW*Weights[Idx]/TotNTW) * (BNTW/BTW) + // + // Since branch "Idx" is also one of the "PredBB Not Taken" branches, that + // means Weights[Idx]/TotNTW <= 1, thus (TotW*Weights[Idx]/TotNTW) is at + // most TotW and thus at most UINT_MAX + + // "Inc" is the second term in the formula above, and we calculate its two + // multipliers one by one. + // Clamp each divisor to at least 1 because nothing prevents a weight from being 0 + uint64_t Inc = TotWeights * Weights[Idx] / + std::max(uint64_t(1), TotNotTakenWeights); + + // The Inc result here is at most UINT32_MAX ^ 2 + Inc = (Inc * BINotTakenWeight) / std::max(uint64_t(1), BITakenWeight); + + // UINT64_MAX = UINT32_MAX ^ 2 + 2 * UINT32_MAX, and here Weights[Idx] is + // at most UINT32_MAX ^ 2 + UINT32_MAX and thus will not exceed UINT64_MAX + Weights[Idx] += Inc; + } +} + +static void FixupBranchWeight(BasicBlock *BB, + const ArrayRef &PredBBs, + uint64_t BITakenWeight, + uint64_t BINotTakenWeight) { + for (BasicBlock *PredBB : PredBBs) { + Instruction *PTI = PredBB->getTerminator(); + SmallVector OriginalWeights; + // We only need to fix branch weights for BranchInst/SwitchInst + if (!isa(PTI) && !isa(PTI)) + continue; + OriginalWeights.assign(PTI->getNumSuccessors(), 1); + extractBranchWeights(*PTI, OriginalWeights); + assert(OriginalWeights.size() == PTI->getNumSuccessors()); + // No need to update weights for unconditional BranchInst + if (OriginalWeights.size() < 2) + continue; + SmallVector Weights(OriginalWeights.begin(), + OriginalWeights.end()); + FitTotalWeights(Weights); // required by ScaleWeights above + + SmallDenseSet PTITakenIndexes; + for (unsigned Idx = 0, End = Weights.size(); Idx != End; Idx++) + if (PTI->getSuccessor(Idx) == BB) + PTITakenIndexes.insert(Idx); + assert(!PTITakenIndexes.empty()); + + ScaleWeights(BITakenWeight, BINotTakenWeight, Weights, PTITakenIndexes); + FitWeights(Weights); + 
SmallVector NewWeights(Weights.begin(), Weights.end()); + // only set when the new weights is different from original one + if (NewWeights != OriginalWeights) + setBranchWeights(PTI, NewWeights); + } +} + /// If we have a conditional branch on something for which we know the constant /// value in predecessors (e.g. a phi node in the current block), thread edges /// from the predecessor to their ultimate destination. @@ -3131,6 +3247,32 @@ EdgeBB->setName(RealDest->getName() + ".critedge"); EdgeBB->moveBefore(RealDest); + // After redirecting PredBB to RealDest, we also need to update the branch + // weight of PTI(the terminator of PredBB). + // Let's use P(A->B) to represent the branch probability from block A to + // block B, and BlockFreq(A) to represent the block frequency of block A. + // So that the new P(PredBB->RealDest) = P(PredBB->BB) * P(BB->RealDest) + // Because original block frequency of RealDest is + // BlockFreq(RealDest) + // = BlockFreq(BB) * P(BB->RealDest) + eps1 + // = (BlockFreq(PredBB) * P(PredBB->BB) + eps2) * P(BB->RealDest) + eps1 + // = BlockFreq(PredBB) * P(PredBB->BB) * P(BB->RealDest) + eps + // Assuming that there are no other paths from PredBB to RealDest + + // BITakenWeight is the branch weight of BB->RealDest, which is the one we + // used in folding + uint64_t BITakenWeight, BINotTakenWeight; + bool BIHasWeights = + extractBranchWeights(*BI, BITakenWeight, BINotTakenWeight); + // Only update when BI has branch weights + if (BIHasWeights) { + // Swap the values if BB->RealDest is the false branch, because we + // assumed that is true branch when calling extractBranchWeights + if (!CB->getZExtValue()) + std::swap(BITakenWeight, BINotTakenWeight); + FixupBranchWeight(EdgeBB, PredBBs, BITakenWeight, BINotTakenWeight); + } + // Update PHI nodes. 
AddPredecessorToBlock(RealDest, EdgeBB, BB); Index: llvm/test/Transforms/SimplifyCFG/fold-cond-bi-fix-weight-overflow.ll =================================================================== --- /dev/null +++ llvm/test/Transforms/SimplifyCFG/fold-cond-bi-fix-weight-overflow.ll @@ -0,0 +1,98 @@ +; RUN: opt -S -simplifycfg < %s | FileCheck %s + +declare void @foo() + +;; This test is for checking there is no overflow issue when fixing branch_weights +;; in FoldCondBranchOnValueKnownInPredecessor +;; which is related to patch [https://reviews.llvm.org/D131287] + + +;; Create a test case that will definitely overflow when doing naive add/mul, +;; by making branch_weights which used in the calculation to be UINT32_MAX. + +;; The CFG before transformation is as follows: +;; +;; entry +;; | \ \ \ \ 25% each (the branch weights are all UINT_MAX) +;; 0% | \ \ \ \ +;; | \ \ \ \ +;; | mid +;; | / +;; | / +;; | / +;; key +;; | \ +;; 0% | \ 100% (this branch weight is UINT_MAX) +;; | \ +;; | func +;; | (call foo) +;; | / +;; | / +;; | / +;; end +;; +;; Actual probability of hitting foo is almost 100%. 
+;; +;; After the transformation, the CFG becomes as follows: +;; +;; entry +;; | \ \ \ \ ?% +;; ?% | \ \ \ \ +;; | \ \ \ \ +;; | func +;; | (call foo) +;; | / +;; | / +;; | / +;; end +;; +;; In the transformation, we will need to do the following: +;; Weights[i] += (TotW * Weights[i] / TotNTW) * (BNTW / BTW) +;; In this test, TotW is 4*UINT_MAX+1, Weights[i] is UINT_MAX, TotNTW is 4*UINT_MAX +;; BNTW is UINT_MAX, BTW is 1 + + +;; After calculation, the branch weights from entry is 0:a:a:a:a where a is UINT_MAX-3 +;; so the probability of hitting foo is still roughly the same and no overflow in the middle +define void @overflow_test(i32 %a) { +; CHECK-LABEL: @overflow_test( +; CHECK-NEXT: entry: +; CHECK-NEXT: switch i32 [[SA0:%.*]], label %end [ +; CHECK-NEXT: i32 2, label %mid +; CHECK-NEXT: i32 4, label %mid +; CHECK-NEXT: i32 64, label %mid +; CHECK-NEXT: i32 128, label %mid +; CHECK-NEXT: ], !prof [[PROF0:![0-9]+]] +; CHECK: mid: +; CHECK-NEXT: call void @foo() +; CHECK-NEXT: br label %end +; CHECK: end: +; CHECK-NEXT: ret void +entry: + switch i32 %a, label %key [ + i32 2, label %mid + i32 4, label %mid + i32 64, label %mid + i32 128, label %mid + ], !prof !0 + +mid: + br label %key + +key: + %c = phi i1 [ true, %entry ], [ false, %mid ] + br i1 %c, label %end, label %func, !prof !1 + +func: + call void @foo() + br label %end + +end: + ret void +} + +!0 = !{!"branch_weights", i32 1, i32 -1, i32 -1, i32 -1, i32 -1} +!1 = !{!"branch_weights", i32 1, i32 -1} + + +; CHECK: [[PROF0]] = !{!"branch_weights", i32 0, i32 -4, i32 -4, i32 -4, i32 -4} Index: llvm/test/Transforms/SimplifyCFG/fold-cond-bi-fix-weight.ll =================================================================== --- /dev/null +++ llvm/test/Transforms/SimplifyCFG/fold-cond-bi-fix-weight.ll @@ -0,0 +1,313 @@ +; RUN: opt -S -simplifycfg < %s | FileCheck %s + +declare void @foo() +declare void @foo1() +declare void @foo2() +declare void @foo3() + +;; This test is for fixing branch_weights in 
FoldCondBranchOnValueKnownInPredecessor +;; which is related to patch [https://reviews.llvm.org/D131287] +;; Basically this patch is intended to check the final probability of hitting the +;; destination block is roughly the same as that before this transformation + +;; Testing cases when PredBB has no branchweight(all are initialised as 1) +;; The CFG before the transformation is like this: +;; +;; entry +;; | \ +;; | \ +;; | lor.rhs +;; | / +;; | / +;; | / +;; lor.end +;; / | +;; 10% / | 90% +;; / | +;; if.then | +;; (call foo)| +;; \ | +;; \ | +;; \ | +;; if.end +;; +;; Actual probability of hitting foo is 1/10(10%) +;; +;; After the transformation, the CFG becomes as follows: +;; +;; entry +;; / \ ?% +;; ?% / \ +;; / lor.rhs +;; / / | +;; / 10%/ | 90% +;; / / | +;; if.then | +;; (call foo) | +;; \ | +;; \ | +;; \ | +;; if.end +;; +;; The patch is used to solve which weights we ought to put on entry->if.then +;; +;; Without the patch, the branch weights are unchanged, so the weights of entry is 1:1, +;; then the probability of hitting foo is 11/20(55%) +;; +;; After the patch, the branch weights from entry are updated as 1:19, then the +;; probability of hitting foo is 29/200(14.5%), which is much closer to original one +define void @branchinst_noweight(i8 %a0, i8 %a1, i8 %a2) { +; CHECK-LABEL: @branchinst_noweight( +; CHECK-NEXT: entry: +; CHECK: br i1 [[COND:%.*]], label %if.then, label %if.end, !prof [[PROF0:![0-9]+]] +; CHECK: if.then: +; CHECK-NEXT: call void @foo() +; CHECK-NEXT: br label %if.end +; CHECK: if.end: +; CHECK-NEXT: ret void +entry: + %and0 = and i8 %a0, 127 + %c0 = icmp eq i8 %and0, 1 + %and1 = and i8 %a1, 127 + %c1 = icmp eq i8 %and1, 1 + %or = or i1 %c0, %c1 + br i1 %or, label %lor.end, label %lor.rhs + +lor.rhs: + %and2 = and i8 %a2, 127 + %c2 = icmp eq i8 %and2, 0 + br label %lor.end + +lor.end: + %c3 = phi i1 [ true, %entry ], [ %c2, %lor.rhs ] + br i1 %c3, label %if.then, label %if.end, !prof !0 + +if.then: + call void @foo() + br label 
%if.end + +if.end: + ret void +} + +;; Testing cases when PredBB already has branch weight +;; The CFG is almost the same as the first test case, excepting the branches +;; from entry already has branch weights themselves: +;; +;; entry +;; | \30% +;; 70%| \ +;; | lor.rhs +;; | / +;; | / +;; | / +;; lor.end +;; / | +;; 10% / | 90% +;; / | +;; if.then | +;; (call foo)| +;; \ | +;; \ | +;; \ | +;; if.end +;; +;; Actual probability of hitting foo is 1/10(10%) +;; Without the patch, the branch weights are unchanged, so the weights of entry is 7:3, +;; then the probability of hitting foo is 73/100(73%) +;; +;; After the patch, the branch weights from entry are updated as 7:93, then the +;; probability of hitting foo is 163/1000(16.3%), which is much closer to original one +define void @branchinst_withweight(i8 %a0, i8 %a1, i8 %a2) { +; CHECK-LABEL: @branchinst_withweight( +; CHECK-NEXT: entry: +; CHECK: br i1 [[COND:%.*]], label %if.then, label %if.end, !prof [[PROF1:![0-9]+]] +; CHECK: if.then: +; CHECK-NEXT: call void @foo() +; CHECK-NEXT: br label %if.end +; CHECK: if.end: +; CHECK-NEXT: ret void +entry: + %and0 = and i8 %a0, 127 + %c0 = icmp eq i8 %and0, 1 + %and1 = and i8 %a1, 127 + %c1 = icmp eq i8 %and1, 1 + %or = or i1 %c0, %c1 + br i1 %or, label %lor.end, label %lor.rhs, !prof !1 + +lor.rhs: + %and2 = and i8 %a2, 127 + %c2 = icmp eq i8 %and2, 0 + br label %lor.end + +lor.end: + %c3 = phi i1 [ true, %entry ], [ %c2, %lor.rhs ] + br i1 %c3, label %if.then, label %if.end, !prof !0 + +if.then: + call void @foo() + br label %if.end + +if.end: + ret void +} + + +;; Testing cases when the terminator of PredBB is SwitchInst +;; The CFG before transformation is as follows: +;; +;; entry +;; | \ \ \ \ 1% each +;; 96% | \ \ \ \ +;; | \ \ \ \ +;; | mid +;; | / +;; | / +;; | / +;; key +;; | \ +;; 10% | \ 90% +;; | \ +;; | func +;; | (call foo) +;; | / +;; | / +;; | / +;; end +;; +;; Actual probability of hitting foo is 9/10(90%) +;; +;; After the transformation, the CFG 
becomes as follows: +;; +;; entry +;; | \ \ \ \ ?% +;; ?% | \ \ \ \ +;; | \ \ \ \ +;; | func +;; | (call foo) +;; | / +;; | / +;; | / +;; end +;; +;; Without the patch, the branch weights are unchanged, so the weights of entry is 96:1:1:1:1, +;; then the probability of hitting foo is 4/100(4%) +;; +;; After the patch, the branch weights from entry are updated as 96:226:226:226:226, then the +;; probability of hitting foo is 904/1000(90.4%), which is much closer to original one +define void @switch_test(i32 %a) { +; CHECK-LABEL: @switch_test( +; CHECK-NEXT: entry: +; CHECK-NEXT: switch i32 [[SA1:%.*]], label %end [ +; CHECK-NEXT: i32 2, label %mid +; CHECK-NEXT: i32 4, label %mid +; CHECK-NEXT: i32 64, label %mid +; CHECK-NEXT: i32 128, label %mid +; CHECK-NEXT: ], !prof [[PROF2:![0-9]+]] +; CHECK: mid: +; CHECK-NEXT: call void @foo() +; CHECK-NEXT: br label %end +; CHECK: end: +; CHECK-NEXT: ret void +entry: + switch i32 %a, label %key [ + i32 2, label %mid + i32 4, label %mid + i32 64, label %mid + i32 128, label %mid + ], !prof !2 + +mid: + br label %key + +key: + %c = phi i1 [ true, %entry ], [ false, %mid ] + br i1 %c, label %end, label %func, !prof !0 + +func: + call void @foo() + br label %end + +end: + ret void +} + + +;; Create a little bit mixed test case +;; The CFG after transformation is as follows: +;; +;; entry +;; / | \ +;; ?% / ?%| \?% +;; / | \ +;; | func2 func3 +;; | (foo2) (foo3) +;; | / / +;; | | / +;; \ | / +;; \ | / +;; end +;; (foo1) +;; +;; original probability is: func1: 100%, func2: 4%, func3: 8.96% +;; Without the patch, they are: func1: 100%, func2: 30%, func3: 50% +;; After the patch, they are: func1: 100%, func2: 1.4%, func3: 1.5% +define void @more_complex_test(i32 %a) { +; CHECK-LABEL: @more_complex_test( +; CHECK-NEXT: entry: +; CHECK-NEXT: switch i32 [[SA3:%.*]], label %end [ +; CHECK-NEXT: i32 2, label %[[FUNC2:.*]] +; CHECK-NEXT: i32 4, label %[[FUNC3:.*]] +; CHECK-NEXT: ], !prof [[PROF3:![0-9]+]] +; CHECK: [[FUNC2]]: +; 
CHECK-NEXT: call void @foo2() +; CHECK-NEXT: br label %end +; CHECK: [[FUNC3]]: +; CHECK-NEXT: call void @foo3() +; CHECK-NEXT: br label %end +; CHECK: end: +; CHECK-NEXT: call void @foo1() +; CHECK-NEXT: ret void +entry: + switch i32 %a, label %end [ + i32 2, label %mid1 + i32 4, label %mid2 + i32 8, label %mid3 + ], !prof !3 + +mid3: + br label %mid1 + +mid1: + %c0 = phi i1 [ true, %entry ], [ false, %mid3 ] + br i1 %c0, label %func2, label %mid2, !prof !4 + +func2: + call void @foo2() + br label %end + +mid2: + %c1 = phi i1 [ true, %entry ], [ false, %mid1 ] + br i1 %c1, label %func3, label %end, !prof !4 + +func3: + call void @foo3() + br label %end + +end: + call void @foo1() + ret void +} + + +!0 = !{!"branch_weights", i32 1, i32 9} +!1 = !{!"branch_weights", i32 7, i32 3} +!2 = !{!"branch_weights", i32 96, i32 1, i32 1, i32 1, i32 1} +!3 = !{!"branch_weights", i32 1, i32 3, i32 5, i32 1} +!4 = !{!"branch_weights", i32 1, i32 99} + + +; CHECK: [[PROF0]] = !{!"branch_weights", i32 29, i32 171} +; CHECK: [[PROF1]] = !{!"branch_weights", i32 163, i32 837} +; CHECK: [[PROF2]] = !{!"branch_weights", i32 96, i32 226, i32 226, i32 226, i32 226} +; CHECK: [[PROF3]] = !{!"branch_weights", i32 87914, i32 1290, i32 698}