diff --git a/llvm/include/llvm/Transforms/IPO/Attributor.h b/llvm/include/llvm/Transforms/IPO/Attributor.h --- a/llvm/include/llvm/Transforms/IPO/Attributor.h +++ b/llvm/include/llvm/Transforms/IPO/Attributor.h @@ -4454,6 +4454,9 @@ /// If the function represented by this possition can reach \p Fn. virtual bool canReach(Attributor &A, Function *Fn) const = 0; + /// Can \p CB reach \p Fn + virtual bool canReach(Attributor &A, CallBase &CB, Function *Fn) const = 0; + /// Create an abstract attribute view for the position \p IRP. static AAFunctionReachability &createForPosition(const IRPosition &IRP, Attributor &A); diff --git a/llvm/lib/Transforms/IPO/AttributorAttributes.cpp b/llvm/lib/Transforms/IPO/AttributorAttributes.cpp --- a/llvm/lib/Transforms/IPO/AttributorAttributes.cpp +++ b/llvm/lib/Transforms/IPO/AttributorAttributes.cpp @@ -9496,115 +9496,192 @@ }; struct AAFunctionReachabilityFunction : public AAFunctionReachability { - AAFunctionReachabilityFunction(const IRPosition &IRP, Attributor &A) - : AAFunctionReachability(IRP, A) {} +private: + /// Cached reachability answers for a single query origin, i.e. either + /// the whole function or one call site inside it. + struct QuerySet { + /// Record \p Fn as reachable and drop it from the unreachable set. + void markReachable(Function *Fn) { + Reachable.insert(Fn); + Unreachable.erase(Fn); + } + + /// Re-evaluate every function currently cached as unreachable against + /// \p AAEdgesList. Returns CHANGED if an answer flipped to reachable or + /// an unknown callee was seen for the first time. + ChangeStatus update(Attributor &A, const AAFunctionReachability &AA, + ArrayRef AAEdgesList) { + ChangeStatus Change = ChangeStatus::UNCHANGED; + + for (auto *AAEdges : AAEdgesList) { + if (AAEdges->hasUnknownCallee()) { + if (!CanReachUnknownCallee) + Change = ChangeStatus::CHANGED; + CanReachUnknownCallee = true; + return Change; + } + } - bool canReach(Attributor &A, Function *Fn) const override { - // Assume that we can reach any function if we can reach a call with - // unknown callee. 
- if (CanReachUnknownCallee) - return true; + for (Function *Fn : make_early_inc_range(Unreachable)) { + if (checkIfReachable(A, AA, AAEdgesList, Fn)) { + Change = ChangeStatus::CHANGED; + markReachable(Fn); + } + } + return Change; + } - if (ReachableQueries.count(Fn)) - return true; + /// Return whether \p Fn is reachable via \p AAEdgesList, caching the + /// answer for later queries and updates. + bool isReachable(Attributor &A, const AAFunctionReachability &AA, + ArrayRef AAEdgesList, Function *Fn) { + // Assume that we can reach the function. + // TODO: Be more specific with the unknown callee. + if (CanReachUnknownCallee) + return true; + + if (Reachable.count(Fn)) + return true; + + if (Unreachable.count(Fn)) + return false; + + // We need to assume that this function can't reach Fn to prevent + // an infinite loop if this function is recursive. + Unreachable.insert(Fn); + + bool Result = checkIfReachable(A, AA, AAEdgesList, Fn); + if (Result) + markReachable(Fn); + return Result; + } + + /// Uncached check: \p Fn is reachable if it is an optimistic call edge + /// or if the function behind any such edge can reach it. + bool checkIfReachable(Attributor &A, const AAFunctionReachability &AA, + ArrayRef AAEdgesList, + Function *Fn) const { + + // Handle the most trivial case first. + for (auto *AAEdges : AAEdgesList) { + const SetVector &Edges = AAEdges->getOptimisticEdges(); + + if (Edges.count(Fn)) + return true; + } + + SmallVector Deps; + for (auto &AAEdges : AAEdgesList) { + const SetVector &Edges = AAEdges->getOptimisticEdges(); + + for (Function *Edge : Edges) { + // We don't need a dependency if the result is reachable. + const AAFunctionReachability &EdgeReachability = + A.getAAFor( + AA, IRPosition::function(*Edge), DepClassTy::NONE); + Deps.push_back(&EdgeReachability); + + if (EdgeReachability.canReach(A, Fn)) + return true; + } + } + + // The result is false for now, set dependencies and leave. + for (auto Dep : Deps) + A.recordDependence(AA, *Dep, DepClassTy::REQUIRED); - if (UnreachableQueries.count(Fn)) return false; + } + + /// Set of functions that we know for sure is reachable. + DenseSet Reachable; + + /// Set of functions that are unreachable, but might become reachable. 
+ DenseSet Unreachable; + /// If we can reach a function with a call to an unknown function we assume + /// that we can reach any function. + bool CanReachUnknownCallee = false; + }; + +public: + AAFunctionReachabilityFunction(const IRPosition &IRP, Attributor &A) + : AAFunctionReachability(IRP, A) {} + + bool canReach(Attributor &A, Function *Fn) const override { const AACallEdges &AAEdges = A.getAAFor(*this, getIRPosition(), DepClassTy::REQUIRED); - const SetVector &Edges = AAEdges.getOptimisticEdges(); - bool Result = checkIfReachable(A, Edges, Fn); + // Attributor returns attributes as const, so this function has to be + // const for users of this attribute to use it without having to do + // a const_cast. + // This is a hack for us to be able to cache queries. + auto *NonConstThis = const_cast(this); + bool Result = + NonConstThis->WholeFunction.isReachable(A, *this, {&AAEdges}, Fn); + + return Result; + } + + /// Can \p CB reach \p Fn + bool canReach(Attributor &A, CallBase &CB, Function *Fn) const override { + const AACallEdges &AAEdges = A.getAAFor( + *this, IRPosition::callsite_function(CB), DepClassTy::REQUIRED); // Attributor returns attributes as const, so this function has to be // const for users of this attribute to use it without having to do // a const_cast. // This is a hack for us to be able to cache queries. auto *NonConstThis = const_cast(this); + QuerySet &CBQuery = NonConstThis->CBQueries[&CB]; - if (Result) - NonConstThis->ReachableQueries.insert(Fn); - else - NonConstThis->UnreachableQueries.insert(Fn); + bool Result = CBQuery.isReachable(A, *this, {&AAEdges}, Fn); return Result; } /// See AbstractAttribute::updateImpl(...). 
ChangeStatus updateImpl(Attributor &A) override { - if (CanReachUnknownCallee) - return ChangeStatus::UNCHANGED; - const AACallEdges &AAEdges = A.getAAFor(*this, getIRPosition(), DepClassTy::REQUIRED); - const SetVector &Edges = AAEdges.getOptimisticEdges(); ChangeStatus Change = ChangeStatus::UNCHANGED; - if (AAEdges.hasUnknownCallee()) { - bool OldCanReachUnknown = CanReachUnknownCallee; - CanReachUnknownCallee = true; - return OldCanReachUnknown ? ChangeStatus::UNCHANGED - : ChangeStatus::CHANGED; - } + Change |= WholeFunction.update(A, *this, {&AAEdges}); - // Check if any of the unreachable functions become reachable. - for (auto Current = UnreachableQueries.begin(); - Current != UnreachableQueries.end();) { - if (!checkIfReachable(A, Edges, *Current)) { - Current++; - continue; - } - ReachableQueries.insert(*Current); - UnreachableQueries.erase(*Current++); - Change = ChangeStatus::CHANGED; + // Iterate by reference: update() mutates the cached QuerySet, so a + // by-value pair would silently discard every change. + for (auto &CBPair : CBQueries) { + const AACallEdges &AAEdges = A.getAAFor( + *this, IRPosition::callsite_function(*CBPair.first), + DepClassTy::REQUIRED); + + Change |= CBPair.second.update(A, *this, {&AAEdges}); } return Change; } const std::string getAsStr() const override { - size_t QueryCount = ReachableQueries.size() + UnreachableQueries.size(); + size_t QueryCount = + WholeFunction.Reachable.size() + WholeFunction.Unreachable.size(); - return "FunctionReachability [" + std::to_string(ReachableQueries.size()) + - "," + std::to_string(QueryCount) + "]"; + return "FunctionReachability [" + + std::to_string(WholeFunction.Reachable.size()) + "," + + std::to_string(QueryCount) + "]"; } void trackStatistics() const override {} - private: - bool canReachUnknownCallee() const override { return CanReachUnknownCallee; } - - bool checkIfReachable(Attributor &A, const SetVector &Edges, - Function *Fn) const { - if (Edges.count(Fn)) - return true; - - for (Function *Edge : Edges) { - // We don't need a dependency if the result is reachable. 
- const AAFunctionReachability &EdgeReachability = - A.getAAFor(*this, IRPosition::function(*Edge), - DepClassTy::NONE); - - if (EdgeReachability.canReach(A, Fn)) - return true; - } - for (Function *Fn : Edges) - A.getAAFor(*this, IRPosition::function(*Fn), - DepClassTy::REQUIRED); - - return false; + bool canReachUnknownCallee() const override { + return WholeFunction.CanReachUnknownCallee; } - /// Set of functions that we know for sure is reachable. - SmallPtrSet ReachableQueries; - - /// Set of functions that are unreachable, but might become reachable. - SmallPtrSet UnreachableQueries; + /// Used to answer if the whole function can reach a specific function. + QuerySet WholeFunction; - /// If we can reach a function with a call to a unknown function we assume - /// that we can reach any function. - bool CanReachUnknownCallee = false; + /// Used to answer if a call base inside this function can reach a specific + /// function. + DenseMap CBQueries; }; } // namespace diff --git a/llvm/unittests/Transforms/IPO/AttributorTest.cpp b/llvm/unittests/Transforms/IPO/AttributorTest.cpp --- a/llvm/unittests/Transforms/IPO/AttributorTest.cpp +++ b/llvm/unittests/Transforms/IPO/AttributorTest.cpp @@ -109,6 +109,13 @@ call void @func5(void ()* @func3) ret void } + + define void @func7() { + entry: + call void @func2() + call void @func4() + ret void + } )"; Module &M = parseModule(ModuleString); @@ -127,6 +134,11 @@ Function *F3 = M.getFunction("func3"); Function *F4 = M.getFunction("func4"); Function *F6 = M.getFunction("func6"); + Function *F7 = M.getFunction("func7"); + + // call void @func2() + CallBase &F7FirstCB = + *static_cast(F7->getEntryBlock().getFirstNonPHI()); const AAFunctionReachability &F1AA = A.getOrCreateAAFor(IRPosition::function(*F1)); @@ -134,15 +146,23 @@ const AAFunctionReachability &F6AA = A.getOrCreateAAFor(IRPosition::function(*F6)); + const AAFunctionReachability &F7AA = + A.getOrCreateAAFor(IRPosition::function(*F7)); + F1AA.canReach(A, F3); 
F1AA.canReach(A, F4); F6AA.canReach(A, F4); + F7AA.canReach(A, F7FirstCB, F3); + F7AA.canReach(A, F7FirstCB, F4); A.run(); ASSERT_TRUE(F1AA.canReach(A, F3)); ASSERT_FALSE(F1AA.canReach(A, F4)); + ASSERT_TRUE(F7AA.canReach(A, F7FirstCB, F3)); + ASSERT_FALSE(F7AA.canReach(A, F7FirstCB, F4)); + // Assumed to be reacahable, since F6 can reach a function with // a unknown callee. ASSERT_TRUE(F6AA.canReach(A, F4));