diff --git a/llvm/include/llvm/IR/CallSite.h b/llvm/include/llvm/IR/CallSite.h --- a/llvm/include/llvm/IR/CallSite.h +++ b/llvm/include/llvm/IR/CallSite.h @@ -854,6 +854,15 @@ return CI.ParameterEncoding[0]; } + /// Return the use of the callee value in the underlying instruction. Only + /// valid for callback calls! + const Use &getCalleeUseForCallback() const { + int CalleeArgIdx = getCallArgOperandNoForCallee(); + assert(CalleeArgIdx >= 0 && + unsigned(CalleeArgIdx) < getInstruction()->getNumOperands()); + return getInstruction()->getOperandUse(CalleeArgIdx); + } + /// Return the pointer to function that is being called. Value *getCalledValue() const { if (isDirectCall()) diff --git a/llvm/include/llvm/Transforms/IPO/Attributor.h b/llvm/include/llvm/Transforms/IPO/Attributor.h --- a/llvm/include/llvm/Transforms/IPO/Attributor.h +++ b/llvm/include/llvm/Transforms/IPO/Attributor.h @@ -216,6 +216,16 @@ ArgNo); } + /// Create a position describing the argument of \p ACS at position \p ArgNo. + static const IRPosition callsite_argument(AbstractCallSite ACS, + unsigned ArgNo) { + int CSArgNo = ACS.getCallArgOperandNo(ArgNo); + if (CSArgNo >= 0) + return IRPosition::callsite_argument( + cast<CallBase>(*ACS.getInstruction()), CSArgNo); + return IRPosition(); + } + /// Create a position with function scope matching the "context" of \p IRP. /// If \p IRP is a call site (see isAnyCallSitePosition()) then the result /// will be a call site position, otherwise the function position of the @@ -825,7 +835,7 @@ /// This method will evaluate \p Pred on call sites and return /// true if \p Pred holds in every call sites. However, this is only possible /// all call sites are known, hence the function has internal linkage. 
- bool checkForAllCallSites(const function_ref<bool(CallSite)> &Pred, + bool checkForAllCallSites(const function_ref<bool(AbstractCallSite)> &Pred, const AbstractAttribute &QueryingAA, bool RequireAllCallSites); diff --git a/llvm/lib/Transforms/IPO/Attributor.cpp b/llvm/lib/Transforms/IPO/Attributor.cpp --- a/llvm/lib/Transforms/IPO/Attributor.cpp +++ b/llvm/lib/Transforms/IPO/Attributor.cpp @@ -596,11 +596,16 @@ // The argument number which is also the call site argument number. unsigned ArgNo = QueryingAA.getIRPosition().getArgNo(); - auto CallSiteCheck = [&](CallSite CS) { - const IRPosition &CSArgPos = IRPosition::callsite_argument(CS, ArgNo); - const AAType &AA = A.getAAFor<AAType>(QueryingAA, CSArgPos); - LLVM_DEBUG(dbgs() << "[Attributor] CS: " << *CS.getInstruction() - << " AA: " << AA.getAsStr() << " @" << CSArgPos << "\n"); + auto CallSiteCheck = [&](AbstractCallSite ACS) { + const IRPosition &ACSArgPos = IRPosition::callsite_argument(ACS, ArgNo); + // Check if a corresponding argument was found or if it is not associated + // (which can happen for callback calls). + if (ACSArgPos.getPositionKind() == IRPosition::IRP_INVALID) + return false; + + const AAType &AA = A.getAAFor<AAType>(QueryingAA, ACSArgPos); + LLVM_DEBUG(dbgs() << "[Attributor] ACS: " << *ACS.getInstruction() + << " AA: " << AA.getAsStr() << " @" << ACSArgPos << "\n"); const StateType &AAS = static_cast<const StateType &>(AA.getState()); if (T.hasValue()) *T &= AAS; @@ -3100,9 +3105,12 @@ ChangeStatus updateImpl(Attributor &A) override { bool HasValueBefore = SimplifiedAssociatedValue.hasValue(); - auto PredForCallSite = [&](CallSite CS) { - return checkAndUpdate(A, *this, *CS.getArgOperand(getArgNo()), - SimplifiedAssociatedValue); + auto PredForCallSite = [&](AbstractCallSite ACS) { + // Check if we have an associated argument or not (which can happen for + // callback calls). 
+ if (Value *ArgOp = ACS.getCallArgOperand(getArgNo())) + return checkAndUpdate(A, *this, *ArgOp, SimplifiedAssociatedValue); + return false; }; if (!A.checkForAllCallSites(PredForCallSite, *this, true)) @@ -3914,9 +3922,9 @@ return true; } -bool Attributor::checkForAllCallSites(const function_ref<bool(CallSite)> &Pred, - const AbstractAttribute &QueryingAA, - bool RequireAllCallSites) { +bool Attributor::checkForAllCallSites( + const function_ref<bool(AbstractCallSite)> &Pred, + const AbstractAttribute &QueryingAA, bool RequireAllCallSites) { // We can try to determine information from // the call sites. However, this is only possible all call sites are known, // hence the function has internal linkage. @@ -3934,15 +3942,21 @@ } for (const Use &U : AssociatedFunction->uses()) { - Instruction *I = dyn_cast<Instruction>(U.getUser()); - // TODO: Deal with abstract call sites here. - if (!I) + AbstractCallSite ACS(&U); + if (!ACS) { + LLVM_DEBUG(dbgs() << "[Attributor] Function " + << AssociatedFunction->getName() + << " has non call site use " << *U.get() << " in " + << *U.getUser() << "\n"); return false; + } + Instruction *I = ACS.getInstruction(); Function *Caller = I->getFunction(); - const auto &LivenessAA = getAAFor<AAIsDead>( - QueryingAA, IRPosition::function(*Caller), /* TrackDependence */ false); + const auto &LivenessAA = + getAAFor<AAIsDead>(QueryingAA, IRPosition::function(*Caller), + /* TrackDependence */ false); // Skip dead calls. if (LivenessAA.isAssumedDead(I)) { @@ -3952,22 +3966,22 @@ continue; } - CallSite CS(U.getUser()); - if (!CS || !CS.isCallee(&U)) { + const Use *EffectiveUse = + ACS.isCallbackCall() ? 
+ &ACS.getCalleeUseForCallback() : &U; + if (!ACS.isCallee(EffectiveUse)) { if (!RequireAllCallSites) continue; - - LLVM_DEBUG(dbgs() << "[Attributor] User " << *U.getUser() + LLVM_DEBUG(dbgs() << "[Attributor] User " << *EffectiveUse->getUser() << " is an invalid use of " << AssociatedFunction->getName() << "\n"); return false; } - if (Pred(CS)) + if (Pred(ACS)) continue; LLVM_DEBUG(dbgs() << "[Attributor] Call site callback failed for " - << *CS.getInstruction() << "\n"); + << *ACS.getInstruction() << "\n"); return false; } @@ -4319,7 +4333,7 @@ const auto *LivenessAA = lookupAAFor<AAIsDead>(IRPosition::function(*F)); if (LivenessAA && - !checkForAllCallSites([](CallSite CS) { return false; }, + !checkForAllCallSites([](AbstractCallSite ACS) { return false; }, *LivenessAA, true)) continue; diff --git a/llvm/test/Transforms/FunctionAttrs/callbacks.ll b/llvm/test/Transforms/FunctionAttrs/callbacks.ll new file mode 100644 --- /dev/null +++ b/llvm/test/Transforms/FunctionAttrs/callbacks.ll @@ -0,0 +1,63 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py +; RUN: opt -S -passes=attributor -aa-pipeline='basic-aa' -attributor-disable=false -attributor-max-iterations-verify -attributor-max-iterations=1 < %s | FileCheck %s +; ModuleID = 'callback_simple.c' +target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128" + +; Test 0 +; +; Make sure we propagate information from the caller to the callback callee but +; only for arguments that are mapped through the callback metadata. Here, the +; first two arguments of the call and the callback callee do not correspond to +; each other but arguments 3-5 of the transitive call site in the caller match +; arguments 2-4 of the callback callee. Here we should see information and value +; transfer in both directions. +; FIXME: The callee -> call site direction is not working yet. 
+ +define void @t0_caller(i32* %a) { +; CHECK: @t0_caller(i32* [[A:%.*]]) +; CHECK-NEXT: entry: +; CHECK-NEXT: [[B:%.*]] = alloca i32, align 32 +; CHECK-NEXT: [[C:%.*]] = alloca i32*, align 64 +; CHECK-NEXT: [[PTR:%.*]] = alloca i32, align 128 +; CHECK-NEXT: [[TMP0:%.*]] = bitcast i32* [[B]] to i8* +; CHECK-NEXT: store i32 42, i32* [[B]], align 32 +; CHECK-NEXT: store i32* [[B]], i32** [[C]], align 64 +; CHECK-NEXT: call void (i32*, i32*, void (i32*, i32*, ...)*, ...) @t0_callback_broker(i32* null, i32* nonnull align 128 dereferenceable(4) [[PTR]], void (i32*, i32*, ...)* nonnull bitcast (void (i32*, i32*, i32*, i64, i32**)* @t0_callback_callee to void (i32*, i32*, ...)*), i32* [[A:%.*]], i64 99, i32** nonnull align 64 dereferenceable(8) [[C]]) +; CHECK-NEXT: ret void +; +entry: + %b = alloca i32, align 32 + %c = alloca i32*, align 64 + %ptr = alloca i32, align 128 + %0 = bitcast i32* %b to i8* + store i32 42, i32* %b, align 4 + store i32* %b, i32** %c, align 8 + call void (i32*, i32*, void (i32*, i32*, ...)*, ...) @t0_callback_broker(i32* null, i32* %ptr, void (i32*, i32*, ...)* bitcast (void (i32*, i32*, i32*, i64, i32**)* @t0_callback_callee to void (i32*, i32*, ...)*), i32* %a, i64 99, i32** %c) + ret void +} + +; Note that the first two arguments are provided by the callback_broker according to the callback metadata in !1 below! +; The others are annotated with alignment information, among other attributes, or even replaced by the constants passed to the call. 
+define internal void @t0_callback_callee(i32* %is_not_null, i32* %ptr, i32* %a, i64 %b, i32** %c) { +; CHECK: @t0_callback_callee(i32* nocapture writeonly [[IS_NOT_NULL:%.*]], i32* nocapture readonly [[PTR:%.*]], i32* [[A:%.*]], i64 [[B:%.*]], i32** nocapture nonnull readonly align 64 dereferenceable(8) [[C:%.*]]) +; CHECK-NEXT: entry: +; CHECK-NEXT: [[PTR_VAL:%.*]] = load i32, i32* [[PTR:%.*]], align 8 +; CHECK-NEXT: store i32 [[PTR_VAL]], i32* [[IS_NOT_NULL:%.*]] +; CHECK-NEXT: [[TMP0:%.*]] = load i32*, i32** [[C:%.*]], align 64 +; CHECK-NEXT: tail call void @t0_check(i32* align 256 [[A:%.*]], i64 99, i32* [[TMP0]]) +; CHECK-NEXT: ret void +; +entry: + %ptr_val = load i32, i32* %ptr, align 8 + store i32 %ptr_val, i32* %is_not_null + %0 = load i32*, i32** %c, align 8 + tail call void @t0_check(i32* %a, i64 %b, i32* %0) + ret void +} + +declare void @t0_check(i32* align 256, i64, i32*) + +declare !callback !0 void @t0_callback_broker(i32*, i32*, void (i32*, i32*, ...)*, ...) + +!0 = !{!1} +!1 = !{i64 2, i64 -1, i64 -1, i1 true}