diff --git a/llvm/include/llvm/IR/CallSite.h b/llvm/include/llvm/IR/CallSite.h --- a/llvm/include/llvm/IR/CallSite.h +++ b/llvm/include/llvm/IR/CallSite.h @@ -693,6 +693,18 @@ User::op_iterator getCallee() const; }; +/// Establish a view to a call site for examination. +class ImmutableCallSite : public CallSiteBase<> { +public: + ImmutableCallSite() = default; + ImmutableCallSite(const CallInst *CI) : CallSiteBase(CI) {} + ImmutableCallSite(const InvokeInst *II) : CallSiteBase(II) {} + ImmutableCallSite(const CallBrInst *CBI) : CallSiteBase(CBI) {} + explicit ImmutableCallSite(const Instruction *II) : CallSiteBase(II) {} + explicit ImmutableCallSite(const Value *V) : CallSiteBase(V) {} + ImmutableCallSite(CallSite CS) : CallSiteBase(CS.getInstruction()) {} +}; + /// AbstractCallSite /// /// An abstract call site is a wrapper that allows to treat direct, @@ -765,6 +777,13 @@ /// as well as the callee of the abstract call site. AbstractCallSite(const Use *U); + /// Add operand uses of \p ICS that represent callback uses into \p CBUses. + /// + /// All uses added to \p CBUses can be used to create abstract call sites for + /// which AbstractCallSite::isCallbackCall() will return true. + static void getCallbackUses(ImmutableCallSite ICS, + SmallVectorImpl<const Use *> &CBUses); + /// Conversion operator to conveniently check for a valid/initialized ACS. explicit operator bool() const { return (bool)CS; } @@ -902,18 +921,6 @@ } }; -/// Establish a view to a call site for examination. 
-class ImmutableCallSite : public CallSiteBase<> { -public: - ImmutableCallSite() = default; - ImmutableCallSite(const CallInst *CI) : CallSiteBase(CI) {} - ImmutableCallSite(const InvokeInst *II) : CallSiteBase(II) {} - ImmutableCallSite(const CallBrInst *CBI) : CallSiteBase(CBI) {} - explicit ImmutableCallSite(const Instruction *II) : CallSiteBase(II) {} - explicit ImmutableCallSite(const Value *V) : CallSiteBase(V) {} - ImmutableCallSite(CallSite CS) : CallSiteBase(CS.getInstruction()) {} -}; - } // end namespace llvm #endif // LLVM_IR_CALLSITE_H diff --git a/llvm/include/llvm/Transforms/IPO/Attributor.h b/llvm/include/llvm/Transforms/IPO/Attributor.h --- a/llvm/include/llvm/Transforms/IPO/Attributor.h +++ b/llvm/include/llvm/Transforms/IPO/Attributor.h @@ -285,17 +285,7 @@ /// Return the associated argument, if any. /// ///{ - Argument *getAssociatedArgument() { - if (auto *Arg = dyn_cast<Argument>(&getAnchorValue())) - return Arg; - int ArgNo = getArgNo(); - if (ArgNo < 0) - return nullptr; - Function *AssociatedFn = getAssociatedFunction(); - if (!AssociatedFn || AssociatedFn->arg_size() <= unsigned(ArgNo)) - return nullptr; - return AssociatedFn->arg_begin() + ArgNo; - } + Argument *getAssociatedArgument(); const Argument *getAssociatedArgument() const { return const_cast<IRPosition *>(this)->getAssociatedArgument(); } diff --git a/llvm/lib/IR/AbstractCallSite.cpp b/llvm/lib/IR/AbstractCallSite.cpp --- a/llvm/lib/IR/AbstractCallSite.cpp +++ b/llvm/lib/IR/AbstractCallSite.cpp @@ -33,6 +33,27 @@ STATISTIC(NumInvalidAbstractCallSitesNoCallback, "Number of invalid abstract call sites created (no callback)"); +void AbstractCallSite::getCallbackUses(ImmutableCallSite ICS, + SmallVectorImpl<const Use *> &CBUses) { + const Function *Callee = ICS.getCalledFunction(); + if (!Callee) + return; + + MDNode *CallbackMD = Callee->getMetadata(LLVMContext::MD_callback); + if (!CallbackMD) + return; + + for (const MDOperand &Op : CallbackMD->operands()) { + MDNode *OpMD = cast<MDNode>(Op.get()); + auto *CBCalleeIdxAsCM 
= cast<ConstantAsMetadata>(OpMD->getOperand(0)); + uint64_t CBCalleeIdx = + cast<ConstantInt>(CBCalleeIdxAsCM->getValue())->getZExtValue(); + // Ignore callee indices that do not name an operand of this call site. + if (CBCalleeIdx < ICS.getNumArgOperands()) + CBUses.push_back(ICS.arg_begin() + CBCalleeIdx); + } +} + /// Create an abstract call site from a use. AbstractCallSite::AbstractCallSite(const Use *U) : CS(U->getUser()) { diff --git a/llvm/lib/Transforms/IPO/Attributor.cpp b/llvm/lib/Transforms/IPO/Attributor.cpp --- a/llvm/lib/Transforms/IPO/Attributor.cpp +++ b/llvm/lib/Transforms/IPO/Attributor.cpp @@ -153,6 +153,47 @@ } ///} +Argument *IRPosition::getAssociatedArgument() { + if (auto *Arg = dyn_cast<Argument>(&getAnchorValue())) + return Arg; + int ArgNo = getArgNo(); + if (ArgNo < 0) + return nullptr; + + // Use abstract call sites to make the connection between the call site + // values and the ones in callbacks. If a callback was found that makes use + // of the underlying call site operand, we want the corresponding callback + // callee argument and not the direct callee argument. + SmallVector<const Use *, 4> CBUses; + ImmutableCallSite ICS(&getAnchorValue()); + AbstractCallSite::getCallbackUses(ICS, CBUses); + for (const Use *U : CBUses) { + AbstractCallSite ACS(U); + assert(ACS && ACS.isCallbackCall()); + if (!ACS.getCalledFunction()) + continue; + + for (unsigned u = 0, e = ACS.getNumArgOperands(); u < e; u++) { + + // Test if the underlying call site operand is argument number u of the + // callback callee. + if (ACS.getCallArgOperandNo(u) != ArgNo) + continue; + + assert(ACS.getCalledFunction()->arg_size() > u && + "ACS mapped into var-args arguments!"); + return ACS.getCalledFunction()->getArg(u); + } + } + + // If no callback was found that uses the underlying call site operand, + // fall back to the direct callee argument if available. + const Function *Callee = ICS.getCalledFunction(); + if (Callee && Callee->arg_size() > unsigned(ArgNo)) + return Callee->getArg(ArgNo); + return nullptr; +} + /// Recursively visit all values that might become \p IRP at some point. 
This /// will be done by looking through cast instructions, selects, phis, and calls /// with the "returned" attribute. Once we cannot look through the value any @@ -1763,6 +1804,7 @@ // (iii) Check there is no other pointer argument which could alias with the // value. + // TODO: AbstractCallSite ImmutableCallSite ICS(&getAnchorValue()); for (unsigned i = 0; i < ICS.getNumArgOperands(); i++) { if (getArgNo() == (int)i) diff --git a/llvm/test/Transforms/FunctionAttrs/callbacks.ll b/llvm/test/Transforms/FunctionAttrs/callbacks.ll --- a/llvm/test/Transforms/FunctionAttrs/callbacks.ll +++ b/llvm/test/Transforms/FunctionAttrs/callbacks.ll @@ -11,7 +11,7 @@ ; each other but argument 3-5 of the transitive call site in the caller match ; arguments 2-4 of the callback callee. Here we should see information and value ; transfer in both directions. -; FIXME: The callee -> call site direction is not working yet. +; FIXME: %a should be align 256 in the callback and at the call site define void @t0_caller(i32* %a) { ; CHECK: @t0_caller(i32* [[A:%.*]]) @@ -22,7 +22,7 @@ ; CHECK-NEXT: [[TMP0:%.*]] = bitcast i32* [[B]] to i8* ; CHECK-NEXT: store i32 42, i32* [[B]], align 32 ; CHECK-NEXT: store i32* [[B]], i32** [[C]], align 64 -; CHECK-NEXT: call void (i32*, i32*, void (i32*, i32*, ...)*, ...) @t0_callback_broker(i32* null, i32* nonnull align 128 dereferenceable(4) [[PTR]], void (i32*, i32*, ...)* nonnull bitcast (void (i32*, i32*, i32*, i64, i32**)* @t0_callback_callee to void (i32*, i32*, ...)*), i32* [[A:%.*]], i64 99, i32** nonnull align 64 dereferenceable(8) [[C]]) +; CHECK-NEXT: call void (i32*, i32*, void (i32*, i32*, ...)*, ...) 
@t0_callback_broker(i32* null, i32* nonnull align 128 dereferenceable(4) [[PTR]], void (i32*, i32*, ...)* nonnull bitcast (void (i32*, i32*, i32*, i64, i32**)* @t0_callback_callee to void (i32*, i32*, ...)*), i32* [[A:%.*]], i64 99, i32** noalias nocapture nonnull align 64 dereferenceable(8) [[C]]) ; CHECK-NEXT: ret void ; entry: @@ -39,7 +39,7 @@ ; Note that the first two arguments are provided by the callback_broker according to the callback in !1 below! ; The others are annotated with alignment information, amongst others, or even replaced by the constants passed to the call. define internal void @t0_callback_callee(i32* %is_not_null, i32* %ptr, i32* %a, i64 %b, i32** %c) { -; CHECK: @t0_callback_callee(i32* nocapture writeonly [[IS_NOT_NULL:%.*]], i32* nocapture readonly [[PTR:%.*]], i32* [[A:%.*]], i64 [[B:%.*]], i32** nocapture nonnull readonly align 64 dereferenceable(8) [[C:%.*]]) +; CHECK: @t0_callback_callee(i32* nocapture writeonly [[IS_NOT_NULL:%.*]], i32* nocapture readonly [[PTR:%.*]], i32* [[A:%.*]], i64 [[B:%.*]], i32** noalias nocapture nonnull readonly align 64 dereferenceable(8) [[C:%.*]]) ; CHECK-NEXT: entry: ; CHECK-NEXT: [[PTR_VAL:%.*]] = load i32, i32* [[PTR:%.*]], align 8 ; CHECK-NEXT: store i32 [[PTR_VAL]], i32* [[IS_NOT_NULL:%.*]]