diff --git a/llvm/include/llvm/Transforms/IPO/Attributor.h b/llvm/include/llvm/Transforms/IPO/Attributor.h --- a/llvm/include/llvm/Transforms/IPO/Attributor.h +++ b/llvm/include/llvm/Transforms/IPO/Attributor.h @@ -980,9 +980,11 @@ /// This method will evaluate \p Pred on call sites and return /// true if \p Pred holds in every call sites. However, this is only possible /// all call sites are known, hence the function has internal linkage. + /// If true is returned, \p AllCallSitesKnown is set if all possible call + /// sites of the function have been visited. bool checkForAllCallSites(const function_ref &Pred, const AbstractAttribute &QueryingAA, - bool RequireAllCallSites); + bool RequireAllCallSites, bool &AllCallSitesKnown); /// Check \p Pred on all values potentially returned by \p F. /// @@ -1040,9 +1042,12 @@ /// This method will evaluate \p Pred on call sites and return /// true if \p Pred holds in every call sites. However, this is only possible /// all call sites are known, hence the function has internal linkage. + /// If true is returned, \p AllCallSitesKnown is set if all possible call + /// sites of the function have been visited. bool checkForAllCallSites(const function_ref &Pred, const Function &Fn, bool RequireAllCallSites, - const AbstractAttribute *QueryingAA); + const AbstractAttribute *QueryingAA, + bool &AllCallSitesKnown); /// The private version of getAAFor that allows to omit a querying abstract /// attribute. See also the public getAAFor method. @@ -2094,6 +2099,9 @@ /// Returns true if the underlying value is assumed dead. virtual bool isAssumedDead() const = 0; + /// Returns true if the underlying value is known dead. + virtual bool isKnownDead() const = 0; + /// Returns true if \p BB is assumed dead. virtual bool isAssumedDead(const BasicBlock *BB) const = 0; diff --git a/llvm/lib/Transforms/IPO/Attributor.cpp b/llvm/lib/Transforms/IPO/Attributor.cpp --- a/llvm/lib/Transforms/IPO/Attributor.cpp +++ b/llvm/lib/Transforms/IPO/Attributor.cpp @@ -13,7 +13,7 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/IPO/Attributor.h" +#include "llvm/Transforms/IPO/Attributor.h" #include "llvm/ADT/DepthFirstIterator.h" #include "llvm/ADT/STLExtras.h" @@ -866,7 +866,9 @@ return T->isValidState(); }; - if (!A.checkForAllCallSites(CallSiteCheck, QueryingAA, true)) + bool AllCallSitesKnown; + if (!A.checkForAllCallSites(CallSiteCheck, QueryingAA, true, + AllCallSitesKnown)) S.indicatePessimisticFixpoint(); else if (T.hasValue()) S ^= *T; @@ -2543,9 +2545,10 @@ // If the argument is never passed through callbacks, no-alias cannot break // synchronization. + bool AllCallSitesKnown; if (A.checkForAllCallSites( [](AbstractCallSite ACS) { return !ACS.isCallbackCall(); }, *this, - true)) + true, AllCallSitesKnown)) return Base::updateImpl(A); // TODO: add no-alias but make sure it doesn't break synchronization by @@ -2773,6 +2776,9 @@ /// See AAIsDead::isAssumedDead(). bool isAssumedDead() const override { return getAssumed(); } + /// See AAIsDead::isKnownDead(). + bool isKnownDead() const override { return getKnown(); } + /// See AAIsDead::isAssumedDead(BasicBlock *). bool isAssumedDead(const BasicBlock *BB) const override { return false; } @@ -2928,18 +2934,25 @@ /// See AbstractAttribute::updateImpl(...). ChangeStatus updateImpl(Attributor &A) override { + bool AllKnownDead = true; auto PredForCallSite = [&](AbstractCallSite ACS) { if (ACS.isCallbackCall()) return false; const IRPosition &CSRetPos = IRPosition::callsite_returned(ACS.getCallSite()); const auto &RetIsDeadAA = A.getAAFor(*this, CSRetPos); + AllKnownDead &= RetIsDeadAA.isKnownDead(); return RetIsDeadAA.isAssumedDead(); }; - if (!A.checkForAllCallSites(PredForCallSite, *this, true)) + bool AllCallSitesKnown; + if (!A.checkForAllCallSites(PredForCallSite, *this, true, + AllCallSitesKnown)) return indicatePessimisticFixpoint(); + if (AllCallSitesKnown && AllKnownDead) + indicateOptimisticFixpoint(); + return ChangeStatus::UNCHANGED; } @@ -3045,6 +3058,9 @@ /// Returns true if the function is assumed dead. bool isAssumedDead() const override { return false; } + /// See AAIsDead::isKnownDead(). + bool isKnownDead() const override { return false; } + /// See AAIsDead::isAssumedDead(BasicBlock *). bool isAssumedDead(const BasicBlock *BB) const override { assert(BB->getParent() == getAssociatedFunction() && @@ -4474,7 +4490,9 @@ return checkAndUpdate(A, *this, *ArgOp, SimplifiedAssociatedValue); }; - if (!A.checkForAllCallSites(PredForCallSite, *this, true)) + bool AllCallSitesKnown; + if (!A.checkForAllCallSites(PredForCallSite, *this, true, + AllCallSitesKnown)) if (!askSimplifiedValueForAAValueConstantRange(A)) return indicatePessimisticFixpoint(); @@ -4883,9 +4901,10 @@ Optional identifyPrivatizableType(Attributor &A) override { // If this is a byval argument and we know all the call sites (so we can // rewrite them), there is no need to check them explicitly. + bool AllCallSitesKnown; if (getIRPosition().hasAttr(Attribute::ByVal) && A.checkForAllCallSites([](AbstractCallSite ACS) { return true; }, *this, - true)) + true, AllCallSitesKnown)) return getAssociatedValue().getType()->getPointerElementType(); Optional Ty; @@ -4934,7 +4953,7 @@ return !Ty.hasValue() || Ty.getValue(); }; - if (!A.checkForAllCallSites(CallSiteCheck, *this, true)) + if (!A.checkForAllCallSites(CallSiteCheck, *this, true, AllCallSitesKnown)) return nullptr; return Ty; } @@ -5098,8 +5117,9 @@ return false; }; - if (!A.checkForAllCallSites(IsCompatiblePrivArgOfOtherCallSite, *this, - true)) + bool AllCallSitesKnown; + if (!A.checkForAllCallSites(IsCompatiblePrivArgOfOtherCallSite, *this, true, + AllCallSitesKnown)) return indicatePessimisticFixpoint(); return ChangeStatus::UNCHANGED; @@ -6415,7 +6435,8 @@ bool Attributor::checkForAllCallSites( const function_ref &Pred, - const AbstractAttribute &QueryingAA, bool RequireAllCallSites) { + const AbstractAttribute &QueryingAA, bool RequireAllCallSites, + bool &AllCallSitesKnown) { // We can try to determine information from // the call sites. However, this is only possible all call sites are known, // hence the function has internal linkage. @@ -6424,24 +6445,30 @@ if (!AssociatedFunction) { LLVM_DEBUG(dbgs() << "[Attributor] No function associated with " << IRP << "\n"); + AllCallSitesKnown = false; return false; } return checkForAllCallSites(Pred, *AssociatedFunction, RequireAllCallSites, - &QueryingAA); + &QueryingAA, AllCallSitesKnown); } bool Attributor::checkForAllCallSites( const function_ref &Pred, const Function &Fn, - bool RequireAllCallSites, const AbstractAttribute *QueryingAA) { + bool RequireAllCallSites, const AbstractAttribute *QueryingAA, + bool &AllCallSitesKnown) { if (RequireAllCallSites && !Fn.hasLocalLinkage()) { LLVM_DEBUG( dbgs() << "[Attributor] Function " << Fn.getName() << " has no internal linkage, hence not all call sites are known\n"); + AllCallSitesKnown = false; return false; } + // If we do not require all call sites we might not see all. + AllCallSitesKnown = RequireAllCallSites; + for (const Use &U : Fn.uses()) { AbstractCallSite ACS(&U); if (!ACS) { @@ -6467,6 +6494,7 @@ // dependence. if (QueryingAA) recordDependence(*LivenessAA, *QueryingAA, DepClassTy::OPTIONAL); + AllCallSitesKnown = false; continue; } @@ -6914,12 +6942,13 @@ if (!F) continue; + bool AllCallSitesKnown; if (!checkForAllCallSites( [this](AbstractCallSite ACS) { return ToBeDeletedFunctions.count( ACS.getInstruction()->getFunction()); }, - *F, true, nullptr)) + *F, true, nullptr, AllCallSitesKnown)) continue; ToBeDeletedFunctions.insert(F); @@ -6979,7 +7008,9 @@ } // Avoid callbacks for now. - if (!checkForAllCallSites(CallSiteCanBeChanged, *Fn, true, nullptr)) { + bool AllCallSitesKnown; + if (!checkForAllCallSites(CallSiteCanBeChanged, *Fn, true, nullptr, + AllCallSitesKnown)) { LLVM_DEBUG(dbgs() << "[Attributor] Cannot rewrite all call sites\n"); return false; } @@ -7178,8 +7209,9 @@ }; // Use the CallSiteReplacementCreator to create replacement call sites. - bool Success = - checkForAllCallSites(CallSiteReplacementCreator, *OldFn, true, nullptr); + bool AllCallSitesKnown; + bool Success = checkForAllCallSites(CallSiteReplacementCreator, *OldFn, + true, nullptr, AllCallSitesKnown); (void)Success; assert(Success && "Assumed call site replacement to succeed!");