diff --git a/llvm/include/llvm/IR/CallSite.h b/llvm/include/llvm/IR/CallSite.h --- a/llvm/include/llvm/IR/CallSite.h +++ b/llvm/include/llvm/IR/CallSite.h @@ -757,7 +757,7 @@ /// For direct/indirect calls the parameter encoding is empty. If it is not, /// the abstract call site represents a callback. In that case, the first /// element of the encoding vector represents which argument of the call - /// site CS is the callback callee. The remaining elements map parameters + /// site CB is the callback callee. The remaining elements map parameters /// (identified by their position) to the arguments that will be passed /// through (also identified by position but in the call site instruction). /// @@ -774,7 +774,7 @@ /// The underlying call site: /// caller -> callee, if this is a direct or indirect call site /// caller -> broker function, if this is a callback call site - CallSite CS; + CallBase *CB; /// The encoding of a callback with regards to the underlying instruction. CallbackInfo CI; @@ -802,26 +802,23 @@ /// - /// All uses added to \p CBUses can be used to create abstract call sites for - /// which AbstractCallSite::isCallbackCall() will return true. + /// All uses added to \p CallbackUses can be used to create abstract call + /// sites for which AbstractCallSite::isCallbackCall() will return true. - static void getCallbackUses(ImmutableCallSite ICS, - SmallVectorImpl<const Use *> &CBUses); + static void getCallbackUses(const CallBase &CB, + SmallVectorImpl<const Use *> &CallbackUses); /// Conversion operator to conveniently check for a valid/initialized ACS. - explicit operator bool() const { return (bool)CS; } + explicit operator bool() const { return CB != nullptr; } /// Return the underlying instruction. - Instruction *getInstruction() const { return CS.getInstruction(); } - - /// Return the call site abstraction for the underlying instruction. - CallSite getCallSite() const { return CS; } + CallBase *getInstruction() const { return CB; } /// Return true if this ACS represents a direct call.
bool isDirectCall() const { - return !isCallbackCall() && !CS.isIndirectCall(); + return !isCallbackCall() && !CB->isIndirectCall(); } /// Return true if this ACS represents an indirect call. bool isIndirectCall() const { - return !isCallbackCall() && CS.isIndirectCall(); + return !isCallbackCall() && CB->isIndirectCall(); } /// Return true if this ACS represents a callback call. @@ -839,18 +836,18 @@ /// Return true if @p U is the use that defines the callee of this ACS. bool isCallee(const Use *U) const { if (isDirectCall()) - return CS.isCallee(U); + return CB->isCallee(U); assert(!CI.ParameterEncoding.empty() && "Callback without parameter encoding!"); - return (int)CS.getArgumentNo(U) == CI.ParameterEncoding[0]; + return (int)CB->getArgOperandNo(U) == CI.ParameterEncoding[0]; } /// Return the number of parameters of the callee. unsigned getNumArgOperands() const { if (isDirectCall()) - return CS.getNumArgOperands(); + return CB->getNumArgOperands(); // Subtract 1 for the callee encoding. return CI.ParameterEncoding.size() - 1; } @@ -879,10 +876,10 @@ /// function parameter number @p ArgNo or nullptr if there is none. Value *getCallArgOperand(unsigned ArgNo) const { if (isDirectCall()) - return CS.getArgOperand(ArgNo); + return CB->getArgOperand(ArgNo); // Add 1 for the callee encoding. return CI.ParameterEncoding[ArgNo + 1] >= 0 - ? CS.getArgOperand(CI.ParameterEncoding[ArgNo + 1]) + ? CB->getArgOperand(CI.ParameterEncoding[ArgNo + 1]) : nullptr; } @@ -906,8 +903,8 @@ /// Return the pointer to function that is being called. 
Value *getCalledValue() const { if (isDirectCall()) - return CS.getCalledValue(); - return CS.getArgOperand(getCallArgOperandNoForCallee()); + return CB->getCalledValue(); + return CB->getArgOperand(getCallArgOperandNoForCallee()); } /// Return the function being called if this is a direct call, otherwise diff --git a/llvm/lib/IR/AbstractCallSite.cpp b/llvm/lib/IR/AbstractCallSite.cpp --- a/llvm/lib/IR/AbstractCallSite.cpp +++ b/llvm/lib/IR/AbstractCallSite.cpp @@ -33,9 +33,9 @@ STATISTIC(NumInvalidAbstractCallSitesNoCallback, "Number of invalid abstract call sites created (no callback)"); -void AbstractCallSite::getCallbackUses(ImmutableCallSite ICS, - SmallVectorImpl<const Use *> &CBUses) { - const Function *Callee = ICS.getCalledFunction(); +void AbstractCallSite::getCallbackUses(const CallBase &CB, + SmallVectorImpl<const Use *> &CallbackUses) { + const Function *Callee = CB.getCalledFunction(); if (!Callee) return; @@ -48,57 +48,58 @@ auto *CBCalleeIdxAsCM = cast<ConstantAsMetadata>(OpMD->getOperand(0)); uint64_t CBCalleeIdx = cast<ConstantInt>(CBCalleeIdxAsCM->getValue())->getZExtValue(); - if (CBCalleeIdx < ICS.arg_size()) - CBUses.push_back(ICS.arg_begin() + CBCalleeIdx); + if (CBCalleeIdx < CB.arg_size()) + CallbackUses.push_back(CB.arg_begin() + CBCalleeIdx); } } /// Create an abstract call site from a use. -AbstractCallSite::AbstractCallSite(const Use *U) : CS(U->getUser()) { +AbstractCallSite::AbstractCallSite(const Use *U) + : CB(dyn_cast<CallBase>(U->getUser())) { // First handle unknown users. - if (!CS) { + if (!CB) { // If the use is actually in a constant cast expression which itself // has only one use, we look through the constant cast expression. // This happens by updating the use @p U to the use of the constant - // cast expression and afterwards re-initializing CS accordingly. + // cast expression and afterwards re-initializing CB accordingly.
if (ConstantExpr *CE = dyn_cast<ConstantExpr>(U->getUser())) if (CE->getNumUses() == 1 && CE->isCast()) { U = &*CE->use_begin(); - CS = CallSite(U->getUser()); + CB = dyn_cast<CallBase>(U->getUser()); } - if (!CS) { + if (!CB) { NumInvalidAbstractCallSitesUnknownUse++; return; } } // Then handle direct or indirect calls. Thus, if U is the callee of the - // call site CS it is not a callback and we are done. - if (CS.isCallee(U)) { + // call site CB it is not a callback and we are done. + if (CB->isCallee(U)) { NumDirectAbstractCallSites++; return; } // If we cannot identify the broker function we cannot create a callback and // invalidate the abstract call site. - Function *Callee = CS.getCalledFunction(); + Function *Callee = CB->getCalledFunction(); if (!Callee) { NumInvalidAbstractCallSitesUnknownCallee++; - CS = CallSite(); + CB = nullptr; return; } MDNode *CallbackMD = Callee->getMetadata(LLVMContext::MD_callback); if (!CallbackMD) { NumInvalidAbstractCallSitesNoCallback++; - CS = CallSite(); + CB = nullptr; return; } - unsigned UseIdx = CS.getArgumentNo(U); + unsigned UseIdx = CB->getArgOperandNo(U); MDNode *CallbackEncMD = nullptr; for (const MDOperand &Op : CallbackMD->operands()) { MDNode *OpMD = cast<MDNode>(Op.get()); @@ -113,7 +114,7 @@ if (!CallbackEncMD) { NumInvalidAbstractCallSitesNoCallback++; - CS = CallSite(); + CB = nullptr; return; } @@ -121,7 +122,7 @@ assert(CallbackEncMD->getNumOperands() >= 2 && "Incomplete !callback metadata"); - unsigned NumCallOperands = CS.getNumArgOperands(); + unsigned NumCallOperands = CB->getNumArgOperands(); // Skip the var-arg flag at the end when reading the metadata.
for (unsigned u = 0, e = CallbackEncMD->getNumOperands() - 1; u < e; u++) { Metadata *OpAsM = CallbackEncMD->getOperand(u).get(); diff --git a/llvm/lib/Transforms/IPO/Attributor.cpp b/llvm/lib/Transforms/IPO/Attributor.cpp --- a/llvm/lib/Transforms/IPO/Attributor.cpp +++ b/llvm/lib/Transforms/IPO/Attributor.cpp @@ -151,10 +151,10 @@ // of the underlying call site operand, we want the corresponding callback // callee argument and not the direct callee argument. Optional<Argument *> CBCandidateArg; - SmallVector<const Use *, 4> CBUses; - ImmutableCallSite ICS(&getAnchorValue()); - AbstractCallSite::getCallbackUses(ICS, CBUses); - for (const Use *U : CBUses) { + SmallVector<const Use *, 4> CallbackUses; + const auto &CB = cast<CallBase>(getAnchorValue()); + AbstractCallSite::getCallbackUses(CB, CallbackUses); + for (const Use *U : CallbackUses) { AbstractCallSite ACS(U); assert(ACS && ACS.isCallbackCall()); if (!ACS.getCalledFunction()) @@ -183,7 +183,7 @@ // If no callbacks were found, or none used the underlying call site operand // exclusively, use the direct callee argument if available. - const Function *Callee = ICS.getCalledFunction(); + const Function *Callee = CB.getCalledFunction(); if (Callee && Callee->arg_size() > unsigned(ArgNo)) return Callee->getArg(ArgNo); @@ -1328,7 +1328,7 @@ auto CallSiteCanBeChanged = [](AbstractCallSite ACS) { // Forbid must-tail calls for now. - return !ACS.isCallbackCall() && !ACS.getCallSite().isMustTailCall(); + return !ACS.isCallbackCall() && !ACS.getInstruction()->isMustTailCall(); }; Function *Fn = Arg.getParent(); diff --git a/llvm/lib/Transforms/IPO/AttributorAttributes.cpp b/llvm/lib/Transforms/IPO/AttributorAttributes.cpp --- a/llvm/lib/Transforms/IPO/AttributorAttributes.cpp +++ b/llvm/lib/Transforms/IPO/AttributorAttributes.cpp @@ -4993,9 +4993,10 @@ // Helper to check if for the given call site the associated argument is // passed to a callback where the privatization would be different.
auto IsCompatiblePrivArgOfCallback = [&](CallSite CS) { - SmallVector<const Use *, 4> CBUses; - AbstractCallSite::getCallbackUses(CS, CBUses); - for (const Use *U : CBUses) { + SmallVector<const Use *, 4> CallbackUses; + AbstractCallSite::getCallbackUses(cast<CallBase>(*CS.getInstruction()), + CallbackUses); + for (const Use *U : CallbackUses) { AbstractCallSite CBACS(U); assert(CBACS && CBACS.isCallbackCall()); for (Argument &CBArg : CBACS.getCalledFunction()->args()) { @@ -5081,7 +5082,7 @@ << Arg->getParent()->getName() << ")\n[AAPrivatizablePtr] because it is an argument in a " "direct call of (" - << ACS.getCallSite().getCalledFunction()->getName() + << ACS.getInstruction()->getCalledFunction()->getName() << ").\n[AAPrivatizablePtr] for which the argument " "privatization is not compatible.\n"; }); @@ -5093,7 +5094,7 @@ // here. auto IsCompatiblePrivArgOfOtherCallSite = [&](AbstractCallSite ACS) { if (ACS.isDirectCall()) - return IsCompatiblePrivArgOfCallback(ACS.getCallSite()); + return IsCompatiblePrivArgOfCallback(CallSite(ACS.getInstruction())); if (ACS.isCallbackCall()) return IsCompatiblePrivArgOfDirectCS(ACS); return false;