Index: llvm/include/llvm/IR/ModuleSummaryIndex.h =================================================================== --- llvm/include/llvm/IR/ModuleSummaryIndex.h +++ llvm/include/llvm/IR/ModuleSummaryIndex.h @@ -370,6 +370,14 @@ return TIdInfo->TypeCheckedLoadConstVCalls; return {}; } + + /// Add a type test to the summary. This is used by WholeProgramDevirt if we + /// were unable to devirtualize a checked call. + void addTypeTest(GlobalValue::GUID Guid) { + if (!TIdInfo) + TIdInfo = llvm::make_unique<TypeIdInfo>(); + TIdInfo->TypeTests.push_back(Guid); + } }; template <> struct DenseMapInfo<FunctionSummary::VFuncId> { Index: llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp =================================================================== --- llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp +++ llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp @@ -25,6 +25,20 @@ // returns 0, or a single vtable's function returns 1, replace each virtual // call with a comparison of the vptr against that vtable's address. // +// This pass is intended to be used during the regular and thin LTO pipelines. +// During regular LTO, the pass determines the best optimization for each +// virtual call and applies the resolutions directly to virtual calls that are +// eligible for virtual call optimization (i.e. calls that use either of the +// llvm.assume(llvm.type.test) or llvm.type.checked.load intrinsics). During +// ThinLTO, the pass operates in two phases: +// - Export phase: this is run during the thin link over a single merged module +// that contains all vtables with !type metadata that participate in the link. +// The pass computes a resolution for each virtual call and stores it in the +// type identifier summary. +// - Import phase: this is run during the thin backends over the individual +// modules. The pass applies the resolutions previously computed during the +// export phase to each eligible virtual call. 
+// //===----------------------------------------------------------------------===// #include "llvm/Transforms/IPO/WholeProgramDevirt.h" @@ -289,7 +303,23 @@ // of constant integer arguments. The grouping by arguments is handled by the // VTableSlotInfo class. struct CallSiteInfo { + /// The set of call sites for this slot. Used during regular LTO and the + /// import phase of ThinLTO (as well as the export phase of ThinLTO for any + /// call sites that appear in the merged module itself); in each of these + /// cases we are directly operating on the call sites at the IR level. std::vector<VirtualCallSite> CallSites; + + // These fields are used during the export phase of ThinLTO and reflect + // information collected from function summaries. + + /// CFI-specific: a vector containing the list of function summaries that use + /// the llvm.type.checked.load intrinsic and therefore will require + /// resolutions for llvm.type.test in order to implement CFI checks if + /// devirtualization was unsuccessful. If devirtualization was successful, the + /// pass will clear this vector. If at the end of the pass the vector is + /// non-empty, we will need to add a use of llvm.type.test to each of the + /// function summaries in the vector. + std::vector<FunctionSummary *> SummaryTypeCheckedLoadUsers; }; // Call site information collected for a specific VTableSlot. @@ -1035,7 +1065,11 @@ M.getFunction(Intrinsic::getName(Intrinsic::type_checked_load)); Function *AssumeFunc = M.getFunction(Intrinsic::getName(Intrinsic::assume)); - if ((!TypeTestFunc || TypeTestFunc->use_empty() || !AssumeFunc || + // Normally if there are no users of the devirtualization intrinsics in the + // module, this pass has nothing to do. But if we are exporting, we also need + // to handle any users that appear only in the function summaries. 
+ if (Action != PassSummaryAction::Export && + (!TypeTestFunc || TypeTestFunc->use_empty() || !AssumeFunc || AssumeFunc->use_empty()) && (!TypeCheckedLoadFunc || TypeCheckedLoadFunc->use_empty())) return false; @@ -1053,6 +1087,35 @@ if (TypeIdMap.empty()) return true; + // Collect information from summary about which calls to try to devirtualize. + if (Action == PassSummaryAction::Export) { + DenseMap<GlobalValue::GUID, TinyPtrVector<Metadata *>> MetadataByGUID; + for (auto &P : TypeIdMap) { + if (auto *TypeId = dyn_cast<MDString>(P.first)) + MetadataByGUID[GlobalValue::getGUID(TypeId->getString())].push_back( + TypeId); + } + + for (auto &P : *Summary) { + for (auto &S : P.second) { + auto *FS = dyn_cast<FunctionSummary>(S.get()); + if (!FS) + continue; + // FIXME: Only add live functions. + for (FunctionSummary::VFuncId VF : FS->type_checked_load_vcalls()) + for (Metadata *MD : MetadataByGUID[VF.GUID]) + CallSlots[{MD, VF.Offset}] + .CSInfo.SummaryTypeCheckedLoadUsers.push_back(FS); + for (const FunctionSummary::ConstVCall &VC : + FS->type_checked_load_const_vcalls()) + for (Metadata *MD : MetadataByGUID[VC.VFunc.GUID]) + CallSlots[{MD, VC.VFunc.Offset}] + .ConstCSInfo[VC.Args] + .SummaryTypeCheckedLoadUsers.push_back(FS); + } + } + } + // For each (type, offset) pair: bool DidVirtualConstProp = false; std::map<std::string, Function*> DevirtTargets; @@ -1061,19 +1124,32 @@ // function implementation at offset S.first.ByteOffset, and add to // TargetsForSlot. std::vector<VirtualCallTarget> TargetsForSlot; - if (!tryFindVirtualCallTargets(TargetsForSlot, TypeIdMap[S.first.TypeID], - S.first.ByteOffset)) - continue; - - if (!trySingleImplDevirt(TargetsForSlot, S.second) && - tryVirtualConstProp(TargetsForSlot, S.second)) + if (tryFindVirtualCallTargets(TargetsForSlot, TypeIdMap[S.first.TypeID], + S.first.ByteOffset)) { + if (!trySingleImplDevirt(TargetsForSlot, S.second) && + tryVirtualConstProp(TargetsForSlot, S.second)) DidVirtualConstProp = true; - // Collect functions devirtualized at least for one call site for stats. 
- if (RemarksEnabled) - for (const auto &T : TargetsForSlot) - if (T.WasDevirt) - DevirtTargets[T.Fn->getName()] = T.Fn; + // Collect functions devirtualized at least for one call site for stats. + if (RemarksEnabled) + for (const auto &T : TargetsForSlot) + if (T.WasDevirt) + DevirtTargets[T.Fn->getName()] = T.Fn; + } + + // CFI-specific: if we are exporting and any llvm.type.checked.load + // intrinsics were *not* devirtualized, we need to add the resulting + // llvm.type.test intrinsics to the function summaries so that the + // LowerTypeTests pass will export them. + if (Action == PassSummaryAction::Export && isa<MDString>(S.first.TypeID)) { + auto GUID = + GlobalValue::getGUID(cast<MDString>(S.first.TypeID)->getString()); + for (auto FS : S.second.CSInfo.SummaryTypeCheckedLoadUsers) + FS->addTypeTest(GUID); + for (auto &CCS : S.second.ConstCSInfo) + for (auto FS : CCS.second.SummaryTypeCheckedLoadUsers) + FS->addTypeTest(GUID); + } } if (RemarksEnabled) { Index: llvm/test/Transforms/WholeProgramDevirt/Inputs/export.yaml =================================================================== --- /dev/null +++ llvm/test/Transforms/WholeProgramDevirt/Inputs/export.yaml @@ -0,0 +1,20 @@ +--- +GlobalValueMap: + 42: + - TypeTestAssumeVCalls: + - GUID: 14276520915468743435 # typeid1 + Offset: 0 + TypeCheckedLoadVCalls: + - GUID: 15427464259790519041 # typeid2 + Offset: 0 + TypeTestAssumeConstVCalls: + - VFunc: + GUID: 3515965990081467659 # typeid3 + Offset: 0 + Args: [12, 24] + TypeCheckedLoadConstVCalls: + - VFunc: + GUID: 17525413373118030901 # typeid4 + Offset: 0 + Args: [24, 12] +... 
Index: llvm/test/Transforms/WholeProgramDevirt/export-unsuccessful-checked.ll =================================================================== --- /dev/null +++ llvm/test/Transforms/WholeProgramDevirt/export-unsuccessful-checked.ll @@ -0,0 +1,30 @@ +; RUN: opt -wholeprogramdevirt -wholeprogramdevirt-summary-action=export -wholeprogramdevirt-read-summary=%S/Inputs/export.yaml -wholeprogramdevirt-write-summary=%t -o /dev/null %s +; RUN: FileCheck %s < %t + +; CHECK: - TypeTests: +; CHECK-NEXT: - 15427464259790519041 +; CHECK-NEXT: - 17525413373118030901 +; CHECK-NEXT: TypeTestAssumeVCalls: + +@vt1a = constant void (i8*)* @vf1a, !type !0 +@vt1b = constant void (i8*)* @vf1b, !type !0 +@vt2a = constant void (i8*)* @vf2a, !type !1 +@vt2b = constant void (i8*)* @vf2b, !type !1 +@vt3a = constant void (i8*)* @vf3a, !type !2 +@vt3b = constant void (i8*)* @vf3b, !type !2 +@vt4a = constant void (i8*)* @vf4a, !type !3 +@vt4b = constant void (i8*)* @vf4b, !type !3 + +declare void @vf1a(i8*) +declare void @vf1b(i8*) +declare void @vf2a(i8*) +declare void @vf2b(i8*) +declare void @vf3a(i8*) +declare void @vf3b(i8*) +declare void @vf4a(i8*) +declare void @vf4b(i8*) + +!0 = !{i32 0, !"typeid1"} +!1 = !{i32 0, !"typeid2"} +!2 = !{i32 0, !"typeid3"} +!3 = !{i32 0, !"typeid4"}