diff --git a/llvm/include/llvm/Transforms/IPO/Attributor.h b/llvm/include/llvm/Transforms/IPO/Attributor.h --- a/llvm/include/llvm/Transforms/IPO/Attributor.h +++ b/llvm/include/llvm/Transforms/IPO/Attributor.h @@ -1618,6 +1618,15 @@ const AAIsDead *FnLivenessAA, DepClassTy DepClass = DepClassTy::OPTIONAL); + /// Check \p Pred on all potential Callees of \p CB. + /// + /// This method will evaluate \p Pred on all potential callees of \p CB and + /// return true if \p Pred holds every time. The second argument to \p Pred + /// indicates if this is a must-be-called callee or a may-be-called callee. + /// If some callees might be unknown this function will return false. + bool checkForAllCallees(function_ref Pred, + const AbstractAttribute &QueryingAA, const CallBase &CB); + /// Check \p Pred on all (transitive) uses of \p V. /// /// This method will evaluate \p Pred on all (transitive) uses of the diff --git a/llvm/lib/Transforms/IPO/Attributor.cpp b/llvm/lib/Transforms/IPO/Attributor.cpp --- a/llvm/lib/Transforms/IPO/Attributor.cpp +++ b/llvm/lib/Transforms/IPO/Attributor.cpp @@ -999,6 +999,24 @@ return false; } +bool Attributor::checkForAllCallees( + function_ref Pred, + const AbstractAttribute &QueryingAA, const CallBase &CB) { + if (const Function *Callee = CB.getCalledFunction()) + return Pred(*Callee, true); + + const auto &CallEdgesAA = getAAFor( + QueryingAA, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL); + if (CallEdgesAA.hasUnknownCallee()) + return false; + + const auto &Callees = CallEdgesAA.getOptimisticEdges(); + for (Function *Callee : Callees) + if (!Pred(*Callee, Callees.size() == 1)) + return false; + return true; +} + bool Attributor::checkForAllUses(function_ref Pred, const AbstractAttribute &QueryingAA, const Value &V, bool CheckBBLivenessOnly, diff --git a/llvm/lib/Transforms/IPO/AttributorAttributes.cpp b/llvm/lib/Transforms/IPO/AttributorAttributes.cpp --- a/llvm/lib/Transforms/IPO/AttributorAttributes.cpp +++ b/llvm/lib/Transforms/IPO/AttributorAttributes.cpp @@ -645,20 +645,22 @@ "positions!"); auto &S = this->getState(); - const Function *AssociatedFunction = - this->getIRPosition().getAssociatedFunction(); - if (!AssociatedFunction) - return S.indicatePessimisticFixpoint(); - - CallBase &CBContext = static_cast(this->getAnchorValue()); + CallBase &CB = cast(this->getAnchorValue()); if (IntroduceCallBaseContext) - LLVM_DEBUG(dbgs() << "[Attributor] Introducing call base context:" - << CBContext << "\n"); + LLVM_DEBUG(dbgs() << "[Attributor] Introducing call base context:" << CB + << "\n"); - IRPosition FnPos = IRPosition::returned( - *AssociatedFunction, IntroduceCallBaseContext ? &CBContext : nullptr); - const AAType &AA = A.getAAFor(*this, FnPos, DepClassTy::REQUIRED); - return clampStateAndIndicateChange(S, AA.getState()); + ChangeStatus Changed = ChangeStatus::UNCHANGED; + auto CalleePred = [&](const Function &Callee, bool MustCallee) { + IRPosition FnPos = IRPosition::returned( + Callee, IntroduceCallBaseContext ? &CB : nullptr); + const AAType &AA = A.getAAFor(*this, FnPos, DepClassTy::REQUIRED); + Changed |= clampStateAndIndicateChange(S, AA.getState()); + return true; + }; + if (!A.checkForAllCallees(CalleePred, *this, CB)) + return S.indicatePessimisticFixpoint(); + return Changed; } }; @@ -1491,10 +1493,17 @@ // call site specific liveness information and then it makes // sense to specialize attributes for call sites arguments instead of // redirecting requests to the callee argument. - Function *F = getAssociatedFunction(); - const IRPosition &FnPos = IRPosition::function(*F); - auto &FnAA = A.getAAFor(*this, FnPos, DepClassTy::REQUIRED); - return clampStateAndIndicateChange(getState(), FnAA.getState()); + ChangeStatus Changed = ChangeStatus::UNCHANGED; + auto CalleePred = [&](const Function &Callee, bool MustCallee) { + const IRPosition &FnPos = IRPosition::function(Callee); + auto &FnAA = A.getAAFor(*this, FnPos, DepClassTy::REQUIRED); + Changed |= clampStateAndIndicateChange(getState(), FnAA.getState()); + return true; + }; + CallBase &CB = *cast(getCtxI()); + if (!A.checkForAllCallees(CalleePred, *this, CB)) + return indicatePessimisticFixpoint(); + return Changed; } /// See AbstractAttribute::trackStatistics() @@ -1537,8 +1546,10 @@ indicatePessimisticFixpoint(); return; } - assert(!F->getReturnType()->isVoidTy() && - "Did not expect a void return type!"); + if (F->getReturnType()->isVoidTy()) { + indicatePessimisticFixpoint(); + return; + } // The map from instruction opcodes to those instructions in the function. auto &OpcodeInstMap = A.getInfoCache().getOpcodeInstMapForFunction(*F); @@ -1889,10 +1900,17 @@ // call site specific liveness information and then it makes // sense to specialize attributes for call sites arguments instead of // redirecting requests to the callee argument. - Function *F = getAssociatedFunction(); - const IRPosition &FnPos = IRPosition::function(*F); - auto &FnAA = A.getAAFor(*this, FnPos, DepClassTy::REQUIRED); - return clampStateAndIndicateChange(getState(), FnAA.getState()); + ChangeStatus Changed = ChangeStatus::UNCHANGED; + auto CalleePred = [&](const Function &Callee, bool MustCallee) { + const IRPosition &FnPos = IRPosition::function(Callee); + auto &FnAA = A.getAAFor(*this, FnPos, DepClassTy::REQUIRED); + Changed |= clampStateAndIndicateChange(getState(), FnAA.getState()); + return true; + }; + CallBase &CB = *cast(getCtxI()); + if (!A.checkForAllCallees(CalleePred, *this, CB)) + return indicatePessimisticFixpoint(); + return Changed; } /// See AbstractAttribute::trackStatistics() @@ -1956,10 +1974,17 @@ // call site specific liveness information and then it makes // sense to specialize attributes for call sites arguments instead of // redirecting requests to the callee argument. - Function *F = getAssociatedFunction(); - const IRPosition &FnPos = IRPosition::function(*F); - auto &FnAA = A.getAAFor(*this, FnPos, DepClassTy::REQUIRED); - return clampStateAndIndicateChange(getState(), FnAA.getState()); + ChangeStatus Changed = ChangeStatus::UNCHANGED; + auto CalleePred = [&](const Function &Callee, bool MustCallee) { + const IRPosition &FnPos = IRPosition::function(Callee); + auto &FnAA = A.getAAFor(*this, FnPos, DepClassTy::REQUIRED); + Changed |= clampStateAndIndicateChange(getState(), FnAA.getState()); + return true; + }; + CallBase &CB = *cast(getCtxI()); + if (!A.checkForAllCallees(CalleePred, *this, CB)) + return indicatePessimisticFixpoint(); + return Changed; } /// See AbstractAttribute::trackStatistics() @@ -2395,24 +2420,23 @@ AANoRecurseCallSite(const IRPosition &IRP, Attributor &A) : AANoRecurseImpl(IRP, A) {} - /// See AbstractAttribute::initialize(...). - void initialize(Attributor &A) override { - AANoRecurseImpl::initialize(A); - Function *F = getAssociatedFunction(); - if (!F || F->isDeclaration()) - indicatePessimisticFixpoint(); - } - /// See AbstractAttribute::updateImpl(...). ChangeStatus updateImpl(Attributor &A) override { // TODO: Once we have call site specific value information we can provide // call site specific liveness information and then it makes // sense to specialize attributes for call sites arguments instead of // redirecting requests to the callee argument. - Function *F = getAssociatedFunction(); - const IRPosition &FnPos = IRPosition::function(*F); - auto &FnAA = A.getAAFor(*this, FnPos, DepClassTy::REQUIRED); - return clampStateAndIndicateChange(getState(), FnAA.getState()); + ChangeStatus Changed = ChangeStatus::UNCHANGED; + auto CalleePred = [&](const Function &Callee, bool MustCallee) { + const IRPosition &FnPos = IRPosition::function(Callee); + auto &FnAA = A.getAAFor(*this, FnPos, DepClassTy::REQUIRED); + Changed |= clampStateAndIndicateChange(getState(), FnAA.getState()); + return true; + }; + CallBase &CB = *cast(getCtxI()); + if (!A.checkForAllCallees(CalleePred, *this, CB)) + return indicatePessimisticFixpoint(); + return Changed; } /// See AbstractAttribute::trackStatistics() @@ -2879,15 +2903,21 @@ ChangeStatus updateImpl(Attributor &A) override { if (isImpliedByMustprogressAndReadonly(A, /* KnownOnly */ false)) return ChangeStatus::UNCHANGED; - // TODO: Once we have call site specific value information we can provide // call site specific liveness information and then it makes // sense to specialize attributes for call sites arguments instead of // redirecting requests to the callee argument. - Function *F = getAssociatedFunction(); - const IRPosition &FnPos = IRPosition::function(*F); - auto &FnAA = A.getAAFor(*this, FnPos, DepClassTy::REQUIRED); - return clampStateAndIndicateChange(getState(), FnAA.getState()); + ChangeStatus Changed = ChangeStatus::UNCHANGED; + auto CalleePred = [&](const Function &Callee, bool MustCallee) { + const IRPosition &FnPos = IRPosition::function(Callee); + auto &FnAA = A.getAAFor(*this, FnPos, DepClassTy::REQUIRED); + Changed |= clampStateAndIndicateChange(getState(), FnAA.getState()); + return true; + }; + CallBase &CB = *cast(getCtxI()); + if (!A.checkForAllCallees(CalleePred, *this, CB)) + return indicatePessimisticFixpoint(); + return Changed; } /// See AbstractAttribute::trackStatistics() @@ -4623,10 +4653,17 @@ // call site specific liveness information and then it makes // sense to specialize attributes for call sites arguments instead of // redirecting requests to the callee argument. - Function *F = getAssociatedFunction(); - const IRPosition &FnPos = IRPosition::function(*F); - auto &FnAA = A.getAAFor(*this, FnPos, DepClassTy::REQUIRED); - return clampStateAndIndicateChange(getState(), FnAA.getState()); + ChangeStatus Changed = ChangeStatus::UNCHANGED; + auto CalleePred = [&](const Function &Callee, bool MustCallee) { + const IRPosition &FnPos = IRPosition::function(Callee); + auto &FnAA = A.getAAFor(*this, FnPos, DepClassTy::REQUIRED); + Changed |= clampStateAndIndicateChange(getState(), FnAA.getState()); + return true; + }; + CallBase &CB = *cast(getCtxI()); + if (!A.checkForAllCallees(CalleePred, *this, CB)) + return indicatePessimisticFixpoint(); + return Changed; } /// See AbstractAttribute::trackStatistics() @@ -7177,14 +7214,21 @@ /// See AbstractAttribute::updateImpl(...). ChangeStatus updateImpl(Attributor &A) override { // TODO: Once we have call site specific value information we can provide - // call site specific liveness liveness information and then it makes + // call site specific liveness information and then it makes // sense to specialize attributes for call sites arguments instead of // redirecting requests to the callee argument. - Function *F = getAssociatedFunction(); - const IRPosition &FnPos = IRPosition::function(*F); - auto &FnAA = - A.getAAFor(*this, FnPos, DepClassTy::REQUIRED); - return clampStateAndIndicateChange(getState(), FnAA.getState()); + ChangeStatus Changed = ChangeStatus::UNCHANGED; + auto CalleePred = [&](const Function &Callee, bool MustCallee) { + const IRPosition &FnPos = IRPosition::function(Callee); + auto &FnAA = + A.getAAFor(*this, FnPos, DepClassTy::REQUIRED); + Changed |= clampStateAndIndicateChange(getState(), FnAA.getState()); + return true; + }; + CallBase &CB = *cast(getCtxI()); + if (!A.checkForAllCallees(CalleePred, *this, CB)) + return indicatePessimisticFixpoint(); + return Changed; } /// See AbstractAttribute::trackStatistics() @@ -9658,6 +9702,7 @@ } void trackStatistics() const override {} + private: bool canReachUnknownCallee() const override { return WholeFunction.CanReachUnknownCallee; diff --git a/llvm/test/Transforms/Attributor/IPConstantProp/arg-count-mismatch.ll b/llvm/test/Transforms/Attributor/IPConstantProp/arg-count-mismatch.ll --- a/llvm/test/Transforms/Attributor/IPConstantProp/arg-count-mismatch.ll +++ b/llvm/test/Transforms/Attributor/IPConstantProp/arg-count-mismatch.ll @@ -36,15 +36,20 @@ ; FIXME we should recognize this as UB and make it an unreachable. define dso_local i16 @foo(i16 %a) { -; NOT_CGSCC_NPM-LABEL: define {{[^@]+}}@foo -; NOT_CGSCC_NPM-SAME: (i16 [[A:%.*]]) { -; NOT_CGSCC_NPM-NEXT: [[CALL:%.*]] = call i16 bitcast (i16 (i16, i16)* @bar to i16 (i16)*)(i16 [[A]]) -; NOT_CGSCC_NPM-NEXT: ret i16 [[CALL]] +; IS__TUNIT____-LABEL: define {{[^@]+}}@foo +; IS__TUNIT____-SAME: (i16 [[A:%.*]]) { +; IS__TUNIT____-NEXT: [[CALL:%.*]] = call noundef i16 bitcast (i16 (i16, i16)* @bar to i16 (i16)*)(i16 [[A]]) +; IS__TUNIT____-NEXT: ret i16 [[CALL]] +; +; IS__CGSCC_OPM-LABEL: define {{[^@]+}}@foo +; IS__CGSCC_OPM-SAME: (i16 [[A:%.*]]) { +; IS__CGSCC_OPM-NEXT: [[CALL:%.*]] = call i16 bitcast (i16 (i16, i16)* @bar to i16 (i16)*)(i16 [[A]]) +; IS__CGSCC_OPM-NEXT: ret i16 [[CALL]] ; ; IS__CGSCC_NPM: Function Attrs: nofree norecurse nosync nounwind readnone ; IS__CGSCC_NPM-LABEL: define {{[^@]+}}@foo ; IS__CGSCC_NPM-SAME: (i16 [[A:%.*]]) #[[ATTR0:[0-9]+]] { -; IS__CGSCC_NPM-NEXT: [[CALL:%.*]] = call i16 bitcast (i16 (i16, i16)* @bar to i16 (i16)*)(i16 [[A]]) +; IS__CGSCC_NPM-NEXT: [[CALL:%.*]] = call noundef i16 bitcast (i16 (i16, i16)* @bar to i16 (i16)*)(i16 [[A]]) ; IS__CGSCC_NPM-NEXT: ret i16 [[CALL]] ; %call = call i16 bitcast (i16 (i16, i16) * @bar to i16 (i16) *)(i16 %a) @@ -116,18 +121,20 @@ ; been provided), define dso_local i16 @vararg_tests(i16 %a) { -; NOT_CGSCC_NPM-LABEL: define {{[^@]+}}@vararg_tests -; NOT_CGSCC_NPM-SAME: (i16 [[A:%.*]]) { -; NOT_CGSCC_NPM-NEXT: [[CALL2:%.*]] = call i16 bitcast (i16 (i16, i16, ...)* @vararg_no_prop to i16 (i16)*)(i16 noundef 7) -; NOT_CGSCC_NPM-NEXT: [[ADD:%.*]] = add i16 7, [[CALL2]] -; NOT_CGSCC_NPM-NEXT: ret i16 [[ADD]] +; IS__TUNIT____-LABEL: define {{[^@]+}}@vararg_tests +; IS__TUNIT____-SAME: (i16 [[A:%.*]]) { +; IS__TUNIT____-NEXT: ret i16 14 ; -; IS__CGSCC_NPM: Function Attrs: nofree norecurse nosync nounwind readnone +; IS__CGSCC_OPM-LABEL: define {{[^@]+}}@vararg_tests +; IS__CGSCC_OPM-SAME: (i16 [[A:%.*]]) { +; IS__CGSCC_OPM-NEXT: [[CALL2:%.*]] = call i16 bitcast (i16 (i16, i16, ...)* @vararg_no_prop to i16 (i16)*)(i16 noundef 7) +; IS__CGSCC_OPM-NEXT: [[ADD:%.*]] = add i16 7, [[CALL2]] +; IS__CGSCC_OPM-NEXT: ret i16 14 +; +; IS__CGSCC_NPM: Function Attrs: nofree norecurse nosync nounwind readnone willreturn ; IS__CGSCC_NPM-LABEL: define {{[^@]+}}@vararg_tests -; IS__CGSCC_NPM-SAME: (i16 [[A:%.*]]) #[[ATTR0]] { -; IS__CGSCC_NPM-NEXT: [[CALL2:%.*]] = call i16 bitcast (i16 (i16, i16, ...)* @vararg_no_prop to i16 (i16)*)(i16 noundef 7) -; IS__CGSCC_NPM-NEXT: [[ADD:%.*]] = add i16 7, [[CALL2]] -; IS__CGSCC_NPM-NEXT: ret i16 [[ADD]] +; IS__CGSCC_NPM-SAME: (i16 [[A:%.*]]) #[[ATTR1]] { +; IS__CGSCC_NPM-NEXT: ret i16 14 ; %call1 = call i16 (i16, ...) @vararg_prop(i16 7, i16 8, i16 %a) %call2 = call i16 bitcast (i16 (i16, i16, ...) * @vararg_no_prop to i16 (i16) *) (i16 7) @@ -150,20 +157,15 @@ } define internal i16 @vararg_no_prop(i16 %p1, i16 %p2, ...) { -; IS__TUNIT____: Function Attrs: nofree nosync nounwind readnone willreturn -; IS__TUNIT____-LABEL: define {{[^@]+}}@vararg_no_prop -; IS__TUNIT____-SAME: (i16 [[P1:%.*]], i16 [[P2:%.*]], ...) #[[ATTR0]] { -; IS__TUNIT____-NEXT: ret i16 7 -; ; IS__CGSCC_OPM: Function Attrs: nofree norecurse nosync nounwind readnone willreturn ; IS__CGSCC_OPM-LABEL: define {{[^@]+}}@vararg_no_prop ; IS__CGSCC_OPM-SAME: (i16 [[P1:%.*]], i16 [[P2:%.*]], ...) #[[ATTR0]] { -; IS__CGSCC_OPM-NEXT: ret i16 7 +; IS__CGSCC_OPM-NEXT: ret i16 undef ; ; IS__CGSCC_NPM: Function Attrs: nofree norecurse nosync nounwind readnone willreturn ; IS__CGSCC_NPM-LABEL: define {{[^@]+}}@vararg_no_prop ; IS__CGSCC_NPM-SAME: (i16 [[P1:%.*]], i16 [[P2:%.*]], ...) #[[ATTR1]] { -; IS__CGSCC_NPM-NEXT: ret i16 7 +; IS__CGSCC_NPM-NEXT: ret i16 undef ; ret i16 %p1 }