diff --git a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp --- a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp +++ b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp @@ -1391,9 +1391,20 @@ IsExported = true; if (CSInfo.AllCallSitesDevirted) return; + + std::map<CallBase *, CallBase *> CallBases; for (auto &&VCallSite : CSInfo.CallSites) { CallBase &CB = VCallSite.CB; + if (CallBases.count(&CB)) { + // When finding devirtualizable calls, it's possible to find the same + // vtable passed to multiple llvm.type.test or llvm.type.checked.load + // calls, which can cause duplicate call sites to be recorded in + // [Const]CallSites. If we've already found one of these + // call instances, just ignore it; it is rewritten after the loop. + continue; + } + // Jump tables are only profitable if the retpoline mitigation is enabled. Attribute FSAttr = CB.getCaller()->getFnAttribute("target-features"); if (!FSAttr.isValid() || @@ -1440,8 +1451,7 @@ AttributeList::get(M.getContext(), Attrs.getFnAttrs(), Attrs.getRetAttrs(), NewArgAttrs)); - CB.replaceAllUsesWith(NewCS); - CB.eraseFromParent(); + CallBases[&CB] = NewCS; // This use is no longer unsafe. if (VCallSite.NumUnsafeUses) @@ -1451,6 +1461,13 @@ // retpoline mitigation, which would mean that they are lowered to // llvm.type.test and therefore require an llvm.type.test resolution for the // type identifier. + + // Rewrite calls only after the loop: a call recorded more than once in + // CSInfo.CallSites must not be erased while later entries still name it. + for (auto &[OldCB, NewCB] : CallBases) { + OldCB->replaceAllUsesWith(NewCB); + OldCB->eraseFromParent(); + } }; Apply(SlotInfo.CSInfo); for (auto &P : SlotInfo.ConstCSInfo) diff --git a/llvm/test/Transforms/WholeProgramDevirt/branch-funnel.ll b/llvm/test/Transforms/WholeProgramDevirt/branch-funnel.ll --- a/llvm/test/Transforms/WholeProgramDevirt/branch-funnel.ll +++ b/llvm/test/Transforms/WholeProgramDevirt/branch-funnel.ll @@ -233,6 +233,54 @@ ret i32 %result } +; CHECK-LABEL: define i32 @fn4 +; CHECK-NOT: call void (...) 
@llvm.icall.branch.funnel +define i32 @fn4(ptr %obj) #0 { + %p = call i1 @llvm.type.test(ptr @vt1_1, metadata !"typeid1") + call void @llvm.assume(i1 %p) + %fptr = load ptr, ptr @vt1_1 + ; RETP: call i32 @__typeid_typeid1_0_branch_funnel(ptr nest @vt1_1, ptr %obj, i32 1) + %result = call i32 %fptr(ptr %obj, i32 1) + ; NORETP: call i32 % + ret i32 %result +} + +; CHECK-LABEL: define i32 @fn4_cpy +; CHECK-NOT: call void (...) @llvm.icall.branch.funnel +define i32 @fn4_cpy(ptr %obj) #0 { + %p = call i1 @llvm.type.test(ptr @vt1_1, metadata !"typeid1") ; second llvm.type.test on @vt1_1 (the first is in @fn4), presumably to exercise the duplicate call-site case handled in WholeProgramDevirt.cpp + call void @llvm.assume(i1 %p) + %fptr = load ptr, ptr @vt1_1 + ; RETP: call i32 @__typeid_typeid1_0_branch_funnel(ptr nest @vt1_1, ptr %obj, i32 1) + %result = call i32 %fptr(ptr %obj, i32 1) + ; NORETP: call i32 % + ret i32 %result +} + +; CHECK-LABEL: define i32 @fn4_rv +; CHECK-NOT: call void (...) @llvm.icall.branch.funnel +define i32 @fn4_rv(ptr %obj) #0 { + %p = call i1 @llvm.type.test(ptr @vt1_1_rv, metadata !"typeid1_rv") + call void @llvm.assume(i1 %p) + %fptr = call ptr @llvm.load.relative.i32(ptr @vt1_1_rv, i32 0) + ; RETP: call i32 @__typeid_typeid1_rv_0_branch_funnel(ptr nest @vt1_1_rv, ptr %obj, i32 1) + %result = call i32 %fptr(ptr %obj, i32 1) + ; NORETP: call i32 % + ret i32 %result +} + +; CHECK-LABEL: define i32 @fn4_rv_cpy +; CHECK-NOT: call void (...) @llvm.icall.branch.funnel +define i32 @fn4_rv_cpy(ptr %obj) #0 { + %p = call i1 @llvm.type.test(ptr @vt1_1_rv, metadata !"typeid1_rv") ; second llvm.type.test on @vt1_1_rv (the first is in @fn4_rv), presumably to exercise the duplicate call-site case handled in WholeProgramDevirt.cpp + call void @llvm.assume(i1 %p) + %fptr = call ptr @llvm.load.relative.i32(ptr @vt1_1_rv, i32 0) + ; RETP: call i32 @__typeid_typeid1_rv_0_branch_funnel(ptr nest @vt1_1_rv, ptr %obj, i32 1) + %result = call i32 %fptr(ptr %obj, i32 1) + ; NORETP: call i32 % + ret i32 %result +} + ; CHECK-LABEL: define hidden void @__typeid_typeid1_0_branch_funnel(ptr nest %0, ...) ; CHECK-NEXT: musttail call void (...) 
@llvm.icall.branch.funnel(ptr %0, ptr {{(nonnull )?}}@vt1_1, ptr {{(nonnull )?}}@vf1_1, ptr {{(nonnull )?}}@vt1_2, ptr {{(nonnull )?}}@vf1_2, ...)