Fix branch weight in FoldCondBranchOnValueKnownInPredecessor pass in SimplifyCFG
Needs Review · Public

Authored by LukeZhuang on Fri, Aug 5, 1:19 PM.

Details

Summary

The patch is intended to fix this test case:

extern void bar();
void foo(char a, char b, char c) {
  char aa = 2 * a, bb = 2 * b, cc = 2 * c;
  if (__builtin_expect_with_probability(((aa == 2) || (bb == 2) || (cc == 0)), 1, 0)) {
    bar();
  }
}

The user specifies, via __builtin_expect_with_probability, that the probability of calling bar() is 0%.

Before the second SimplifyCFG pass, the LLVM IR is as follows:

define dso_local void @foo(i8 noundef signext %0, i8 noundef signext %1, i8 noundef signext %2) local_unnamed_addr #0 {
  %4 = and i8 %0, 127
  %5 = icmp eq i8 %4, 1
  %6 = and i8 %1, 127
  %7 = icmp eq i8 %6, 1
  %8 = or i1 %5, %7
  br i1 %8, label %12, label %9

9:                                                ; preds = %3
  %10 = and i8 %2, 127
  %11 = icmp eq i8 %10, 0
  br label %12

12:                                               ; preds = %9, %3
  %13 = phi i1 [ true, %3 ], [ %11, %9 ]
  br i1 %13, label %14, label %15, !prof !5

14:                                               ; preds = %12
  call void (...) @bar() #2
  br label %15

15:                                               ; preds = %14, %12
  ret void
}

!5 = !{!"branch_weights", i32 1, i32 2147483647}

Here, the probability of calling bar() is still effectively 0%. Block12 is executed with probability 100%, and the edge from block12 to block14 has weight 1 out of 1+2147483647, which is a probability of about 0%.
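As a quick check of the number above, the sketch below turns a two-way branch_weights pair into the probability of the first (taken) successor. The helper names takenProbability and barProbabilityBeforeFold are hypothetical, not LLVM APIs. They simply compute taken / (taken + notTaken) and multiply by block12's execution probability of 1.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper (not an LLVM API): probability of the first successor of a
// two-way branch whose !prof metadata is {"branch_weights", Taken, NotTaken}.
double takenProbability(uint64_t Taken, uint64_t NotTaken) {
  assert(Taken + NotTaken != 0 && "weights must not both be zero");
  return static_cast<double>(Taken) / static_cast<double>(Taken + NotTaken);
}

// Block12 always executes, so P(bar) = 1.0 * P(block12 -> block14) using !5.
double barProbabilityBeforeFold() {
  const double PBlock12 = 1.0;
  return PBlock12 * takenProbability(1, 2147483647);
}
```

This gives 1/2^31, roughly 4.7e-10, which is effectively the 0% the user asked for.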

Then, after FoldCondBranchOnValueKnownInPredecessor, the entry block branches directly to block14 (numbered %12 in the IR below), and the IR becomes:

define dso_local void @foo(i8 noundef signext %0, i8 noundef signext %1, i8 noundef signext %2) local_unnamed_addr #0 {
  %4 = and i8 %0, 127
  %5 = icmp eq i8 %4, 1
  %6 = and i8 %1, 127
  %7 = icmp eq i8 %6, 1
  %8 = or i1 %5, %7
  br i1 %8, label %12, label %9

9:                                                ; preds = %3
  %10 = and i8 %2, 127
  %11 = icmp eq i8 %10, 0
  br i1 %11, label %12, label %13, !prof !5

12:                                               ; preds = %3, %9
  call void (...) @bar() #2
  br label %13

13:                                               ; preds = %12, %9
  ret void
}

!5 = !{!"branch_weights", i32 1, i32 2147483647}

Note that only block9's branch carries !prof metadata; the entry block's branch has none. Both the entry block and block9 end in a conditional branch to block12, so performBranchToCommonDestFolding merges them and recalculates the branch weights.

So the final IR is as follows:

define dso_local void @foo(i8 noundef signext %0, i8 noundef signext %1, i8 noundef signext %2) local_unnamed_addr #0 {
  %4 = and i8 %0, 127
  %5 = icmp eq i8 %4, 1
  %6 = and i8 %1, 127
  %7 = icmp eq i8 %6, 1
  %8 = or i1 %5, %7
  %9 = and i8 %2, 127
  %10 = icmp eq i8 %9, 0
  %11 = or i1 %8, %10
  br i1 %11, label %12, label %13, !prof !5

12:                                               ; preds = %3
  call void (...) @bar() #2
  br label %13

13:                                               ; preds = %3, %12
  ret void
}

!5 = !{!"branch_weights", i32 -2147483647, i32 2147483647}

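Here, i32 -2147483647 is the unsigned weight 2147483649 printed as a signed 32-bit value. The probability of calling bar() is therefore 2147483649/(2147483649+2147483647), slightly over 50%, which is very different from what the user specified.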
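The final weights can be reproduced with a small sketch. It rests on two assumptions that are not spelled out in this summary. First, when the entry branch (true edge to the common destination, false edge to block9) is folded with block9's branch, the merged taken weight is PredTrue * (SuccTrue + SuccFalse) + PredFalse * SuccTrue and the merged not-taken weight is PredFalse * SuccFalse. Second, the entry branch, which has no !prof, is treated as 1:1. The function names are hypothetical.

```cpp
#include <cassert>
#include <cstdint>
#include <utility>

// ASSUMPTION: weight combination used when folding
//   Pred: br i1 %x, CommonDest, BB
//   BB:   br i1 %y, CommonDest, Other
// into a single branch on (%x or %y).
std::pair<uint64_t, uint64_t> combineOrWeights(uint64_t PredTrue, uint64_t PredFalse,
                                               uint64_t SuccTrue, uint64_t SuccFalse) {
  uint64_t NewTrue = PredTrue * (SuccTrue + SuccFalse) + PredFalse * SuccTrue;
  uint64_t NewFalse = PredFalse * SuccFalse;
  return {NewTrue, NewFalse};
}

// Show a 32-bit unsigned weight as the textual IR prints it (a signed i32):
// values above INT32_MAX wrap around to negative numbers.
int64_t asPrintedI32(uint32_t Weight) {
  return Weight > 0x7FFFFFFFu ? static_cast<int64_t>(Weight) - 0x100000000LL
                              : static_cast<int64_t>(Weight);
}
```

With PredTrue = PredFalse = 1 (assumed) and SuccTrue = 1, SuccFalse = 2147483647 from !5, the sketch yields 2147483649 and 2147483647, which print as i32 -2147483647 and i32 2147483647, matching the final !5. The resulting probability is 2147483649 / 2^32 = 0.5 + 2^-32. The intermediate IR already gives bar() the same probability: 0.5 from the unannotated entry edge plus 0.5 * 2^-31 through block9. So under these assumptions the folding step only preserves the error, which is consistent with the weights going wrong earlier, in FoldCondBranchOnValueKnownInPredecessor.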

We believe the issue is caused by FoldCondBranchOnValueKnownInPredecessor not updating the branch weights. The transformation it performs can be simplified to the following case:

before:
  PredBB  other2
       \  /
        \/
other1  BB
     \  /\
      \/  \
 RealDest  other3

The block execution frequency of RealDest here can be represented as:

Formula 1:
  BlockFreq(RealDest)
= BlockFreq(BB) * P(BB->RealDest) + other1
= (BlockFreq(PredBB) * P(PredBB->BB) + other2) * P(BB->RealDest) + other1
= BlockFreq(PredBB) * P(PredBB->BB) * P(BB->RealDest) + other2 * P(BB->RealDest) + other1

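in which P(A->B) is the probability of the edge from block A to block B. The terms other1 and other2 stand for the frequency flowing into RealDest and into BB, respectively, from their other predecessors.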

Then after the pass, the graph is updated to:

after:
  PredBB   other2
       |     /
       |    /
other1 |   BB
     \ |   /\
      \|  /  \
   RealDest    other3

PredBB is redirected to RealDest when the branch in BB is known to go to RealDest whenever BB is entered from PredBB.
The new block execution frequency of RealDest is:

Formula 2:
  BlockFreq(RealDest)
= BlockFreq(PredBB) * newP(PredBB->RealDest) + BlockFreq(BB) * P(BB->RealDest) + other1

in which BlockFreq(BB) now equals other2, since PredBB no longer branches to BB. So it can be rewritten as:

Formula 3:
  BlockFreq(RealDest)
= BlockFreq(PredBB) * newP(PredBB->RealDest) + other2 * P(BB->RealDest) + other1

Thus, to match Formula 1 and Formula 3, we need to make newP(PredBB->RealDest) in the new graph equal to P(PredBB->BB) * P(BB->RealDest) in the original graph.
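To see Formula 1 and Formula 3 on the test case, map PredBB to the entry block, BB to block12, RealDest to block14, and other2 to block9's contribution. We assume the unannotated entry branch is 50/50, so P(PredBB->BB) = 0.5 and other2 = 0.5 (block9 always falls through to block12). We also take other1 = 0 and P(BB->RealDest) = 2^-31 from !5. The function names are hypothetical.

```cpp
#include <cassert>

// Formula 1: frequency of RealDest before the transformation.
double freqBefore(double FreqPred, double PPredToBB, double Other2,
                  double PBBToReal, double Other1) {
  return FreqPred * PPredToBB * PBBToReal + Other2 * PBBToReal + Other1;
}

// Formula 3: frequency of RealDest after PredBB is redirected to RealDest.
double freqAfter(double FreqPred, double NewPPredToReal, double Other2,
                 double PBBToReal, double Other1) {
  return FreqPred * NewPPredToReal + Other2 * PBBToReal + Other1;
}
```

Leaving newP(PredBB->RealDest) at 0.5, which is what the unannotated entry branch implies, gives a frequency just above 0.5. Setting it to 0.5 * 2^-31 reproduces the original 2^-31.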

One way is to multiply every branch weight from PredBB that does not go to BB by a factor x.
Suppose:
(1) A is the sum of all branch weights from PredBB to BB, which is PTITotalTakenWeight in the source code
(2) B is the sum of the branch weights from PredBB to successors other than BB, which is PTITotalWeight-PTITotalTakenWeight in the source code
(3) S is the sum of all branch weights from PredBB, which is PTITotalWeight in the source code and equals A+B
(4) C is the branch weight from BB to RealDest, which is BITakenWeight in the source code
(5) D is the branch weight from BB to its other successor, which is BINotTakenWeight in the source code
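Since the weights that went from PredBB to BB now go to RealDest, the numerator of newP(PredBB->RealDest) stays A. So: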

   newP(PredBB->RealDest) = P(PredBB->BB) * P(BB->RealDest)
=> A/(A+x*B) = A/(A+B) * C/(C+D)
=> x = 1 + (S/B) * (D/C)
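The algebra from the second line to the third, spelled out and assuming A, B, C > 0:

```latex
\begin{aligned}
\frac{A}{A + xB} &= \frac{A}{A+B}\cdot\frac{C}{C+D} \\
(A+B)(C+D) &= C\,(A + xB) \\
AC + AD + BC + BD &= AC + xBC \\
xBC &= (A+B)\,D + BC \\
x &= 1 + \frac{(A+B)\,D}{BC} = 1 + \frac{S}{B}\cdot\frac{D}{C}
\end{aligned}
```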

Then, for each branch weight from PredBB to a successor other than BB, we apply this transformation:

   NewWeight = OldWeight * x
=> NewWeight = OldWeight * (1 + (S/B) * (D/C))
=> NewWeight = OldWeight + (S*OldWeight/B) * (D/C)

The last form is used to prevent overflow (see the source code for details). Since OldWeight <= B, the intermediate term S*OldWeight/B never exceeds S.
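Below is a minimal sketch of the per-weight update. It assumes PredBB's weights are collected into a vector together with flags marking which successors are BB. The function and parameter names are hypothetical and this is not the patch's code. In particular, it uses long double instead of reproducing the patch's overflow handling. For the test case, PredBB is the entry block with weights {1, 1} (assumed, since it has no !prof), successor 0 is BB (block12), and C = 1, D = 2147483647 come from !5. That gives x = 1 + 2 * 2147483647 = 4294967295.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical sketch, not the patch's code. Weights[I] is PredBB's weight for
// successor I, and ToBB[I] is true if that successor is BB. C and D are BB's
// weights towards RealDest and towards its other successor.
std::vector<uint64_t> scaleWeightsNotToBB(const std::vector<uint64_t> &Weights,
                                          const std::vector<bool> &ToBB,
                                          uint64_t C, uint64_t D) {
  assert(Weights.size() == ToBB.size() && C != 0);
  uint64_t S = 0, A = 0;
  for (std::size_t I = 0; I < Weights.size(); ++I) {
    S += Weights[I];
    if (ToBB[I])
      A += Weights[I];
  }
  const uint64_t B = S - A;
  std::vector<uint64_t> Result = Weights;
  if (B == 0)
    return Result; // every edge goes to BB, so there is nothing to scale
  for (std::size_t I = 0; I < Weights.size(); ++I) {
    if (ToBB[I])
      continue;
    // NewWeight = OldWeight + (S * OldWeight / B) * (D / C).
    // Computed in long double for simplicity and truncated toward zero; a real
    // implementation would also have to fit the results into 32 bits.
    const long double Scaled = static_cast<long double>(S) * Weights[I] / B;
    Result[I] = Weights[I] + static_cast<uint64_t>(Scaled * D / C);
  }
  return Result;
}
```

For the test case this turns the entry branch's weights into {1, 4294967295}. The redirected edge to RealDest then has probability 1/2^32 = 0.5 * 2^-31, which equals P(PredBB->BB) * P(BB->RealDest) in the original graph.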

Also, since this change may attach branch weights to BranchInst/SwitchInst instructions that did not originally have them, we skip the cases where the new weights would all be equal, which matches the default behavior anyway.

Event Timeline

LukeZhuang created this revision. · Fri, Aug 5, 1:19 PM
Herald added a project: Restricted Project. · Fri, Aug 5, 1:19 PM
Herald added a subscriber: hiraditya.
LukeZhuang requested review of this revision. · Fri, Aug 5, 1:19 PM
lattner resigned from this revision. · Fri, Aug 5, 3:56 PM

Hello, does anyone have comments on this change, or suggestions for fixing this in a better way? Thanks!

nikic added a comment. · Tue, Aug 9, 9:16 AM

This is a jump threading optimization -- is it possible to reuse the branch weight adjustment code from the JumpThreading pass? The updateBlockFreqAndEdgeWeight() method looks relevant.

Hi Nikic, thanks for the comments! I actually have two questions about this:

  1. I just looked at the code, and the JumpThreading pass seems to do the same thing as FoldCondBranchOnValueKnownInPredecessor in SimplifyCFG. May I ask why we originally needed both of them? Or is this just my misunderstanding?
  2. I was also previously thinking of using BlockFrequencyInfo, but I found it is not used or maintained in the SimplifyCFG pass. Does that mean we need to add it?