diff --git a/llvm/include/llvm/Analysis/TypeMetadataUtils.h b/llvm/include/llvm/Analysis/TypeMetadataUtils.h
--- a/llvm/include/llvm/Analysis/TypeMetadataUtils.h
+++ b/llvm/include/llvm/Analysis/TypeMetadataUtils.h
@@ -15,7 +15,7 @@
 #define LLVM_ANALYSIS_TYPEMETADATAUTILS_H
 
 #include "llvm/ADT/SmallVector.h"
-#include "llvm/IR/CallSite.h"
+#include "llvm/IR/Instructions.h"
 
 namespace llvm {
 
@@ -33,7 +33,7 @@
   /// The offset from the address point to the virtual function.
   uint64_t Offset;
-  /// The call site itself.
-  CallSite CS;
+  /// The call site itself (referenced, not owned).
+  CallBase &CB;
 };
 
 /// Given a call to the intrinsic \@llvm.type.test, find all devirtualizable
diff --git a/llvm/lib/Analysis/ModuleSummaryAnalysis.cpp b/llvm/lib/Analysis/ModuleSummaryAnalysis.cpp
--- a/llvm/lib/Analysis/ModuleSummaryAnalysis.cpp
+++ b/llvm/lib/Analysis/ModuleSummaryAnalysis.cpp
@@ -28,7 +28,6 @@
 #include "llvm/Analysis/TypeMetadataUtils.h"
 #include "llvm/IR/Attributes.h"
 #include "llvm/IR/BasicBlock.h"
-#include "llvm/IR/CallSite.h"
 #include "llvm/IR/Constant.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/Dominators.h"
@@ -99,7 +98,7 @@
     if (!Visited.insert(U).second)
       continue;
 
-    ImmutableCallSite CS(U);
+    const auto *CB = dyn_cast<CallBase>(U);
 
     for (const auto &OI : U->operands()) {
       const User *Operand = dyn_cast<User>(OI);
@@ -113,7 +112,7 @@
         // We have a reference to a global value. This should be added to
         // the reference set unless it is a callee. Callees are handled
         // specially by WriteFunction and are added to a separate list.
-        if (!(CS && CS.isCallee(&OI)))
+        if (!(CB && CB->isCallee(&OI)))
           RefEdges.insert(Index.getOrInsertValueInfo(GV));
         continue;
       }
@@ -145,7 +144,7 @@
                           SetVector<FunctionSummary::ConstVCall> &ConstVCalls) {
   std::vector<uint64_t> Args;
   // Start from the second argument to skip the "this" pointer.
-  for (auto &Arg : make_range(Call.CS.arg_begin() + 1, Call.CS.arg_end())) {
+  for (auto &Arg : make_range(Call.CB.arg_begin() + 1, Call.CB.arg_end())) {
     auto *CI = dyn_cast<ConstantInt>(Arg);
     if (!CI || CI->getBitWidth() > 64) {
       VCalls.insert({Guid, Call.Offset});
@@ -304,8 +303,8 @@
         }
       }
       findRefEdges(Index, &I, RefEdges, Visited);
-      auto CS = ImmutableCallSite(&I);
-      if (!CS)
+      const auto *CB = dyn_cast<CallBase>(&I);
+      if (!CB)
         continue;
 
       const auto *CI = dyn_cast<CallInst>(&I);
@@ -317,8 +316,8 @@
       if (HasLocalsInUsedOrAsm && CI && CI->isInlineAsm())
         HasInlineAsmMaybeReferencingInternal = true;
 
-      auto *CalledValue = CS.getCalledValue();
-      auto *CalledFunction = CS.getCalledFunction();
+      auto *CalledValue = CB->getCalledValue();
+      auto *CalledFunction = CB->getCalledFunction();
       if (CalledValue && !CalledFunction) {
         CalledValue = CalledValue->stripPointerCasts();
         // Stripping pointer casts can reveal a called function.
diff --git a/llvm/lib/Analysis/TypeMetadataUtils.cpp b/llvm/lib/Analysis/TypeMetadataUtils.cpp
--- a/llvm/lib/Analysis/TypeMetadataUtils.cpp
+++ b/llvm/lib/Analysis/TypeMetadataUtils.cpp
@@ -37,10 +37,10 @@
     if (isa<BitCastInst>(User)) {
       findCallsAtConstantOffset(DevirtCalls, HasNonCallUses, User, Offset, CI,
                                 DT);
-    } else if (auto CI = dyn_cast<CallInst>(User)) {
-      DevirtCalls.push_back({Offset, CI});
-    } else if (auto II = dyn_cast<InvokeInst>(User)) {
-      DevirtCalls.push_back({Offset, II});
+    } else if (auto *CI = dyn_cast<CallInst>(User)) {
+      DevirtCalls.push_back({Offset, *CI});
+    } else if (auto *II = dyn_cast<InvokeInst>(User)) {
+      DevirtCalls.push_back({Offset, *II});
     } else if (HasNonCallUses) {
       *HasNonCallUses = true;
     }
diff --git a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
--- a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
+++ b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
@@ -64,7 +64,6 @@
 #include "llvm/Analysis/TypeMetadataUtils.h"
 #include "llvm/Bitcode/BitcodeReader.h"
 #include "llvm/Bitcode/BitcodeWriter.h"
-#include "llvm/IR/CallSite.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/DataLayout.h"
 #include "llvm/IR/DebugLoc.h"
@@ -354,20 +353,20 @@
-// A virtual call site. VTable is the loaded virtual table pointer, and CS is
+// A virtual call site. VTable is the loaded virtual table pointer, and CB is
 // the indirect virtual call.
 struct VirtualCallSite {
-  Value *VTable;
-  CallSite CS;
+  Value *VTable = nullptr;
+  CallBase &CB;
 
   // If non-null, this field points to the associated unsafe use count stored in
   // the DevirtModule::NumUnsafeUsesForTypeTest map below. See the description
   // of that field for details.
-  unsigned *NumUnsafeUses;
+  unsigned *NumUnsafeUses = nullptr;
 
   void
   emitRemark(const StringRef OptName, const StringRef TargetName,
              function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter) {
-    Function *F = CS.getCaller();
-    DebugLoc DLoc = CS->getDebugLoc();
-    BasicBlock *Block = CS.getParent();
+    Function *F = CB.getCaller();
+    DebugLoc DLoc = CB.getDebugLoc();
+    BasicBlock *Block = CB.getParent();
 
     using namespace ore;
     OREGetter(F).emit(OptimizationRemark(DEBUG_TYPE, OptName, DLoc, Block)
@@ -382,12 +381,12 @@
       Value *New) {
     if (RemarksEnabled)
       emitRemark(OptName, TargetName, OREGetter);
-    CS->replaceAllUsesWith(New);
-    if (auto II = dyn_cast<InvokeInst>(CS.getInstruction())) {
-      BranchInst::Create(II->getNormalDest(), CS.getInstruction());
+    CB.replaceAllUsesWith(New);
+    if (auto *II = dyn_cast<InvokeInst>(&CB)) {
+      BranchInst::Create(II->getNormalDest(), &CB);
       II->getUnwindDest()->removePredecessor(II->getParent());
     }
-    CS->eraseFromParent();
+    CB.eraseFromParent();
     // This use is no longer unsafe.
     if (NumUnsafeUses)
       --*NumUnsafeUses;
@@ -460,18 +459,18 @@
   // "this"), grouped by argument list.
   std::map<std::vector<uint64_t>, CallSiteInfo> ConstCSInfo;
 
-  void addCallSite(Value *VTable, CallSite CS, unsigned *NumUnsafeUses);
+  void addCallSite(Value *VTable, CallBase &CB, unsigned *NumUnsafeUses);
 
 private:
-  CallSiteInfo &findCallSiteInfo(CallSite CS);
+  CallSiteInfo &findCallSiteInfo(CallBase &CB);
 };
 
-CallSiteInfo &VTableSlotInfo::findCallSiteInfo(CallSite CS) {
+CallSiteInfo &VTableSlotInfo::findCallSiteInfo(CallBase &CB) {
   std::vector<uint64_t> Args;
-  auto *CI = dyn_cast<IntegerType>(CS.getType());
-  if (!CI || CI->getBitWidth() > 64 || CS.arg_empty())
+  auto *CBType = dyn_cast<IntegerType>(CB.getType());
+  if (!CBType || CBType->getBitWidth() > 64 || CB.arg_empty())
     return CSInfo;
-  for (auto &&Arg : make_range(CS.arg_begin() + 1, CS.arg_end())) {
+  for (auto &&Arg : make_range(CB.arg_begin() + 1, CB.arg_end())) {
     auto *CI = dyn_cast<ConstantInt>(Arg);
     if (!CI || CI->getBitWidth() > 64)
       return CSInfo;
@@ -480,11 +479,11 @@
   return ConstCSInfo[Args];
 }
 
-void VTableSlotInfo::addCallSite(Value *VTable, CallSite CS,
+void VTableSlotInfo::addCallSite(Value *VTable, CallBase &CB,
                                  unsigned *NumUnsafeUses) {
-  auto &CSI = findCallSiteInfo(CS);
+  auto &CSI = findCallSiteInfo(CB);
   CSI.AllCallSitesDevirted = false;
-  CSI.CallSites.push_back({VTable, CS, NumUnsafeUses});
+  CSI.CallSites.push_back({VTable, CB, NumUnsafeUses});
 }
 
 struct DevirtModule {
@@ -1029,8 +1028,8 @@
       if (RemarksEnabled)
         VCallSite.emitRemark("single-impl",
                              TheFn->stripPointerCasts()->getName(), OREGetter);
-      VCallSite.CS.setCalledFunction(ConstantExpr::getBitCast(
-          TheFn, VCallSite.CS.getCalledValue()->getType()));
+      VCallSite.CB.setCalledOperand(ConstantExpr::getBitCast(
+          TheFn, VCallSite.CB.getCalledValue()->getType()));
       // This use is no longer unsafe.
       if (VCallSite.NumUnsafeUses)
         --*VCallSite.NumUnsafeUses;
@@ -1253,10 +1252,10 @@
     if (CSInfo.AllCallSitesDevirted)
       return;
     for (auto &&VCallSite : CSInfo.CallSites) {
-      CallSite CS = VCallSite.CS;
+      CallBase &CB = VCallSite.CB;
 
       // Jump tables are only profitable if the retpoline mitigation is enabled.
-      Attribute FSAttr = CS.getCaller()->getFnAttribute("target-features");
+      Attribute FSAttr = CB.getCaller()->getFnAttribute("target-features");
       if (FSAttr.hasAttribute(Attribute::None) ||
           !FSAttr.getValueAsString().contains("+retpoline"))
         continue;
@@ -1269,42 +1268,40 @@
       // x86_64.
       std::vector<Type *> NewArgs;
       NewArgs.push_back(Int8PtrTy);
-      for (Type *T : CS.getFunctionType()->params())
+      for (Type *T : CB.getFunctionType()->params())
         NewArgs.push_back(T);
       FunctionType *NewFT =
-          FunctionType::get(CS.getFunctionType()->getReturnType(), NewArgs,
-                            CS.getFunctionType()->isVarArg());
+          FunctionType::get(CB.getFunctionType()->getReturnType(), NewArgs,
+                            CB.getFunctionType()->isVarArg());
       PointerType *NewFTPtr = PointerType::getUnqual(NewFT);
 
-      IRBuilder<> IRB(CS.getInstruction());
+      IRBuilder<> IRB(&CB);
       std::vector<Value *> Args;
       Args.push_back(IRB.CreateBitCast(VCallSite.VTable, Int8PtrTy));
-      for (unsigned I = 0; I != CS.getNumArgOperands(); ++I)
-        Args.push_back(CS.getArgOperand(I));
+      Args.insert(Args.end(), CB.arg_begin(), CB.arg_end());
 
-      CallSite NewCS;
-      if (CS.isCall())
-        NewCS = IRB.CreateCall(NewFT, IRB.CreateBitCast(JT, NewFTPtr), Args);
+      CallBase *NewCB = nullptr;
+      if (isa<CallInst>(CB))
+        NewCB = IRB.CreateCall(NewFT, IRB.CreateBitCast(JT, NewFTPtr), Args);
       else
-        NewCS = IRB.CreateInvoke(
-            NewFT, IRB.CreateBitCast(JT, NewFTPtr),
-            cast<InvokeInst>(CS.getInstruction())->getNormalDest(),
-            cast<InvokeInst>(CS.getInstruction())->getUnwindDest(), Args);
-      NewCS.setCallingConv(CS.getCallingConv());
+        NewCB = IRB.CreateInvoke(NewFT, IRB.CreateBitCast(JT, NewFTPtr),
+                                 cast<InvokeInst>(CB).getNormalDest(),
+                                 cast<InvokeInst>(CB).getUnwindDest(), Args);
+      NewCB->setCallingConv(CB.getCallingConv());
 
-      AttributeList Attrs = CS.getAttributes();
+      AttributeList Attrs = CB.getAttributes();
       std::vector<AttributeSet> NewArgAttrs;
       NewArgAttrs.push_back(AttributeSet::get(
           M.getContext(), ArrayRef<Attribute>{Attribute::get(
                               M.getContext(), Attribute::Nest)}));
       for (unsigned I = 0; I + 2 < Attrs.getNumAttrSets(); ++I)
         NewArgAttrs.push_back(Attrs.getParamAttributes(I));
-      NewCS.setAttributes(
+      NewCB->setAttributes(
           AttributeList::get(M.getContext(), Attrs.getFnAttributes(),
                              Attrs.getRetAttributes(), NewArgAttrs));
 
-      CS->replaceAllUsesWith(NewCS.getInstruction());
-      CS->eraseFromParent();
+      CB.replaceAllUsesWith(NewCB);
+      CB.eraseFromParent();
 
       // This use is no longer unsafe.
       if (VCallSite.NumUnsafeUses)
@@ -1355,7 +1352,7 @@
   for (auto Call : CSInfo.CallSites)
     Call.replaceAndErase(
         "uniform-ret-val", FnName, RemarksEnabled, OREGetter,
-        ConstantInt::get(cast<IntegerType>(Call.CS.getType()), TheRetVal));
+        ConstantInt::get(cast<IntegerType>(Call.CB.getType()), TheRetVal));
   CSInfo.markDevirt();
 }
 
@@ -1461,11 +1458,11 @@
                                         bool IsOne,
                                         Constant *UniqueMemberAddr) {
   for (auto &&Call : CSInfo.CallSites) {
-    IRBuilder<> B(Call.CS.getInstruction());
+    IRBuilder<> B(&Call.CB);
     Value *Cmp =
         B.CreateICmp(IsOne ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE, Call.VTable,
                      B.CreateBitCast(UniqueMemberAddr, Call.VTable->getType()));
-    Cmp = B.CreateZExt(Cmp, Call.CS->getType());
+    Cmp = B.CreateZExt(Cmp, Call.CB.getType());
     Call.replaceAndErase("unique-ret-val", FnName, RemarksEnabled, OREGetter,
                          Cmp);
   }
@@ -1529,8 +1526,8 @@
 void DevirtModule::applyVirtualConstProp(CallSiteInfo &CSInfo, StringRef FnName,
                                          Constant *Byte, Constant *Bit) {
   for (auto Call : CSInfo.CallSites) {
-    auto *RetType = cast<IntegerType>(Call.CS.getType());
-    IRBuilder<> B(Call.CS.getInstruction());
+    auto *RetType = cast<IntegerType>(Call.CB.getType());
+    IRBuilder<> B(&Call.CB);
     Value *Addr =
         B.CreateGEP(Int8Ty, B.CreateBitCast(Call.VTable, Int8PtrTy), Byte);
     if (RetType->getBitWidth() == 1) {
@@ -1716,7 +1713,7 @@
   // points to a member of the type identifier %md. Group calls by (type ID,
   // offset) pair (effectively the identity of the virtual function) and store
   // to CallSlots.
-  DenseSet<CallSite> SeenCallSites;
+  DenseSet<CallBase *> SeenCallSites;
   for (auto I = TypeTestFunc->use_begin(), E = TypeTestFunc->use_end();
        I != E;) {
     auto CI = dyn_cast<CallInst>(I->getUser());
@@ -1741,8 +1738,8 @@
         // and we don't want to process call sites multiple times. We can't
         // just skip the vtable Ptr if it has been seen before, however, since
         // it may be shared by type tests that dominate different calls.
-        if (SeenCallSites.insert(Call.CS).second)
-          CallSlots[{TypeId, Call.Offset}].addCallSite(Ptr, Call.CS, nullptr);
+        if (SeenCallSites.insert(&Call.CB).second)
+          CallSlots[{TypeId, Call.Offset}].addCallSite(Ptr, Call.CB, nullptr);
       }
     }
 
@@ -1828,7 +1825,7 @@
     if (HasNonCallUses)
       ++NumUnsafeUses;
     for (DevirtCallSite Call : DevirtCalls) {
-      CallSlots[{TypeId, Call.Offset}].addCallSite(Ptr, Call.CS,
+      CallSlots[{TypeId, Call.Offset}].addCallSite(Ptr, Call.CB,
                                                    &NumUnsafeUses);
     }
 