diff --git a/llvm/include/llvm/Transforms/IPO/Attributor.h b/llvm/include/llvm/Transforms/IPO/Attributor.h
--- a/llvm/include/llvm/Transforms/IPO/Attributor.h
+++ b/llvm/include/llvm/Transforms/IPO/Attributor.h
@@ -4199,6 +4199,42 @@
   static const char ID;
 };
 
+/// An abstract Attribute for computing reachability between functions.
+struct AAFunctionReachability
+    : public StateWrapper<BooleanState, AbstractAttribute> {
+  using Base = StateWrapper<BooleanState, AbstractAttribute>;
+
+  AAFunctionReachability(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
+
+  /// Return true if the function represented by this position can reach \p Fn.
+  virtual bool canReach(Attributor &A, Function *Fn) const = 0;
+
+  /// Create an abstract attribute view for the position \p IRP.
+  static AAFunctionReachability &createForPosition(const IRPosition &IRP,
+                                                   Attributor &A);
+
+  /// See AbstractAttribute::getName()
+  const std::string getName() const override {
+    return "AAFunctionReachability";
+  }
+
+  /// See AbstractAttribute::getIdAddr()
+  const char *getIdAddr() const override { return &ID; }
+
+  /// This function should return true if the type of the \p AA is
+  /// AAFunctionReachability.
+  static bool classof(const AbstractAttribute *AA) {
+    return (AA->getIdAddr() == &ID);
+  }
+
+  /// Unique ID (due to the unique address)
+  static const char ID;
+
+private:
+  /// Can this function reach a call with an unknown callee.
+  virtual bool canReachUnknownCallee() const = 0;
+};
+
 /// Run options, used by the pass manager.
enum AttributorRunOption { NONE = 0, diff --git a/llvm/lib/Transforms/IPO/AttributorAttributes.cpp b/llvm/lib/Transforms/IPO/AttributorAttributes.cpp --- a/llvm/lib/Transforms/IPO/AttributorAttributes.cpp +++ b/llvm/lib/Transforms/IPO/AttributorAttributes.cpp @@ -136,6 +136,7 @@ PIPE_OPERATOR(AAPotentialValues) PIPE_OPERATOR(AANoUndef) PIPE_OPERATOR(AACallEdges) +PIPE_OPERATOR(AAFunctionReachability) #undef PIPE_OPERATOR } // namespace llvm @@ -8276,6 +8277,118 @@ bool HasUnknownCallee = false; }; +struct AAFunctionReachabilityFunction : public AAFunctionReachability { + AAFunctionReachabilityFunction(const IRPosition &IRP, Attributor &A) + : AAFunctionReachability(IRP, A) {} + + bool canReach(Attributor &A, Function *Fn) const override { + // Assume that we can reach any function if we can reach a call with + // unknown callee. + if (CanReachUnknownCallee) + return true; + + if (ReachableQueries.count(Fn)) + return true; + + if (UnreachableQueries.count(Fn)) + return false; + + const AACallEdges &AAEdges = + A.getAAFor(*this, getIRPosition(), DepClassTy::REQUIRED); + + const SetVector &Edges = AAEdges.getOptimisticEdges(); + bool Result = checkIfReachable(A, Edges, Fn); + + // Attributor returns attributes as const, so this function has to be + // const for users of this attribute to use it without having to do + // a const_cast. + // This is a hack for us to be able to cache queries. + auto *NonConstThis = const_cast(this); + + if (Result) + NonConstThis->ReachableQueries.insert(Fn); + else + NonConstThis->UnreachableQueries.insert(Fn); + + return Result; + } + + /// See AbstractAttribute::updateImpl(...). 
+ ChangeStatus updateImpl(Attributor &A) override { + if (CanReachUnknownCallee) + return ChangeStatus::UNCHANGED; + + const AACallEdges &AAEdges = + A.getAAFor(*this, getIRPosition(), DepClassTy::REQUIRED); + const SetVector &Edges = AAEdges.getOptimisticEdges(); + ChangeStatus Change = ChangeStatus::UNCHANGED; + + if (AAEdges.hasUnknownCallee()) { + bool OldCanReachUnknown = CanReachUnknownCallee; + CanReachUnknownCallee = true; + return OldCanReachUnknown ? ChangeStatus::UNCHANGED + : ChangeStatus::CHANGED; + } + + // Check if any of the unreachable functions become reachable. + for (auto Current = UnreachableQueries.begin(); + Current != UnreachableQueries.end();) { + if (!checkIfReachable(A, Edges, *Current)) { + Current++; + continue; + } + ReachableQueries.insert(*Current); + UnreachableQueries.erase(*Current++); + Change = ChangeStatus::CHANGED; + } + + return Change; + } + + const std::string getAsStr() const override { + size_t QueryCount = ReachableQueries.size() + UnreachableQueries.size(); + + return "FunctionReachability [" + std::to_string(ReachableQueries.size()) + + "," + std::to_string(QueryCount) + "]"; + } + + void trackStatistics() const override {} + +private: + bool canReachUnknownCallee() const override { return CanReachUnknownCallee; } + + bool checkIfReachable(Attributor &A, const SetVector &Edges, + Function *Fn) const { + if (Edges.count(Fn)) + return true; + + for (Function *Edge : Edges) { + // We don't need a dependency if the result is reachable. + const AAFunctionReachability &EdgeReachability = + A.getAAFor(*this, IRPosition::function(*Edge), + DepClassTy::NONE); + + if (EdgeReachability.canReach(A, Fn)) + return true; + } + for (Function *Fn : Edges) + A.getAAFor(*this, IRPosition::function(*Fn), + DepClassTy::REQUIRED); + + return false; + } + + /// Set of functions that we know for sure is reachable. + SmallPtrSet ReachableQueries; + + /// Set of functions that are unreachable, but might become reachable. 
+ SmallPtrSet UnreachableQueries; + + /// If we can reach a function with a call to a unknown function we assume + /// that we can reach any function. + bool CanReachUnknownCallee = false; +}; + } // namespace AACallGraphNode *AACallEdgeIterator::operator*() const { @@ -8311,6 +8424,7 @@ const char AAPotentialValues::ID = 0; const char AANoUndef::ID = 0; const char AACallEdges::ID = 0; +const char AAFunctionReachability::ID = 0; // Macro magic to create the static generator function for attributes that // follow the naming scheme. @@ -8431,6 +8545,7 @@ CREATE_FUNCTION_ONLY_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAReachability) CREATE_FUNCTION_ONLY_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAUndefinedBehavior) CREATE_FUNCTION_ONLY_ABSTRACT_ATTRIBUTE_FOR_POSITION(AACallEdges) +CREATE_FUNCTION_ONLY_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAFunctionReachability) CREATE_NON_RET_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAMemoryBehavior) diff --git a/llvm/unittests/Transforms/IPO/AttributorTest.cpp b/llvm/unittests/Transforms/IPO/AttributorTest.cpp --- a/llvm/unittests/Transforms/IPO/AttributorTest.cpp +++ b/llvm/unittests/Transforms/IPO/AttributorTest.cpp @@ -73,4 +73,71 @@ ASSERT_TRUE(SSucc); } +TEST_F(AttributorTestBase, AAReachabilityTest) { + const char *ModuleString = R"( + declare void @func4() + declare void @func3() + + define void @func2() { + entry: + call void @func3() + ret void + } + + define void @func1() { + entry: + call void @func2() + ret void + } + + define void @func5(void ()* %unknown) { + entry: + call void %unknown() + ret void + } + + define void @func6() { + entry: + call void @func5(void ()* @func3) + ret void + } + )"; + + Module &M = parseModule(ModuleString); + + SetVector Functions; + AnalysisGetter AG; + for (Function &F : M) + Functions.insert(&F); + + CallGraphUpdater CGUpdater; + BumpPtrAllocator Allocator; + InformationCache InfoCache(M, AG, Allocator, nullptr); + Attributor A(Functions, InfoCache, CGUpdater); + + Function *F1 = M.getFunction("func1"); + Function *F3 = 
M.getFunction("func3"); + Function *F4 = M.getFunction("func4"); + Function *F6 = M.getFunction("func6"); + + const AAFunctionReachability &F1AA = + A.getOrCreateAAFor(IRPosition::function(*F1)); + + const AAFunctionReachability &F6AA = + A.getOrCreateAAFor(IRPosition::function(*F6)); + + F1AA.canReach(A, F3); + F1AA.canReach(A, F4); + F6AA.canReach(A, F4); + + A.run(); + + ASSERT_TRUE(F1AA.canReach(A, F3)); + ASSERT_FALSE(F1AA.canReach(A, F4)); + + // Assumed to be reacahable, since F6 can reach a function with + // a unknown callee. + ASSERT_TRUE(F6AA.canReach(A, F4)); +} + } // namespace llvm