Add instruction-level queries (instructionCanReach) to AAFunctionReachability.

Reviewer notes on this revision of the patch:
  * Line breaks, indentation and template argument lists were reconstructed
    from a whitespace-collapsed copy; blank context lines were placed so that
    every hunk matches the line counts in its @@ header. Verify context
    whitespace against the target tree before applying.
  * getReachableCallEdges now casts the visited call-like instruction
    (CBInst), not the query instruction (Inst), to CallBase.
  * updateImpl iterates InstQueries / InstQueriesBackwards by reference so
    that updates land in the cached query sets instead of in copies.
  * instructionCanReach answers the backwards query through
    isCachedReachable() so that CanReachUnknownCallee is honoured.

diff --git a/llvm/include/llvm/Transforms/IPO/Attributor.h b/llvm/include/llvm/Transforms/IPO/Attributor.h
--- a/llvm/include/llvm/Transforms/IPO/Attributor.h
+++ b/llvm/include/llvm/Transforms/IPO/Attributor.h
@@ -4452,10 +4452,15 @@
   AAFunctionReachability(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
 
   /// If the function represented by this possition can reach \p Fn.
-  virtual bool canReach(Attributor &A, Function *Fn) const = 0;
+  virtual bool canReach(Attributor &A, Function &Fn) const = 0;
 
   /// Can \p CB reach \p Fn
-  virtual bool canReach(Attributor &A, CallBase &CB, Function *Fn) const = 0;
+  virtual bool canReach(Attributor &A, CallBase &CB, Function &Fn) const = 0;
+
+  /// Can \p Inst reach \p Fn. If \p UseBackwards, callers are checked too.
+  virtual bool instructionCanReach(Attributor &A, Instruction &Inst,
+                                   Function &Fn,
+                                   bool UseBackwards = true) const = 0;
 
   /// Create an abstract attribute view for the position \p IRP.
   static AAFunctionReachability &createForPosition(const IRPosition &IRP,
@@ -4474,7 +4479,6 @@
   /// Unique ID (due to the unique address)
   static const char ID;
 
-
 private:
   /// Can this function reach a call with unknown calee.
   virtual bool canReachUnknownCallee() const = 0;
diff --git a/llvm/lib/Transforms/IPO/AttributorAttributes.cpp b/llvm/lib/Transforms/IPO/AttributorAttributes.cpp
--- a/llvm/lib/Transforms/IPO/AttributorAttributes.cpp
+++ b/llvm/lib/Transforms/IPO/AttributorAttributes.cpp
@@ -9503,6 +9503,34 @@
       Unreachable.erase(Fn);
     }
 
+    /// If there is no information about the function None is returned.
+    Optional<bool> isCachedReachable(Function &Fn) {
+      // Assume that we can reach the function.
+      // TODO: Be more specific with the unknown callee.
+      if (CanReachUnknownCallee)
+        return true;
+
+      if (Reachable.count(&Fn))
+        return true;
+
+      if (Unreachable.count(&Fn))
+        return false;
+
+      return llvm::None;
+    }
+
+    /// Set of functions that we know for sure are reachable.
+    DenseSet<Function *> Reachable;
+
+    /// Set of functions that are unreachable, but might become reachable.
+    DenseSet<Function *> Unreachable;
+
+    /// If we can reach a function with a call to an unknown function we assume
+    /// that we can reach any function.
+    bool CanReachUnknownCallee = false;
+  };
+
+  struct QueryResolver : public QuerySet {
     ChangeStatus update(Attributor &A, const AAFunctionReachability &AA,
                         ArrayRef<const AACallEdges *> AAEdgesList) {
       ChangeStatus Change = ChangeStatus::UNCHANGED;
@@ -9527,16 +9555,13 @@
 
     bool isReachable(Attributor &A, const AAFunctionReachability &AA,
                      ArrayRef<const AACallEdges *> AAEdgesList, Function *Fn) {
-      // Assume that we can reach the function.
-      // TODO: Be more specific with the unknown callee.
-      if (CanReachUnknownCallee)
-        return true;
+      Optional<bool> Cached = isCachedReachable(*Fn);
+      if (Cached.hasValue())
+        return Cached.getValue();
 
-      if (Reachable.count(Fn))
-        return true;
-
-      if (Unreachable.count(Fn))
-        return false;
+      // Assume that the function is unreachable for now,
+      // this is to make sure that we don't get stuck in a loop.
+      Unreachable.insert(Fn);
 
       // We need to assume that this function can't reach Fn to prevent
       // an infinite loop if this function is recursive.
@@ -9571,7 +9596,7 @@
                   AA, IRPosition::function(*Edge), DepClassTy::NONE);
           Deps.push_back(&EdgeReachability);
 
-          if (EdgeReachability.canReach(A, Fn))
+          if (EdgeReachability.canReach(A, *Fn))
             return true;
         }
       }
@@ -9582,23 +9607,76 @@
 
       return false;
     }
+  };
 
-    /// Set of functions that we know for sure is reachable.
-    DenseSet<Function *> Reachable;
+  bool getReachableCallEdges(Attributor &A, const AAReachability &Reachability,
+                             Instruction &Inst,
+                             SmallVector<const AACallEdges *> &Result) const {
+    // Determine call like instructions that we can reach from the inst.
+    auto CheckCallBase = [&](Instruction &CBInst) {
+      if (!Reachability.isAssumedReachable(A, Inst, CBInst))
+        return true;
 
-    /// Set of functions that are unreachable, but might become reachable.
-    DenseSet<Function *> Unreachable;
+      CallBase &CB = static_cast<CallBase &>(CBInst);
+      const AACallEdges &AAEdges = A.getAAFor<AACallEdges>(
+          *this, IRPosition::callsite_function(CB), DepClassTy::REQUIRED);
 
-    /// If we can reach a function with a call to a unknown function we assume
-    /// that we can reach any function.
-    bool CanReachUnknownCallee = false;
-  };
+      Result.push_back(&AAEdges);
+      return true;
+    };
+
+    bool UsedAssumedInformation = false;
+    return A.checkForAllCallLikeInstructions(CheckCallBase, *this,
+                                             UsedAssumedInformation);
+  }
+
+  ChangeStatus checkReachableBackwards(Attributor &A, QuerySet &Set) {
+    ChangeStatus Change = ChangeStatus::UNCHANGED;
+
+    // For all remaining instruction queries, check
+    // callers. A call inside that function might satisfy the query.
+    auto CheckCallSite = [&](AbstractCallSite CallSite) {
+      CallBase *CB = CallSite.getInstruction();
+      if (!CB)
+        return false;
+
+      if (isa<InvokeInst>(CB))
+        return false;
+
+      Instruction *Inst = CB->getNextNonDebugInstruction();
+      const AAFunctionReachability &AA = A.getAAFor<AAFunctionReachability>(
+          *this, IRPosition::function(*Inst->getFunction()),
+          DepClassTy::REQUIRED);
+      for (Function *Fn : make_early_inc_range(Set.Unreachable)) {
+        if (AA.instructionCanReach(A, *Inst, *Fn, /* UseBackwards */ false)) {
+          Set.markReachable(Fn);
+          Change = ChangeStatus::CHANGED;
+        }
+      }
+      return true;
+    };
+
+    bool NoUnknownCall = true;
+    NoUnknownCall &=
+        A.checkForAllCallSites(CheckCallSite, *this, true, NoUnknownCall);
+
+    // If we don't know all callsites we have to assume that we can reach fn.
+    if (!NoUnknownCall) {
+      for (auto &QSet : InstQueriesBackwards) {
+        if (!QSet.second.CanReachUnknownCallee)
+          Change = ChangeStatus::CHANGED;
+        QSet.second.CanReachUnknownCallee = true;
+      }
+    }
+
+    return Change;
+  }
 
 public:
   AAFunctionReachabilityFunction(const IRPosition &IRP, Attributor &A)
       : AAFunctionReachability(IRP, A) {}
 
-  bool canReach(Attributor &A, Function *Fn) const override {
+  bool canReach(Attributor &A, Function &Fn) const override {
     const AACallEdges &AAEdges =
         A.getAAFor<AACallEdges>(*this, getIRPosition(), DepClassTy::REQUIRED);
 
@@ -9608,13 +9686,13 @@
     // This is a hack for us to be able to cache queries.
     auto *NonConstThis = const_cast<AAFunctionReachabilityFunction *>(this);
     bool Result =
-        NonConstThis->WholeFunction.isReachable(A, *this, {&AAEdges}, Fn);
+        NonConstThis->WholeFunction.isReachable(A, *this, {&AAEdges}, &Fn);
 
     return Result;
   }
 
   /// Can \p CB reach \p Fn
-  bool canReach(Attributor &A, CallBase &CB, Function *Fn) const override {
+  bool canReach(Attributor &A, CallBase &CB, Function &Fn) const override {
     const AACallEdges &AAEdges = A.getAAFor<AACallEdges>(
         *this, IRPosition::callsite_function(CB), DepClassTy::REQUIRED);
 
@@ -9623,13 +9701,45 @@
     // a const_cast.
     // This is a hack for us to be able to cache queries.
     auto *NonConstThis = const_cast<AAFunctionReachabilityFunction *>(this);
-    QuerySet &CBQuery = NonConstThis->CBQueries[&CB];
+    QueryResolver &CBQuery = NonConstThis->CBQueries[&CB];
 
-    bool Result = CBQuery.isReachable(A, *this, {&AAEdges}, Fn);
+    bool Result = CBQuery.isReachable(A, *this, {&AAEdges}, &Fn);
 
     return Result;
   }
 
+  bool instructionCanReach(Attributor &A, Instruction &Inst, Function &Fn,
+                           bool UseBackwards) const override {
+    auto Reachability = &A.getAAFor<AAReachability>(
+        *this, IRPosition::function(*getAssociatedFunction()),
+        DepClassTy::REQUIRED);
+
+    SmallVector<const AACallEdges *> CallEdges;
+    bool AllKnown = getReachableCallEdges(A, *Reachability, Inst, CallEdges);
+    // Attributor returns attributes as const, so this function has to be
+    // const for users of this attribute to use it without having to do
+    // a const_cast.
+    // This is a hack for us to be able to cache queries.
+    auto *NonConstThis = const_cast<AAFunctionReachabilityFunction *>(this);
+    QueryResolver &InstQSet = NonConstThis->InstQueries[&Inst];
+    if (!AllKnown)
+      InstQSet.CanReachUnknownCallee = true;
+
+    bool ForwardsResult = InstQSet.isReachable(A, *this, CallEdges, &Fn);
+    if (ForwardsResult)
+      return true;
+    // We are done.
+    if (!UseBackwards)
+      return false;
+
+    QuerySet &InstBackwardsQSet = NonConstThis->InstQueriesBackwards[&Inst];
+    InstBackwardsQSet.Unreachable.insert(&Fn);
+
+    // Check backwards reachability.
+    NonConstThis->checkReachableBackwards(A, InstBackwardsQSet);
+    return InstBackwardsQSet.isCachedReachable(Fn).getValue();
+  }
+
   /// See AbstractAttribute::updateImpl(...).
   ChangeStatus updateImpl(Attributor &A) override {
     const AACallEdges &AAEdges =
@@ -9646,6 +9756,28 @@
       Change |= CBPair.second.update(A, *this, {&AAEdges});
     }
 
+    // Update the Instruction queries.
+    const AAReachability *Reachability = nullptr;
+    if (!InstQueries.empty()) {
+      Reachability = &A.getAAFor<AAReachability>(
+          *this, IRPosition::function(*getAssociatedFunction()),
+          DepClassTy::REQUIRED);
+    }
+
+    // Check for local callbases first.
+    for (auto &InstPair : InstQueries) {
+      SmallVector<const AACallEdges *> CallEdges;
+      bool AllKnown =
+          getReachableCallEdges(A, *Reachability, *InstPair.first, CallEdges);
+      // Update will return change if this affects any queries.
+      if (!AllKnown)
+        InstPair.second.CanReachUnknownCallee = true;
+      Change |= InstPair.second.update(A, *this, CallEdges);
+    }
+
+    for (auto &QueryPair : InstQueriesBackwards)
+      Change |= checkReachableBackwards(A, QueryPair.second);
+
     return Change;
   }
 
@@ -9665,11 +9797,16 @@
   }
 
   /// Used to answer if a the whole function can reacha a specific function.
-  QuerySet WholeFunction;
+  QueryResolver WholeFunction;
 
   /// Used to answer if a call base inside this function can reach a specific
   /// function.
-  DenseMap<CallBase *, QuerySet> CBQueries;
+  DenseMap<CallBase *, QueryResolver> CBQueries;
+
+  DenseMap<Instruction *, QueryResolver> InstQueries;
+
+  // This is for instruction queries that scan backwards.
+  DenseMap<Instruction *, QuerySet> InstQueriesBackwards;
 };
 } // namespace
 
diff --git a/llvm/unittests/Transforms/IPO/AttributorTest.cpp b/llvm/unittests/Transforms/IPO/AttributorTest.cpp
--- a/llvm/unittests/Transforms/IPO/AttributorTest.cpp
+++ b/llvm/unittests/Transforms/IPO/AttributorTest.cpp
@@ -75,47 +75,67 @@
 
 TEST_F(AttributorTestBase, AAReachabilityTest) {
   const char *ModuleString = R"(
-    @x = global i32 0
-    define void @func4() {
+    @x = external global i32
+    define internal void @func4() {
       store i32 0, i32* @x
       ret void
     }
 
-    define void @func3() {
+    define internal void @func3() {
       store i32 0, i32* @x
       ret void
     }
 
-    define void @func2() {
+    define internal void @func8() {
+      store i32 0, i32* @x
+      ret void
+    }
+
+    define internal void @func2() {
     entry:
       call void @func3()
       ret void
     }
 
-    define void @func1() {
+    define internal void @func1() {
     entry:
       call void @func2()
       ret void
     }
 
-    define void @func5(void ()* %unknown) {
+    define internal void @func5(void ()* %unknown) {
     entry:
       call void %unknown()
       ret void
     }
 
-    define void @func6() {
+    define internal void @func6() {
     entry:
       call void @func5(void ()* @func3)
       ret void
     }
 
-    define void @func7() {
+    define internal void @func7() {
+    entry:
+      call void @func2()
+      call void @func4()
+      ret void
+    }
+
+    define internal void @func9() {
     entry:
       call void @func2()
+      call void @func8()
+      ret void
+    }
+
+    define internal void @func10() {
+    entry:
+      call void @func9()
       call void @func4()
       ret void
     }
+
   )";
 
   Module &M = parseModule(ModuleString);
@@ -128,32 +148,43 @@
   CallGraphUpdater CGUpdater;
   BumpPtrAllocator Allocator;
   InformationCache InfoCache(M, AG, Allocator, nullptr);
-  Attributor A(Functions, InfoCache, CGUpdater);
+  Attributor A(Functions, InfoCache, CGUpdater, /* Allowed */ nullptr,
+               /*DeleteFns*/ false);
 
-  Function *F1 = M.getFunction("func1");
-  Function *F3 = M.getFunction("func3");
-  Function *F4 = M.getFunction("func4");
-  Function *F6 = M.getFunction("func6");
-  Function *F7 = M.getFunction("func7");
+  Function &F1 = *M.getFunction("func1");
+  Function &F3 = *M.getFunction("func3");
+  Function &F4 = *M.getFunction("func4");
+  Function &F6 = *M.getFunction("func6");
+  Function &F7 = *M.getFunction("func7");
+  Function &F9 = *M.getFunction("func9");
 
   // call void @func2()
-  CallBase &F7FirstCB =
-      *static_cast<CallBase *>(F7->getEntryBlock().getFirstNonPHI());
+  CallBase &F7FirstCB = static_cast<CallBase &>(*F7.getEntryBlock().begin());
+  // call void @func2()
+  Instruction &F9FirstInst = *F9.getEntryBlock().begin();
+  // call void @func8
+  Instruction &F9SecondInst = *++(F9.getEntryBlock().begin());
 
   const AAFunctionReachability &F1AA =
-      A.getOrCreateAAFor<AAFunctionReachability>(IRPosition::function(*F1));
+      A.getOrCreateAAFor<AAFunctionReachability>(IRPosition::function(F1));
 
   const AAFunctionReachability &F6AA =
-      A.getOrCreateAAFor<AAFunctionReachability>(IRPosition::function(*F6));
+      A.getOrCreateAAFor<AAFunctionReachability>(IRPosition::function(F6));
 
   const AAFunctionReachability &F7AA =
-      A.getOrCreateAAFor<AAFunctionReachability>(IRPosition::function(*F7));
+      A.getOrCreateAAFor<AAFunctionReachability>(IRPosition::function(F7));
+
+  const AAFunctionReachability &F9AA =
+      A.getOrCreateAAFor<AAFunctionReachability>(IRPosition::function(F9));
 
   F1AA.canReach(A, F3);
   F1AA.canReach(A, F4);
   F6AA.canReach(A, F4);
   F7AA.canReach(A, F7FirstCB, F3);
   F7AA.canReach(A, F7FirstCB, F4);
+  F9AA.instructionCanReach(A, F9FirstInst, F3);
+  F9AA.instructionCanReach(A, F9SecondInst, F3, false);
+  F9AA.instructionCanReach(A, F9FirstInst, F4);
 
   A.run();
 
@@ -166,6 +197,15 @@
   // Assumed to be reacahable, since F6 can reach a function with
   // a unknown callee.
   ASSERT_TRUE(F6AA.canReach(A, F4));
+
+  // The second instruction of F9 can't reach the first call.
+  ASSERT_FALSE(F9AA.instructionCanReach(A, F9SecondInst, F3, false));
+  ASSERT_FALSE(F9AA.instructionCanReach(A, F9SecondInst, F3, true));
+
+  // The first instruction of F9 can reach the first call.
+  ASSERT_TRUE(F9AA.instructionCanReach(A, F9FirstInst, F3));
+  // Because func10 calls the func4 after the call to func9 it is reachable.
+  ASSERT_TRUE(F9AA.instructionCanReach(A, F9FirstInst, F4));
 }
 
 } // namespace llvm