diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -875,6 +875,20 @@
   I->setMetadata(LLVMContext::MD_prof, N);
 }
 
+/// Set branch_weights metadata on a SwitchInst, BranchInst or SelectInst.
+///
+/// A SwitchInst receives all of \p Weights; for a BranchInst or SelectInst
+/// \p Weights must hold exactly two entries, which are forwarded in order.
+static void setBranchWeights(Instruction *I, ArrayRef<uint32_t> Weights) {
+  assert(isa<SwitchInst>(I) || isa<BranchInst>(I) || isa<SelectInst>(I));
+  if (auto *SI = dyn_cast<SwitchInst>(I)) {
+    setBranchWeights(SI, Weights);
+    return;
+  }
+  assert(Weights.size() == 2 && "Expected exactly two branch weights");
+  setBranchWeights(I, Weights[0], Weights[1]);
+}
+
 /// If TI is known to be a terminator instruction and its block is known to
 /// only have a single predecessor block, check to see if that predecessor is
 /// also a value comparison with the same value, and if that comparison
@@ -1084,6 +1098,20 @@
   }
 }
 
+/// Scale the weights down until their sum fits in uint32_t.
+///
+/// Every weight is shifted right by the number of bits by which the sum
+/// exceeds 32 bits, so small weights may be rounded down to 0.
+static void FitTotalWeights(MutableArrayRef<uint64_t> Weights) {
+  uint64_t Sum = std::accumulate(Weights.begin(), Weights.end(), uint64_t(0));
+  if (Sum <= UINT32_MAX)
+    return;
+  // Sum occupies 64 - countLeadingZeros(Sum) bits; shift away those above 32.
+  unsigned Offset = 32 - countLeadingZeros(Sum);
+  for (uint64_t &Weight : Weights)
+    Weight >>= Offset;
+}
+
 static void CloneInstructionsIntoPredecessorBlockAndUpdateSSAUses(
     BasicBlock *BB, BasicBlock *PredBlock, ValueToValueMapTy &VMap) {
   Instruction *PTI = PredBlock->getTerminator();
@@ -2973,6 +3001,107 @@
   return nullptr;
 }
 
+/// Rescale the weights of PTI (the terminator of PredBB) so that the
+/// probability of its edges to BB becomes the original probability multiplied
+/// by the probability of BI being taken (from BB to Dest).
+///
+/// Every not-taken weight of PTI is scaled by the factor
+///   X = 1 + (BNTW / BTW) * (totW / totNTW)
+/// where BNTW is BINotTakenWeight, BTW is BITakenWeight, totW is the sum of
+/// all weights of PTI, and totNTW is the sum of its not-taken weights (totW
+/// minus the sum of the taken weights). Taken weights are left unchanged.
+///
+/// \p PTITakenIndexes holds the indexes in \p PTIWeights of the taken edges.
+/// Nothing overflows as long as totW and both BI weights fit in uint32_t.
+static void ScaleWeights(uint64_t BITakenWeight, uint64_t BINotTakenWeight,
+                         SmallVectorImpl<uint64_t> &PTIWeights,
+                         const SmallDenseSet<unsigned, 8> &PTITakenIndexes) {
+  const uint64_t PTITotalWeight =
+      std::accumulate(PTIWeights.begin(), PTIWeights.end(), uint64_t(0));
+  assert(PTITotalWeight <= UINT32_MAX && "Call FitTotalWeights first");
+  uint64_t PTITotalTakenWeight = 0;
+  for (unsigned I : PTITakenIndexes)
+    PTITotalTakenWeight += PTIWeights[I];
+
+  // Nothing prevents a weight from being 0, so clamp both divisors to at
+  // least 1. The type is spelled out because a 1UL literal is not uint64_t on
+  // every host, and std::max requires both arguments to have the same type.
+  const uint64_t NotTakenDivisor =
+      std::max<uint64_t>(1, PTITotalWeight - PTITotalTakenWeight);
+  const uint64_t BITakenDivisor = std::max<uint64_t>(1, BITakenWeight);
+
+  for (unsigned I = 0, E = PTIWeights.size(); I != E; ++I) {
+    if (PTITakenIndexes.count(I))
+      continue;
+
+    // Scale by adding an increment instead of multiplying:
+    //   new W[I] = W[I] + (totW * W[I] / totNTW) * (BNTW / BTW)
+    //
+    // Branch I is a not-taken branch, so W[I] / totNTW <= 1 and Inc is at
+    // most totW, which is at most UINT32_MAX.
+    uint64_t Inc = PTITotalWeight * PTIWeights[I] / NotTakenDivisor;
+
+    // Inc is now at most UINT32_MAX ^ 2.
+    Inc = (Inc * BINotTakenWeight) / BITakenDivisor;
+
+    // UINT64_MAX = UINT32_MAX ^ 2 + 2 * UINT32_MAX, while the new weight is at
+    // most UINT32_MAX ^ 2 + UINT32_MAX, so the addition cannot overflow.
+    PTIWeights[I] += Inc;
+  }
+}
+
+/// Rescale the branch weights of the terminator of every block in \p PredBBs
+/// so that the probability of its edges to \p BB is scaled by
+/// BITakenWeight / (BITakenWeight + BINotTakenWeight).
+///
+/// Only conditional BranchInst and SwitchInst terminators are updated. A
+/// terminator without branch weights is treated as having weight 1 on every
+/// successor, and metadata is rewritten only if the weights actually change.
+static void FixupBranchWeight(BasicBlock *BB, ArrayRef<BasicBlock *> PredBBs,
+                              uint64_t BITakenWeight,
+                              uint64_t BINotTakenWeight) {
+  for (BasicBlock *PredBB : PredBBs) {
+    Instruction *PTI = PredBB->getTerminator();
+    SmallVector<uint64_t, 8> PTIWeights;
+    if (auto *PBI = dyn_cast<BranchInst>(PTI)) {
+      if (PBI->isUnconditional())
+        continue;
+      uint64_t PBITrue, PBIFalse;
+      if (!PBI->extractProfMetadata(PBITrue, PBIFalse))
+        PBITrue = PBIFalse = 1;
+      PTIWeights.push_back(PBITrue);
+      PTIWeights.push_back(PBIFalse);
+    } else if (auto *PSI = dyn_cast<SwitchInst>(PTI)) {
+      if (HasBranchWeights(PSI))
+        GetBranchWeights(PSI, PTIWeights);
+      else
+        PTIWeights.assign(1 + PSI->getNumCases(), 1);
+    } else {
+      // Only BranchInst and SwitchInst weights need fixing.
+      continue;
+    }
+
+    assert(PTIWeights.size() == PTI->getNumSuccessors());
+    const SmallVector<uint64_t, 8> PTIOriginalWeights = PTIWeights;
+    // ScaleWeights requires the sum of the weights to fit in uint32_t.
+    FitTotalWeights(PTIWeights);
+
+    SmallDenseSet<unsigned, 8> PTITakenIndexes;
+    for (unsigned I = 0, E = PTIWeights.size(); I != E; ++I)
+      if (PTI->getSuccessor(I) == BB)
+        PTITakenIndexes.insert(I);
+    assert(!PTITakenIndexes.empty() && "PredBB does not branch to BB");
+
+    ScaleWeights(BITakenWeight, BINotTakenWeight, PTIWeights, PTITakenIndexes);
+    FitWeights(PTIWeights);
+    // Only rewrite the metadata when the weights actually changed.
+    if (PTIWeights == PTIOriginalWeights)
+      continue;
+    SmallVector<uint32_t, 8> PTIMDWeights(PTIWeights.begin(), PTIWeights.end());
+    setBranchWeights(PTI, PTIMDWeights);
+  }
+}
+
 /// If we have a conditional branch on something for which we know the constant
 /// value in predecessors (e.g. a phi node in the current block), thread edges
 /// from the predecessor to their ultimate destination.
@@ -3042,6 +3171,27 @@
     EdgeBB->setName(RealDest->getName() + ".critedge");
     EdgeBB->moveBefore(RealDest);
 
+    // After redirecting PredBB to RealDest, we also need to update the branch
+    // weights of PTI (the terminator of PredBB).
+    // Let P(A->B) denote the branch probability from block A to block B, and
+    // BlockFreq(A) denote the block frequency of block A.
+    // The new P(PredBB->RealDest) must be P(PredBB->BB) * P(BB->RealDest),
+    // because the original block frequency of RealDest is
+    //   BlockFreq(RealDest)
+    // = BlockFreq(BB) * P(BB->RealDest) + eps1
+    // = (BlockFreq(PredBB) * P(PredBB->BB) + eps2) * P(BB->RealDest) + eps1
+    // = BlockFreq(PredBB) * P(PredBB->BB) * P(BB->RealDest) + eps
+    // assuming that there are no other paths from PredBB to RealDest.
+    uint64_t BITakenWeight, BINotTakenWeight;
+    bool BIHasWeights =
+        BI->extractProfMetadata(BITakenWeight, BINotTakenWeight);
+    // Only update when BI has branch weights
+    if (BIHasWeights) {
+      if (!CB->getZExtValue())
+        std::swap(BITakenWeight, BINotTakenWeight);
+      FixupBranchWeight(EdgeBB, PredBBs, BITakenWeight, BINotTakenWeight);
+    }
+
     // Update PHI nodes.
     AddPredecessorToBlock(RealDest, EdgeBB, BB);
 
diff --git a/llvm/test/Transforms/SimplifyCFG/fold-cond-bi-fix-weight.ll b/llvm/test/Transforms/SimplifyCFG/fold-cond-bi-fix-weight.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/Transforms/SimplifyCFG/fold-cond-bi-fix-weight.ll
@@ -0,0 +1,236 @@
+; RUN: opt -S -simplifycfg < %s | FileCheck %s
+
+declare void @foo()
+declare void @foo1()
+declare void @foo2()
+declare void @foo3()
+
+;; This test is for fixing branch_weights in FoldCondBranchOnValueKnownInPredecessor
+
+;; Testing cases when PredBB has no branch weights (all are initialised to 1)
+;; Actual probability of hitting foo is 1/10 (10%)
+;; Before fix, the probability of hitting foo is 11/20 (55%)
+;; After fix, the probability of hitting foo is 29/200 (14.5%)
+define void @branchinst_noweight(i8 %a0, i8 %a1, i8 %a2) {
+; CHECK-LABEL: @branchinst_noweight(
+; CHECK-NEXT: entry:
+; CHECK: br i1 [[COND:%.*]], label %if.then, label %if.end, !prof [[PROF0:![0-9]+]]
+; CHECK: if.then:
+; CHECK-NEXT: call void @foo()
+; CHECK-NEXT: br label %if.end
+; CHECK: if.end:
+; CHECK-NEXT: ret void
+entry:
+  %and0 = and i8 %a0, 127
+  %c0 = icmp eq i8 %and0, 1
+  %and1 = and i8 %a1, 127
+  %c1 = icmp eq i8 %and1, 1
+  %or = or i1 %c0, %c1
+  br i1 %or, label %lor.end, label %lor.rhs
+
+lor.rhs:
+  %and2 = and i8 %a2, 127
+  %c2 = icmp eq i8 %and2, 0
+  br label %lor.end
+
+lor.end:
+  %c3 = phi i1 [ true, %entry ], [ %c2, %lor.rhs ]
+  br i1 %c3, label %if.then, label %if.end, !prof !0
+
+if.then:
+  call void @foo()
+  br label %if.end
+
+if.end:
+  ret void
+}
+
+;; Testing cases when PredBB already has branch weights
+;; Actual probability of hitting foo is 1/10 (10%)
+;; Before fix, the probability of hitting foo is 73/100 (73%)
+;; After fix, the probability of hitting foo is 163/1000 (16.3%)
+define void @branchinst_withweight(i8 %a0, i8 %a1, i8 %a2) {
+; CHECK-LABEL: @branchinst_withweight(
+; CHECK-NEXT: entry:
+; CHECK: br i1 [[COND:%.*]], label %if.then, label %if.end, !prof [[PROF1:![0-9]+]]
+; CHECK: if.then:
+; CHECK-NEXT: call void @foo()
+; CHECK-NEXT: br label %if.end
+; CHECK: if.end:
+; CHECK-NEXT: ret void
+entry:
+  %and0 = and i8 %a0, 127
+  %c0 = icmp eq i8 %and0, 1
+  %and1 = and i8 %a1, 127
+  %c1 = icmp eq i8 %and1, 1
+  %or = or i1 %c0, %c1
+  br i1 %or, label %lor.end, label %lor.rhs, !prof !1
+
+lor.rhs:
+  %and2 = and i8 %a2, 127
+  %c2 = icmp eq i8 %and2, 0
+  br label %lor.end
+
+lor.end:
+  %c3 = phi i1 [ true, %entry ], [ %c2, %lor.rhs ]
+  br i1 %c3, label %if.then, label %if.end, !prof !0
+
+if.then:
+  call void @foo()
+  br label %if.end
+
+if.end:
+  ret void
+}
+
+
+;; Testing cases when the terminator of PredBB is a SwitchInst
+;; Actual probability of hitting foo is 9/10 (90%)
+;; Before fix, the probability of hitting foo is 4/100 (4%)
+;; After fix, the probability of hitting foo is 904/1000 (90.4%)
+define void @switch_test(i32 %a) {
+; CHECK-LABEL: @switch_test(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: switch i32 [[SA1:%.*]], label %end [
+; CHECK-NEXT: i32 2, label %mid
+; CHECK-NEXT: i32 4, label %mid
+; CHECK-NEXT: i32 64, label %mid
+; CHECK-NEXT: i32 128, label %mid
+; CHECK-NEXT: ], !prof [[PROF2:![0-9]+]]
+; CHECK: mid:
+; CHECK-NEXT: call void @foo()
+; CHECK-NEXT: br label %end
+; CHECK: end:
+; CHECK-NEXT: ret void
+entry:
+  switch i32 %a, label %key [
+    i32 2, label %mid
+    i32 4, label %mid
+    i32 64, label %mid
+    i32 128, label %mid
+  ], !prof !2
+
+mid:
+  br label %key
+
+key:
+  %c = phi i1 [ true, %entry ], [ false, %mid ]
+  br i1 %c, label %end, label %func, !prof !0
+
+func:
+  call void @foo()
+  br label %end
+
+end:
+  ret void
+}
+
+
+;; Create a test case that will definitely overflow when doing naive add/mul,
+;; by making the branch_weights used in the calculation UINT32_MAX.
+;; After calculation, the probability of hitting foo is roughly the same.
+define void @overflow_test(i32 %a) {
+; CHECK-LABEL: @overflow_test(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: switch i32 [[SA2:%.*]], label %end [
+; CHECK-NEXT: i32 2, label %mid
+; CHECK-NEXT: i32 4, label %mid
+; CHECK-NEXT: i32 64, label %mid
+; CHECK-NEXT: i32 128, label %mid
+; CHECK-NEXT: ], !prof [[PROF3:![0-9]+]]
+; CHECK: mid:
+; CHECK-NEXT: call void @foo()
+; CHECK-NEXT: br label %end
+; CHECK: end:
+; CHECK-NEXT: ret void
+entry:
+  switch i32 %a, label %key [
+    i32 2, label %mid
+    i32 4, label %mid
+    i32 64, label %mid
+    i32 128, label %mid
+  ], !prof !3
+
+mid:
+  br label %key
+
+key:
+  %c = phi i1 [ true, %entry ], [ false, %mid ]
+  br i1 %c, label %end, label %func, !prof !4
+
+func:
+  call void @foo()
+  br label %end
+
+end:
+  ret void
+}
+
+
+;; A slightly more mixed test case.
+;; The original probabilities are:
+;; func1: 100%, func2: 4%, func3: 8.96%
+;; before fix, they are: func1: 100%, func2: 30%, func3: 50%
+;; after fix, they are: func1: 100%, func2: 1.4%, func3: 1.5%
+define void @more_complex_test(i32 %a) {
+; CHECK-LABEL: @more_complex_test(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: switch i32 [[SA3:%.*]], label %end [
+; CHECK-NEXT: i32 2, label %[[FUNC2:.*]]
+; CHECK-NEXT: i32 4, label %[[FUNC3:.*]]
+; CHECK-NEXT: ], !prof [[PROF4:![0-9]+]]
+; CHECK: [[FUNC2]]:
+; CHECK-NEXT: call void @foo2()
+; CHECK-NEXT: br label %end
+; CHECK: [[FUNC3]]:
+; CHECK-NEXT: call void @foo3()
+; CHECK-NEXT: br label %end
+; CHECK: end:
+; CHECK-NEXT: call void @foo1()
+; CHECK-NEXT: ret void
+entry:
+  switch i32 %a, label %end [
+    i32 2, label %mid1
+    i32 4, label %mid2
+    i32 8, label %mid3
+  ], !prof !5
+
+mid3:
+  br label %mid1
+
+mid1:
+  %c0 = phi i1 [ true, %entry ], [ false, %mid3 ]
+  br i1 %c0, label %func2, label %mid2, !prof !6
+
+func2:
+  call void @foo2()
+  br label %end
+
+mid2:
+  %c1 = phi i1 [ true, %entry ], [ false, %mid1 ]
+  br i1 %c1, label %func3, label %end, !prof !6
+
+func3:
+  call void @foo3()
+  br label %end
+
+end:
+  call void @foo1()
+  ret void
+}
+
+
+!0 = !{!"branch_weights", i32 1, i32 9}
+!1 = !{!"branch_weights", i32 7, i32 3}
+!2 = !{!"branch_weights", i32 96, i32 1, i32 1, i32 1, i32 1}
+!3 = !{!"branch_weights", i32 1, i32 -1, i32 -1, i32 -1, i32 -1}
+!4 = !{!"branch_weights", i32 1, i32 -1}
+!5 = !{!"branch_weights", i32 1, i32 3, i32 5, i32 1}
+!6 = !{!"branch_weights", i32 1, i32 99}
+
+
+; CHECK: [[PROF0]] = !{!"branch_weights", i32 29, i32 171}
+; CHECK: [[PROF1]] = !{!"branch_weights", i32 163, i32 837}
+; CHECK: [[PROF2]] = !{!"branch_weights", i32 96, i32 226, i32 226, i32 226, i32 226}
+; CHECK: [[PROF3]] = !{!"branch_weights", i32 0, i32 -4, i32 -4, i32 -4, i32 -4}
+; CHECK: [[PROF4]] = !{!"branch_weights", i32 87914, i32 1290, i32 698}