Index: compiler-rt/test/cfi/simple-pass.cpp
===================================================================
--- compiler-rt/test/cfi/simple-pass.cpp
+++ compiler-rt/test/cfi/simple-pass.cpp
@@ -1,5 +1,7 @@
 // RUN: %clangxx_cfi -o %t %s
 // RUN: %run %t
+// RUN: %clangxx_cfi -mretpoline -o %t2 %s
+// RUN: %run %t2
 
 // Tests that the CFI mechanism does not crash the program when making various
 // kinds of valid calls involving classes with various different linkages and
Index: llvm/include/llvm/ADT/PointerUnion.h
===================================================================
--- llvm/include/llvm/ADT/PointerUnion.h
+++ llvm/include/llvm/ADT/PointerUnion.h
@@ -346,6 +346,12 @@
   };
 };
 
+template <typename PT1, typename PT2, typename PT3>
+bool operator<(PointerUnion3<PT1, PT2, PT3> lhs,
+               PointerUnion3<PT1, PT2, PT3> rhs) {
+  return lhs.getOpaqueValue() < rhs.getOpaqueValue();
+}
+
 /// A pointer union of four pointer types. See documentation for PointerUnion
 /// for usage.
 template <typename PT1, typename PT2, typename PT3, typename PT4>
Index: llvm/include/llvm/IR/Intrinsics.td
===================================================================
--- llvm/include/llvm/IR/Intrinsics.td
+++ llvm/include/llvm/IR/Intrinsics.td
@@ -874,6 +874,11 @@
                             [llvm_ptr_ty, llvm_i32_ty, llvm_metadata_ty],
                             [IntrNoMem]>;
 
+// Create a binary search jump table that implements an indirect call to a
+// limited set of callees. This expands to inline asm that implements the jump
+// table, so it needs to be the only instruction in a naked function.
+def int_icall_jumptable : Intrinsic<[], [llvm_vararg_ty], []>;
+
 def int_load_relative: Intrinsic<[llvm_ptr_ty], [llvm_ptr_ty, llvm_anyint_ty],
                                  [IntrReadMem, IntrArgMemOnly]>;
 
Index: llvm/include/llvm/IR/ModuleSummaryIndex.h
===================================================================
--- llvm/include/llvm/IR/ModuleSummaryIndex.h
+++ llvm/include/llvm/IR/ModuleSummaryIndex.h
@@ -586,6 +586,8 @@
   enum Kind {
     Indir,      ///< Just do a regular virtual call
     SingleImpl, ///< Single implementation devirtualization
+    JumpTable,  ///< When retpoline mitigation is enabled, use a jump table that
+                ///< is defined in the merged module. Otherwise same as Indir.
   } TheKind = Indir;
 
   std::string SingleImplName;
Index: llvm/include/llvm/IR/ModuleSummaryIndexYAML.h
===================================================================
--- llvm/include/llvm/IR/ModuleSummaryIndexYAML.h
+++ llvm/include/llvm/IR/ModuleSummaryIndexYAML.h
@@ -98,6 +98,7 @@
   static void enumeration(IO &io, WholeProgramDevirtResolution::Kind &value) {
     io.enumCase(value, "Indir", WholeProgramDevirtResolution::Indir);
     io.enumCase(value, "SingleImpl", WholeProgramDevirtResolution::SingleImpl);
+    io.enumCase(value, "JumpTable", WholeProgramDevirtResolution::JumpTable);
   }
 };
 
Index: llvm/lib/Transforms/IPO/LowerTypeTests.cpp
===================================================================
--- llvm/lib/Transforms/IPO/LowerTypeTests.cpp
+++ llvm/lib/Transforms/IPO/LowerTypeTests.cpp
@@ -7,7 +7,8 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// This pass lowers type metadata and calls to the llvm.type.test intrinsic.
+// This pass lowers type metadata and calls to the llvm.type.test and
+// llvm.icall.jumptable intrinsics.
 // See http://llvm.org/docs/TypeMetadata.html for more information.
// //===----------------------------------------------------------------------===// @@ -25,6 +26,7 @@ #include "llvm/ADT/TinyPtrVector.h" #include "llvm/ADT/Triple.h" #include "llvm/Analysis/TypeMetadataUtils.h" +#include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/Attributes.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constant.h" @@ -291,6 +293,35 @@ } }; +struct IcallJumptableTarget { + GlobalTypeMember *GTM; + uint64_t Offset; + Constant *Callee; +}; + +struct IcallJumptable final + : TrailingObjects { + static IcallJumptable *create(BumpPtrAllocator &Alloc, CallInst *CI, + ArrayRef Targets) { + auto *Call = static_cast( + Alloc.Allocate(totalSizeToAlloc(Targets.size()), + alignof(IcallJumptable))); + Call->CI = CI; + Call->NTargets = Targets.size(); + std::uninitialized_copy(Targets.begin(), Targets.end(), + Call->getTrailingObjects()); + return Call; + } + + CallInst *CI; + ArrayRef targets() const { + return makeArrayRef(getTrailingObjects(), NTargets); + } + +private: + size_t NTargets; +}; + class LowerTypeTestsModule { Module &M; @@ -372,22 +403,38 @@ const DenseMap &GlobalLayout); Value *lowerTypeTestCall(Metadata *TypeId, CallInst *CI, const TypeIdLowering &TIL); - void buildBitSetsFromGlobalVariables(ArrayRef TypeIds, - ArrayRef Globals); + + void emitTernaryJumptable(Constant *CombinedGlobalAddr, + ArrayRef Targets, + raw_ostream &AsmOS, std::vector &Args, + uint64_t &LabelNo); + void lowerIcallJumptableCalls( + ArrayRef IcallJumptables, + Constant *CombinedGlobalAddr, + const DenseMap &GlobalLayout); + + void buildBitSetsFromGlobalVariables( + ArrayRef TypeIds, ArrayRef Globals, + ArrayRef IcallJumptables); unsigned getJumpTableEntrySize(); Type *getJumpTableEntryType(); void createJumpTableEntry(raw_ostream &AsmOS, raw_ostream &ConstraintOS, Triple::ArchType JumpTableArch, SmallVectorImpl &AsmArgs, Function *Dest); void verifyTypeMDNode(GlobalObject *GO, MDNode *Type); - void buildBitSetsFromFunctions(ArrayRef TypeIds, - ArrayRef Functions); - void buildBitSetsFromFunctionsNative(ArrayRef TypeIds, - ArrayRef Functions); + void + buildBitSetsFromFunctions(ArrayRef TypeIds, + ArrayRef Functions, + ArrayRef IcallJumptables); + void buildBitSetsFromFunctionsNative( + ArrayRef TypeIds, ArrayRef Functions, + ArrayRef IcallJumptables); void buildBitSetsFromFunctionsWASM(ArrayRef TypeIds, ArrayRef Functions); - void buildBitSetsFromDisjointSet(ArrayRef TypeIds, - ArrayRef Globals); + void + buildBitSetsFromDisjointSet(ArrayRef TypeIds, + ArrayRef Globals, + ArrayRef IcallJumptables); void replaceWeakDeclarationWithJumpTablePtr(Function *F, Constant *JT); void moveInitializerToModuleConstructor(GlobalVariable *GV); @@ -714,7 +761,8 @@ /// Given a disjoint set of type identifiers and globals, lay out the globals, /// build the bit sets and lower the llvm.type.test calls. void LowerTypeTestsModule::buildBitSetsFromGlobalVariables( - ArrayRef TypeIds, ArrayRef Globals) { + ArrayRef TypeIds, ArrayRef Globals, + ArrayRef IcallJumptables) { // Build a new global with the combined contents of the referenced globals. 
// This global is a struct whose even-indexed elements contain the original // contents of the referenced globals and whose odd-indexed elements contain @@ -754,6 +802,7 @@ GlobalLayout[Globals[I]] = CombinedGlobalLayout->getElementOffset(I * 2); lowerTypeTestCalls(TypeIds, CombinedGlobal, GlobalLayout); + lowerIcallJumptableCalls(IcallJumptables, CombinedGlobal, GlobalLayout); // Build aliases pointing to offsets into the combined global for each // global from which we built the combined global, and replace references @@ -1048,6 +1097,107 @@ } } +void LowerTypeTestsModule::emitTernaryJumptable( + Constant *CombinedGlobalAddr, ArrayRef Targets, + raw_ostream &AsmOS, std::vector &Args, uint64_t &LabelNo) { + assert(!Targets.empty()); + + auto AdjR11 = [&](const IcallJumptableTarget &T) { + AsmOS << "leaq ${" << Args.size() << ":c}+" << T.Offset + << "(%rip), %r11\n"; + Args.push_back(CombinedGlobalAddr); + }; + + if (Targets.size() == 1) { + AsmOS << "jmp ${" << Args.size() << ":c}@plt\n"; + Args.push_back(Targets[0].Callee); + return; + } + + if (Targets.size() == 2) { + AdjR11(Targets[1]); + AsmOS << "cmp %r11, %r10\n"; + AsmOS << "jb ${" << Args.size() << ":c}@plt\n"; + Args.push_back(Targets[0].Callee); + AsmOS << "jmp ${" << Args.size() << ":c}@plt\n"; + Args.push_back(Targets[1].Callee); + return; + } + + if (Targets.size() < 6) { + AdjR11(Targets[1]); + AsmOS << "cmp %r11, %r10\n"; + AsmOS << "jb ${" << Args.size() << ":c}@plt\n"; + Args.push_back(Targets[0].Callee); + AsmOS << "je ${" << Args.size() << ":c}@plt\n"; + Args.push_back(Targets[1].Callee); + emitTernaryJumptable(CombinedGlobalAddr, Targets.slice(2), + AsmOS, Args, LabelNo); + return; + } + + uint64_t L = LabelNo++; + AdjR11(Targets[Targets.size() / 2]); + AsmOS << "cmp %r11, %r10\n"; + AsmOS << "jb " << L << "f\n"; + AsmOS << "je ${" << Args.size() << ":c}@plt\n"; + Args.push_back(Targets[Targets.size() / 2].Callee); + emitTernaryJumptable(CombinedGlobalAddr, + Targets.slice(Targets.size() / 2 + 1), AsmOS, Args, + LabelNo); + + AsmOS << L << ":\n"; + emitTernaryJumptable(CombinedGlobalAddr, Targets.slice(0, Targets.size() / 2), + AsmOS, Args, LabelNo); +} + +void LowerTypeTestsModule::lowerIcallJumptableCalls( + ArrayRef IcallJumptables, + Constant *CombinedGlobalAddr, + const DenseMap &GlobalLayout) { + for (auto *JT : IcallJumptables) { + std::vector Targets = JT->targets(); + + // Adjust the offsets to be relative to the CombinedGlobalAddr. 
+ for (auto &T : Targets) { + auto I = GlobalLayout.find(T.GTM); + assert(I != GlobalLayout.end()); + T.Offset += I->second; + } + std::sort( + Targets.begin(), Targets.end(), + [](const IcallJumptableTarget &T1, const IcallJumptableTarget &T2) { + return T1.Offset < T2.Offset; + }); + + std::string Asm; + raw_string_ostream AsmOS(Asm); + std::vector Args; + uint64_t LabelNo = 0; + emitTernaryJumptable(CombinedGlobalAddr, Targets, AsmOS, Args, LabelNo); + + std::string Constraints; + if (!Args.empty()) { + Constraints = "s"; + for (unsigned I = 1; I != Args.size(); ++I) + Constraints += ",s"; + } + + SmallVector ArgTypes; + ArgTypes.reserve(Args.size()); + for (const auto &Arg : Args) + ArgTypes.push_back(Arg->getType()); + + InlineAsm *JumpTableAsm = + InlineAsm::get(FunctionType::get(Type::getVoidTy(M.getContext()), ArgTypes, false), + AsmOS.str(), Constraints, + /*hasSideEffects=*/true); + CallInst::Create(JumpTableAsm, Args, "", JT->CI); + + JT->CI->eraseFromParent(); + } +} + void LowerTypeTestsModule::verifyTypeMDNode(GlobalObject *GO, MDNode *Type) { if (Type->getNumOperands() != 2) report_fatal_error("All operands of type metadata must have 2 elements"); @@ -1118,10 +1268,11 @@ /// Given a disjoint set of type identifiers and functions, build the bit sets /// and lower the llvm.type.test calls, architecture dependently. void LowerTypeTestsModule::buildBitSetsFromFunctions( - ArrayRef TypeIds, ArrayRef Functions) { + ArrayRef TypeIds, ArrayRef Functions, + ArrayRef IcallJumptables) { if (Arch == Triple::x86 || Arch == Triple::x86_64 || Arch == Triple::arm || Arch == Triple::thumb || Arch == Triple::aarch64) - buildBitSetsFromFunctionsNative(TypeIds, Functions); + buildBitSetsFromFunctionsNative(TypeIds, Functions, IcallJumptables); else if (Arch == Triple::wasm32 || Arch == Triple::wasm64) buildBitSetsFromFunctionsWASM(TypeIds, Functions); else @@ -1283,7 +1434,8 @@ /// Given a disjoint set of type identifiers and functions, build a jump table /// for the functions, build the bit sets and lower the llvm.type.test calls. void LowerTypeTestsModule::buildBitSetsFromFunctionsNative( - ArrayRef TypeIds, ArrayRef Functions) { + ArrayRef TypeIds, ArrayRef Functions, + ArrayRef IcallJumptables) { // Unlike the global bitset builder, the function bitset builder cannot // re-arrange functions in a particular order and base its calculations on the // layout of the functions' entry points, as we have no idea how large a @@ -1378,6 +1530,7 @@ ConstantExpr::getPointerCast(JumpTableFn, JumpTableType->getPointerTo(0)); lowerTypeTestCalls(TypeIds, JumpTable, GlobalLayout); + lowerIcallJumptableCalls(IcallJumptables, JumpTable, GlobalLayout); // Build aliases pointing to offsets into the jump table, and replace // references to the original functions with references to the aliases. @@ -1462,7 +1615,8 @@ } void LowerTypeTestsModule::buildBitSetsFromDisjointSet( - ArrayRef TypeIds, ArrayRef Globals) { + ArrayRef TypeIds, ArrayRef Globals, + ArrayRef IcallJumptables) { DenseMap TypeIdIndices; for (unsigned I = 0; I != TypeIds.size(); ++I) TypeIdIndices[TypeIds[I]] = I; @@ -1471,15 +1625,25 @@ // the type identifier. 
std::vector> TypeMembers(TypeIds.size()); unsigned GlobalIndex = 0; + DenseMap GlobalIndices; for (GlobalTypeMember *GTM : Globals) { for (MDNode *Type : GTM->types()) { // Type = { offset, type identifier } - unsigned TypeIdIndex = TypeIdIndices[Type->getOperand(1)]; - TypeMembers[TypeIdIndex].insert(GlobalIndex); + auto I = TypeIdIndices.find(Type->getOperand(1)); + if (I != TypeIdIndices.end()) + TypeMembers[I->second].insert(GlobalIndex); } + GlobalIndices[GTM] = GlobalIndex; GlobalIndex++; } + for (IcallJumptable *JT : IcallJumptables) { + TypeMembers.emplace_back(); + std::set &TMSet = TypeMembers.back(); + for (const IcallJumptableTarget &T : JT->targets()) + TMSet.insert(GlobalIndices[T.GTM]); + } + // Order the sets of indices by size. The GlobalLayoutBuilder works best // when given small index sets first. std::stable_sort( @@ -1511,9 +1675,9 @@ // Build the bitsets from this disjoint set. if (IsGlobalSet) - buildBitSetsFromGlobalVariables(TypeIds, OrderedGTMs); + buildBitSetsFromGlobalVariables(TypeIds, OrderedGTMs, IcallJumptables); else - buildBitSetsFromFunctions(TypeIds, OrderedGTMs); + buildBitSetsFromFunctions(TypeIds, OrderedGTMs, IcallJumptables); } /// Lower all type tests in this module. @@ -1567,8 +1731,11 @@ bool LowerTypeTestsModule::lower() { Function *TypeTestFunc = M.getFunction(Intrinsic::getName(Intrinsic::type_test)); - if ((!TypeTestFunc || TypeTestFunc->use_empty()) && !ExportSummary && - !ImportSummary) + Function *IcallJumptableFunc = + M.getFunction(Intrinsic::getName(Intrinsic::icall_jumptable)); + if ((!TypeTestFunc || TypeTestFunc->use_empty()) && + (!IcallJumptableFunc || IcallJumptableFunc->use_empty()) && + !ExportSummary && !ImportSummary) return false; if (ImportSummary) { @@ -1580,6 +1747,10 @@ } } + if (IcallJumptableFunc && !IcallJumptableFunc->use_empty()) + report_fatal_error( + "unexpected call to llvm.icall.jumptable during import phase"); + SmallVector Defs; SmallVector Decls; for (auto &F : M) { @@ -1604,8 +1775,8 @@ // Equivalence class set containing type identifiers and the globals that // reference them. This is used to partition the set of type identifiers in // the module into disjoint sets. 
- using GlobalClassesTy = - EquivalenceClasses>; + using GlobalClassesTy = EquivalenceClasses< + PointerUnion3>; GlobalClassesTy GlobalClasses; // Verify the type metadata and build a few data structures to let us @@ -1688,14 +1859,13 @@ } } + DenseMap GlobalTypeMembers; for (GlobalObject &GO : M.global_objects()) { if (isa(GO) && GO.isDeclarationForLinker()) continue; Types.clear(); GO.getMetadata(LLVMContext::MD_type, Types); - if (Types.empty()) - continue; bool IsDefinition = !GO.isDeclarationForLinker(); bool IsExported = false; @@ -1706,6 +1876,7 @@ auto *GTM = GlobalTypeMember::create(Alloc, &GO, IsDefinition, IsExported, Types); + GlobalTypeMembers[&GO] = GTM; for (MDNode *Type : Types) { verifyTypeMDNode(&GO, Type); auto &Info = TypeIdInfo[Type->getOperand(1)]; @@ -1746,6 +1917,48 @@ } } + if (IcallJumptableFunc) { + for (const Use &U : IcallJumptableFunc->uses()) { + if (Arch != Triple::x86_64) + report_fatal_error("llvm.icall.jumptable not supported on this target"); + + auto CI = cast(U.getUser()); + + std::vector Targets; + if (CI->getNumArgOperands() % 2) + report_fatal_error("number of arguments should be a multiple of 2"); + + GlobalClassesTy::member_iterator CurSet; + for (unsigned I = 0; I != CI->getNumArgOperands(); I += 2) { + int64_t Offset; + auto *Base = dyn_cast(GetPointerBaseWithConstantOffset( + CI->getOperand(I), Offset, M.getDataLayout())); + if (!Base) + report_fatal_error("Expected jump table operand to be global value"); + + IcallJumptableTarget Target; + Target.GTM = GlobalTypeMembers[Base]; + GlobalClassesTy::member_iterator NewSet = + GlobalClasses.findLeader(GlobalClasses.insert(Target.GTM)); + if (I == 0) + CurSet = NewSet; + else + CurSet = GlobalClasses.unionSets(CurSet, NewSet); + + Target.Offset = Offset; + auto *Callee = dyn_cast(CI->getArgOperand(I + 1)); + if (!Callee) + report_fatal_error("Callee must be a constant"); + Target.Callee = Callee; + Targets.push_back(Target); + } + + GlobalClasses.unionSets(CurSet, + GlobalClasses.findLeader(GlobalClasses.insert( + IcallJumptable::create(Alloc, CI, Targets)))); + } + } + if (ExportSummary) { DenseMap> MetadataByGUID; for (auto &P : TypeIdInfo) { @@ -1798,13 +2011,16 @@ // Build the list of type identifiers in this disjoint set. std::vector TypeIds; std::vector Globals; + std::vector IcallJumptables; for (GlobalClassesTy::member_iterator MI = GlobalClasses.member_begin(S.first); MI != GlobalClasses.member_end(); ++MI) { - if ((*MI).is()) + if (MI->is()) TypeIds.push_back(MI->get()); - else + else if (MI->is()) Globals.push_back(MI->get()); + else + IcallJumptables.push_back(MI->get()); } // Order type identifiers by global index for determinism. This ordering is @@ -1814,7 +2030,7 @@ }); // Build bitsets for this disjoint set. - buildBitSetsFromDisjointSet(TypeIds, Globals); + buildBitSetsFromDisjointSet(TypeIds, Globals, IcallJumptables); } allocateByteArrays(); Index: llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp =================================================================== --- llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp +++ llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp @@ -316,12 +316,17 @@ /// cases we are directly operating on the call sites at the IR level. std::vector CallSites; + /// Whether all call sites represented by this CallSiteInfo, including those + /// in summaries, have been devirtualized. This starts off as true because a + /// default constructed CallSiteInfo represents no call sites. 
+ bool AllCallSitesDevirted = true; + // These fields are used during the export phase of ThinLTO and reflect // information collected from function summaries. /// Whether any function summary contains an llvm.assume(llvm.type.test) for /// this slot. - bool SummaryHasTypeTestAssumeUsers; + bool SummaryHasTypeTestAssumeUsers = false; /// CFI-specific: a vector containing the list of function summaries that use /// the llvm.type.checked.load intrinsic and therefore will require @@ -337,8 +342,22 @@ !SummaryTypeCheckedLoadUsers.empty(); } - /// As explained in the comment for SummaryTypeCheckedLoadUsers. - void markDevirt() { SummaryTypeCheckedLoadUsers.clear(); } + void markSummaryHasTypeTestAssumeUsers() { + SummaryHasTypeTestAssumeUsers = true; + AllCallSitesDevirted = false; + } + + void addSummaryTypeCheckedLoadUser(FunctionSummary *FS) { + SummaryTypeCheckedLoadUsers.push_back(FS); + AllCallSitesDevirted = false; + } + + void markDevirt() { + AllCallSitesDevirted = true; + + // As explained in the comment for SummaryTypeCheckedLoadUsers. + SummaryTypeCheckedLoadUsers.clear(); + } }; // Call site information collected for a specific VTableSlot. @@ -373,7 +392,9 @@ void VTableSlotInfo::addCallSite(Value *VTable, CallSite CS, unsigned *NumUnsafeUses) { - findCallSiteInfo(CS).CallSites.push_back({VTable, CS, NumUnsafeUses}); + auto &CSI = findCallSiteInfo(CS); + CSI.AllCallSitesDevirted = false; + CSI.CallSites.push_back({VTable, CS, NumUnsafeUses}); } struct DevirtModule { @@ -438,6 +459,12 @@ VTableSlotInfo &SlotInfo, WholeProgramDevirtResolution *Res); + void applyIcallJumpTable(VTableSlotInfo &SlotInfo, Constant *JT, + bool &IsExported); + void tryIcallJumpTable(MutableArrayRef TargetsForSlot, + VTableSlotInfo &SlotInfo, + WholeProgramDevirtResolution *Res, VTableSlot Slot); + bool tryEvaluateFunctionsWithArgs( MutableArrayRef TargetsForSlot, ArrayRef Args); @@ -471,6 +498,8 @@ StringRef Name, IntegerType *IntTy, uint32_t Storage); + Constant *getMemberAddr(const TypeMemberInfo *M); + void applyUniqueRetValOpt(CallSiteInfo &CSInfo, StringRef FnName, bool IsOne, Constant *UniqueMemberAddr); bool tryUniqueRetValOpt(unsigned BitWidth, @@ -726,10 +755,9 @@ if (VCallSite.NumUnsafeUses) --*VCallSite.NumUnsafeUses; } - if (CSInfo.isExported()) { + if (CSInfo.isExported()) IsExported = true; - CSInfo.markDevirt(); - } + CSInfo.markDevirt(); }; Apply(SlotInfo.CSInfo); for (auto &P : SlotInfo.ConstCSInfo) @@ -785,6 +813,136 @@ return true; } +void DevirtModule::tryIcallJumpTable( + MutableArrayRef TargetsForSlot, VTableSlotInfo &SlotInfo, + WholeProgramDevirtResolution *Res, VTableSlot Slot) { + Triple T(M.getTargetTriple()); + if (T.getArch() != Triple::x86_64) + return; + + const unsigned kJumpTableThreshold = 10; + if (TargetsForSlot.size() > kJumpTableThreshold) + return; + + bool HasNonDevirt = !SlotInfo.CSInfo.AllCallSitesDevirted; + if (!HasNonDevirt) + for (auto &P : SlotInfo.ConstCSInfo) + if (!P.second.AllCallSitesDevirted) { + HasNonDevirt = true; + break; + } + + if (!HasNonDevirt) + return; + + std::vector JTArgs; + for (auto &T : TargetsForSlot) { + JTArgs.push_back(getMemberAddr(T.TM)); + JTArgs.push_back(T.Fn); + } + + FunctionType *FT = FunctionType::get(Type::getVoidTy(M.getContext()), {}, false); + Function *JT; + if (isa(Slot.TypeID)) { + JT = Function::Create(FT, Function::ExternalLinkage, + getGlobalName(Slot, {}, "jumptable"), &M); + JT->setVisibility(GlobalValue::HiddenVisibility); + } else { + JT = Function::Create(FT, Function::InternalLinkage, "jumptable", &M); + } + 
// Skip prologue. + // Disabled on win32 due to https://llvm.org/bugs/show_bug.cgi?id=28641#c3. + // Luckily, this function does not get any prologue even without the + // attribute. + if (T.getOS() != Triple::Win32) + JT->addFnAttr(Attribute::Naked); + + BasicBlock *BB = BasicBlock::Create(M.getContext(), "", JT, nullptr); + Constant *Intr = + Intrinsic::getDeclaration(&M, llvm::Intrinsic::icall_jumptable, {}); + + CallInst::Create(Intr, JTArgs, "", BB); + ReturnInst::Create(M.getContext(), nullptr, BB); + + bool IsExported = false; + applyIcallJumpTable(SlotInfo, JT, IsExported); + if (IsExported) + Res->TheKind = WholeProgramDevirtResolution::JumpTable; +} + +void DevirtModule::applyIcallJumpTable(VTableSlotInfo &SlotInfo, Constant *JT, + bool &IsExported) { + auto Apply = [&](CallSiteInfo &CSInfo) { + if (CSInfo.isExported()) + IsExported = true; + if (CSInfo.AllCallSitesDevirted) + return; + for (auto &&VCallSite : CSInfo.CallSites) { + CallSite CS = VCallSite.CS; + + // Jump tables are only profitable if the retpoline mitigation is enabled. + Attribute FSAttr = CS.getCaller()->getFnAttribute("target-features"); + if (FSAttr.hasAttribute(Attribute::None) || + !FSAttr.getValueAsString().contains("+retpoline")) + continue; + + if (RemarksEnabled) + VCallSite.emitRemark("jump-table", JT->getName(), OREGetter); + + // Pass the address of the vtable in the nest register, which is r10 on + // x86_64. + std::vector NewArgs; + NewArgs.push_back(Int8PtrTy); + for (Type *T : CS.getFunctionType()->params()) + NewArgs.push_back(T); + PointerType *NewFT = PointerType::getUnqual( + FunctionType::get(CS.getFunctionType()->getReturnType(), NewArgs, + CS.getFunctionType()->isVarArg())); + + IRBuilder<> IRB(CS.getInstruction()); + std::vector Args; + Args.push_back(IRB.CreateBitCast(VCallSite.VTable, Int8PtrTy)); + for (unsigned I = 0; I != CS.getNumArgOperands(); ++I) + Args.push_back(CS.getArgOperand(I)); + + CallSite NewCS; + if (CS.isCall()) + NewCS = IRB.CreateCall(IRB.CreateBitCast(JT, NewFT), Args); + else + NewCS = IRB.CreateInvoke( + IRB.CreateBitCast(JT, NewFT), + cast(CS.getInstruction())->getNormalDest(), + cast(CS.getInstruction())->getUnwindDest(), Args); + NewCS.setCallingConv(CS.getCallingConv()); + + AttributeList Attrs = CS.getAttributes(); + std::vector NewArgAttrs; + NewArgAttrs.push_back(AttributeSet::get( + M.getContext(), + ArrayRef{Attribute::get(M.getContext(), Attribute::Nest)})); + for (unsigned I = 0; I + 2 < Attrs.getNumAttrSets(); ++I) + NewArgAttrs.push_back(Attrs.getParamAttributes(I)); + NewCS.setAttributes( + AttributeList::get(M.getContext(), Attrs.getFnAttributes(), + Attrs.getRetAttributes(), NewArgAttrs)); + + CS->replaceAllUsesWith(NewCS.getInstruction()); + CS->eraseFromParent(); + + // This use is no longer unsafe. + if (VCallSite.NumUnsafeUses) + --*VCallSite.NumUnsafeUses; + } + // Don't mark as devirtualized because there may be callers compiled without + // retpoline mitigation, which would mean that they are lowered to + // llvm.type.test and therefore require an llvm.type.test resolution for the + // type identifier. 
+ }; + Apply(SlotInfo.CSInfo); + for (auto &P : SlotInfo.ConstCSInfo) + Apply(P.second); +} + bool DevirtModule::tryEvaluateFunctionsWithArgs( MutableArrayRef TargetsForSlot, ArrayRef Args) { @@ -937,6 +1095,12 @@ CSInfo.markDevirt(); } +Constant *DevirtModule::getMemberAddr(const TypeMemberInfo *M) { + Constant *C = ConstantExpr::getBitCast(M->Bits->GV, Int8PtrTy); + return ConstantExpr::getGetElementPtr(Int8Ty, C, + ConstantInt::get(Int64Ty, M->Offset)); +} + bool DevirtModule::tryUniqueRetValOpt( unsigned BitWidth, MutableArrayRef TargetsForSlot, CallSiteInfo &CSInfo, WholeProgramDevirtResolution::ByArg *Res, @@ -956,12 +1120,7 @@ // checked for a uniform return value in tryUniformRetValOpt. assert(UniqueMember); - Constant *UniqueMemberAddr = - ConstantExpr::getBitCast(UniqueMember->Bits->GV, Int8PtrTy); - UniqueMemberAddr = ConstantExpr::getGetElementPtr( - Int8Ty, UniqueMemberAddr, - ConstantInt::get(Int64Ty, UniqueMember->Offset)); - + Constant *UniqueMemberAddr = getMemberAddr(UniqueMember); if (CSInfo.isExported()) { Res->TheKind = WholeProgramDevirtResolution::ByArg::UniqueRetVal; Res->Info = IsOne; @@ -1348,6 +1507,14 @@ break; } } + + if (Res.TheKind == WholeProgramDevirtResolution::JumpTable) { + auto *JT = M.getOrInsertFunction(getGlobalName(Slot, {}, "jumptable"), + Type::getVoidTy(M.getContext())); + bool IsExported = false; + applyIcallJumpTable(SlotInfo, JT, IsExported); + assert(!IsExported); + } } void DevirtModule::removeRedundantTypeTests() { @@ -1417,14 +1584,13 @@ // FIXME: Only add live functions. for (FunctionSummary::VFuncId VF : FS->type_test_assume_vcalls()) { for (Metadata *MD : MetadataByGUID[VF.GUID]) { - CallSlots[{MD, VF.Offset}].CSInfo.SummaryHasTypeTestAssumeUsers = - true; + CallSlots[{MD, VF.Offset}] + .CSInfo.markSummaryHasTypeTestAssumeUsers(); } } for (FunctionSummary::VFuncId VF : FS->type_checked_load_vcalls()) { for (Metadata *MD : MetadataByGUID[VF.GUID]) { - CallSlots[{MD, VF.Offset}] - .CSInfo.SummaryTypeCheckedLoadUsers.push_back(FS); + CallSlots[{MD, VF.Offset}].CSInfo.addSummaryTypeCheckedLoadUser(FS); } } for (const FunctionSummary::ConstVCall &VC : @@ -1432,7 +1598,7 @@ for (Metadata *MD : MetadataByGUID[VC.VFunc.GUID]) { CallSlots[{MD, VC.VFunc.Offset}] .ConstCSInfo[VC.Args] - .SummaryHasTypeTestAssumeUsers = true; + .markSummaryHasTypeTestAssumeUsers(); } } for (const FunctionSummary::ConstVCall &VC : @@ -1440,7 +1606,7 @@ for (Metadata *MD : MetadataByGUID[VC.VFunc.GUID]) { CallSlots[{MD, VC.VFunc.Offset}] .ConstCSInfo[VC.Args] - .SummaryTypeCheckedLoadUsers.push_back(FS); + .addSummaryTypeCheckedLoadUser(FS); } } } @@ -1464,9 +1630,12 @@ cast(S.first.TypeID)->getString()) .WPDRes[S.first.ByteOffset]; - if (!trySingleImplDevirt(TargetsForSlot, S.second, Res) && - tryVirtualConstProp(TargetsForSlot, S.second, Res, S.first)) - DidVirtualConstProp = true; + if (!trySingleImplDevirt(TargetsForSlot, S.second, Res)) { + DidVirtualConstProp |= + tryVirtualConstProp(TargetsForSlot, S.second, Res, S.first); + + tryIcallJumpTable(TargetsForSlot, S.second, Res, S.first); + } // Collect functions devirtualized at least for one call site for stats. 
if (RemarksEnabled) Index: llvm/test/Transforms/LowerTypeTests/icall-jumptable.ll =================================================================== --- /dev/null +++ llvm/test/Transforms/LowerTypeTests/icall-jumptable.ll @@ -0,0 +1,117 @@ +; RUN: opt -S -lowertypetests < %s | FileCheck %s + +target datalayout = "e-p:64:64" +target triple = "x86_64-unknown-linux" + +; CHECK: @0 = private constant { i32, [0 x i8], i32, [0 x i8], i32 } { i32 1, [0 x i8] zeroinitializer, i32 2, [0 x i8] zeroinitializer, i32 3 } +@g1 = constant i32 1 +@g2 = constant i32 2, !type !0 +@g3 = constant i32 3, !type !0 + +define void @f1() !type !1 { + ret void +} + +define void @f2() !type !1 { + ret void +} + +define void @f3() !type !1 { + ret void +} + +define void @f4() !type !1 { + ret void +} + +define void @f5() !type !1 { + ret void +} + +define void @f6() !type !1 { + ret void +} + +define void @f7() !type !1 { + ret void +} + +define void @f8() !type !1 { + ret void +} + +define void @f9() !type !1 { + ret void +} + +define void @f10() !type !1 { + ret void +} + +declare void @g1f() +declare void @g2f() + +; CHECK: define void @jt2() +define void @jt2() { + ; CHECK-NEXT: call void asm sideeffect "leaq ${0:c}+5(%rip), %r11\0Acmp %r11, %r10\0Ajb ${1:c}@plt\0Ajmp ${2:c}@plt\0A", "s,s,s"({ i32, [0 x i8], i32, [0 x i8], i32 }* @0, void ()* @g1f, void ()* @g2f) + call void (...) @llvm.icall.jumptable( + i32* @g1, void ()* @g1f, + i8* getelementptr (i8, i8* bitcast (i32* @g2 to i8*), i64 1), void ()* @g2f + ) + ret void +} + +; CHECK: define void @jt3() +define void @jt3() { + ; CHECK-NEXT: call void asm sideeffect "leaq ${0:c}+8(%rip), %r11\0Acmp %r11, %r10\0Ajb ${1:c}@plt\0Aje ${2:c}@plt\0Ajmp ${3:c}@plt\0A", "s,s,s,s"([10 x [8 x i8]]* bitcast (void ()* @.cfi.jumptable to [10 x [8 x i8]]*), void ()* @f1, void ()* @f2, void ()* @f3) + call void (...) @llvm.icall.jumptable( + void ()* @f1, void ()* @f1, + void ()* @f2, void ()* @f2, + void ()* @f3, void ()* @f3 + ) + ret void +} + +; CHECK: define void @jt7() +define void @jt7() { + ; CHECK-NEXT: call void asm sideeffect "leaq ${0:c}+24(%rip), %r11\0Acmp %r11, %r10\0Ajb 0f\0Aje ${1:c}@plt\0Aleaq ${2:c}+40(%rip), %r11\0Acmp %r11, %r10\0Ajb ${3:c}@plt\0Aje ${4:c}@plt\0Ajmp ${5:c}@plt\0A0:\0Aleaq ${6:c}+8(%rip), %r11\0Acmp %r11, %r10\0Ajb ${7:c}@plt\0Aje ${8:c}@plt\0Ajmp ${9:c}@plt\0A", "s,s,s,s,s,s,s,s,s,s"([10 x [8 x i8]]* bitcast (void ()* @.cfi.jumptable to [10 x [8 x i8]]*), void ()* @f4, [10 x [8 x i8]]* bitcast (void ()* @.cfi.jumptable to [10 x [8 x i8]]*), void ()* @f5, void ()* @f6, void ()* @f7, [10 x [8 x i8]]* bitcast (void ()* @.cfi.jumptable to [10 x [8 x i8]]*), void ()* @f1, void ()* @f2, void ()* @f3) + call void (...) 
@llvm.icall.jumptable( + void ()* @f1, void ()* @f1, + void ()* @f2, void ()* @f2, + void ()* @f3, void ()* @f3, + void ()* @f4, void ()* @f4, + void ()* @f5, void ()* @f5, + void ()* @f6, void ()* @f6, + void ()* @f7, void ()* @f7 + ) + ret void +} + +; CHECK: define void @jt10() +define void @jt10() { + ; CHECK-NEXT: call void asm sideeffect "leaq ${0:c}+40(%rip), %r11\0Acmp %r11, %r10\0Ajb 0f\0Aje ${1:c}@plt\0Aleaq ${2:c}+56(%rip), %r11\0Acmp %r11, %r10\0Ajb ${3:c}@plt\0Aje ${4:c}@plt\0Aleaq ${5:c}+72(%rip), %r11\0Acmp %r11, %r10\0Ajb ${6:c}@plt\0Ajmp ${7:c}@plt\0A0:\0Aleaq ${8:c}+8(%rip), %r11\0Acmp %r11, %r10\0Ajb ${9:c}@plt\0Aje ${10:c}@plt\0Aleaq ${11:c}+24(%rip), %r11\0Acmp %r11, %r10\0Ajb ${12:c}@plt\0Aje ${13:c}@plt\0Ajmp ${14:c}@plt\0A", "s,s,s,s,s,s,s,s,s,s,s,s,s,s,s"([10 x [8 x i8]]* bitcast (void ()* @.cfi.jumptable to [10 x [8 x i8]]*), void ()* @f6, [10 x [8 x i8]]* bitcast (void ()* @.cfi.jumptable to [10 x [8 x i8]]*), void ()* @f7, void ()* @f8, [10 x [8 x i8]]* bitcast (void ()* @.cfi.jumptable to [10 x [8 x i8]]*), void ()* @f9, void ()* @f10, [10 x [8 x i8]]* bitcast (void ()* @.cfi.jumptable to [10 x [8 x i8]]*), void ()* @f1, void ()* @f2, [10 x [8 x i8]]* bitcast (void ()* @.cfi.jumptable to [10 x [8 x i8]]*), void ()* @f3, void ()* @f4, void ()* @f5) + call void (...) @llvm.icall.jumptable( + void ()* @f1, void ()* @f1, + void ()* @f2, void ()* @f2, + void ()* @f3, void ()* @f3, + void ()* @f4, void ()* @f4, + void ()* @f5, void ()* @f5, + void ()* @f6, void ()* @f6, + void ()* @f7, void ()* @f7, + void ()* @f8, void ()* @f8, + void ()* @f9, void ()* @f9, + void ()* @f10, void ()* @f10 + ) + ret void +} + +define i1 @tt(i8* %ptr) { + %p = call i1 @llvm.type.test(i8* %ptr, metadata !"typeid1") + ret i1 %p +} + +!0 = !{i32 0, !"typeid1"} +!1 = !{i32 0, !"typeid2"} + +declare i1 @llvm.type.test(i8* %ptr, metadata %bitset) nounwind readnone +declare void @llvm.icall.jumptable(...) Index: llvm/test/Transforms/WholeProgramDevirt/Inputs/import-jumptable.yaml =================================================================== --- /dev/null +++ llvm/test/Transforms/WholeProgramDevirt/Inputs/import-jumptable.yaml @@ -0,0 +1,11 @@ +--- +TypeIdMap: + typeid1: + WPDRes: + 0: + Kind: JumpTable + typeid2: + WPDRes: + 8: + Kind: JumpTable +... Index: llvm/test/Transforms/WholeProgramDevirt/Inputs/import-vcp-jumptable.yaml =================================================================== --- /dev/null +++ llvm/test/Transforms/WholeProgramDevirt/Inputs/import-vcp-jumptable.yaml @@ -0,0 +1,23 @@ +--- +TypeIdMap: + typeid1: + WPDRes: + 0: + Kind: JumpTable + ResByArg: + 1: + Kind: VirtualConstProp + Info: 0 + Byte: 42 + Bit: 0 + typeid2: + WPDRes: + 8: + Kind: JumpTable + ResByArg: + 3: + Kind: VirtualConstProp + Info: 0 + Byte: 43 + Bit: 128 +... 
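Note on the JumpTable summary kind used in the two YAML inputs above: during the import phase the jump table body is never visible; WholeProgramDevirt only redirects retpoline-enabled virtual calls to the externally defined __typeid_<typeid>_<offset>_jumptable thunk, passing the vtable address as a nest argument (r10 on x86-64). A rough sketch of that rewrite, using the typeid1/offset-0 slot from the YAML above (the import.ll test that follows checks this exact pattern; the caller name here is a placeholder):

; Sketch only -- in the importing module the thunk is declared, not defined.
declare void @__typeid_typeid1_0_jumptable()

define i32 @caller(i8* %obj, i8* %vtablei8) "target-features"="+retpoline" {
  ; was: %result = call i32 %fptr_casted(i8* %obj, i32 1)
  %result = call i32 bitcast (void ()* @__typeid_typeid1_0_jumptable to i32 (i8*, i8*, i32)*)(i8* nest %vtablei8, i8* %obj, i32 1)
  ret i32 %result
}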
Index: llvm/test/Transforms/WholeProgramDevirt/import.ll =================================================================== --- llvm/test/Transforms/WholeProgramDevirt/import.ll +++ llvm/test/Transforms/WholeProgramDevirt/import.ll @@ -1,10 +1,12 @@ ; RUN: opt -S -wholeprogramdevirt -wholeprogramdevirt-summary-action=import -wholeprogramdevirt-read-summary=%S/Inputs/import-single-impl.yaml < %s | FileCheck --check-prefixes=CHECK,SINGLE-IMPL %s -; RUN: opt -S -wholeprogramdevirt -wholeprogramdevirt-summary-action=import -wholeprogramdevirt-read-summary=%S/Inputs/import-uniform-ret-val.yaml < %s | FileCheck --check-prefixes=CHECK,UNIFORM-RET-VAL %s -; RUN: opt -S -wholeprogramdevirt -wholeprogramdevirt-summary-action=import -wholeprogramdevirt-read-summary=%S/Inputs/import-unique-ret-val0.yaml < %s | FileCheck --check-prefixes=CHECK,UNIQUE-RET-VAL0 %s -; RUN: opt -S -wholeprogramdevirt -wholeprogramdevirt-summary-action=import -wholeprogramdevirt-read-summary=%S/Inputs/import-unique-ret-val1.yaml < %s | FileCheck --check-prefixes=CHECK,UNIQUE-RET-VAL1 %s -; RUN: opt -S -wholeprogramdevirt -wholeprogramdevirt-summary-action=import -wholeprogramdevirt-read-summary=%S/Inputs/import-vcp.yaml < %s | FileCheck --check-prefixes=CHECK,VCP,VCP-X86,VCP64 %s +; RUN: opt -S -wholeprogramdevirt -wholeprogramdevirt-summary-action=import -wholeprogramdevirt-read-summary=%S/Inputs/import-uniform-ret-val.yaml < %s | FileCheck --check-prefixes=CHECK,INDIR,UNIFORM-RET-VAL %s +; RUN: opt -S -wholeprogramdevirt -wholeprogramdevirt-summary-action=import -wholeprogramdevirt-read-summary=%S/Inputs/import-unique-ret-val0.yaml < %s | FileCheck --check-prefixes=CHECK,INDIR,UNIQUE-RET-VAL0 %s +; RUN: opt -S -wholeprogramdevirt -wholeprogramdevirt-summary-action=import -wholeprogramdevirt-read-summary=%S/Inputs/import-unique-ret-val1.yaml < %s | FileCheck --check-prefixes=CHECK,INDIR,UNIQUE-RET-VAL1 %s +; RUN: opt -S -wholeprogramdevirt -wholeprogramdevirt-summary-action=import -wholeprogramdevirt-read-summary=%S/Inputs/import-vcp.yaml < %s | FileCheck --check-prefixes=CHECK,VCP,VCP-X86,VCP64,INDIR %s ; RUN: opt -S -wholeprogramdevirt -wholeprogramdevirt-summary-action=import -wholeprogramdevirt-read-summary=%S/Inputs/import-vcp.yaml -mtriple=i686-unknown-linux -data-layout=e-p:32:32 < %s | FileCheck --check-prefixes=CHECK,VCP,VCP-X86,VCP32 %s ; RUN: opt -S -wholeprogramdevirt -wholeprogramdevirt-summary-action=import -wholeprogramdevirt-read-summary=%S/Inputs/import-vcp.yaml -mtriple=armv7-unknown-linux -data-layout=e-p:32:32 < %s | FileCheck --check-prefixes=CHECK,VCP,VCP-ARM %s +; RUN: opt -S -wholeprogramdevirt -wholeprogramdevirt-summary-action=import -wholeprogramdevirt-read-summary=%S/Inputs/import-vcp-jumptable.yaml < %s | FileCheck --check-prefixes=CHECK,VCP,VCP-X86,VCP64,JUMPTABLE %s +; RUN: opt -S -wholeprogramdevirt -wholeprogramdevirt-summary-action=import -wholeprogramdevirt-read-summary=%S/Inputs/import-jumptable.yaml < %s | FileCheck --check-prefixes=CHECK,JUMPTABLE,JUMPTABLE-NOVCP %s target datalayout = "e-p:64:64" target triple = "x86_64-unknown-linux-gnu" @@ -18,7 +20,7 @@ ; constant propagation. 
; CHECK: define i32 @call1 -define i32 @call1(i8* %obj) { +define i32 @call1(i8* %obj) #0 { %vtableptr = bitcast i8* %obj to [3 x i8*]** %vtable = load [3 x i8*]*, [3 x i8*]** %vtableptr %vtablei8 = bitcast [3 x i8*]* %vtable to i8* @@ -27,16 +29,18 @@ %fptrptr = getelementptr [3 x i8*], [3 x i8*]* %vtable, i32 0, i32 0 %fptr = load i8*, i8** %fptrptr %fptr_casted = bitcast i8* %fptr to i32 (i8*, i32)* + ; CHECK: {{.*}} = bitcast {{.*}} to i8* + ; VCP: [[VT1:%.*]] = bitcast {{.*}} to i8* ; SINGLE-IMPL: call i32 bitcast (void ()* @singleimpl1 to i32 (i8*, i32)*) %result = call i32 %fptr_casted(i8* %obj, i32 1) ; UNIFORM-RET-VAL: ret i32 42 - ; VCP: {{.*}} = bitcast {{.*}} to i8* - ; VCP: [[VT1:%.*]] = bitcast {{.*}} to i8* ; VCP-X86: [[GEP1:%.*]] = getelementptr i8, i8* [[VT1]], i32 ptrtoint (i8* @__typeid_typeid1_0_1_byte to i32) ; VCP-ARM: [[GEP1:%.*]] = getelementptr i8, i8* [[VT1]], i32 42 ; VCP: [[BC1:%.*]] = bitcast i8* [[GEP1]] to i32* ; VCP: [[LOAD1:%.*]] = load i32, i32* [[BC1]] ; VCP: ret i32 [[LOAD1]] + ; JUMPTABLE-NOVCP: [[VT1:%.*]] = bitcast {{.*}} to i8* + ; JUMPTABLE-NOVCP: call i32 bitcast (void ()* @__typeid_typeid1_0_jumptable to i32 (i8*, i8*, i32)*)(i8* nest [[VT1]], i8* %obj, i32 1) ret i32 %result } @@ -44,7 +48,8 @@ ; constant propagation. ; CHECK: define i1 @call2 -define i1 @call2(i8* %obj) { +define i1 @call2(i8* %obj) #0 { + ; JUMPTABLE: [[VT1:%.*]] = bitcast {{.*}} to i8* %vtableptr = bitcast i8* %obj to [1 x i8*]** %vtable = load [1 x i8*]*, [1 x i8*]** %vtableptr %vtablei8 = bitcast [1 x i8*]* %vtable to i8* @@ -57,9 +62,8 @@ cont: %fptr_casted = bitcast i8* %fptr to i1 (i8*, i32)* ; SINGLE-IMPL: call i1 bitcast (void ()* @singleimpl2 to i1 (i8*, i32)*) - ; UNIFORM-RET-VAL: call i1 % - ; UNIQUE-RET-VAL0: call i1 % - ; UNIQUE-RET-VAL1: call i1 % + ; INDIR: call i1 % + ; JUMPTABLE: call i1 bitcast (void ()* @__typeid_typeid2_8_jumptable to i1 (i8*, i8*, i32)*)(i8* nest [[VT1]], i8* %obj, i32 undef) %result = call i1 %fptr_casted(i8* %obj, i32 undef) ret i1 %result @@ -69,7 +73,7 @@ } ; CHECK: define i1 @call3 -define i1 @call3(i8* %obj) { +define i1 @call3(i8* %obj) #0 { %vtableptr = bitcast i8* %obj to [1 x i8*]** %vtable = load [1 x i8*]*, [1 x i8*]** %vtableptr %vtablei8 = bitcast [1 x i8*]* %vtable to i8* @@ -91,6 +95,8 @@ ; VCP-ARM: [[AND2:%.*]] = and i8 [[LOAD2]], -128 ; VCP: [[ICMP2:%.*]] = icmp ne i8 [[AND2]], 0 ; VCP: ret i1 [[ICMP2]] + ; JUMPTABLE-NOVCP: [[VT2:%.*]] = bitcast {{.*}} to i8* + ; JUMPTABLE-NOVCP: call i1 bitcast (void ()* @__typeid_typeid2_8_jumptable to i1 (i8*, i8*, i32)*)(i8* nest [[VT2]], i8* %obj, i32 3) ret i1 %result trap: @@ -111,3 +117,5 @@ declare void @llvm.trap() declare {i8*, i1} @llvm.type.checked.load(i8*, i32, metadata) declare i1 @llvm.type.test(i8*, metadata) + +attributes #0 = { "target-features"="+retpoline" } Index: llvm/test/Transforms/WholeProgramDevirt/jumptable.ll =================================================================== --- /dev/null +++ llvm/test/Transforms/WholeProgramDevirt/jumptable.ll @@ -0,0 +1,157 @@ +; RUN: opt -S -wholeprogramdevirt %s | FileCheck --check-prefixes=CHECK,RETP %s +; RUN: sed -e 's,+retpoline,-retpoline,g' %s | opt -S -wholeprogramdevirt | FileCheck --check-prefixes=CHECK,NORETP %s +; RUN: opt -wholeprogramdevirt -wholeprogramdevirt-summary-action=export -wholeprogramdevirt-read-summary=%S/Inputs/export.yaml -wholeprogramdevirt-write-summary=%t -S -o - %s | FileCheck --check-prefixes=CHECK,RETP %s +; RUN: FileCheck --check-prefix=SUMMARY %s < %t + +; SUMMARY: TypeIdMap: +; 
SUMMARY-NEXT: typeid1: +; SUMMARY-NEXT: TTRes: +; SUMMARY-NEXT: Kind: Unsat +; SUMMARY-NEXT: SizeM1BitWidth: 0 +; SUMMARY-NEXT: AlignLog2: 0 +; SUMMARY-NEXT: SizeM1: 0 +; SUMMARY-NEXT: BitMask: 0 +; SUMMARY-NEXT: InlineBits: 0 +; SUMMARY-NEXT: WPDRes: +; SUMMARY-NEXT: 0: +; SUMMARY-NEXT: Kind: JumpTable +; SUMMARY-NEXT: SingleImplName: '' +; SUMMARY-NEXT: ResByArg: +; SUMMARY-NEXT: typeid2: +; SUMMARY-NEXT: TTRes: +; SUMMARY-NEXT: Kind: Unsat +; SUMMARY-NEXT: SizeM1BitWidth: 0 +; SUMMARY-NEXT: AlignLog2: 0 +; SUMMARY-NEXT: SizeM1: 0 +; SUMMARY-NEXT: BitMask: 0 +; SUMMARY-NEXT: InlineBits: 0 +; SUMMARY-NEXT: WPDRes: +; SUMMARY-NEXT: 0: +; SUMMARY-NEXT: Kind: Indir +; SUMMARY-NEXT: SingleImplName: '' +; SUMMARY-NEXT: ResByArg: +; SUMMARY-NEXT: typeid3: +; SUMMARY-NEXT: TTRes: +; SUMMARY-NEXT: Kind: Unsat +; SUMMARY-NEXT: SizeM1BitWidth: 0 +; SUMMARY-NEXT: AlignLog2: 0 +; SUMMARY-NEXT: SizeM1: 0 +; SUMMARY-NEXT: BitMask: 0 +; SUMMARY-NEXT: InlineBits: 0 +; SUMMARY-NEXT: WPDRes: +; SUMMARY-NEXT: 0: +; SUMMARY-NEXT: Kind: JumpTable +; SUMMARY-NEXT: SingleImplName: '' +; SUMMARY-NEXT: ResByArg: + +target datalayout = "e-p:64:64" +target triple = "x86_64-unknown-linux-gnu" + +@vt1_1 = constant [1 x i8*] [i8* bitcast (i32 (i8*, i32)* @vf1_1 to i8*)], !type !0 +@vt1_2 = constant [1 x i8*] [i8* bitcast (i32 (i8*, i32)* @vf1_2 to i8*)], !type !0 + +declare i32 @vf1_1(i8* %this, i32 %arg) +declare i32 @vf1_2(i8* %this, i32 %arg) + +@vt2_1 = constant [1 x i8*] [i8* bitcast (i32 (i8*, i32)* @vf2_1 to i8*)], !type !1 +@vt2_2 = constant [1 x i8*] [i8* bitcast (i32 (i8*, i32)* @vf2_2 to i8*)], !type !1 +@vt2_3 = constant [1 x i8*] [i8* bitcast (i32 (i8*, i32)* @vf2_3 to i8*)], !type !1 +@vt2_4 = constant [1 x i8*] [i8* bitcast (i32 (i8*, i32)* @vf2_4 to i8*)], !type !1 +@vt2_5 = constant [1 x i8*] [i8* bitcast (i32 (i8*, i32)* @vf2_5 to i8*)], !type !1 +@vt2_6 = constant [1 x i8*] [i8* bitcast (i32 (i8*, i32)* @vf2_6 to i8*)], !type !1 +@vt2_7 = constant [1 x i8*] [i8* bitcast (i32 (i8*, i32)* @vf2_7 to i8*)], !type !1 +@vt2_8 = constant [1 x i8*] [i8* bitcast (i32 (i8*, i32)* @vf2_8 to i8*)], !type !1 +@vt2_9 = constant [1 x i8*] [i8* bitcast (i32 (i8*, i32)* @vf2_9 to i8*)], !type !1 +@vt2_10 = constant [1 x i8*] [i8* bitcast (i32 (i8*, i32)* @vf2_10 to i8*)], !type !1 +@vt2_11 = constant [1 x i8*] [i8* bitcast (i32 (i8*, i32)* @vf2_11 to i8*)], !type !1 + +declare i32 @vf2_1(i8* %this, i32 %arg) +declare i32 @vf2_2(i8* %this, i32 %arg) +declare i32 @vf2_3(i8* %this, i32 %arg) +declare i32 @vf2_4(i8* %this, i32 %arg) +declare i32 @vf2_5(i8* %this, i32 %arg) +declare i32 @vf2_6(i8* %this, i32 %arg) +declare i32 @vf2_7(i8* %this, i32 %arg) +declare i32 @vf2_8(i8* %this, i32 %arg) +declare i32 @vf2_9(i8* %this, i32 %arg) +declare i32 @vf2_10(i8* %this, i32 %arg) +declare i32 @vf2_11(i8* %this, i32 %arg) + +@vt3_1 = constant [1 x i8*] [i8* bitcast (i32 (i8*, i32)* @vf3_1 to i8*)], !type !2 +@vt3_2 = constant [1 x i8*] [i8* bitcast (i32 (i8*, i32)* @vf3_2 to i8*)], !type !2 + +declare i32 @vf3_1(i8* %this, i32 %arg) +declare i32 @vf3_2(i8* %this, i32 %arg) + +@vt4_1 = constant [1 x i8*] [i8* bitcast (i32 (i8*, i32)* @vf4_1 to i8*)], !type !3 +@vt4_2 = constant [1 x i8*] [i8* bitcast (i32 (i8*, i32)* @vf4_2 to i8*)], !type !3 + +declare i32 @vf4_1(i8* %this, i32 %arg) +declare i32 @vf4_2(i8* %this, i32 %arg) + +; CHECK: define i32 @fn1 +define i32 @fn1(i8* %obj) #0 { + %vtableptr = bitcast i8* %obj to [1 x i8*]** + %vtable = load [1 x i8*]*, [1 x i8*]** %vtableptr + %vtablei8 = bitcast [1 x i8*]* %vtable to 
i8* + %p = call i1 @llvm.type.test(i8* %vtablei8, metadata !"typeid1") + call void @llvm.assume(i1 %p) + %fptrptr = getelementptr [1 x i8*], [1 x i8*]* %vtable, i32 0, i32 0 + %fptr = load i8*, i8** %fptrptr + %fptr_casted = bitcast i8* %fptr to i32 (i8*, i32)* + ; RETP: {{.*}} = bitcast {{.*}} to i8* + ; RETP: [[VT1:%.*]] = bitcast {{.*}} to i8* + ; RETP: call i32 bitcast (void ()* @__typeid_typeid1_0_jumptable to i32 (i8*, i8*, i32)*)(i8* nest [[VT1]], i8* %obj, i32 1) + %result = call i32 %fptr_casted(i8* %obj, i32 1) + ; NORETP: call i32 % + ret i32 %result +} + +; CHECK: define i32 @fn2 +define i32 @fn2(i8* %obj) #0 { + %vtableptr = bitcast i8* %obj to [1 x i8*]** + %vtable = load [1 x i8*]*, [1 x i8*]** %vtableptr + %vtablei8 = bitcast [1 x i8*]* %vtable to i8* + %p = call i1 @llvm.type.test(i8* %vtablei8, metadata !"typeid2") + call void @llvm.assume(i1 %p) + %fptrptr = getelementptr [1 x i8*], [1 x i8*]* %vtable, i32 0, i32 0 + %fptr = load i8*, i8** %fptrptr + %fptr_casted = bitcast i8* %fptr to i32 (i8*, i32)* + ; CHECK: call i32 % + %result = call i32 %fptr_casted(i8* %obj, i32 1) + ret i32 %result +} + +; CHECK: define i32 @fn3 +define i32 @fn3(i8* %obj) #0 { + %vtableptr = bitcast i8* %obj to [1 x i8*]** + %vtable = load [1 x i8*]*, [1 x i8*]** %vtableptr + %vtablei8 = bitcast [1 x i8*]* %vtable to i8* + %p = call i1 @llvm.type.test(i8* %vtablei8, metadata !4) + call void @llvm.assume(i1 %p) + %fptrptr = getelementptr [1 x i8*], [1 x i8*]* %vtable, i32 0, i32 0 + %fptr = load i8*, i8** %fptrptr + %fptr_casted = bitcast i8* %fptr to i32 (i8*, i32)* + ; RETP: call i32 bitcast (void ()* @jumptable to + ; NORETP: call i32 % + %result = call i32 %fptr_casted(i8* %obj, i32 1) + ret i32 %result +} + +; CHECK: define internal void @jumptable() + +; CHECK: define hidden void @__typeid_typeid1_0_jumptable() [[A:#[0-9]+]] +; CHECK-NEXT: call void (...) @llvm.icall.jumptable(i8* bitcast ([1 x i8*]* @vt1_1 to i8*), i32 (i8*, i32)* @vf1_1, i8* bitcast ([1 x i8*]* @vt1_2 to i8*), i32 (i8*, i32)* @vf1_2) + +declare i1 @llvm.type.test(i8*, metadata) +declare void @llvm.assume(i1) + +!0 = !{i32 0, !"typeid1"} +!1 = !{i32 0, !"typeid2"} +!2 = !{i32 0, !"typeid3"} +!3 = !{i32 0, !4} +!4 = distinct !{} + +; CHECK: attributes [[A]] = { naked } + +attributes #0 = { "target-features"="+retpoline" }
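
As a companion to the tests above, a minimal sketch (not part of the patch) of the export-phase artifact: tryIcallJumpTable emits a naked thunk whose body is a single llvm.icall.jumptable call, and LowerTypeTests later replaces that call with the inline-asm compare/branch chain keyed off the vtable address in %r10, as checked in icall-jumptable.ll. The vtable, function, and typeid names here are placeholders.

; Placeholder vtables and virtual functions for one call slot.
@vt_a = constant [1 x i8*] [i8* bitcast (i32 (i8*, i32)* @vf_a to i8*)], !type !0
@vt_b = constant [1 x i8*] [i8* bitcast (i32 (i8*, i32)* @vf_b to i8*)], !type !0

declare i32 @vf_a(i8* %this, i32 %arg)
declare i32 @vf_b(i8* %this, i32 %arg)

; The intrinsic call is the only real instruction in the naked thunk; the
; inline asm produced by LowerTypeTests tail-jumps to the selected callee, so
; the trailing ret is never reached.
define hidden void @__typeid_example_0_jumptable() naked {
  call void (...) @llvm.icall.jumptable(
      i8* bitcast ([1 x i8*]* @vt_a to i8*), i32 (i8*, i32)* @vf_a,
      i8* bitcast ([1 x i8*]* @vt_b to i8*), i32 (i8*, i32)* @vf_b)
  ret void
}

declare void @llvm.icall.jumptable(...)

!0 = !{i32 0, !"example_typeid"}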