diff --git a/llvm/lib/Transforms/Utils/CallGraphUpdater.cpp b/llvm/lib/Transforms/Utils/CallGraphUpdater.cpp
--- a/llvm/lib/Transforms/Utils/CallGraphUpdater.cpp
+++ b/llvm/lib/Transforms/Utils/CallGraphUpdater.cpp
@@ -26,49 +26,56 @@
                          DeadFunctionsInComdats.end());
   }
 
-  for (Function *DeadFn : DeadFunctions) {
-    DeadFn->removeDeadConstantUsers();
-
-    if (CG) {
-      CallGraphNode *OldCGN = CG->getOrInsertFunction(DeadFn);
-      CG->getExternalCallingNode()->removeAnyCallEdgeTo(OldCGN);
-      OldCGN->removeAllCalledFunctions();
+  if (CG) {
+    // First remove all references, e.g., outgoing via called functions. This is
+    // necessary as we can delete functions that have circular references.
+    for (Function *DeadFn : DeadFunctions) {
+      DeadFn->removeDeadConstantUsers();
+      CallGraphNode *DeadCGN = (*CG)[DeadFn];
+      DeadCGN->removeAllCalledFunctions();
+      CG->getExternalCallingNode()->removeAnyCallEdgeTo(DeadCGN);
       DeadFn->replaceAllUsesWith(UndefValue::get(DeadFn->getType()));
-
-      assert(OldCGN->getNumReferences() == 0);
-
-      delete CG->removeFunctionFromModule(OldCGN);
-      continue;
     }
 
-    // The old style call graph (CG) has a value handle we do not want to
-    // replace with undef so we do this here.
-    DeadFn->replaceAllUsesWith(UndefValue::get(DeadFn->getType()));
-
-    if (LCG && !ReplacedFunctions.count(DeadFn)) {
-      // Taken mostly from the inliner:
-      LazyCallGraph::Node &N = LCG->get(*DeadFn);
-      auto *DeadSCC = LCG->lookupSCC(N);
-      assert(DeadSCC && DeadSCC->size() == 1 &&
-             &DeadSCC->begin()->getFunction() == DeadFn);
-      auto &DeadRC = DeadSCC->getOuterRefSCC();
-
-      FunctionAnalysisManager &FAM =
-          AM->getResult<FunctionAnalysisManagerCGSCCProxy>(*DeadSCC, *LCG)
-              .getManager();
-
-      FAM.clear(*DeadFn, DeadFn->getName());
-      AM->clear(*DeadSCC, DeadSCC->getName());
-      LCG->removeDeadFunction(*DeadFn);
-
-      // Mark the relevant parts of the call graph as invalid so we don't visit
-      // them.
-      UR->InvalidatedSCCs.insert(DeadSCC);
-      UR->InvalidatedRefSCCs.insert(&DeadRC);
+    // Then remove the node and function from the module.
+    for (Function *DeadFn : DeadFunctions) {
+      CallGraphNode *DeadCGN = CG->getOrInsertFunction(DeadFn);
+      assert(DeadCGN->getNumReferences() == 0 &&
+             "References should have been handled by now");
+      delete CG->removeFunctionFromModule(DeadCGN);
     }
+  } else {
+    // This is the code path for the new lazy call graph and for the case where
+    // no call graph was provided.
+    for (Function *DeadFn : DeadFunctions) {
+      DeadFn->removeDeadConstantUsers();
+      DeadFn->replaceAllUsesWith(UndefValue::get(DeadFn->getType()));
 
-    // The function is now really dead and de-attached from everything.
-    DeadFn->eraseFromParent();
+      if (LCG && !ReplacedFunctions.count(DeadFn)) {
+        // Taken mostly from the inliner:
+        LazyCallGraph::Node &N = LCG->get(*DeadFn);
+        auto *DeadSCC = LCG->lookupSCC(N);
+        assert(DeadSCC && DeadSCC->size() == 1 &&
+               &DeadSCC->begin()->getFunction() == DeadFn);
+        auto &DeadRC = DeadSCC->getOuterRefSCC();
+
+        FunctionAnalysisManager &FAM =
+            AM->getResult<FunctionAnalysisManagerCGSCCProxy>(*DeadSCC, *LCG)
+                .getManager();
+
+        FAM.clear(*DeadFn, DeadFn->getName());
+        AM->clear(*DeadSCC, DeadSCC->getName());
+        LCG->removeDeadFunction(*DeadFn);
+
+        // Mark the relevant parts of the call graph as invalid so we don't
+        // visit them.
+        UR->InvalidatedSCCs.insert(DeadSCC);
+        UR->InvalidatedRefSCCs.insert(&DeadRC);
+      }
+
+      // The function is now really dead and detached from everything.
+      DeadFn->eraseFromParent();
+    }
   }
 
   bool Changed = !DeadFunctions.empty();
diff --git a/llvm/unittests/IR/LegacyPassManagerTest.cpp b/llvm/unittests/IR/LegacyPassManagerTest.cpp
--- a/llvm/unittests/IR/LegacyPassManagerTest.cpp
+++ b/llvm/unittests/IR/LegacyPassManagerTest.cpp
@@ -560,6 +560,26 @@
       return mod;
     }
 
+    /// Split a simple function, which contains only a call and a return, into
+    /// two functions such that the first calls the second and the second calls
+    /// whoever was called initially. \p F is renamed with an "a" suffix; the
+    /// newly created function is named with a "b" suffix and returned.
+    Function *splitSimpleFunction(Function &F) {
+      LLVMContext &Context = F.getContext();
+      Function *SF = Function::Create(F.getFunctionType(), F.getLinkage(),
+                                      F.getName() + "b", F.getParent());
+      F.setName(F.getName() + "a");
+      BasicBlock *Entry = BasicBlock::Create(Context, "entry", SF, nullptr);
+      CallInst &CI = cast<CallInst>(F.getEntryBlock().front());
+      Instruction *NewCI = CI.clone();
+      // Forward the call's result unless F returns void; `ret void` in a
+      // non-void function is invalid IR. Assumes F returns its call's result.
+      Value *RetVal = F.getReturnType()->isVoidTy() ? nullptr : NewCI;
+      NewCI->insertBefore(ReturnInst::Create(Context, RetVal, Entry));
+      CI.setCalledFunction(SF);
+      return SF;
+    }
+
     struct CGModifierPass : public CGPass {
       unsigned NumSCCs = 0;
       unsigned NumFns = 0;
@@ -582,7 +602,8 @@
         Function *F = N->getFunction();
         Module *M = F->getParent();
         Function *Test1F = M->getFunction("test1");
-        Function *Test2F = M->getFunction("test2");
+        Function *Test2aF = M->getFunction("test2a");
+        Function *Test2bF = M->getFunction("test2b");
         Function *Test3F = M->getFunction("test3");
         auto InSCC = [&](Function *Fn) {
           return llvm::any_of(SCMM, [Fn](CallGraphNode *CGN) {
@@ -590,18 +611,19 @@
           });
         };
 
-        if (!Test1F || !Test2F || !Test3F || !InSCC(Test1F) || !InSCC(Test2F) ||
-            !InSCC(Test3F))
+        if (!Test1F || !Test2aF || !Test2bF || !Test3F || !InSCC(Test1F) ||
+            !InSCC(Test2aF) || !InSCC(Test2bF) || !InSCC(Test3F))
           return SetupWorked = false;
 
         CallInst *CI = dyn_cast<CallInst>(&Test1F->getEntryBlock().front());
-        if (!CI || CI->getCalledFunction() != Test2F)
+        if (!CI || CI->getCalledFunction() != Test2aF)
           return SetupWorked = false;
 
         CI->setCalledFunction(Test3F);
 
         CGU.initialize(const_cast<CallGraph &>(SCMM.getCallGraph()), SCMM);
-        CGU.removeFunction(*Test2F);
+        CGU.removeFunction(*Test2aF);
+        CGU.removeFunction(*Test2bF);
         CGU.reanalyzeFunction(*Test1F);
         return true;
       }
@@ -610,20 +632,26 @@
     };
 
     TEST(PassManager, CallGraphUpdater0) {
-      // SCC#1: test1->test2->test3->test1
+      // SCC#1: test1->test2a->test2b->test3->test1 and test2b->test2a
       // SCC#2: test4
       // SCC#3: indirect call node
 
       LLVMContext Context;
       std::unique_ptr<Module> M(makeLLVMModule(Context));
       ASSERT_EQ(M->getFunctionList().size(), 4U);
+      // Split test2 into test2a->test2b and make test2b call test2a as well so
+      // the two functions removed by CGModifierPass reference each other.
+      Function *F = M->getFunction("test2");
+      Function *SF = splitSimpleFunction(*F);
+      CallInst::Create(F, "", SF->getEntryBlock().getTerminator());
+      ASSERT_EQ(M->getFunctionList().size(), 5U);
       CGModifierPass *P = new CGModifierPass();
       legacy::PassManager Passes;
       Passes.add(P);
       Passes.run(*M);
       ASSERT_TRUE(P->SetupWorked);
       ASSERT_EQ(P->NumSCCs, 3U);
-      ASSERT_EQ(P->NumFns, 4U);
+      ASSERT_EQ(P->NumFns, 5U);
       ASSERT_EQ(M->getFunctionList().size(), 3U);
     }
   }