Index: llvm/lib/Transforms/IPO/AttributorAttributes.cpp
===================================================================
--- llvm/lib/Transforms/IPO/AttributorAttributes.cpp
+++ llvm/lib/Transforms/IPO/AttributorAttributes.cpp
@@ -413,8 +413,9 @@
 /// Clamp the information known for all returned values of a function
 /// (identified by \p QueryingAA) into \p S.
 template <typename AAType, typename StateType = typename AAType::StateType>
-static void clampReturnedValueStates(Attributor &A, const AAType &QueryingAA,
-                                     StateType &S) {
+static void clampReturnedValueStates(
+    Attributor &A, const AAType &QueryingAA, StateType &S,
+    const IRPosition::CallBaseContext *CBContext = nullptr) {
   LLVM_DEBUG(dbgs() << "[Attributor] Clamp return value states for "
                     << QueryingAA << " into " << S << "\n");
 
@@ -431,7 +432,7 @@
 
   // Callback for each possibly returned value.
   auto CheckReturnValue = [&](Value &RV) -> bool {
-    const IRPosition &RVPos = IRPosition::value(RV);
+    const IRPosition &RVPos = IRPosition::value(RV, CBContext);
     const AAType &AA = A.getAAFor<AAType>(QueryingAA, RVPos);
     LLVM_DEBUG(dbgs() << "[Attributor] RV: " << RV << " AA: " << AA.getAsStr()
                       << " @ " << RVPos << "\n");
@@ -453,7 +454,8 @@
 
 /// Helper class for generic deduction: return value -> returned position.
 template <typename AAType, typename BaseType,
-          typename StateType = typename BaseType::StateType>
+          typename StateType = typename BaseType::StateType,
+          bool PropagateCallBaseContext = false>
 struct AAReturnedFromReturnedValues : public BaseType {
   AAReturnedFromReturnedValues(const IRPosition &IRP, Attributor &A)
       : BaseType(IRP, A) {}
@@ -461,7 +463,9 @@
   /// See AbstractAttribute::updateImpl(...).
   ChangeStatus updateImpl(Attributor &A) override {
     StateType S(StateType::getBestState(this->getState()));
-    clampReturnedValueStates<AAType, StateType>(A, *this, S);
+    clampReturnedValueStates<AAType, StateType>(
+        A, *this, S,
+        PropagateCallBaseContext ? this->getCallBaseContext() : nullptr);
     // TODO: If we know we visited all returned values, thus no are assumed
     // dead, we can take the known information from the state T.
     return clampStateAndIndicateChange(this->getState(), S);
@@ -515,17 +519,64 @@
   S ^= *T;
 }
 
-/// Helper class for generic deduction: call site argument -> argument position.
+/// Bridge between an argument position and its call base context.
+///
+/// If \p Pos carries a call base context, the \p AAType state of the matching
+/// operand of that call is combined into \p State via ^= and true is returned.
+/// Otherwise \p State is left untouched and false is returned.
 template <typename AAType, typename BaseType,
           typename StateType = typename AAType::StateType>
+bool getArgumentStateFromCallBaseContext(Attributor &A,
+                                         BaseType &QueryingAttribute,
+                                         IRPosition &Pos, StateType &State) {
+  assert((Pos.getPositionKind() == IRPosition::IRP_ARGUMENT) &&
+         "Expected an 'argument' position!");
+  const CallBase *CBContext = Pos.getCallBaseContext();
+  if (!CBContext)
+    return false;
+  int ArgNo = Pos.getArgNo();
+  // Never index past the operand list of the context call.
+  if (ArgNo < 0 || static_cast<unsigned>(ArgNo) >= CBContext->arg_size())
+    return false;
+  Value *Val = CBContext->getArgOperand(ArgNo);
+  assert(Val && "CBContext argument value must not be nullptr!");
+
+  const AAType &CBArgumentAA =
+      A.getAAFor<AAType>(QueryingAttribute, IRPosition::value(*Val));
+  const StateType &CBArgumentState =
+      static_cast<const StateType &>(CBArgumentAA.getState());
+
+  LLVM_DEBUG(dbgs() << "[Attributor] Bridging call site context to argument "
+                    << "position: " << Pos << " Argument value: " << *Val
+                    << " CB argument state: " << CBArgumentState << "\n");
+
+  // NOTE: If we want to do call site grouping it should happen here.
+  State ^= CBArgumentState;
+  return true;
+}
+
+/// Helper class for generic deduction: call site argument -> argument position.
+template <typename AAType, typename BaseType,
+          typename StateType = typename AAType::StateType,
+          bool BridgeCallBaseContext = false>
 struct AAArgumentFromCallSiteArguments : public BaseType {
   AAArgumentFromCallSiteArguments(const IRPosition &IRP, Attributor &A)
       : BaseType(IRP, A) {}
 
   /// See AbstractAttribute::updateImpl(...).
   ChangeStatus updateImpl(Attributor &A) override {
     StateType S(StateType::getBestState(this->getState()));
+
+    // With bridging enabled and a call base context present, take the state
+    // from the matching call operand and skip the call site clamping below.
+    if (BridgeCallBaseContext) {
+      bool Success =
+          getArgumentStateFromCallBaseContext<AAType, BaseType, StateType>(
+              A, *this, this->getIRPosition(), S);
+      if (Success)
+        return clampStateAndIndicateChange(this->getState(), S);
+    }
     clampCallSiteArgumentStates<AAType, StateType>(A, *this, S);
     // TODO: If we know we visited all incoming values, thus no are assumed
     // dead, we can take the known information from the state T.
     return clampStateAndIndicateChange(this->getState(), S);
@@ -534,7 +585,8 @@
 
 /// Helper class for generic replication: function returned -> cs returned.
 template <typename AAType, typename BaseType,
-          typename StateType = typename BaseType::StateType>
+          typename StateType = typename BaseType::StateType,
+          bool IntroduceCallBaseContext = false>
 struct AACallSiteReturnedFromReturned : public BaseType {
   AACallSiteReturnedFromReturned(const IRPosition &IRP, Attributor &A)
       : BaseType(IRP, A) {}
@@ -552,7 +604,14 @@
     if (!AssociatedFunction)
       return S.indicatePessimisticFixpoint();
 
-    IRPosition FnPos = IRPosition::returned(*AssociatedFunction);
+    // The anchor is expected to be the call; cast<> asserts this.
+    const CallBase &CBContext = cast<CallBase>(this->getAnchorValue());
+    if (IntroduceCallBaseContext)
+      LLVM_DEBUG(dbgs() << "[Attributor] Introducing call base context: "
+                        << CBContext << "\n");
+
+    IRPosition FnPos = IRPosition::returned(
+        *AssociatedFunction, IntroduceCallBaseContext ? &CBContext : nullptr);
     const AAType &AA = A.getAAFor<AAType>(*this, FnPos);
     return clampStateAndIndicateChange(
         S, static_cast<const StateType &>(AA.getState()));