diff --git a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
--- a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
+++ b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
@@ -569,6 +569,10 @@
   // optimize a call more than once.
   SmallPtrSet<CallBase *, 8> OptimizedCalls;
 
+  // Calls that had their ptrauth bundle removed. They are erased at the end of
+  // the pass run; every exit path that devirtualizes must erase them too.
+  SmallVector<CallBase *, 8> CallsWithPtrAuthBundleRemoved;
+
   // This map keeps track of the number of "unsafe" uses of a loaded function
   // pointer. The key is the associated llvm.type.test intrinsic call generated
   // by this pass. An unsafe use is one that calls the loaded function pointer
@@ -1158,6 +1162,15 @@
         // !callees metadata.
         CB.setMetadata(LLVMContext::MD_prof, nullptr);
         CB.setMetadata(LLVMContext::MD_callees, nullptr);
+        // Direct calls may not carry a "ptrauth" bundle; clone CB without it.
+        if (CB.getOperandBundle(LLVMContext::OB_ptrauth)) {
+          auto *NewCS = CallBase::removeOperandBundle(
+              &CB, LLVMContext::OB_ptrauth, &CB);
+          CB.replaceAllUsesWith(NewCS);
+          // Erase the unused original only at the end of the pass run, because
+          // other bookkeeping (e.g. OptimizedCalls) may still refer to it.
+          CallsWithPtrAuthBundleRemoved.push_back(&CB);
+        }
       }
 
       // This use is no longer unsafe.
@@ -2301,6 +2314,12 @@
   for (GlobalVariable &GV : M.globals())
     GV.eraseMetadata(LLVMContext::MD_vcall_visibility);
 
+  // Erase the calls whose ptrauth bundle was stripped (uses already moved to
+  // the clones). NOTE(review): an earlier return that follows devirtualization
+  // (e.g. a ThinLTO import path) would skip this; confirm it is handled.
+  for (auto *CI : CallsWithPtrAuthBundleRemoved)
+    CI->eraseFromParent();
+
   return true;
 }
 
diff --git a/llvm/test/Transforms/WholeProgramDevirt/devirt-single-impl-check-ptrauth.ll b/llvm/test/Transforms/WholeProgramDevirt/devirt-single-impl-check-ptrauth.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/Transforms/WholeProgramDevirt/devirt-single-impl-check-ptrauth.ll
@@ -0,0 +1,40 @@
+; RUN: opt -S -passes=wholeprogramdevirt,verify -whole-program-visibility -pass-remarks=wholeprogramdevirt %s 2>&1 | FileCheck %s
+
+target datalayout = "e-p:64:64"
+target triple = "x86_64-unknown-linux-gnu"
+
+; CHECK: remark: <unknown>:0:0: single-impl: devirtualized a call to vf
+; CHECK: remark: <unknown>:0:0: devirtualized vf
+; CHECK-NOT: devirtualized
+
+@vt1 = constant [1 x ptr] [ptr @vf], !type !0
+@vt2 = constant [1 x ptr] [ptr @vf], !type !0
+
+define void @vf(ptr %this) {
+  ret void
+}
+
+; CHECK: define void @call
+define void @call(ptr %obj) {
+  %vtable = load ptr, ptr %obj
+  %pair = call {ptr, i1} @llvm.type.checked.load(ptr %vtable, i32 0, metadata !"typeid")
+  %fptr = extractvalue {ptr, i1} %pair, 0
+  %p = extractvalue {ptr, i1} %pair, 1
+  ; CHECK: br i1 true,
+  br i1 %p, label %cont, label %trap
+
+cont:
+  ; CHECK: call void @vf(ptr %obj){{$}}
+  ; CHECK-NEXT: ret void
+  call void %fptr(ptr %obj) [ "ptrauth"(i32 5, i64 120) ]
+  ret void
+
+trap:
+  call void @llvm.trap()
+  unreachable
+}
+
+declare {ptr, i1} @llvm.type.checked.load(ptr, i32, metadata)
+declare void @llvm.trap()
+
+!0 = !{i32 0, !"typeid"}