diff --git a/llvm/include/llvm/IR/CallSite.h b/llvm/include/llvm/IR/CallSite.h
--- a/llvm/include/llvm/IR/CallSite.h
+++ b/llvm/include/llvm/IR/CallSite.h
@@ -688,6 +688,18 @@
   User::op_iterator getCallee() const;
 };
 
+/// Establish a view to a call site for examination.
+class ImmutableCallSite : public CallSiteBase<> {
+public:
+  ImmutableCallSite() = default;
+  ImmutableCallSite(const CallInst *CI) : CallSiteBase(CI) {}
+  ImmutableCallSite(const InvokeInst *II) : CallSiteBase(II) {}
+  ImmutableCallSite(const CallBrInst *CBI) : CallSiteBase(CBI) {}
+  explicit ImmutableCallSite(const Instruction *II) : CallSiteBase(II) {}
+  explicit ImmutableCallSite(const Value *V) : CallSiteBase(V) {}
+  ImmutableCallSite(CallSite CS) : CallSiteBase(CS.getInstruction()) {}
+};
+
 /// AbstractCallSite
 ///
 /// An abstract call site is a wrapper that allows to treat direct,
@@ -760,6 +772,13 @@
   /// as well as the callee of the abstract call site.
   AbstractCallSite(const Use *U);
 
+  /// Add operand uses of \p ICS that represent callback uses into \p CBUses.
+  ///
+  /// All uses added to \p CBUses can be used to create abstract call sites for
+  /// which AbstractCallSite::isCallbackCall() will return true.
+  static void getCallbackUses(ImmutableCallSite ICS,
+                              SmallVectorImpl<const Use *> &CBUses);
+
   /// Conversion operator to conveniently check for a valid/initialized ACS.
   explicit operator bool() const { return (bool)CS; }
 
@@ -888,18 +907,6 @@
   }
 };
 
-/// Establish a view to a call site for examination.
-class ImmutableCallSite : public CallSiteBase<> {
-public:
-  ImmutableCallSite() = default;
-  ImmutableCallSite(const CallInst *CI) : CallSiteBase(CI) {}
-  ImmutableCallSite(const InvokeInst *II) : CallSiteBase(II) {}
-  ImmutableCallSite(const CallBrInst *CBI) : CallSiteBase(CBI) {}
-  explicit ImmutableCallSite(const Instruction *II) : CallSiteBase(II) {}
-  explicit ImmutableCallSite(const Value *V) : CallSiteBase(V) {}
-  ImmutableCallSite(CallSite CS) : CallSiteBase(CS.getInstruction()) {}
-};
-
 } // end namespace llvm
 
 #endif // LLVM_IR_CALLSITE_H
diff --git a/llvm/lib/IR/AbstractCallSite.cpp b/llvm/lib/IR/AbstractCallSite.cpp
--- a/llvm/lib/IR/AbstractCallSite.cpp
+++ b/llvm/lib/IR/AbstractCallSite.cpp
@@ -33,6 +33,27 @@
 STATISTIC(NumInvalidAbstractCallSitesNoCallback,
           "Number of invalid abstract call sites created (no callback)");
 
+void AbstractCallSite::getCallbackUses(ImmutableCallSite ICS,
+                                       SmallVectorImpl<const Use *> &CBUses) {
+  const Function *Callee = ICS.getCalledFunction();
+  if (!Callee)
+    return;
+
+  MDNode *CallbackMD = Callee->getMetadata(LLVMContext::MD_callback);
+  if (!CallbackMD)
+    return;
+
+  // Every operand of the !callback metadata describes one callback. Its first
+  // operand is the index of the call site argument holding the callback callee.
+  for (const MDOperand &Op : CallbackMD->operands()) {
+    MDNode *OpMD = cast<MDNode>(Op.get());
+    auto *CBCalleeIdxAsCM = cast<ConstantAsMetadata>(OpMD->getOperand(0));
+    uint64_t CBCalleeIdx =
+        cast<ConstantInt>(CBCalleeIdxAsCM->getValue())->getZExtValue();
+    CBUses.push_back(ICS.arg_begin() + CBCalleeIdx);
+  }
+}
+
 /// Create an abstract call site from a use.
 AbstractCallSite::AbstractCallSite(const Use *U) : CS(U->getUser()) {
 
diff --git a/llvm/lib/Transforms/IPO/Attributor.cpp b/llvm/lib/Transforms/IPO/Attributor.cpp
--- a/llvm/lib/Transforms/IPO/Attributor.cpp
+++ b/llvm/lib/Transforms/IPO/Attributor.cpp
@@ -211,7 +211,55 @@
   const Argument *getArgument() const {
     ImmutableCallSite ICS(&AAImplType::getAnchoredValue());
     const Function *Callee = ICS.getCalledFunction();
-    if (Callee && Callee->arg_size() > ArgNo)
+    if (!Callee)
+      return nullptr;
+
+    // Use abstract call sites to make the connection between the call site
+    // values and the ones in callbacks. If a callback was found that makes use
+    // of the underlying call site operand, we want the corresponding callback
+    // callee argument and not the direct callee argument.
+    SmallVector<const Use *, 4> CBUses;
+    AbstractCallSite::getCallbackUses(ICS, CBUses);
+
+    // The callback callee argument that receives the underlying call site
+    // operand. It is only used if it is unique: if the operand reaches more
+    // than one callback callee argument we cannot pick one and fall back to
+    // the direct callee argument instead.
+    const Argument *CBCandidateArg = nullptr;
+    bool HasMultipleCBCandidates = false;
+    for (const Use *U : CBUses) {
+      AbstractCallSite ACS(U);
+      assert(ACS && ACS.isCallbackCall());
+      const Function *CBCallee = ACS.getCalledFunction();
+      if (!CBCallee)
+        continue;
+
+      for (unsigned u = 0, e = ACS.getNumArgOperands(); u < e; u++) {
+        // Test if the underlying call site operand is argument number u of the
+        // callback callee.
+        if (ACS.getCallArgOperandNo(u) != ArgNo)
+          continue;
+
+        assert(CBCallee->arg_size() > u &&
+               "ACS mapped into var-args arguments!");
+        if (CBCandidateArg) {
+          HasMultipleCBCandidates = true;
+          break;
+        }
+        CBCandidateArg = CBCallee->arg_begin() + u;
+      }
+      if (HasMultipleCBCandidates)
+        break;
+    }
+
+    // If exactly one callback callee argument receives the underlying call
+    // site operand, it is the argument we are looking for.
+    if (CBCandidateArg && !HasMultipleCBCandidates)
+      return CBCandidateArg;
+
+    // If no callbacks were found, or none used the underlying call site operand
+    // exclusively, use the direct callee argument if available.
+    if (Callee->arg_size() > ArgNo)
       return Callee->arg_begin() + ArgNo;
     return nullptr;
   }