diff --git a/llvm/include/llvm/Transforms/IPO/Attributor.h b/llvm/include/llvm/Transforms/IPO/Attributor.h
--- a/llvm/include/llvm/Transforms/IPO/Attributor.h
+++ b/llvm/include/llvm/Transforms/IPO/Attributor.h
@@ -4409,6 +4409,12 @@
   /// Can \p CB reach \p Fn
   virtual bool canReach(Attributor &A, CallBase &CB, Function *Fn) const = 0;
 
+  /// Can \p Inst reach \p Fn. If \p UseBackwards is set, the instructions
+  /// following the call sites of the containing function are considered too.
+  virtual bool instructionCanReach(Attributor &A, Instruction &Inst,
+                                   Function *Fn,
+                                   bool UseBackwards = true) const = 0;
+
   /// Create an abstract attribute view for the position \p IRP.
   static AAFunctionReachability &createForPosition(const IRPosition &IRP,
                                                    Attributor &A);
@@ -4426,7 +4432,6 @@
   /// Unique ID (due to the unique address)
   static const char ID;
 
-private:
   /// Can this function reach a call with unknown calee.
   virtual bool canReachUnknownCallee() const = 0;
 
diff --git a/llvm/lib/Transforms/IPO/AttributorAttributes.cpp b/llvm/lib/Transforms/IPO/AttributorAttributes.cpp
--- a/llvm/lib/Transforms/IPO/AttributorAttributes.cpp
+++ b/llvm/lib/Transforms/IPO/AttributorAttributes.cpp
@@ -9314,6 +9314,40 @@
 struct AAFunctionReachabilityFunction : public AAFunctionReachability {
 private:
   struct QuerySet {
+    /// If there is no information about the function None is returned.
+    Optional<bool> isCachedReachable(Function &Fn) {
+      // Assume that we can reach the function.
+      // TODO: Be more specific with the unknown callee.
+      if (CanReachUnknownCallee)
+        return true;
+
+      if (Reachable.count(&Fn))
+        return true;
+
+      if (Unreachable.count(&Fn))
+        return false;
+
+      return llvm::None;
+    }
+
+    // Mark the function \p Fn as reachable.
+    void markReachable(Function &Fn) {
+      Reachable.insert(&Fn);
+      Unreachable.erase(&Fn);
+    }
+
+    /// Set of functions that we know for sure is reachable.
+    DenseSet<Function *> Reachable;
+
+    /// Set of functions that are unreachable, but might become reachable.
+    DenseSet<Function *> Unreachable;
+
+    /// If we can reach a function with a call to a unknown function we assume
+    /// that we can reach any function.
+    bool CanReachUnknownCallee = false;
+  };
+
+  struct QueryResolver : public QuerySet {
     ChangeStatus update(Attributor &A, const AAFunctionReachability &AA,
                         ArrayRef<const AACallEdges *> AAEdgesList) {
       ChangeStatus Change = ChangeStatus::UNCHANGED;
@@ -9339,24 +9373,18 @@
 
     bool isReachable(Attributor &A, const AAFunctionReachability &AA,
                      ArrayRef<const AACallEdges *> AAEdgesList, Function *Fn) {
-      // Assume that we can reach the function.
-      // TODO: Be more specific with the unknown callee.
-      if (CanReachUnknownCallee)
-        return true;
-
-      if (Reachable.count(Fn))
-        return true;
+      Optional<bool> Cached = isCachedReachable(*Fn);
+      if (Cached.hasValue())
+        return Cached.getValue();
 
-      if (Unreachable.count(Fn))
-        return false;
+      // Assume that the function is unreachable for now,
+      // this is to make sure that we don't get stuck in a loop.
+      Unreachable.insert(Fn);
 
       bool Result = checkIfReachable(A, AA, AAEdgesList, Fn);
-      if (Result) {
-        Reachable.insert(Fn);
-        Unreachable.erase(Fn);
-      } else {
-        Unreachable.insert(Fn);
-      }
+      if (Result)
+        markReachable(*Fn);
+
       return Result;
     }
 
@@ -9396,17 +9424,78 @@
       }
       return false;
     }
+  };
 
-    /// Set of functions that we know for sure is reachable.
-    DenseSet<Function *> Reachable;
+  /// Append to \p Result the call edges of every call-like instruction that
+  /// is assumed reachable from \p Inst. The return value is forwarded from
+  /// checkForAllCallLikeInstructions; callers treat false as "unknown callee".
+  bool getReachableCallEdges(Attributor &A, const AAReachability &Reachability,
+                             Instruction &Inst,
+                             SmallVector<const AACallEdges *> &Result) const {
+    // Determine call like instructions that we can reach from the inst.
+    auto CheckCallBase = [&](Instruction &CBInst) {
+      if (!Reachability.isAssumedReachable(A, Inst, CBInst))
+        return true;
 
-    /// Set of functions that are unreachable, but might become reachable.
-    DenseSet<Function *> Unreachable;
+      CallBase &CB = static_cast<CallBase &>(CBInst);
+      const AACallEdges &AAEdges = A.getAAFor<AACallEdges>(
+          *this, IRPosition::callsite_function(CB), DepClassTy::REQUIRED);
 
-    /// If we can reach a function with a call to a unknown function we assume
-    /// that we can reach any function.
-    bool CanReachUnknownCallee = false;
-  };
+      Result.push_back(&AAEdges);
+      return true;
+    };
+
+    bool UsedAssumedInformation = false;
+    return A.checkForAllCallLikeInstructions(CheckCallBase, *this,
+                                             UsedAssumedInformation);
+  }
+
+  /// Re-check the functions in \p Set.Unreachable against the instruction
+  /// that follows each call site of this function; functions that turn out
+  /// to be reachable from there are marked reachable in \p Set.
+  ChangeStatus checkReachableBackwards(Attributor &A, QuerySet &Set) {
+    ChangeStatus Change = ChangeStatus::UNCHANGED;
+
+    // For all remaining instruction queries, check
+    // callers. A call inside that function might satisfy the query.
+    auto CheckCallSite = [&](AbstractCallSite CallSite) {
+      CallBase *CB = CallSite.getInstruction();
+      if (!CB)
+        return false;
+
+      if (isa<InvokeInst>(CB))
+        return false;
+
+      Instruction *Inst = CB->getNextNonDebugInstruction();
+      const AAFunctionReachability &AA = A.getAAFor<AAFunctionReachability>(
+          *this, IRPosition::function(*Inst->getFunction()),
+          DepClassTy::REQUIRED);
+
+      for (auto It = Set.Unreachable.begin(); It != Set.Unreachable.end();) {
+        Function *Fn = *It;
+        It++;
+        if (AA.instructionCanReach(A, *Inst, Fn, /* UseBackwards */ false)) {
+          Set.markReachable(*Fn);
+          Change = ChangeStatus::CHANGED;
+        }
+      }
+      return true;
+    };
+    // The last argument is written by checkForAllCallSites, so it is kept
+    // apart from the variable that receives the combined result.
+    bool AllCallSitesKnown = false;
+    bool NoUnknownCall =
+        A.checkForAllCallSites(CheckCallSite, *this, true, AllCallSitesKnown) &&
+        AllCallSitesKnown;
+
+    // If we don't know all callsites we have to assume that we can reach
+    // fn.
+    if (!NoUnknownCall)
+      for (auto &QSet : InstQueriesBackwards)
+        QSet.second.CanReachUnknownCallee = true;
+
+    return Change;
+  }
 
 public:
   AAFunctionReachabilityFunction(const IRPosition &IRP, Attributor &A)
@@ -9437,13 +9526,49 @@
     // a const_cast.
     // This is a hack for us to be able to cache queries.
     auto *NonConstThis = const_cast<AAFunctionReachabilityFunction *>(this);
-    QuerySet &CBQuery = NonConstThis->CBQueries[&CB];
+    QueryResolver &CBQuery = NonConstThis->CBQueries[&CB];
 
     bool Result = CBQuery.isReachable(A, *this, {&AAEdges}, Fn);
 
     return Result;
   }
 
+  /// Can \p Inst reach \p Fn through a call reachable from \p Inst or, when
+  /// \p UseBackwards is set, through code after a call site of this function.
+  bool instructionCanReach(Attributor &A, Instruction &Inst, Function *Fn,
+                           bool UseBackwards) const override {
+    const AAReachability &Reachability = A.getAAFor<AAReachability>(
+        *this, IRPosition::function(*getAssociatedFunction()),
+        DepClassTy::REQUIRED);
+
+    SmallVector<const AACallEdges *> CallEdges;
+    bool AllKnown = getReachableCallEdges(A, Reachability, Inst, CallEdges);
+    // Attributor returns attributes as const, so this function has to be
+    // const for users of this attribute to use it without having to do
+    // a const_cast.
+    // This is a hack for us to be able to cache queries.
+    auto *NonConstThis = const_cast<AAFunctionReachabilityFunction *>(this);
+    QueryResolver &InstQSet = NonConstThis->InstQueries[&Inst];
+    if (!AllKnown)
+      InstQSet.CanReachUnknownCallee = true;
+
+    bool ForwardsResult = InstQSet.isReachable(A, *this, CallEdges, Fn);
+    if (ForwardsResult)
+      return true;
+    // We are done.
+    if (!UseBackwards)
+      return false;
+
+    QuerySet &InstBackwardsQSet = NonConstThis->InstQueriesBackwards[&Inst];
+    InstBackwardsQSet.Unreachable.insert(Fn);
+
+    // Check backwards reachability.
+    NonConstThis->checkReachableBackwards(A, InstBackwardsQSet);
+    // Not knowing every call site means anything may follow the return.
+    return InstBackwardsQSet.CanReachUnknownCallee ||
+           InstBackwardsQSet.Reachable.count(Fn);
+  }
+
   /// See AbstractAttribute::updateImpl(...).
   ChangeStatus updateImpl(Attributor &A) override {
     const AACallEdges &AAEdges =
@@ -9460,6 +9585,28 @@
       Change |= CBPair.second.update(A, *this, {&AAEdges});
     }
 
+    // Update the Instruction queries.
+    const AAReachability *Reachability = nullptr;
+    if (!InstQueries.empty()) {
+      Reachability = &A.getAAFor<AAReachability>(
+          *this, IRPosition::function(*getAssociatedFunction()),
+          DepClassTy::REQUIRED);
+    }
+
+    // Check for local callbases first.
+    for (auto &InstPair : InstQueries) {
+      SmallVector<const AACallEdges *> CallEdges;
+      bool AllKnown =
+          getReachableCallEdges(A, *Reachability, *InstPair.first, CallEdges);
+      // Update will return CHANGED if this affects any queries.
+      if (!AllKnown)
+        InstPair.second.CanReachUnknownCallee = true;
+      Change |= InstPair.second.update(A, *this, CallEdges);
+    }
+
+    for (auto &QueryPair : InstQueriesBackwards)
+      Change |= checkReachableBackwards(A, QueryPair.second);
+
     return Change;
   }
 
@@ -9478,8 +9625,12 @@
     return WholeFunction.CanReachUnknownCallee;
   }
 
-  QuerySet WholeFunction;
-  DenseMap<CallBase *, QuerySet> CBQueries;
+  QueryResolver WholeFunction;
+  DenseMap<CallBase *, QueryResolver> CBQueries;
+  DenseMap<Instruction *, QueryResolver> InstQueries;
+
+  // This is for instruction queries that scan backwards.
+  DenseMap<Instruction *, QuerySet> InstQueriesBackwards;
 };
 } // namespace
 
diff --git a/llvm/unittests/Transforms/IPO/AttributorTest.cpp b/llvm/unittests/Transforms/IPO/AttributorTest.cpp
--- a/llvm/unittests/Transforms/IPO/AttributorTest.cpp
+++ b/llvm/unittests/Transforms/IPO/AttributorTest.cpp
@@ -75,47 +75,67 @@
 
 TEST_F(AttributorTestBase, AAReachabilityTest) {
   const char *ModuleString = R"(
-    @x = global i32 0
-    define void @func4() {
+    @x = external global i32
+    define internal void @func4() {
       store i32 0, i32* @x
       ret void
     }
 
-    define void @func3() {
+    define internal void @func3() {
       store i32 0, i32* @x
       ret void
     }
 
-    define void @func2() {
+    define internal void @func8() {
+      store i32 0, i32* @x
+      ret void
+    }
+
+    define internal void @func2() {
     entry:
       call void @func3()
       ret void
     }
 
-    define void @func1() {
+    define internal void @func1() {
     entry:
       call void @func2()
       ret void
     }
 
-    define void @func5(void ()* %unknown) {
+    define internal void @func5(void ()* %unknown) {
     entry:
       call void %unknown()
       ret void
     }
 
-    define void @func6() {
+    define internal void @func6() {
     entry:
       call void @func5(void ()* @func3)
       ret void
     }
 
-    define void @func7() {
+    define internal void @func7() {
+    entry:
+      call void @func2()
+      call void @func4()
+      ret void
+    }
+
+    define internal void @func9() {
     entry:
       call void @func2()
+      call void @func8()
+      ret void
+    }
+
+    define internal void @func10() {
+    entry:
+      call void @func9()
       call void @func4()
       ret void
     }
+
   )";
 
   Module &M = parseModule(ModuleString);
@@ -128,17 +148,22 @@
   CallGraphUpdater CGUpdater;
   BumpPtrAllocator Allocator;
   InformationCache InfoCache(M, AG, Allocator, nullptr);
-  Attributor A(Functions, InfoCache, CGUpdater);
+  Attributor A(Functions, InfoCache, CGUpdater, /* Allowed */ nullptr,
+               /*DeleteFns*/ false);
 
   Function *F1 = M.getFunction("func1");
   Function *F3 = M.getFunction("func3");
   Function *F4 = M.getFunction("func4");
   Function *F6 = M.getFunction("func6");
   Function *F7 = M.getFunction("func7");
+  Function *F9 = M.getFunction("func9");
 
   // call void @func2()
-  CallBase &F7FirstCB =
-      *static_cast<CallBase *>(F7->getEntryBlock().getFirstNonPHI());
+  CallBase &F7FirstCB = static_cast<CallBase &>(*F7->getEntryBlock().begin());
+  // call void @func2()
+  Instruction &F9FirstInst = *F9->getEntryBlock().begin();
+  // call void @func8()
+  Instruction &F9SecondInst = *++(F9->getEntryBlock().begin());
 
   const AAFunctionReachability &F1AA =
       A.getOrCreateAAFor<AAFunctionReachability>(IRPosition::function(*F1));
@@ -149,11 +174,17 @@
   const AAFunctionReachability &F7AA =
       A.getOrCreateAAFor<AAFunctionReachability>(IRPosition::function(*F7));
 
+  const AAFunctionReachability &F9AA =
+      A.getOrCreateAAFor<AAFunctionReachability>(IRPosition::function(*F9));
+
   F1AA.canReach(A, F3);
   F1AA.canReach(A, F4);
   F6AA.canReach(A, F4);
   F7AA.canReach(A, F7FirstCB, F3);
   F7AA.canReach(A, F7FirstCB, F4);
+  F9AA.instructionCanReach(A, F9FirstInst, F3);
+  F9AA.instructionCanReach(A, F9SecondInst, F3, false);
+  F9AA.instructionCanReach(A, F9FirstInst, F4);
 
   A.run();
 
@@ -164,6 +195,15 @@
   // Assumed to be reacahable, since F6 can reach a function with
   // a unknown callee.
   ASSERT_TRUE(F6AA.canReach(A, F4));
+
+  // The second instruction of F9 can't reach the first call.
+  ASSERT_FALSE(F9AA.instructionCanReach(A, F9SecondInst, F3, false));
+  ASSERT_FALSE(F9AA.instructionCanReach(A, F9SecondInst, F3, true));
+
+  // The first instruction of F9 can reach the first call.
+  ASSERT_TRUE(F9AA.instructionCanReach(A, F9FirstInst, F3));
+  // Because func10 calls the func4 after the call to func9 it is reachable.
+  ASSERT_TRUE(F9AA.instructionCanReach(A, F9FirstInst, F4));
 }
 
 } // namespace llvm