diff --git a/llvm/include/llvm/Transforms/IPO/Attributor.h b/llvm/include/llvm/Transforms/IPO/Attributor.h
--- a/llvm/include/llvm/Transforms/IPO/Attributor.h
+++ b/llvm/include/llvm/Transforms/IPO/Attributor.h
@@ -4616,17 +4616,27 @@
   AAFunctionReachability(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
 
   /// If the function represented by this possition can reach \p Fn.
-  virtual bool canReach(Attributor &A, Function *Fn) const = 0;
+  virtual bool canReach(Attributor &A, const Function &Fn) const = 0;
 
   /// Can \p CB reach \p Fn
-  virtual bool canReach(Attributor &A, CallBase &CB, Function *Fn) const = 0;
+  virtual bool canReach(Attributor &A, CallBase &CB,
+                        const Function &Fn) const = 0;
+
+  /// Can \p Inst reach \p Fn.
+  /// If \p UseBackwards is true, also consider code in the direct callers
+  /// that executes after they call the function containing \p Inst.
+  virtual bool instructionCanReach(Attributor &A, const Instruction &Inst,
+                                   const Function &Fn,
+                                   bool UseBackwards = true) const = 0;
 
   /// Create an abstract attribute view for the position \p IRP.
   static AAFunctionReachability &createForPosition(const IRPosition &IRP,
                                                    Attributor &A);
 
   /// See AbstractAttribute::getName()
-  const std::string getName() const override { return "AAFunctionReachability"; }
+  const std::string getName() const override {
+    return "AAFunctionReachability";
+  }
 
   /// See AbstractAttribute::getIdAddr()
   const char *getIdAddr() const override { return &ID; }
diff --git a/llvm/lib/Transforms/IPO/AttributorAttributes.cpp b/llvm/lib/Transforms/IPO/AttributorAttributes.cpp
--- a/llvm/lib/Transforms/IPO/AttributorAttributes.cpp
+++ b/llvm/lib/Transforms/IPO/AttributorAttributes.cpp
@@ -656,7 +656,7 @@
     if (!AssociatedFunction)
       return S.indicatePessimisticFixpoint();
 
-    CallBase &CBContext = static_cast<CallBase &>(this->getAnchorValue());
+    CallBase &CBContext = cast<CallBase>(this->getAnchorValue());
     if (IntroduceCallBaseContext)
       LLVM_DEBUG(dbgs() << "[Attributor] Introducing call base context:"
                         << CBContext << "\n");
@@ -2468,7 +2468,7 @@
     const AAFunctionReachability &EdgeReachability =
         A.getAAFor<AAFunctionReachability>(*this, getIRPosition(),
                                            DepClassTy::REQUIRED);
-    if (EdgeReachability.canReach(A, getAnchorScope()))
+    if (EdgeReachability.canReach(A, *getAnchorScope()))
       return indicatePessimisticFixpoint();
     return ChangeStatus::UNCHANGED;
   }
@@ -9482,7 +9482,7 @@
       }
     };
 
-    CallBase *CB = static_cast<CallBase *>(getCtxI());
+    CallBase *CB = cast<CallBase>(getCtxI());
 
     if (CB->isInlineAsm()) {
       setHasUnknownCallee(false, Change);
@@ -9521,7 +9521,7 @@
     ChangeStatus Change = ChangeStatus::UNCHANGED;
 
     auto ProcessCallInst = [&](Instruction &Inst) {
-      CallBase &CB = static_cast<CallBase &>(Inst);
+      CallBase &CB = cast<CallBase>(Inst);
 
       auto &CBEdges = A.getAAFor<AACallEdges>(
           *this, IRPosition::callsite_function(CB), DepClassTy::REQUIRED);
@@ -9552,11 +9552,39 @@
 struct AAFunctionReachabilityFunction : public AAFunctionReachability {
 private:
   struct QuerySet {
-    void markReachable(Function *Fn) {
-      Reachable.insert(Fn);
-      Unreachable.erase(Fn);
+    void markReachable(const Function &Fn) {
+      Reachable.insert(&Fn);
+      Unreachable.erase(&Fn);
     }
 
+    /// Returns the cached answer for \p Fn, or None if nothing is known yet.
+    Optional<bool> isCachedReachable(const Function &Fn) const {
+      // Assume that we can reach the function.
+      // TODO: Be more specific with the unknown callee.
+      if (CanReachUnknownCallee)
+        return true;
+
+      if (Reachable.count(&Fn))
+        return true;
+
+      if (Unreachable.count(&Fn))
+        return false;
+
+      return llvm::None;
+    }
+
+    /// Set of functions that we know for sure are reachable.
+    DenseSet<const Function *> Reachable;
+
+    /// Set of functions that are unreachable, but might become reachable.
+    DenseSet<const Function *> Unreachable;
+
+    /// If we can reach a function with a call to an unknown function, we
+    /// assume that we can reach any function.
+    bool CanReachUnknownCallee = false;
+  };
+
+  struct QueryResolver : public QuerySet {
     ChangeStatus update(Attributor &A, const AAFunctionReachability &AA,
                         ArrayRef<const AACallEdges *> AAEdgesList) {
       ChangeStatus Change = ChangeStatus::UNCHANGED;
@@ -9570,31 +9598,25 @@
         }
       }
 
-      for (Function *Fn : make_early_inc_range(Unreachable)) {
-        if (checkIfReachable(A, AA, AAEdgesList, Fn)) {
+      for (const Function *Fn : make_early_inc_range(Unreachable)) {
+        if (checkIfReachable(A, AA, AAEdgesList, *Fn)) {
           Change = ChangeStatus::CHANGED;
-          markReachable(Fn);
+          markReachable(*Fn);
         }
       }
       return Change;
     }
 
     bool isReachable(Attributor &A, const AAFunctionReachability &AA,
-                     ArrayRef<const AACallEdges *> AAEdgesList, Function *Fn) {
-      // Assume that we can reach the function.
-      // TODO: Be more specific with the unknown callee.
-      if (CanReachUnknownCallee)
-        return true;
-
-      if (Reachable.count(Fn))
-        return true;
-
-      if (Unreachable.count(Fn))
-        return false;
+                     ArrayRef<const AACallEdges *> AAEdgesList,
+                     const Function &Fn) {
+      Optional<bool> Cached = isCachedReachable(Fn);
+      if (Cached.hasValue())
+        return Cached.getValue();
 
       // We need to assume that this function can't reach Fn to prevent
       // an infinite loop if this function is recursive.
-      Unreachable.insert(Fn);
+      Unreachable.insert(&Fn);
 
       bool Result = checkIfReachable(A, AA, AAEdgesList, Fn);
       if (Result)
@@ -9604,13 +9626,13 @@
 
     bool checkIfReachable(Attributor &A, const AAFunctionReachability &AA,
                           ArrayRef<const AACallEdges *> AAEdgesList,
-                          Function *Fn) const {
+                          const Function &Fn) const {
 
       // Handle the most trivial case first.
       for (auto *AAEdges : AAEdgesList) {
         const SetVector<Function *> &Edges = AAEdges->getOptimisticEdges();
 
-        if (Edges.count(Fn))
+        if (Edges.count(const_cast<Function *>(&Fn)))
           return true;
       }
 
@@ -9631,28 +9653,87 @@
       }
 
       // The result is false for now, set dependencies and leave.
-      for (auto Dep : Deps)
-        A.recordDependence(AA, *Dep, DepClassTy::REQUIRED);
+      for (auto *Dep : Deps)
+        A.recordDependence(*Dep, AA, DepClassTy::REQUIRED);
 
       return false;
     }
+  };
 
-    /// Set of functions that we know for sure is reachable.
-    DenseSet<Function *> Reachable;
+  /// Collect in \p Result the call edges of every call-like instruction that
+  /// may be reachable from \p Inst. Returns false if not all call-like
+  /// instructions could be inspected.
+  bool getReachableCallEdges(Attributor &A, const AAReachability &Reachability,
+                             const Instruction &Inst,
+                             SmallVector<const AACallEdges *> &Result) const {
+    // Determine call like instructions that we can reach from the inst.
+    auto CheckCallBase = [&](Instruction &CBInst) {
+      if (!Reachability.isAssumedReachable(A, Inst, CBInst))
+        return true;
 
-    /// Set of functions that are unreachable, but might become reachable.
-    DenseSet<Function *> Unreachable;
+      const auto &CB = cast<CallBase>(CBInst);
+      const AACallEdges &AAEdges = A.getAAFor<AACallEdges>(
+          *this, IRPosition::callsite_function(CB), DepClassTy::REQUIRED);
 
-    /// If we can reach a function with a call to a unknown function we assume
-    /// that we can reach any function.
-    bool CanReachUnknownCallee = false;
-  };
+      Result.push_back(&AAEdges);
+      return true;
+    };
+
+    bool UsedAssumedInformation = false;
+    return A.checkForAllCallLikeInstructions(CheckCallBase, *this,
+                                             UsedAssumedInformation);
+  }
+
+  /// Try to resolve the still-unreachable queries in \p Set: for every call
+  /// site of this function, ask whether the instruction after the call can
+  /// reach the queried function. Returns CHANGED if any query changed.
+  ChangeStatus checkReachableBackwards(Attributor &A, QuerySet &Set) {
+    ChangeStatus Change = ChangeStatus::UNCHANGED;
+
+    // For all remaining instruction queries, check
+    // callers. A call inside that function might satisfy the query.
+    auto CheckCallSite = [&](AbstractCallSite CallSite) {
+      CallBase *CB = CallSite.getInstruction();
+      if (!CB)
+        return false;
+
+      // A call that terminates its block (e.g. invoke, callbr) has no
+      // following instruction in that block to continue the query from.
+      if (CB->isTerminator())
+        return false;
+
+      Instruction *Inst = CB->getNextNonDebugInstruction();
+      const AAFunctionReachability &AA = A.getAAFor<AAFunctionReachability>(
+          *this, IRPosition::function(*CB->getFunction()),
+          DepClassTy::REQUIRED);
+      for (const Function *Fn : make_early_inc_range(Set.Unreachable)) {
+        if (AA.instructionCanReach(A, *Inst, *Fn, /* UseBackwards */ false)) {
+          Set.markReachable(*Fn);
+          Change = ChangeStatus::CHANGED;
+        }
+      }
+      return true;
+    };
+
+    bool AllCallSitesKnown = true;
+    if (A.checkForAllCallSites(CheckCallSite, *this, true, AllCallSitesKnown))
+      return Change;
+
+    // Unknown call sites: conservatively mark all backwards queries reachable.
+    for (auto &QSet : InstQueriesBackwards) {
+      if (!QSet.second.CanReachUnknownCallee)
+        Change = ChangeStatus::CHANGED;
+      QSet.second.CanReachUnknownCallee = true;
+    }
+
+    return Change;
+  }
 
 public:
   AAFunctionReachabilityFunction(const IRPosition &IRP, Attributor &A)
       : AAFunctionReachability(IRP, A) {}
 
-  bool canReach(Attributor &A, Function *Fn) const override {
+  bool canReach(Attributor &A, const Function &Fn) const override {
     const AACallEdges &AAEdges =
         A.getAAFor<AACallEdges>(*this, getIRPosition(), DepClassTy::REQUIRED);
 
@@ -9668,7 +9749,8 @@
   }
 
   /// Can \p CB reach \p Fn
-  bool canReach(Attributor &A, CallBase &CB, Function *Fn) const override {
+  bool canReach(Attributor &A, CallBase &CB,
+                const Function &Fn) const override {
     const AACallEdges &AAEdges = A.getAAFor<AACallEdges>(
         *this, IRPosition::callsite_function(CB), DepClassTy::REQUIRED);
 
@@ -9677,13 +9759,52 @@
     // a const_cast.
     // This is a hack for us to be able to cache queries.
     auto *NonConstThis = const_cast<AAFunctionReachabilityFunction *>(this);
-    QuerySet &CBQuery = NonConstThis->CBQueries[&CB];
+    QueryResolver &CBQuery = NonConstThis->CBQueries[&CB];
 
     bool Result = CBQuery.isReachable(A, *this, {&AAEdges}, Fn);
 
     return Result;
   }
 
+  bool instructionCanReach(Attributor &A, const Instruction &Inst,
+                           const Function &Fn,
+                           bool UseBackwards) const override {
+    const AAReachability &Reachability = A.getAAFor<AAReachability>(
+        *this, IRPosition::function(*getAssociatedFunction()),
+        DepClassTy::REQUIRED);
+
+    SmallVector<const AACallEdges *> CallEdges;
+    bool AllKnown = getReachableCallEdges(A, Reachability, Inst, CallEdges);
+    // Attributor returns attributes as const, so this function has to be
+    // const for users of this attribute to use it without having to do
+    // a const_cast.
+    // This is a hack for us to be able to cache queries.
+    auto *NonConstThis = const_cast<AAFunctionReachabilityFunction *>(this);
+    QueryResolver &InstQSet = NonConstThis->InstQueries[&Inst];
+    if (!AllKnown)
+      InstQSet.CanReachUnknownCallee = true;
+
+    bool ForwardsResult = InstQSet.isReachable(A, *this, CallEdges, Fn);
+    if (ForwardsResult)
+      return true;
+    // Forward search failed; continue only if backwards search is enabled.
+    if (!UseBackwards)
+      return false;
+
+    QuerySet &InstBackwardsQSet = NonConstThis->InstQueriesBackwards[&Inst];
+
+    Optional<bool> BackwardsCached = InstBackwardsQSet.isCachedReachable(Fn);
+    if (BackwardsCached.hasValue())
+      return BackwardsCached.getValue();
+
+    // Optimistically assume unreachable so that recursive queries terminate.
+    InstBackwardsQSet.Unreachable.insert(&Fn);
+
+    // Check backwards reachability.
+    NonConstThis->checkReachableBackwards(A, InstBackwardsQSet);
+    return InstBackwardsQSet.isCachedReachable(Fn).getValue();
+  }
+
   /// See AbstractAttribute::updateImpl(...).
   ChangeStatus updateImpl(Attributor &A) override {
     const AACallEdges &AAEdges =
@@ -9692,7 +9813,7 @@
 
     Change |= WholeFunction.update(A, *this, {&AAEdges});
 
-    for (auto CBPair : CBQueries) {
+    for (auto &CBPair : CBQueries) {
       const AACallEdges &AAEdges = A.getAAFor<AACallEdges>(
           *this, IRPosition::callsite_function(*CBPair.first),
           DepClassTy::REQUIRED);
@@ -9700,6 +9821,32 @@
       Change |= CBPair.second.update(A, *this, {&AAEdges});
     }
 
+    // Update the Instruction queries.
+    const AAReachability *Reachability = nullptr;
+    if (!InstQueries.empty()) {
+      Reachability = &A.getAAFor<AAReachability>(
+          *this, IRPosition::function(*getAssociatedFunction()),
+          DepClassTy::REQUIRED);
+    }
+
+    // Check for local call bases first.
+    for (auto &InstPair : InstQueries) {
+      SmallVector<const AACallEdges *> CallEdges;
+      bool AllKnown =
+          getReachableCallEdges(A, *Reachability, *InstPair.first, CallEdges);
+      // If not every call could be inspected, all queries of this instruction
+      // become reachable; report it so dependent attributes are updated.
+      if (!AllKnown && !InstPair.second.CanReachUnknownCallee) {
+        InstPair.second.CanReachUnknownCallee = true;
+        Change = ChangeStatus::CHANGED;
+      }
+      Change |= InstPair.second.update(A, *this, CallEdges);
+    }
+
+    // Update backwards queries.
+    for (auto &QueryPair : InstQueriesBackwards)
+      Change |= checkReachableBackwards(A, QueryPair.second);
+
     return Change;
   }
 
@@ -9720,11 +9867,17 @@
   }
 
-  /// Used to answer if a the whole function can reacha a specific function.
-  QuerySet WholeFunction;
+  /// Used to answer if the whole function can reach a specific function.
+  QueryResolver WholeFunction;
 
   /// Used to answer if a call base inside this function can reach a specific
   /// function.
-  DenseMap<CallBase *, QuerySet> CBQueries;
+  DenseMap<CallBase *, QueryResolver> CBQueries;
+
+  /// Instruction queries that scan "forward" from the instruction.
+  DenseMap<const Instruction *, QueryResolver> InstQueries;
+
+  /// Instruction queries that scan "backward" into the callers.
+  DenseMap<const Instruction *, QuerySet> InstQueriesBackwards;
 };
 
 /// ---------------------- Assumption Propagation ------------------------------
diff --git a/llvm/unittests/Transforms/IPO/AttributorTest.cpp b/llvm/unittests/Transforms/IPO/AttributorTest.cpp
--- a/llvm/unittests/Transforms/IPO/AttributorTest.cpp
+++ b/llvm/unittests/Transforms/IPO/AttributorTest.cpp
@@ -75,47 +75,69 @@
 
 TEST_F(AttributorTestBase, AAReachabilityTest) {
   const char *ModuleString = R"(
-  @x = global i32 0
-  define void @func4() {
+  @x = external global i32
+  define internal void @func4() {
     store i32 0, i32* @x
     ret void
   }
 
-  define void @func3() {
+  define internal void @func3() {
     store i32 0, i32* @x
     ret void
   }
 
-  define void @func2() {
+  define internal void @func8() {
+    store i32 0, i32* @x
+    ret void
+  }
+
+  define internal void @func2() {
   entry:
     call void @func3()
     ret void
   }
 
-  define void @func1() {
+  define internal void @func1() {
   entry:
     call void @func2()
     ret void
   }
 
-  define void @func5(void ()* %unknown) {
+  declare void @unknown()
+  define internal void @func5(void ()* %ptr) {
   entry:
-    call void %unknown()
+    call void %ptr()
+    call void @unknown()
     ret void
   }
 
-  define void @func6() {
+  define internal void @func6() {
   entry:
     call void @func5(void ()* @func3)
     ret void
   }
 
-  define void @func7() {
+  define internal void @func7() {
+  entry:
+    call void @func2()
+    call void @func4()
+    ret void
+  }
+
+  define internal void @func9() {
   entry:
     call void @func2()
+    call void @func8()
+    ret void
+  }
+
+  define void @func10() {
+  entry:
+    call void @func9()
     call void @func4()
     ret void
   }
+
   )";
 
   Module &M = parseModule(ModuleString);
@@ -128,32 +150,43 @@
   CallGraphUpdater CGUpdater;
   BumpPtrAllocator Allocator;
   InformationCache InfoCache(M, AG, Allocator, nullptr);
-  Attributor A(Functions, InfoCache, CGUpdater);
+  Attributor A(Functions, InfoCache, CGUpdater, /* Allowed */ nullptr,
+               /*DeleteFns*/ false);
 
-  Function *F1 = M.getFunction("func1");
-  Function *F3 = M.getFunction("func3");
-  Function *F4 = M.getFunction("func4");
-  Function *F6 = M.getFunction("func6");
-  Function *F7 = M.getFunction("func7");
+  Function &F1 = *M.getFunction("func1");
+  Function &F3 = *M.getFunction("func3");
+  Function &F4 = *M.getFunction("func4");
+  Function &F6 = *M.getFunction("func6");
+  Function &F7 = *M.getFunction("func7");
+  Function &F9 = *M.getFunction("func9");
 
   // call void @func2()
-  CallBase &F7FirstCB =
-      *static_cast<CallBase *>(F7->getEntryBlock().getFirstNonPHI());
+  CallBase &F7FirstCB = cast<CallBase>(*F7.getEntryBlock().begin());
+  // call void @func2()
+  Instruction &F9FirstInst = *F9.getEntryBlock().begin();
+  // call void @func8()
+  Instruction &F9SecondInst = *std::next(F9.getEntryBlock().begin());
 
   const AAFunctionReachability &F1AA =
-      A.getOrCreateAAFor<AAFunctionReachability>(IRPosition::function(*F1));
+      A.getOrCreateAAFor<AAFunctionReachability>(IRPosition::function(F1));
 
   const AAFunctionReachability &F6AA =
-      A.getOrCreateAAFor<AAFunctionReachability>(IRPosition::function(*F6));
+      A.getOrCreateAAFor<AAFunctionReachability>(IRPosition::function(F6));
 
   const AAFunctionReachability &F7AA =
-      A.getOrCreateAAFor<AAFunctionReachability>(IRPosition::function(*F7));
+      A.getOrCreateAAFor<AAFunctionReachability>(IRPosition::function(F7));
+
+  const AAFunctionReachability &F9AA =
+      A.getOrCreateAAFor<AAFunctionReachability>(IRPosition::function(F9));
 
   F1AA.canReach(A, F3);
   F1AA.canReach(A, F4);
   F6AA.canReach(A, F4);
   F7AA.canReach(A, F7FirstCB, F3);
   F7AA.canReach(A, F7FirstCB, F4);
+  F9AA.instructionCanReach(A, F9FirstInst, F3);
+  F9AA.instructionCanReach(A, F9SecondInst, F3, false);
+  F9AA.instructionCanReach(A, F9FirstInst, F4);
 
   A.run();
 
@@ -166,6 +199,15 @@
   // Assumed to be reacahable, since F6 can reach a function with
   // a unknown callee.
   ASSERT_TRUE(F6AA.canReach(A, F4));
+
+  // F9's second call (to func8) cannot reach func3, forwards or backwards.
+  ASSERT_FALSE(F9AA.instructionCanReach(A, F9SecondInst, F3, false));
+  ASSERT_FALSE(F9AA.instructionCanReach(A, F9SecondInst, F3, true));
+
+  // F9's first call (to func2) reaches func3 through func2.
+  ASSERT_TRUE(F9AA.instructionCanReach(A, F9FirstInst, F3));
+  // func4 is only reachable backwards: func10 calls it after calling func9.
+  ASSERT_TRUE(F9AA.instructionCanReach(A, F9FirstInst, F4));
 }
 
 } // namespace llvm