This is an archive of the discontinued LLVM Phabricator instance.

[llvm] Handle duplicate call bases when applying branch funneling
Closed · Public

Authored by leonardchan on Mar 16 2023, 4:50 PM.

Details

Summary

It's possible to segfault in DevirtModule::applyICallBranchFunnel when calling getCaller on a call base that was already erased in a prior iteration. This can happen when findDevirtualizableCallsForTypeTest looks for devirtualizable calls and the vtable passed to llvm.type.test is a global rather than a local. The function takes the first argument of the llvm.type.test call (the vtable), iterates through all of its uses, and adds every relevant use that is a call associated with that intrinsic call to a vector. In the common case where the vtable is actually a *local*, this isn't an issue. Take for example:

define i32 @fn(ptr %obj) #0 {
  %vtable = load ptr, ptr %obj
  %p = call i1 @llvm.type.test(ptr %vtable, metadata !"typeid2")
  call void @llvm.assume(i1 %p)
  %fptr = load ptr, ptr %vtable
  %result = call i32 %fptr(ptr %obj, i32 1)
  ret i32 %result
}

findDevirtualizableCallsForTypeTest will check the call base %result = call i32 %fptr(ptr %obj, i32 1), find that it is associated with a devirtualizable call through %vtable, find all loads of %vtable, and add every call made through those loaded pointers to a vector. Now consider the case where the vtable is the global itself rather than a local:

define i32 @fn(ptr %obj) #0 {
  %p = call i1 @llvm.type.test(ptr @vtable, metadata !"typeid2")
  call void @llvm.assume(i1 %p)
  %fptr = load ptr, ptr @vtable
  %result = call i32 %fptr(ptr %obj, i32 1)
  ret i32 %result
}

findDevirtualizableCallsForTypeTest still works normally here and adds exactly one call to the vector. However, if the same global is passed to llvm.type.test in more than one place, as in:

define i32 @fn(ptr %obj) #0 {
  %p = call i1 @llvm.type.test(ptr @vtable, metadata !"typeid2")
  call void @llvm.assume(i1 %p)
  %fptr = load ptr, ptr @vtable
  %result = call i32 %fptr(ptr %obj, i32 1)
  ret i32 %result
}

define i32 @fn2(ptr %obj) #0 {
  %p = call i1 @llvm.type.test(ptr @vtable, metadata !"typeid2")
  call void @llvm.assume(i1 %p)
  %fptr = load ptr, ptr @vtable
  %result = call i32 %fptr(ptr %obj, i32 1)
  ret i32 %result

Then each call base %result = call i32 %fptr(ptr %obj, i32 1) is added to the vector twice. For either call base, we determine that it is associated with a devirtualizable call through @vtable, and then iterate through all uses of @vtable, which span multiple functions. So scanning the first %result = call i32 %fptr(ptr %obj, i32 1) adds both call bases to the vector, and scanning the second one adds both again, leaving duplicate call bases in the CSInfo.CallSites vector.

Note that this is already accounted for everywhere else WPD iterates over CallSites: those places add the call base to the OptimizedCalls set and skip it if it's already there. We can't reuse that particular set, since it serves a different purpose (marking which calls were devirtualized), which applyICallBranchFunnel explicitly says it doesn't do. For this fix, we instead track duplicates with a map and do the actual replacements afterwards by iterating over the map.

Event Timeline

leonardchan created this revision. Mar 16 2023, 4:50 PM
Herald added a project: Restricted Project. Mar 16 2023, 4:50 PM
leonardchan requested review of this revision. Mar 16 2023, 4:50 PM

@tejohnson There was an old comment you left in D134320 asking why the check was warranted. We managed to reproduce the segfault and hopefully this description answers your initial comment.

tejohnson added inline comments. Mar 17 2023, 10:01 AM
llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
1399

I don't think this is safe since when we eraseFromParent the instruction is deleted. Can we track a different way?

any overlap with D104798?

Yeah, this looks pretty similar. Opting to use the set approach you used there.

llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
1399

Updated to just add erased pointers to a set and check on future iterations if they're in the set. This way we don't have to deref them.

aeubanks added inline comments. Mar 21 2023, 4:48 PM
llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
1399

if the instruction is erased, can there be instructions created after that share the same pointer?

leonardchan marked an inline comment as not done. Mar 21 2023, 5:07 PM
leonardchan added inline comments.
llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
1399

Hmm, I would think yes, but I don't think it matters for this specific case: none of the instructions newly created in this loop get added back to CSInfo.CallSites, so each of the CBs should only refer to CBs that were added earlier.

tejohnson added inline comments. Mar 23 2023, 1:41 PM
llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
1399

While I think this would work, given what you mentioned above, I think the following is preferable from a safety/clarity perspective: change CallBases to a map from the original CallBase to the new one, keep this bit of code the same so it skips any repeats, and after the loop walk the map and do all of the replacing/erasing. Would that work?

leonardchan added inline comments.
llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
1399

Should be fine. Updated.

tejohnson accepted this revision. Mar 23 2023, 2:34 PM

Lgtm. Patch description looks like it needs an update before you push.

This revision is now accepted and ready to land. Mar 23 2023, 2:34 PM
leonardchan edited the summary of this revision. Mar 23 2023, 2:42 PM
This revision was landed with ongoing or failed builds. Mar 23 2023, 2:48 PM
This revision was automatically updated to reflect the committed changes.