Index: llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
===================================================================
--- llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
+++ llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
@@ -401,6 +401,11 @@
 
   void rebuildGlobal(VTableBits &B);
 
+  void importResolution(VTableSlotInfo &SlotInfo,
+                        const WholeProgramDevirtResolution &Res);
+
+  void removeUnusedTypeTests();
+
   bool run();
 
   // Lower the module using the action and summary passed as command line
@@ -1043,6 +1048,36 @@
   }
 }
 
+/// Applies a devirtualization resolution imported from the summary to the
+/// call sites recorded in SlotInfo. Only SingleImpl resolutions are handled;
+/// any other resolution kind leaves the call sites untouched.
+void DevirtModule::importResolution(VTableSlotInfo &SlotInfo,
+                                    const WholeProgramDevirtResolution &Res) {
+  if (Res.TheKind == WholeProgramDevirtResolution::SingleImpl) {
+    // The declared type is a placeholder: each devirtualized call site
+    // bitcasts the callee to its own function type.
+    auto *SingleImpl = M.getOrInsertFunction(
+        Res.SingleImplName, Type::getVoidTy(M.getContext()), nullptr);
+
+    // The export flag is not used on the import path; initialize it anyway
+    // so the out-parameter never holds an indeterminate value.
+    bool IsExported = false;
+    applySingleImplDevirt(SlotInfo, SingleImpl, IsExported);
+  }
+}
+
+/// Replaces each type test that has no remaining unsafe uses with `true` and
+/// erases it.
+void DevirtModule::removeUnusedTypeTests() {
+  auto True = ConstantInt::getTrue(M.getContext());
+  for (auto &&U : NumUnsafeUsesForTypeTest) {
+    if (U.second == 0) {
+      U.first->replaceAllUsesWith(True);
+      U.first->eraseFromParent();
+    }
+  }
+}
+
 bool DevirtModule::run() {
   Function *TypeTestFunc =
       M.getFunction(Intrinsic::getName(Intrinsic::type_test));
@@ -1062,6 +1097,20 @@
   if (TypeCheckedLoadFunc)
     scanTypeCheckedLoadUsers(TypeCheckedLoadFunc);
 
+  // In import mode, apply the resolutions recorded in the summary and skip
+  // the whole-program analysis below.
+  if (Action == PassSummaryAction::Import) {
+    for (auto &S : CallSlots) {
+      auto &Res =
+          Summary->getTypeIdSummary(cast<MDString>(S.first.TypeID)->getString())
+              .WPDRes[S.first.ByteOffset];
+      importResolution(S.second, Res);
+    }
+
+    removeUnusedTypeTests();
+    return true;
+  }
+
   // Rebuild type metadata into a map for easy lookup.
   std::vector<VTableBits> Bits;
   DenseMap<Metadata *, std::set<TypeMemberInfo>> TypeIdMap;
@@ -1159,15 +1208,7 @@
 
   // If we were able to eliminate all unsafe uses for a type checked load,
   // eliminate the type test by replacing it with true.
-  if (TypeCheckedLoadFunc) {
-    auto True = ConstantInt::getTrue(M.getContext());
-    for (auto &&U : NumUnsafeUsesForTypeTest) {
-      if (U.second == 0) {
-        U.first->replaceAllUsesWith(True);
-        U.first->eraseFromParent();
-      }
-    }
-  }
+  removeUnusedTypeTests();
 
   // Rebuild each global we touched as part of virtual constant propagation to
   // include the before and after bytes.
Index: llvm/test/Transforms/WholeProgramDevirt/Inputs/import-single-impl.yaml
===================================================================
--- /dev/null
+++ llvm/test/Transforms/WholeProgramDevirt/Inputs/import-single-impl.yaml
@@ -0,0 +1,13 @@
+---
+TypeIdMap:
+  typeid1:
+    WPDRes:
+      0:
+        Kind:            SingleImpl
+        SingleImplName:  singleimpl1
+  typeid2:
+    WPDRes:
+      8:
+        Kind:            SingleImpl
+        SingleImplName:  singleimpl2
+...
Index: llvm/test/Transforms/WholeProgramDevirt/import.ll
===================================================================
--- /dev/null
+++ llvm/test/Transforms/WholeProgramDevirt/import.ll
@@ -0,0 +1,49 @@
+; RUN: opt -S -wholeprogramdevirt -wholeprogramdevirt-summary-action=import -wholeprogramdevirt-read-summary=%S/Inputs/import-single-impl.yaml < %s | FileCheck --check-prefixes=CHECK,SINGLE-IMPL %s
+
+target datalayout = "e-p:64:64"
+target triple = "x86_64-unknown-linux-gnu"
+
+; CHECK: define i32 @call1
+define i32 @call1(i8* %obj) {
+  %vtableptr = bitcast i8* %obj to [3 x i8*]**
+  %vtable = load [3 x i8*]*, [3 x i8*]** %vtableptr
+  %vtablei8 = bitcast [3 x i8*]* %vtable to i8*
+  %p = call i1 @llvm.type.test(i8* %vtablei8, metadata !"typeid1")
+  call void @llvm.assume(i1 %p)
+  %fptrptr = getelementptr [3 x i8*], [3 x i8*]* %vtable, i32 0, i32 0
+  %fptr = load i8*, i8** %fptrptr
+  %fptr_casted = bitcast i8* %fptr to i32 (i8*, i32)*
+  ; SINGLE-IMPL: call i32 bitcast (void ()* @singleimpl1 to i32 (i8*, i32)*)
+  %result = call i32 %fptr_casted(i8* %obj, i32 1)
+  ret i32 %result
+}
+
+; CHECK: define i1 @call2
+define i1 @call2(i8* %obj) {
+  %vtableptr = bitcast i8* %obj to [1 x i8*]**
+  %vtable = load [1 x i8*]*, [1 x i8*]** %vtableptr
+  %vtablei8 = bitcast [1 x i8*]* %vtable to i8*
+  %pair = call {i8*, i1} @llvm.type.checked.load(i8* %vtablei8, i32 8, metadata !"typeid2")
+  %fptr = extractvalue {i8*, i1} %pair, 0
+  %p = extractvalue {i8*, i1} %pair, 1
+  ; SINGLE-IMPL: br i1 true,
+  br i1 %p, label %cont, label %trap
+
+cont:
+  %fptr_casted = bitcast i8* %fptr to i1 (i8*, i32)*
+  ; SINGLE-IMPL: call i1 bitcast (void ()* @singleimpl2 to i1 (i8*, i32)*)
+  %result = call i1 %fptr_casted(i8* %obj, i32 undef)
+  ret i1 %result
+
+trap:
+  call void @llvm.trap()
+  unreachable
+}
+
+; SINGLE-IMPL-DAG: declare void @singleimpl1()
+; SINGLE-IMPL-DAG: declare void @singleimpl2()
+
+declare void @llvm.assume(i1)
+declare void @llvm.trap()
+declare {i8*, i1} @llvm.type.checked.load(i8*, i32, metadata)
+declare i1 @llvm.type.test(i8*, metadata)