diff --git a/llvm/include/llvm/Analysis/CallGraph.h b/llvm/include/llvm/Analysis/CallGraph.h --- a/llvm/include/llvm/Analysis/CallGraph.h +++ b/llvm/include/llvm/Analysis/CallGraph.h @@ -94,10 +94,6 @@ /// callers from the old function to the new. void spliceFunction(const Function *From, const Function *To); - /// Add a function to the call graph, and link the node to all of the - /// functions that it calls. - void addToCallGraph(Function *F); - public: explicit CallGraph(Module &M); CallGraph(CallGraph &&Arg); @@ -158,6 +154,13 @@ /// Similar to operator[], but this will insert a new CallGraphNode for /// \c F if one does not already exist. CallGraphNode *getOrInsertFunction(const Function *F); + + /// Populate \p CGN based on the calls inside the associated function. + void populateCallGraphNode(CallGraphNode *CGN); + + /// Add a function to the call graph, and link the node to all of the + /// functions that it calls. + void addToCallGraph(Function *F); }; /// A node in the call graph for a module. diff --git a/llvm/include/llvm/Analysis/LazyCallGraph.h b/llvm/include/llvm/Analysis/LazyCallGraph.h --- a/llvm/include/llvm/Analysis/LazyCallGraph.h +++ b/llvm/include/llvm/Analysis/LazyCallGraph.h @@ -1058,6 +1058,9 @@ /// fully visited by the DFS prior to calling this routine. void removeDeadFunction(Function &F); + /// Introduce a new, populated node for the function \p NewF in the SCC \p C. + void addNewFunctionIntoSCC(Function &NewF, SCC &C); + ///@} ///@{ diff --git a/llvm/include/llvm/Transforms/Utils/CallGraphUpdater.h b/llvm/include/llvm/Transforms/Utils/CallGraphUpdater.h new file mode 100644 --- /dev/null +++ b/llvm/include/llvm/Transforms/Utils/CallGraphUpdater.h @@ -0,0 +1,106 @@ +//===- CallGraphUpdater.h - A (lazy) call graph update helper ---*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// \file +/// +/// This file provides interfaces used to manipulate a call graph, regardless +/// of whether it is an "old style" CallGraph or a "new style" LazyCallGraph. +/// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_TRANSFORMS_UTILS_CALLGRAPHUPDATER_H +#define LLVM_TRANSFORMS_UTILS_CALLGRAPHUPDATER_H + +#include "llvm/Analysis/CGSCCPassManager.h" +#include "llvm/Analysis/CallGraph.h" +#include "llvm/Analysis/CallGraphSCCPass.h" +#include "llvm/Analysis/LazyCallGraph.h" + +namespace llvm { + +/// Wrapper to unify "old style" CallGraph and "new style" LazyCallGraph. This +/// simplifies the interface and the call sites, e.g., new and old pass manager +/// passes can share the same code. +class CallGraphUpdater { + /// Containers for functions which we did replace or want to delete when + /// `finalize` is called. This can happen explicitly or as part of the + /// destructor. Dead functions in comdat sections are tracked separately + /// because a function with discardable linkage in a COMDAT should only + /// be dropped if the entire COMDAT is dropped, see git ac07703842cf. + ///{ + SmallPtrSet<Function *, 16> ReplacedFunctions; + SmallVector<Function *, 16> DeadFunctions; + SmallVector<Function *, 16> DeadFunctionsInComdats; + ///} + + /// Old PM variables + ///{ + CallGraph *CG = nullptr; + CallGraphSCC *CGSCC = nullptr; + ///} + + /// New PM variables + ///{ + LazyCallGraph *LCG = nullptr; + LazyCallGraph::SCC *SCC = nullptr; + CGSCCAnalysisManager *AM = nullptr; + CGSCCUpdateResult *UR = nullptr; + ///} + +public: + CallGraphUpdater() {} + ~CallGraphUpdater() { finalize(); } + + /// Initializers for usage outside of a CGSCC pass, inside a CGSCC pass in + /// the old and new pass manager (PM).
+ ///{ + void initialize(CallGraph &CG, CallGraphSCC &SCC) { + this->CG = &CG; + this->CGSCC = &SCC; + } + void initialize(LazyCallGraph &LCG, LazyCallGraph::SCC &SCC, + CGSCCAnalysisManager &AM, CGSCCUpdateResult &UR) { + this->LCG = &LCG; + this->SCC = &SCC; + this->AM = &AM; + this->UR = &UR; + } + ///} + + /// Finalizer that will trigger actions like function removal from the CG. + bool finalize(); + + /// Remove \p Fn from the call graph. + void removeFunction(Function &Fn); + + /// After a CGSCC pass changes a function in ways that affect the call + /// graph, this method can be called to update it. + void reanalyzeFunction(Function &Fn); + + /// If a new function was created by outlining, this method can be called + /// to update the call graph for the new function. Note that the old one + /// still needs to be re-analyzed or manually updated. + void registerOutlinedFunction(Function &NewFn); + + /// Replace \p OldFn in the call graph (and SCC) with \p NewFn. The uses + /// outside the call graph and the function \p OldFn are not modified. + /// Note that \p OldFn is also removed from the call graph + /// (\see removeFunction). + void replaceFunctionWith(Function &OldFn, Function &NewFn); + + /// Remove the call site \p CS from the call graph. + void removeCallSite(CallBase &CS); + + /// Replace \p OldCS with the new call site \p NewCS. + /// \return True if the replacement was successful, otherwise False. In the + /// latter case the parent function of \p OldCS needs to be re-analyzed.
+ bool replaceCallSite(CallBase &OldCS, CallBase &NewCS); +}; + +} // end namespace llvm + +#endif // LLVM_TRANSFORMS_UTILS_CALLGRAPHUPDATER_H diff --git a/llvm/lib/Analysis/CallGraph.cpp b/llvm/lib/Analysis/CallGraph.cpp --- a/llvm/lib/Analysis/CallGraph.cpp +++ b/llvm/lib/Analysis/CallGraph.cpp @@ -74,6 +74,12 @@ if (!F->hasLocalLinkage() || F->hasAddressTaken()) ExternalCallingNode->addCalledFunction(nullptr, Node); + populateCallGraphNode(Node); +} + +void CallGraph::populateCallGraphNode(CallGraphNode *Node) { + Function *F = Node->getFunction(); + // If this function is not defined in this translation unit, it could call // anything. if (F->isDeclaration() && !F->isIntrinsic()) diff --git a/llvm/lib/Analysis/CallGraphSCCPass.cpp b/llvm/lib/Analysis/CallGraphSCCPass.cpp --- a/llvm/lib/Analysis/CallGraphSCCPass.cpp +++ b/llvm/lib/Analysis/CallGraphSCCPass.cpp @@ -549,7 +549,10 @@ for (unsigned i = 0; ; ++i) { assert(i != Nodes.size() && "Node not in SCC"); if (Nodes[i] != Old) continue; - Nodes[i] = New; + if (New) + Nodes[i] = New; + else + Nodes.erase(Nodes.begin() + i); break; } diff --git a/llvm/lib/Analysis/LazyCallGraph.cpp b/llvm/lib/Analysis/LazyCallGraph.cpp --- a/llvm/lib/Analysis/LazyCallGraph.cpp +++ b/llvm/lib/Analysis/LazyCallGraph.cpp @@ -1566,6 +1566,15 @@ // allocators.
} +void LazyCallGraph::addNewFunctionIntoSCC(Function &NewF, SCC &C) { + Node &CGNode = get(NewF); + CGNode.DFSNumber = CGNode.LowLink = -1; + CGNode.populate(); + C.Nodes.push_back(&CGNode); + SCCMap[&CGNode] = &C; + NodeMap[&NewF] = &CGNode; +} + LazyCallGraph::Node &LazyCallGraph::insertInto(Function &F, Node *&MappedN) { return *new (MappedN = BPA.Allocate()) Node(*this, F); } diff --git a/llvm/lib/Transforms/Utils/CMakeLists.txt b/llvm/lib/Transforms/Utils/CMakeLists.txt --- a/llvm/lib/Transforms/Utils/CMakeLists.txt +++ b/llvm/lib/Transforms/Utils/CMakeLists.txt @@ -7,6 +7,7 @@ BuildLibCalls.cpp BypassSlowDivision.cpp CallPromotionUtils.cpp + CallGraphUpdater.cpp CanonicalizeAliases.cpp CloneFunction.cpp CloneModule.cpp diff --git a/llvm/lib/Transforms/Utils/CallGraphUpdater.cpp b/llvm/lib/Transforms/Utils/CallGraphUpdater.cpp new file mode 100644 --- /dev/null +++ b/llvm/lib/Transforms/Utils/CallGraphUpdater.cpp @@ -0,0 +1,156 @@ +//===- CallGraphUpdater.cpp - A (lazy) call graph update helper -----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// \file +/// +/// This file provides interfaces used to manipulate a call graph, regardless +/// of whether it is an "old style" CallGraph or a "new style" LazyCallGraph.
+/// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Utils/CallGraphUpdater.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Transforms/Utils/ModuleUtils.h" + +using namespace llvm; + +bool CallGraphUpdater::finalize() { + if (!DeadFunctionsInComdats.empty()) { + filterDeadComdatFunctions(*DeadFunctionsInComdats.front()->getParent(), + DeadFunctionsInComdats); + DeadFunctions.append(DeadFunctionsInComdats.begin(), + DeadFunctionsInComdats.end()); + } + + // Dead functions might call each other. Drop all of their outgoing call + // edges first so no dead node is still referenced when it is deleted below. + if (CG) + for (Function *DeadFn : DeadFunctions) + CG->getOrInsertFunction(DeadFn)->removeAllCalledFunctions(); + + for (Function *DeadFn : DeadFunctions) { + DeadFn->removeDeadConstantUsers(); + + if (CG) { + CallGraphNode *OldCGN = CG->getOrInsertFunction(DeadFn); + CG->getExternalCallingNode()->removeAnyCallEdgeTo(OldCGN); + DeadFn->replaceAllUsesWith(UndefValue::get(DeadFn->getType())); + + assert(OldCGN->getNumReferences() == 0); + + delete CG->removeFunctionFromModule(OldCGN); + continue; + } + + // The old style call graph (CG) has a value handle we do not want to + // replace with undef so we do this here. + DeadFn->replaceAllUsesWith(UndefValue::get(DeadFn->getType())); + + if (LCG && !ReplacedFunctions.count(DeadFn)) { + // Taken mostly from the inliner: + FunctionAnalysisManager &FAM = + AM->getResult<FunctionAnalysisManagerCGSCCProxy>(*SCC, *LCG) + .getManager(); + + LazyCallGraph::Node &N = LCG->get(*DeadFn); + auto *DeadSCC = LCG->lookupSCC(N); + assert(DeadSCC && DeadSCC->size() == 1 && + &DeadSCC->begin()->getFunction() == DeadFn); + auto &DeadRC = DeadSCC->getOuterRefSCC(); + + FAM.clear(*DeadFn, DeadFn->getName()); + AM->clear(*DeadSCC, DeadSCC->getName()); + LCG->removeDeadFunction(*DeadFn); + + // Mark the relevant parts of the call graph as invalid so we don't visit + // them. + UR->InvalidatedSCCs.insert(DeadSCC); + UR->InvalidatedRefSCCs.insert(&DeadRC); + } + + // The function is now really dead and detached from everything.
+ DeadFn->eraseFromParent(); + } + + bool Changed = !DeadFunctions.empty(); + DeadFunctionsInComdats.clear(); + DeadFunctions.clear(); + return Changed; +} + +void CallGraphUpdater::reanalyzeFunction(Function &Fn) { + if (CG) { + CallGraphNode *OldCGN = CG->getOrInsertFunction(&Fn); + OldCGN->removeAllCalledFunctions(); + CG->populateCallGraphNode(OldCGN); + } else if (LCG) { + LazyCallGraph::Node &N = LCG->get(Fn); + LazyCallGraph::SCC *C = LCG->lookupSCC(N); + updateCGAndAnalysisManagerForCGSCCPass(*LCG, *C, N, *AM, *UR); + } +} + +void CallGraphUpdater::registerOutlinedFunction(Function &NewFn) { + if (CG) + CG->addToCallGraph(&NewFn); + else if (LCG) + LCG->addNewFunctionIntoSCC(NewFn, *SCC); +} + +void CallGraphUpdater::removeFunction(Function &DeadFn) { + DeadFn.deleteBody(); + DeadFn.setLinkage(GlobalValue::ExternalLinkage); + if (DeadFn.hasComdat()) + DeadFunctionsInComdats.push_back(&DeadFn); + else + DeadFunctions.push_back(&DeadFn); +} + +void CallGraphUpdater::replaceFunctionWith(Function &OldFn, Function &NewFn) { + ReplacedFunctions.insert(&OldFn); + if (CG) { + // Update the call graph for the newly promoted function. + CallGraphNode *OldCGN = (*CG)[&OldFn]; + CallGraphNode *NewCGN = CG->getOrInsertFunction(&NewFn); + NewCGN->stealCalledFunctionsFrom(OldCGN); + + // And update the SCC we're iterating as well. + CGSCC->ReplaceNode(OldCGN, NewCGN); + } else if (LCG) { + // Directly substitute the functions in the call graph. + LazyCallGraph::Node &OldLCGN = LCG->get(OldFn); + SCC->getOuterRefSCC().replaceNodeFunction(OldLCGN, NewFn); + } + removeFunction(OldFn); +} + +bool CallGraphUpdater::replaceCallSite(CallBase &OldCS, CallBase &NewCS) { + // This is only necessary in the (old) CG.
+ if (!CG) + return true; + + Function *Caller = OldCS.getCaller(); + CallGraphNode *NewCalleeNode = + CG->getOrInsertFunction(NewCS.getCalledFunction()); + CallGraphNode *CallerNode = (*CG)[Caller]; + if (llvm::none_of(*CallerNode, [&OldCS](const CallGraphNode::CallRecord &CR) { + return CR.first == &OldCS; + })) + return false; + CallerNode->replaceCallEdge(OldCS, NewCS, NewCalleeNode); + return true; +} + +void CallGraphUpdater::removeCallSite(CallBase &CS) { + // This is only necessary in the (old) CG. + if (!CG) + return; + + Function *Caller = CS.getCaller(); + CallGraphNode *CallerNode = (*CG)[Caller]; + CallerNode->removeCallEdgeFor(CS); +} diff --git a/llvm/unittests/Analysis/CGSCCPassManagerTest.cpp b/llvm/unittests/Analysis/CGSCCPassManagerTest.cpp --- a/llvm/unittests/Analysis/CGSCCPassManagerTest.cpp +++ b/llvm/unittests/Analysis/CGSCCPassManagerTest.cpp @@ -16,6 +16,7 @@ #include "llvm/IR/Module.h" #include "llvm/IR/PassManager.h" #include "llvm/Support/SourceMgr.h" +#include "llvm/Transforms/Utils/CallGraphUpdater.h" #include "gtest/gtest.h" using namespace llvm; @@ -1315,7 +1316,11 @@ PreservedAnalyses run(LazyCallGraph::SCC &C, CGSCCAnalysisManager &AM, LazyCallGraph &CG, CGSCCUpdateResult &UR) { Func(C, AM, CG, UR); - return PreservedAnalyses::none(); + PreservedAnalyses PA; + // We update the core CGSCC data structures and so can preserve the proxy to + // the function analysis manager. + PA.preserve<FunctionAnalysisManagerCGSCCProxy>(); + return PA; } std::functiongetFunction("f"); + Function *FnewF = Function::Create(FnF->getFunctionType(), + FnF->getLinkage(), "newF", *M); + BasicBlock *BB = BasicBlock::Create(FnewF->getContext(), "", FnewF); + ReturnInst::Create(FnewF->getContext(), BB); + + // Use the CallGraphUpdater to update the call graph for the new + // function.
+ CallGraphUpdater CGU; + CGU.initialize(CG, C, AM, UR); + CGU.registerOutlinedFunction(*FnewF); + + // And insert a call to `newF` + Instruction *IP = &FnF->getEntryBlock().front(); + (void)CallInst::Create(FnewF, {}, "", IP); + + auto &FN = *llvm::find_if( + C, [](LazyCallGraph::Node &N) { return N.getName() == "f"; }); + + ASSERT_NO_FATAL_FAILURE( + updateCGAndAnalysisManagerForCGSCCPass(CG, C, FN, AM, UR)); + })); + + ModulePassManager MPM(/*DebugLogging*/ true); + MPM.addPass(createModuleToPostOrderCGSCCPassAdaptor(std::move(CGPM))); + MPM.run(*M, MAM); +} + +TEST_F(CGSCCPassManagerTest, TestUpdateCGAndAnalysisManagerForPasses5) { + CGSCCPassManager CGPM(/*DebugLogging*/ true); + CGPM.addPass(LambdaSCCPassNoPreserve([&](LazyCallGraph::SCC &C, + CGSCCAnalysisManager &AM, + LazyCallGraph &CG, + CGSCCUpdateResult &UR) { + if (C.getName() != "(f)") + return; + + Function *FnF = M->getFunction("f"); + Function *FnewF = + Function::Create(FnF->getFunctionType(), FnF->getLinkage(), "newF", *M); + BasicBlock *BB = BasicBlock::Create(FnewF->getContext(), "", FnewF); + ReturnInst::Create(FnewF->getContext(), BB); + + // Use the CallGraphUpdater to update the call graph for the new + // function.
+ CallGraphUpdater CGU; + CGU.initialize(CG, C, AM, UR); + CGU.registerOutlinedFunction(*FnewF); + + // And insert a call to `newF` + Instruction *IP = &FnF->getEntryBlock().front(); + (void)CallInst::Create(FnewF, {}, "", IP); + + auto &FN = *llvm::find_if( + C, [](LazyCallGraph::Node &N) { return N.getName() == "f"; }); + + ASSERT_DEATH(updateCGAndAnalysisManagerForFunctionPass(CG, C, FN, AM, UR), + "Any new calls should be modeled as"); + })); + + ModulePassManager MPM(/*DebugLogging*/ true); + MPM.addPass(createModuleToPostOrderCGSCCPassAdaptor(std::move(CGPM))); + MPM.run(*M, MAM); +} + +TEST_F(CGSCCPassManagerTest, TestUpdateCGAndAnalysisManagerForPasses6) { + CGSCCPassManager CGPM(/*DebugLogging*/ true); + CGPM.addPass(LambdaSCCPassNoPreserve( + [&](LazyCallGraph::SCC &C, CGSCCAnalysisManager &AM, LazyCallGraph &CG, + CGSCCUpdateResult &UR) { + if (C.getName() != "(h3, h1, h2)") + return; + + Function *FnX = M->getFunction("x"); + Function *FnH1 = M->getFunction("h1"); + Function *FnH2 = M->getFunction("h2"); + Function *FnH3 = M->getFunction("h3"); + ASSERT_NE(FnX, nullptr); + ASSERT_NE(FnH1, nullptr); + ASSERT_NE(FnH2, nullptr); + ASSERT_NE(FnH3, nullptr); + + // And insert a call to `h1`, `h2`, and `h3`. + Instruction *IP = &FnH2->getEntryBlock().front(); + (void)CallInst::Create(FnH1, {}, "", IP); + (void)CallInst::Create(FnH2, {}, "", IP); + (void)CallInst::Create(FnH3, {}, "", IP); + + // Use the CallGraphUpdater to update the call graph for the new + // calls.
+ CallGraphUpdater CGU; + CGU.initialize(CG, C, AM, UR); + ASSERT_NO_FATAL_FAILURE(CGU.reanalyzeFunction(*FnH2)); + })); + + ModulePassManager MPM(/*DebugLogging*/ true); + MPM.addPass(createModuleToPostOrderCGSCCPassAdaptor(std::move(CGPM))); + MPM.run(*M, MAM); +} + +TEST_F(CGSCCPassManagerTest, TestUpdateCGAndAnalysisManagerForPasses7) { + CGSCCPassManager CGPM(/*DebugLogging*/ true); + CGPM.addPass(LambdaSCCPassNoPreserve( + [&](LazyCallGraph::SCC &C, CGSCCAnalysisManager &AM, LazyCallGraph &CG, + CGSCCUpdateResult &UR) { + if (C.getName() != "(f)") + return; + + Function *FnF = M->getFunction("f"); + Function *FnH2 = M->getFunction("h2"); + ASSERT_NE(FnF, nullptr); + ASSERT_NE(FnH2, nullptr); + + // And insert a call to `h2` + Instruction *IP = &FnF->getEntryBlock().front(); + (void)CallInst::Create(FnH2, {}, "", IP); + + // Use the CallGraphUpdater to update the call graph for the new + // call. + CallGraphUpdater CGU; + CGU.initialize(CG, C, AM, UR); + ASSERT_NO_FATAL_FAILURE(CGU.reanalyzeFunction(*FnF)); + })); + + ModulePassManager MPM(/*DebugLogging*/ true); + MPM.addPass(createModuleToPostOrderCGSCCPassAdaptor(std::move(CGPM))); + MPM.run(*M, MAM); +} + +TEST_F(CGSCCPassManagerTest, TestUpdateCGAndAnalysisManagerForPasses8) { + CGSCCPassManager CGPM(/*DebugLogging*/ true); + CGPM.addPass(LambdaSCCPassNoPreserve( + [&](LazyCallGraph::SCC &C, CGSCCAnalysisManager &AM, LazyCallGraph &CG, + CGSCCUpdateResult &UR) { + if (C.getName() != "(f)") + return; + + Function *FnF = M->getFunction("f"); + ASSERT_NE(FnF, nullptr); + Function *FnewF = Function::Create(FnF->getFunctionType(), + FnF->getLinkage(), "newF", *M); + BasicBlock *BB = BasicBlock::Create(FnewF->getContext(), "", FnewF); + auto *RI = ReturnInst::Create(FnewF->getContext(), BB); + while (FnF->getEntryBlock().size() > 1) + FnF->getEntryBlock().front().moveBefore(RI); + + // Use the CallGraphUpdater to update the call graph.
+ CallGraphUpdater CGU; + CGU.initialize(CG, C, AM, UR); + ASSERT_NO_FATAL_FAILURE(CGU.replaceFunctionWith(*FnF, *FnewF)); + ASSERT_TRUE(FnF->isDeclaration()); + ASSERT_EQ(FnF->getNumUses(), 0U); + })); + + ModulePassManager MPM(/*DebugLogging*/ true); + MPM.addPass(createModuleToPostOrderCGSCCPassAdaptor(std::move(CGPM))); + MPM.run(*M, MAM); +} + +TEST_F(CGSCCPassManagerTest, TestUpdateCGAndAnalysisManagerForPasses9) { + CGSCCPassManager CGPM(/*DebugLogging*/ true); + CGPM.addPass(LambdaSCCPassNoPreserve( + [&](LazyCallGraph::SCC &C, CGSCCAnalysisManager &AM, LazyCallGraph &CG, + CGSCCUpdateResult &UR) { + if (C.getName() != "(f)") + return; + + Function *FnF = M->getFunction("f"); + + // Use the CallGraphUpdater to update the call graph. + { + CallGraphUpdater CGU; + CGU.initialize(CG, C, AM, UR); + ASSERT_NO_FATAL_FAILURE(CGU.removeFunction(*FnF)); + ASSERT_EQ(M->getFunctionList().size(), 6U); + } + ASSERT_EQ(M->getFunctionList().size(), 5U); + })); + + ModulePassManager MPM(/*DebugLogging*/ true); + MPM.addPass(createModuleToPostOrderCGSCCPassAdaptor(std::move(CGPM))); + MPM.run(*M, MAM); +} + +TEST_F(CGSCCPassManagerTest, TestUpdateCGAndAnalysisManagerForPasses10) { + CGSCCPassManager CGPM(/*DebugLogging*/ true); + CGPM.addPass(LambdaSCCPassNoPreserve( + [&](LazyCallGraph::SCC &C, CGSCCAnalysisManager &AM, LazyCallGraph &CG, + CGSCCUpdateResult &UR) { + if (C.getName() != "(h3, h1, h2)") + return; + + Function *FnX = M->getFunction("x"); + Function *FnH1 = M->getFunction("h1"); + Function *FnH2 = M->getFunction("h2"); + Function *FnH3 = M->getFunction("h3"); + ASSERT_NE(FnX, nullptr); + ASSERT_NE(FnH1, nullptr); + ASSERT_NE(FnH2, nullptr); + ASSERT_NE(FnH3, nullptr); + + // And insert a call to `h1` and `h3`. + Instruction *IP = &FnH1->getEntryBlock().front(); + (void)CallInst::Create(FnH1, {}, "", IP); + (void)CallInst::Create(FnH3, {}, "", IP); + + // Remove the `h2` call.
+ ASSERT_TRUE(isa<CallInst>(IP)); + ASSERT_EQ(cast<CallInst>(IP)->getCalledFunction(), FnH2); + IP->eraseFromParent(); + + // Use the CallGraphUpdater to update the call graph. + CallGraphUpdater CGU; + CGU.initialize(CG, C, AM, UR); + ASSERT_NO_FATAL_FAILURE(CGU.reanalyzeFunction(*FnH1)); + ASSERT_NO_FATAL_FAILURE(CGU.removeFunction(*FnH2)); + })); + + ModulePassManager MPM(/*DebugLogging*/ true); + MPM.addPass(createModuleToPostOrderCGSCCPassAdaptor(std::move(CGPM))); + MPM.run(*M, MAM); +} + #endif } // namespace diff --git a/llvm/unittests/Analysis/CMakeLists.txt b/llvm/unittests/Analysis/CMakeLists.txt --- a/llvm/unittests/Analysis/CMakeLists.txt +++ b/llvm/unittests/Analysis/CMakeLists.txt @@ -3,6 +3,7 @@ AsmParser Core Support + TransformUtils ) add_llvm_unittest(AnalysisTests diff --git a/llvm/unittests/Analysis/LazyCallGraphTest.cpp b/llvm/unittests/Analysis/LazyCallGraphTest.cpp --- a/llvm/unittests/Analysis/LazyCallGraphTest.cpp +++ b/llvm/unittests/Analysis/LazyCallGraphTest.cpp @@ -450,6 +450,47 @@ EXPECT_EQ(0, std::distance(B->begin(), B->end())); } +TEST(LazyCallGraphTest, BasicGraphMutationOutlining) { + LLVMContext Context; + std::unique_ptr<Module> M = parseAssembly(Context, "define void @a() {\n" + "entry:\n" + " call void @b()\n" + " call void @c()\n" + " ret void\n" + "}\n" + "define void @b() {\n" + "entry:\n" + " ret void\n" + "}\n" + "define void @c() {\n" + "entry:\n" + " ret void\n" + "}\n"); + LazyCallGraph CG = buildCG(*M); + + LazyCallGraph::Node &A = CG.get(lookupFunction(*M, "a")); + LazyCallGraph::Node &B = CG.get(lookupFunction(*M, "b")); + LazyCallGraph::Node &C = CG.get(lookupFunction(*M, "c")); + A.populate(); + B.populate(); + C.populate(); + CG.buildRefSCCs(); + + // Add a new function that is called from @b and verify it is in the same SCC.
+ Function &BFn = B.getFunction(); + Function *NewFn = + Function::Create(BFn.getFunctionType(), BFn.getLinkage(), "NewFn", *M); + auto IP = BFn.getEntryBlock().getFirstInsertionPt(); + CallInst::Create(NewFn, "", &*IP); + CG.addNewFunctionIntoSCC(*NewFn, *CG.lookupSCC(B)); + + EXPECT_EQ(CG.lookupSCC(A)->size(), 1U); + EXPECT_EQ(CG.lookupSCC(B)->size(), 2U); + EXPECT_EQ(CG.lookupSCC(C)->size(), 1U); + EXPECT_EQ(CG.lookupSCC(*CG.lookup(*NewFn))->size(), 2U); + EXPECT_EQ(CG.lookupSCC(*CG.lookup(*NewFn)), CG.lookupSCC(B)); +} + TEST(LazyCallGraphTest, InnerSCCFormation) { LLVMContext Context; std::unique_ptr<Module> M = parseAssembly(Context, DiamondOfTriangles); diff --git a/llvm/unittests/IR/CMakeLists.txt b/llvm/unittests/IR/CMakeLists.txt --- a/llvm/unittests/IR/CMakeLists.txt +++ b/llvm/unittests/IR/CMakeLists.txt @@ -4,6 +4,7 @@ Core Support Passes + TransformUtils ) add_llvm_unittest(IRTests diff --git a/llvm/unittests/IR/LegacyPassManagerTest.cpp b/llvm/unittests/IR/LegacyPassManagerTest.cpp --- a/llvm/unittests/IR/LegacyPassManagerTest.cpp +++ b/llvm/unittests/IR/LegacyPassManagerTest.cpp @@ -29,6 +29,7 @@ #include "llvm/InitializePasses.h" #include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Utils/CallGraphUpdater.h" #include "gtest/gtest.h" using namespace llvm; @@ -559,6 +560,72 @@ return mod; } + struct CGModifierPass : public CGPass { + unsigned NumSCCs = 0; + unsigned NumFns = 0; + bool SetupWorked = true; + + CallGraphUpdater CGU; + + bool runOnSCC(CallGraphSCC &SCMM) override { + ++NumSCCs; + for (CallGraphNode *N : SCMM) + if (N->getFunction()) + ++NumFns; + + CGPass::run(); + + if (SCMM.size() <= 1) + return false; + + CallGraphNode *N = *(SCMM.begin()); + Function *F = N->getFunction(); + Module *M = F->getParent(); + Function *Test1F = M->getFunction("test1"); + Function *Test2F = M->getFunction("test2"); + Function *Test3F = M->getFunction("test3"); + auto InSCC = [&](Function *Fn) {
+ return llvm::any_of(SCMM, [Fn](CallGraphNode *CGN) { + return CGN->getFunction() == Fn; + }); + }; + + if (!Test1F || !Test2F || !Test3F || !InSCC(Test1F) || !InSCC(Test2F) || + !InSCC(Test3F)) + return SetupWorked = false; + + CallInst *CI = dyn_cast<CallInst>(&Test1F->getEntryBlock().front()); + if (!CI || CI->getCalledFunction() != Test2F) + return SetupWorked = false; + + CI->setCalledFunction(Test3F); + + CGU.initialize(const_cast<CallGraph &>(SCMM.getCallGraph()), SCMM); + CGU.removeFunction(*Test2F); + CGU.reanalyzeFunction(*Test1F); + return true; + } + + bool doFinalization(CallGraph &CG) override { return CGU.finalize(); } + }; + + TEST(PassManager, CallGraphUpdater0) { + // SCC#1: test1->test2->test3->test1 + // SCC#2: test4 + // SCC#3: indirect call node + + LLVMContext Context; + std::unique_ptr<Module> M(makeLLVMModule(Context)); + ASSERT_EQ(M->getFunctionList().size(), 4U); + CGModifierPass *P = new CGModifierPass(); + legacy::PassManager Passes; + Passes.add(P); + Passes.run(*M); + ASSERT_TRUE(P->SetupWorked); + ASSERT_EQ(P->NumSCCs, 3U); + ASSERT_EQ(P->NumFns, 4U); + ASSERT_EQ(M->getFunctionList().size(), 3U); + } } }