The patch is intended to fix this test case:
extern void bar();
void foo(char a, char b, char c) {
  char aa = 2 * a, bb = 2 * b, cc = 2 * c;
  if (__builtin_expect_with_probability(((aa == 2) || (bb == 2) || (cc == 0)), 1, 0)) {
    bar();
  }
}
The user defines the probability of calling function bar as 0%.
Before the second SimplifyCFG pass, the LLVM IR is as follows:
define dso_local void @foo(i8 noundef signext %0, i8 noundef signext %1, i8 noundef signext %2) local_unnamed_addr #0 {
  %4 = and i8 %0, 127
  %5 = icmp eq i8 %4, 1
  %6 = and i8 %1, 127
  %7 = icmp eq i8 %6, 1
  %8 = or i1 %5, %7
  br i1 %8, label %12, label %9

9:                                                ; preds = %3
  %10 = and i8 %2, 127
  %11 = icmp eq i8 %10, 0
  br label %12

12:                                               ; preds = %9, %3
  %13 = phi i1 [ true, %3 ], [ %11, %9 ]
  br i1 %13, label %14, label %15, !prof !5

14:                                               ; preds = %12
  call void (...) @bar() #2
  br label %15

15:                                               ; preds = %14, %12
  ret void
}

!5 = !{!"branch_weights", i32 1, i32 2147483647}
Here, the probability of calling function bar() is still 0%, since the probability of executing block12 is 100%, and the probability from block12 to block14 is 0%.
Then after the FoldCondBranchOnValueKnownInPredecessor pass, the entry block branches directly to block14 (renumbered to block12 in the output), and the IR becomes as follows:
define dso_local void @foo(i8 noundef signext %0, i8 noundef signext %1, i8 noundef signext %2) local_unnamed_addr #0 {
  %4 = and i8 %0, 127
  %5 = icmp eq i8 %4, 1
  %6 = and i8 %1, 127
  %7 = icmp eq i8 %6, 1
  %8 = or i1 %5, %7
  br i1 %8, label %12, label %9

9:                                                ; preds = %3
  %10 = and i8 %2, 127
  %11 = icmp eq i8 %10, 0
  br i1 %11, label %12, label %13, !prof !5

12:                                               ; preds = %3, %9
  call void (...) @bar() #2
  br label %13

13:                                               ; preds = %12, %9
  ret void
}

!5 = !{!"branch_weights", i32 1, i32 2147483647}
Both the entry block and block9 end in a conditional BranchInst that has block12 as a successor, so performBranchToCommonDestFolding merges them and recalculates the branch weight.
So the final IR is as follows:
define dso_local void @foo(i8 noundef signext %0, i8 noundef signext %1, i8 noundef signext %2) local_unnamed_addr #0 {
  %4 = and i8 %0, 127
  %5 = icmp eq i8 %4, 1
  %6 = and i8 %1, 127
  %7 = icmp eq i8 %6, 1
  %8 = or i1 %5, %7
  %9 = and i8 %2, 127
  %10 = icmp eq i8 %9, 0
  %11 = or i1 %8, %10
  br i1 %11, label %12, label %13, !prof !5

12:                                               ; preds = %3
  call void (...) @bar() #2
  br label %13

13:                                               ; preds = %3, %12
  ret void
}

!5 = !{!"branch_weights", i32 -2147483647, i32 2147483647}
Now the probability of calling function bar() is just over 50% (the first weight, printed as a signed i32, is 2147483649 when read as unsigned), which is very different from the 0% the user specified.
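To make the percentages concrete, here is a minimal sketch of how a two-way branch_weights pair maps to a taken probability. The helper names takenProbability and asUnsignedWeight are hypothetical and not part of LLVM; the sketch assumes the weights are interpreted as unsigned 32-bit values.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper (not an LLVM API): probability of taking the first
// successor of a two-way branch, given its two branch weights.
double takenProbability(uint32_t TakenWeight, uint32_t NotTakenWeight) {
  uint64_t Total = uint64_t(TakenWeight) + uint64_t(NotTakenWeight);
  assert(Total != 0 && "at least one weight must be non-zero");
  return static_cast<double>(TakenWeight) / static_cast<double>(Total);
}

// Hypothetical helper: reinterpret a weight that was printed as a signed i32
// as the unsigned value it encodes.
uint32_t asUnsignedWeight(int32_t Printed) {
  return static_cast<uint32_t>(Printed);
}
```

With the original weights (1, 2147483647) the taken probability is 1/2^31, which is effectively 0%. With the final weights (-2147483647 read as 2147483649, and 2147483647) it is slightly above one half.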
We believe the issue is caused by the branch weight not being updated in the FoldCondBranchOnValueKnownInPredecessor pass. The optimization that this pass does can be simplified to this case:
before:
       PredBB other2
            \  /
             \/
    other1   BB
          \  /\
           \/  \
      RealDest  other3
The block execution frequency of RealDest here can be represented as:
Formula 1:
BlockFreq(RealDest) = BlockFreq(BB) * P(BB->RealDest) + other1
                    = (BlockFreq(PredBB) * P(PredBB->BB) + other2) * P(BB->RealDest) + other1
                    = BlockFreq(PredBB) * P(PredBB->BB) * P(BB->RealDest) + other2 * P(BB->RealDest) + other1
in which P(A->B) represents the probability from blockA to blockB.
Then after the pass, the graph is updated to:
after:
        PredBB  other2
          |      /
          |     /
  other1  |   BB
        \ |   /\
         \|  /  \
       RealDest  other3
PredBB is redirected to RealDest if the BranchInst in BB always goes to RealDest whenever control arrives from PredBB.
The new block execution frequency of RealDest is:
Formula 2: BlockFreq(RealDest) = BlockFreq(PredBB) * newP(PredBB->RealDest) + BlockFreq(BB) * P(BB->RealDest) + other1
in which BlockFreq(BB) is equal to other2 since PredBB no longer points to BB. So it can be represented as:
Formula 3: BlockFreq(RealDest) = BlockFreq(PredBB) * newP(PredBB->RealDest) + other2 * P(BB->RealDest) + other1
Thus, to match Formula 1 and Formula 3, we need to make newP(PredBB->RealDest) in the new graph equal to P(PredBB->BB) * P(BB->RealDest) in the original graph.
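The equality of Formula 1 and Formula 3 under this choice of newP can be checked numerically. The following is only a sketch with made-up frequencies and probabilities; freqBefore and freqAfter are hypothetical names that transcribe the two formulas, not functions from the patch.

```cpp
#include <cassert>
#include <cmath>

// Formula 1: frequency of RealDest before the transform.
double freqBefore(double FreqPredBB, double PPredToBB, double PBBToRealDest,
                  double Other1, double Other2) {
  double FreqBB = FreqPredBB * PPredToBB + Other2;
  return FreqBB * PBBToRealDest + Other1;
}

// Formula 3: frequency of RealDest after PredBB is redirected to RealDest,
// so that BB is only reached through the other2 edges.
double freqAfter(double FreqPredBB, double NewPPredToRealDest,
                 double PBBToRealDest, double Other1, double Other2) {
  return FreqPredBB * NewPPredToRealDest + Other2 * PBBToRealDest + Other1;
}
```

For example, with BlockFreq(PredBB) = 8, P(PredBB->BB) = 0.25, P(BB->RealDest) = 0.1, other1 = 3 and other2 = 2, both formulas give 3.4 when newP is 0.25 * 0.1. If the old probability 0.25 is kept on the redirected edge instead, Formula 3 gives 5.2, which overstates the frequency of RealDest.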
One way to achieve this is to multiply every branch weight on an edge from PredBB that does not go to BB by a factor x.
Supposing:
(1) A means the sum of all branch weights from PredBB to BB, which is PTITotalTakenWeight in the source code
(2) B means the sum of all branch weights from PredBB but not to BB, which is PTITotalWeight-PTITotalTakenWeight in the source code
(3) S means the sum of all branch weights from PredBB, which is PTITotalWeight in the source code and equal to A+B
(4) C means the branch weight from BB to RealDest, which is BITakenWeight in the source code
(5) D means the branch weight from BB to the other successor, which is BINotTakenWeight in the source code
So:
   newP(PredBB->RealDest) = P(PredBB->BB) * P(BB->RealDest)
=> A/(A+x*B) = A/(A+B) * C/(C+D)
=> x = 1 + (S/B) * (D/C)
Then for each branch weight from PredBB but not to BB, we do this transformation:
   NewWeight = OldWeight * x
=> NewWeight = OldWeight * (1 + (S/B) * (D/C))
=> NewWeight = OldWeight + (S*OldWeight/B) * (D/C)
Writing the update in this last form helps avoid overflow in the intermediate arithmetic (see the source code for details).
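As a rough illustration of the scaling formula, the sketch below applies it with 64-bit integers. It is not the code from the patch: scaleNotTakenWeight is a hypothetical name, the parameters follow the letters S, B, C and D defined above, and the sketch assumes the inputs are small enough that the 64-bit intermediates do not overflow.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical illustration (not the patch's code) of
//   NewWeight = OldWeight + (S * OldWeight / B) * (D / C)
// for one branch weight on an edge from PredBB that does not go to BB.
// The multiplication by D happens before the division by C so that integer
// division does not truncate D / C to zero when D < C.
uint64_t scaleNotTakenWeight(uint64_t OldWeight, uint64_t S, uint64_t B,
                             uint64_t C, uint64_t D) {
  assert(B != 0 && C != 0 && "B and C appear as divisors");
  return OldWeight + (S * OldWeight / B) * D / C;
}
```

Taking hypothetical weights A = 1 and B = 1 (so S = 2) on PredBB, and C = 1, D = 2147483647 on BB, the factor is x = 1 + 2 * 2147483647 = 4294967295, so the single not-taken weight becomes 4294967295. The new probability A/(A + x*B) = 1/4294967296 then equals P(PredBB->BB) * P(BB->RealDest) = (1/2) * (1/2147483648).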
Also, since this change may attach branch weights to a BranchInst/SwitchInst that did not originally have them, we skip the update when it would only assign equal branch weights, which is already the default behavior.
I don't think you need this helper function.
The setBranchWeights(SwitchInst*,...) version should be able to handle these cases if you just modify the signature to take Instruction* instead and add the assert below.
The reason is that CreateBranchWeights() only offers a convenient API for branch/select that takes two weights, to avoid writing boilerplate at call sites. Internally it just calls the version of CreateBranchWeights() that takes an ArrayRef.
In this case, I think it's better to follow the same approach and avoid introducing a redundant function implementation.