diff --git a/llvm/include/llvm/Transforms/IPO/Attributor.h b/llvm/include/llvm/Transforms/IPO/Attributor.h
--- a/llvm/include/llvm/Transforms/IPO/Attributor.h
+++ b/llvm/include/llvm/Transforms/IPO/Attributor.h
@@ -4355,6 +4355,12 @@
   /// Can \p CB reach \p Fn
   virtual bool canReach(Attributor &A, CallBase &CB, Function *Fn) const = 0;
 
+  /// Can \p Inst reach \p Fn. If \p UseBackwards is set, the code executed
+  /// after each call site of the function containing \p Inst (i.e. after it
+  /// returns to its caller) is taken into account as well.
+  virtual bool instCanReach(Attributor &A, Instruction &Inst, Function *Fn,
+                            bool UseBackwards = true) const = 0;
+
   /// Create an abstract attribute view for the position \p IRP.
   static AAFunctionReachability &createForPosition(const IRPosition &IRP,
                                                    Attributor &A);
@@ -4372,7 +4378,6 @@
   /// Unique ID (due to the unique address)
   static const char ID;
 
-private:
   /// Can this function reach a call with unknown calee.
   virtual bool canReachUnknownCallee() const = 0;
 
diff --git a/llvm/lib/Transforms/IPO/AttributorAttributes.cpp b/llvm/lib/Transforms/IPO/AttributorAttributes.cpp
--- a/llvm/lib/Transforms/IPO/AttributorAttributes.cpp
+++ b/llvm/lib/Transforms/IPO/AttributorAttributes.cpp
@@ -8599,6 +8599,11 @@
       return Change;
     }
 
+    void markReachable(Function &Fn) {
+      Reachable.insert(&Fn);
+      Unreachable.erase(&Fn);
+    }
+
     bool isReachable(Attributor &A, const AAFunctionReachability &AA,
                      ArrayRef<const AACallEdges *> AAEdgesList, Function *Fn) {
       // Assume that we can reach the function.
@@ -8614,8 +8619,7 @@
 
       bool Result = checkIfReachable(A, AA, AAEdgesList, Fn);
       if (Result) {
-        Reachable.insert(Fn);
-        Unreachable.erase(Fn);
+        markReachable(*Fn);
       } else {
         Unreachable.insert(Fn);
       }
@@ -8670,6 +8674,87 @@
     bool CanReachUnknownCallee = false;
   };
 
+  /// Collect the call edges of every call-like instruction in this function
+  /// that is assumed to be reachable from \p Inst (including \p Inst itself)
+  /// into \p Result. Returns false if not all call-like instructions could be
+  /// inspected, in which case an unknown callee has to be assumed.
+  bool getReachableCallEdges(Attributor &A, const AAReachability &Reachability,
+                             const Instruction &Inst,
+                             SmallVector<const AACallEdges *> &Result) const {
+    // Determine call like instructions that we can reach from the inst.
+    auto CheckCallBase = [&](Instruction &CBInst) {
+      // The query is "is CBInst reachable from Inst", i.e. From = Inst.
+      if (!Reachability.isAssumedReachable(A, Inst, CBInst))
+        return true;
+
+      // CBInst is the call-like instruction being visited; Inst is the query
+      // instruction and need not be a call at all.
+      auto &CB = cast<CallBase>(CBInst);
+      const AACallEdges &AAEdges = A.getAAFor<AACallEdges>(
+          *this, IRPosition::callsite_function(CB), DepClassTy::REQUIRED);
+
+      Result.push_back(&AAEdges);
+      return true;
+    };
+
+    bool UsedAssumedInformation = false;
+    return A.checkForAllCallLikeInstructions(CheckCallBase, *this,
+                                             UsedAssumedInformation);
+  }
+
+  /// Try to resolve the functions in \p Set.Unreachable by looking at the
+  /// code that runs after each call site of this function, i.e. after it
+  /// returns to its caller. Returns CHANGED if \p Set was updated.
+  ///
+  /// Callers are only queried forwards (UseBackwards = false). That bounds
+  /// the recursion for recursive call graphs and guarantees that only
+  /// InstQueries, never InstQueriesBackwards (which \p Set belongs to), is
+  /// modified while \p Set is in use, even if a caller is this function.
+  /// FIXME: As a consequence only a single level of callers is considered.
+  ChangeStatus checkReachableBackwards(Attributor &A, QuerySet &Set) {
+    ChangeStatus Change = ChangeStatus::UNCHANGED;
+    if (Set.CanReachUnknownCallee || Set.Unreachable.empty())
+      return Change;
+
+    // For all remaining instruction queries, check
+    // callers. A call inside that function might satisfy the query.
+    auto CheckCallSite = [&](AbstractCallSite CallSite) {
+      // A callback callee returns into the broker rather than the caller and
+      // an invoke has no single next instruction; give up in both cases.
+      CallBase *CB = CallSite.getInstruction();
+      if (!CB || CallSite.isCallbackCall() || isa<InvokeInst>(CB))
+        return false;
+
+      Instruction *Inst = CB->getNextNonDebugInstruction();
+      if (!Inst)
+        return false;
+
+      const AAFunctionReachability &AA = A.getAAFor<AAFunctionReachability>(
+          *this, IRPosition::function(*Inst->getFunction()),
+          DepClassTy::REQUIRED);
+
+      // markReachable erases from Set.Unreachable while we iterate over it.
+      for (Function *Fn : make_early_inc_range(Set.Unreachable)) {
+        if (AA.instCanReach(A, *Inst, Fn, /*UseBackwards=*/false)) {
+          Set.markReachable(*Fn);
+          Change = ChangeStatus::CHANGED;
+        }
+      }
+      return true;
+    };
+
+    bool AllCallSitesKnown = false;
+    if (A.checkForAllCallSites(CheckCallSite, *this,
+                               /*RequireAllCallSites=*/true,
+                               AllCallSitesKnown))
+      return Change;
+
+    // Not every call site could be inspected, so an unknown caller might
+    // reach any function after this one returns. Assume it does.
+    Set.CanReachUnknownCallee = true;
+    return ChangeStatus::CHANGED;
+  }
+
 public:
   AAFunctionReachabilityFunction(const IRPosition &IRP, Attributor &A)
       : AAFunctionReachability(IRP, A) {}
@@ -8706,6 +8791,41 @@
     return Result;
   }
 
+  /// See AAFunctionReachability::instCanReach(...).
+  bool instCanReach(Attributor &A, Instruction &Inst, Function *Fn,
+                    bool UseBackwards) const override {
+    const auto &Reachability = A.getAAFor<AAReachability>(
+        *this, IRPosition::function(*getAssociatedFunction()),
+        DepClassTy::REQUIRED);
+
+    // Attributor returns attributes as const, so this function has to be
+    // const for users of this attribute to use it without having to do
+    // a const_cast.
+    // This is a hack for us to be able to cache queries.
+    auto *NonConstThis = const_cast<AAFunctionReachabilityFunction *>(this);
+    QuerySet &InstQuery = NonConstThis->InstQueries[&Inst];
+
+    SmallVector<const AACallEdges *> CallEdges;
+    if (!getReachableCallEdges(A, Reachability, Inst, CallEdges))
+      InstQuery.CanReachUnknownCallee = true;
+
+    if (InstQuery.isReachable(A, *this, CallEdges, Fn))
+      return true;
+    if (!UseBackwards)
+      return false;
+
+    // Check backwards reachability. Those results are cached separately:
+    // checkReachableBackwards may insert into InstQueries (a caller can be
+    // this very function), which would invalidate references into that map,
+    // but it never inserts into InstQueriesBackwards.
+    QuerySet &BackwardsQuery = NonConstThis->InstQueriesBackwards[&Inst];
+    if (!BackwardsQuery.Reachable.count(Fn))
+      BackwardsQuery.Unreachable.insert(Fn);
+    NonConstThis->checkReachableBackwards(A, BackwardsQuery);
+    return BackwardsQuery.CanReachUnknownCallee ||
+           BackwardsQuery.Reachable.count(Fn);
+  }
+
   /// See AbstractAttribute::updateImpl(...).
   ChangeStatus updateImpl(Attributor &A) override {
     const AACallEdges &AAEdges =
@@ -8713,7 +8833,7 @@
 
     ChangeStatus Change = ChangeStatus::UNCHANGED;
     Change = Change | WholeFunction.update(A, *this, {&AAEdges});
-
-    for (auto CBPair : CBQueries) {
+    // Update the CallBaseQueries.
+    for (auto &CBPair : CBQueries) {
       const AACallEdges &AAEdges = A.getAAFor<AACallEdges>(
           *this, IRPosition::callsite_function(*CBPair.first),
@@ -8722,6 +8842,33 @@
       Change = Change | CBPair.second.update(A, *this, {&AAEdges});
     }
 
+    // Update the forward instruction queries. Iterate by reference so that
+    // the cached query sets are updated rather than copies of them.
+    if (!InstQueries.empty()) {
+      const auto &Reachability = A.getAAFor<AAReachability>(
+          *this, IRPosition::function(*getAssociatedFunction()),
+          DepClassTy::REQUIRED);
+
+      for (auto &InstPair : InstQueries) {
+        QuerySet &InstQuery = InstPair.second;
+        SmallVector<const AACallEdges *> CallEdges;
+        if (!getReachableCallEdges(A, Reachability, *InstPair.first,
+                                   CallEdges)) {
+          if (!InstQuery.CanReachUnknownCallee)
+            Change = ChangeStatus::CHANGED;
+          InstQuery.CanReachUnknownCallee = true;
+          continue;
+        }
+        Change = Change | InstQuery.update(A, *this, CallEdges);
+      }
+    }
+
+    // Update the backwards instruction queries. checkReachableBackwards only
+    // issues forward queries, so InstQueriesBackwards is not modified while
+    // we iterate over it.
+    for (auto &QueryPair : InstQueriesBackwards)
+      Change = Change | checkReachableBackwards(A, QueryPair.second);
+
     return Change;
   }
 
@@ -8742,6 +8889,10 @@
   QuerySet WholeFunction;
 
   DenseMap<const CallBase *, QuerySet> CBQueries;
+  /// Forward instruction queries, keyed by the query instruction.
+  DenseMap<const Instruction *, QuerySet> InstQueries;
+  /// Backwards part of the instruction queries, see checkReachableBackwards.
+  DenseMap<const Instruction *, QuerySet> InstQueriesBackwards;
 };
 
 } // namespace
diff --git a/llvm/unittests/Transforms/IPO/AttributorTest.cpp b/llvm/unittests/Transforms/IPO/AttributorTest.cpp
--- a/llvm/unittests/Transforms/IPO/AttributorTest.cpp
+++ b/llvm/unittests/Transforms/IPO/AttributorTest.cpp
@@ -76,46 +76,66 @@
 TEST_F(AttributorTestBase, AAReachabilityTest) {
   const char *ModuleString = R"(
     @x = global i32 0
-    define void @func4() {
+    define internal void @func4() {
       store i32 0, i32* @x
       ret void
     }
 
-    define void @func3() {
+    define internal void @func3() {
       store i32 0, i32* @x
       ret void
     }
 
-    define void @func2() {
+    define internal void @func8() {
+      store i32 0, i32* @x
+      ret void
+    }
+
+    define internal void @func2() {
     entry:
       call void @func3()
       ret void
     }
 
-    define void @func1() {
+    define internal void @func1() {
     entry:
       call void @func2()
       ret void
     }
 
-    define void @func5(void ()* %unknown) {
+    define internal void @func5(void ()* %unknown) {
     entry:
       call void %unknown()
       ret void
     }
 
-    define void @func6() {
+    define internal void @func6() {
     entry:
       call void @func5(void ()* @func3)
       ret void
     }
 
-    define void @func7() {
+    define internal void @func7() {
+    entry:
+      call void @func2()
+      call void @func4()
+      ret void
+    }
+
+    define internal void @func9() {
     entry:
       call void @func2()
+      call void @func8()
+      ret void
+    }
+
+    define internal void @func10() {
+    entry:
+      call void @func9()
       call void @func4()
       ret void
     }
+
   )";
 
   Module &M = parseModule(ModuleString);
@@ -135,10 +155,14 @@
   Function *F4 = M.getFunction("func4");
   Function *F6 = M.getFunction("func6");
   Function *F7 = M.getFunction("func7");
+  Function *F9 = M.getFunction("func9");
 
   // call void @func2()
-  CallBase &F7FirstCB =
-      *static_cast<CallBase *>(F7->getEntryBlock().getFirstNonPHI());
+  CallBase &F7FirstCB = static_cast<CallBase &>(*F7->getEntryBlock().begin());
+  // call void @func2()
+  Instruction &F9FirstInst = *F9->getEntryBlock().begin();
+  // call void @func8
+  Instruction &F9SecondInst = *++(F9->getEntryBlock().begin());
 
   const AAFunctionReachability &F1AA =
       A.getOrCreateAAFor<AAFunctionReachability>(IRPosition::function(*F1));
@@ -149,11 +173,17 @@
   const AAFunctionReachability &F7AA =
       A.getOrCreateAAFor<AAFunctionReachability>(IRPosition::function(*F7));
 
+  const AAFunctionReachability &F9AA =
+      A.getOrCreateAAFor<AAFunctionReachability>(IRPosition::function(*F9));
+
   F1AA.canReach(A, F3);
   F1AA.canReach(A, F4);
   F6AA.canReach(A, F4);
   F7AA.canReach(A, F7FirstCB, F3);
   F7AA.canReach(A, F7FirstCB, F4);
+  F9AA.instCanReach(A, F9FirstInst, F3);
+  F9AA.instCanReach(A, F9SecondInst, F3);
+  F9AA.instCanReach(A, F9FirstInst, F4);
 
   A.run();
 
@@ -164,6 +194,14 @@
   // Assumed to be reacahable, since F6 can reach a function with
   // a unknown callee.
   ASSERT_TRUE(F6AA.canReach(A, F4));
+
+  // The second instruction of F9 (the call to func8) can't reach func3: the
+  // call to func2 comes before it and func10 doesn't call func3 after func9.
+  ASSERT_FALSE(F9AA.instCanReach(A, F9SecondInst, F3));
+  // The first instruction of F9 can reach the first call.
+  ASSERT_TRUE(F9AA.instCanReach(A, F9FirstInst, F3));
+  // Because func10 calls the func4 after the call to func9 it is reachable.
+  ASSERT_TRUE(F9AA.instCanReach(A, F9FirstInst, F4));
 }
 
 } // namespace llvm