diff --git a/llvm/lib/Analysis/CallGraph.cpp b/llvm/lib/Analysis/CallGraph.cpp
--- a/llvm/lib/Analysis/CallGraph.cpp
+++ b/llvm/lib/Analysis/CallGraph.cpp
@@ -281,13 +281,42 @@
       I->second = NewNode;
       NewNode->AddRef();
 
-      // Refresh callback references.
-      forEachCallbackFunction(Call, [=](Function *CB) {
-        removeOneAbstractEdgeTo(CG->getOrInsertFunction(CB));
+      // Refresh callback references. Do not resize CalledFunctions if the
+      // number of callbacks is the same for new and old call sites.
+      SmallVector<CallGraphNode *, 4> OldCBs;
+      SmallVector<CallGraphNode *, 4> NewCBs;
+      forEachCallbackFunction(Call, [this, &OldCBs](Function *CB) {
+        OldCBs.push_back(CG->getOrInsertFunction(CB));
       });
-      forEachCallbackFunction(NewCall, [=](Function *CB) {
-        addCalledFunction(nullptr, CG->getOrInsertFunction(CB));
+      forEachCallbackFunction(NewCall, [this, &NewCBs](Function *CB) {
+        NewCBs.push_back(CG->getOrInsertFunction(CB));
       });
+      if (OldCBs.size() == NewCBs.size()) {
+        // Pair old and new callback nodes by position and retarget one
+        // existing abstract (null call site) edge per pair in place, so the
+        // size of CalledFunctions does not change.
+        for (unsigned N = 0; N < OldCBs.size(); ++N) {
+          CallGraphNode *OldCBNode = OldCBs[N];
+          CallGraphNode *NewCBNode = NewCBs[N];
+          for (auto J = CalledFunctions.begin();; ++J) {
+            assert(J != CalledFunctions.end() &&
+                   "Cannot find callsite to update!");
+            if (!J->first && J->second == OldCBNode) {
+              J->second = NewCBNode;
+              OldCBNode->DropRef();
+              NewCBNode->AddRef();
+              break;
+            }
+          }
+        }
+      } else {
+        // The callback count differs, so fall back to removing the old
+        // abstract edges and adding fresh ones.
+        for (auto *CGN : OldCBs)
+          removeOneAbstractEdgeTo(CGN);
+        for (auto *CGN : NewCBs)
+          addCalledFunction(nullptr, CGN);
+      }
       return;
     }
   }
diff --git a/llvm/unittests/IR/LegacyPassManagerTest.cpp b/llvm/unittests/IR/LegacyPassManagerTest.cpp
--- a/llvm/unittests/IR/LegacyPassManagerTest.cpp
+++ b/llvm/unittests/IR/LegacyPassManagerTest.cpp
@@ -16,6 +16,8 @@
 #include "llvm/Analysis/CallGraphSCCPass.h"
 #include "llvm/Analysis/LoopInfo.h"
 #include "llvm/Analysis/LoopPass.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/AbstractCallSite.h"
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/CallingConv.h"
 #include "llvm/IR/DataLayout.h"
@@ -28,6 +30,7 @@
 #include "llvm/IR/OptBisect.h"
 #include "llvm/InitializePasses.h"
 #include "llvm/Support/MathExtras.h"
+#include "llvm/Support/SourceMgr.h"
 #include "llvm/Support/raw_ostream.h"
 #include "llvm/Transforms/Utils/CallGraphUpdater.h"
 #include "gtest/gtest.h"
@@ -694,6 +697,91 @@
       ASSERT_EQ(P->NumExtCalledBefore, /* test1, 2a, 2b, 3, 4 */ 5U);
       ASSERT_EQ(P->NumExtCalledAfter, /* test1, 3repl, 4 */ 3U);
     }
+
+    // Test for call graph SCC pass that replaces all callback call instructions
+    // with clones and updates CallGraph by calling CallGraph::replaceCallEdge()
+    // method. Test is expected to complete successfully after running pass on
+    // all SCCs in the test module.
+    struct CallbackCallsModifierPass : public CGPass {
+      bool runOnSCC(CallGraphSCC &SCC) override {
+        CGPass::run();
+
+        CallGraph &CG = const_cast<CallGraph &>(SCC.getCallGraph());
+
+        bool Changed = false;
+        for (CallGraphNode *CGN : SCC) {
+          Function *F = CGN->getFunction();
+          if (!F || F->isDeclaration())
+            continue;
+
+          SmallVector<CallBase *, 4> Calls;
+          for (Use &U : F->uses()) {
+            AbstractCallSite ACS(&U);
+            if (!ACS || !ACS.isCallbackCall() || !ACS.isCallee(&U))
+              continue;
+            Calls.push_back(cast<CallBase>(ACS.getInstruction()));
+          }
+          if (Calls.empty())
+            continue;
+
+          for (CallBase *OldCB : Calls) {
+            CallGraphNode *CallerCGN = CG[OldCB->getParent()->getParent()];
+            assert(any_of(*CallerCGN,
+                          [CGN](const CallGraphNode::CallRecord &CallRecord) {
+                            return CallRecord.second == CGN;
+                          }) &&
+                   "function is not a callee");
+
+            CallBase *NewCB = cast<CallBase>(OldCB->clone());
+
+            NewCB->insertBefore(OldCB);
+            NewCB->takeName(OldCB);
+
+            CallerCGN->replaceCallEdge(*OldCB, *NewCB, CG[F]);
+
+            OldCB->replaceAllUsesWith(NewCB);
+            OldCB->eraseFromParent();
+          }
+          Changed = true;
+        }
+        return Changed;
+      }
+    };
+
+    TEST(PassManager, CallbackCallsModifier0) {
+      LLVMContext Context;
+
+      const char *IR = "define void @foo() {\n"
+                       "  call void @broker(void (i8*)* @callback0, i8* null)\n"
+                       "  call void @broker(void (i8*)* @callback1, i8* null)\n"
+                       "  ret void\n"
+                       "}\n"
+                       "\n"
+                       "declare !callback !0 void @broker(void (i8*)*, i8*)\n"
+                       "\n"
+                       "define internal void @callback0(i8* %arg) {\n"
+                       "  ret void\n"
+                       "}\n"
+                       "\n"
+                       "define internal void @callback1(i8* %arg) {\n"
+                       "  ret void\n"
+                       "}\n"
+                       "\n"
+                       "!0 = !{!1}\n"
+                       "!1 = !{i64 0, i64 1, i1 false}";
+
+      SMDiagnostic Err;
+      std::unique_ptr<Module> M = parseAssemblyString(IR, Err, Context);
+      if (!M)
+        Err.print("LegacyPassManagerTest", errs());
+      // Stop on a parse failure instead of running passes on a null module.
+      ASSERT_TRUE(M);
+
+      CallbackCallsModifierPass *P = new CallbackCallsModifierPass();
+      legacy::PassManager Passes;
+      Passes.add(P);
+      Passes.run(*M);
+    }
   }
 }
 