diff --git a/llvm/include/llvm/Passes/StandardInstrumentations.h b/llvm/include/llvm/Passes/StandardInstrumentations.h --- a/llvm/include/llvm/Passes/StandardInstrumentations.h +++ b/llvm/include/llvm/Passes/StandardInstrumentations.h @@ -153,7 +153,7 @@ #endif void registerCallbacks(PassInstrumentationCallbacks &PIC, - FunctionAnalysisManager &FAM); + ModuleAnalysisManager &MAM); }; // Base class for classes that report changes to the IR. @@ -574,7 +574,7 @@ // Register all the standard instrumentation callbacks. If \p FAM is nullptr // then PreservedCFGChecker is not enabled. void registerCallbacks(PassInstrumentationCallbacks &PIC, - FunctionAnalysisManager *FAM = nullptr); + ModuleAnalysisManager *MAM = nullptr); TimePassesHandler &getTimePasses() { return TimePasses; } }; diff --git a/llvm/lib/LTO/LTOBackend.cpp b/llvm/lib/LTO/LTOBackend.cpp --- a/llvm/lib/LTO/LTOBackend.cpp +++ b/llvm/lib/LTO/LTOBackend.cpp @@ -260,7 +260,7 @@ PassInstrumentationCallbacks PIC; StandardInstrumentations SI(Mod.getContext(), Conf.DebugPassManager); - SI.registerCallbacks(PIC, &FAM); + SI.registerCallbacks(PIC, &MAM); PassBuilder PB(TM, Conf.PTO, PGOOpt, &PIC); RegisterPassPlugins(Conf.PassPlugins, PB); diff --git a/llvm/lib/LTO/ThinLTOCodeGenerator.cpp b/llvm/lib/LTO/ThinLTOCodeGenerator.cpp --- a/llvm/lib/LTO/ThinLTOCodeGenerator.cpp +++ b/llvm/lib/LTO/ThinLTOCodeGenerator.cpp @@ -245,7 +245,7 @@ PassInstrumentationCallbacks PIC; StandardInstrumentations SI(TheModule.getContext(), DebugPassManager); - SI.registerCallbacks(PIC, &FAM); + SI.registerCallbacks(PIC, &MAM); PipelineTuningOptions PTO; PTO.LoopVectorization = true; PTO.SLPVectorization = true; diff --git a/llvm/lib/Passes/PassBuilderBindings.cpp b/llvm/lib/Passes/PassBuilderBindings.cpp --- a/llvm/lib/Passes/PassBuilderBindings.cpp +++ b/llvm/lib/Passes/PassBuilderBindings.cpp @@ -66,7 +66,7 @@ PB.crossRegisterProxies(LAM, FAM, CGAM, MAM); StandardInstrumentations SI(Mod->getContext(), Debug, VerifyEach); - SI.registerCallbacks(PIC, &FAM); + SI.registerCallbacks(PIC, &MAM); ModulePassManager MPM; if (VerifyEach) { MPM.addPass(VerifierPass()); diff --git a/llvm/lib/Passes/StandardInstrumentations.cpp b/llvm/lib/Passes/StandardInstrumentations.cpp --- a/llvm/lib/Passes/StandardInstrumentations.cpp +++ b/llvm/lib/Passes/StandardInstrumentations.cpp @@ -1075,29 +1075,46 @@ PAC.preservedSet()); } +static SmallVector GetFunctions(Any IR) { + SmallVector Functions; + + if (const auto **MaybeF = any_cast(&IR)) { + Functions.push_back(*const_cast(MaybeF)); + } else if (const auto **MaybeM = any_cast(&IR)) { + for (Function &F : **const_cast(MaybeM)) + Functions.push_back(&F); + } + return Functions; +} + void PreservedCFGCheckerInstrumentation::registerCallbacks( - PassInstrumentationCallbacks &PIC, FunctionAnalysisManager &FAM) { + PassInstrumentationCallbacks &PIC, ModuleAnalysisManager &MAM) { if (!VerifyAnalysisInvalidation) return; - FAM.registerPass([&] { return PreservedCFGCheckerAnalysis(); }); - FAM.registerPass([&] { return PreservedFunctionHashAnalysis(); }); - - PIC.registerBeforeNonSkippedPassCallback( - [this, &FAM](StringRef P, Any IR) { + bool Registered = false; + PIC.registerBeforeNonSkippedPassCallback([this, &MAM, Registered]( + StringRef P, Any IR) mutable { #ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS - assert(&PassStack.emplace_back(P)); + assert(&PassStack.emplace_back(P)); #endif - (void)this; - const auto **F = any_cast(&IR); - if (!F) - return; + (void)this; - // Make sure a fresh CFG snapshot is available before the pass. - FAM.getResult(*const_cast(*F)); - FAM.getResult( - *const_cast(*F)); - }); + auto &FAM = MAM.getResult( + *const_cast(unwrapModule(IR, /*Force=*/true))) + .getManager(); + if (!Registered) { + FAM.registerPass([&] { return PreservedCFGCheckerAnalysis(); }); + FAM.registerPass([&] { return PreservedFunctionHashAnalysis(); }); + Registered = true; + } + + for (Function *F : GetFunctions(IR)) { + // Make sure a fresh CFG snapshot is available before the pass. + FAM.getResult(*F); + FAM.getResult(*F); + } + }); PIC.registerAfterPassInvalidatedCallback( [this](StringRef P, const PreservedAnalyses &PassPA) { @@ -1108,7 +1125,7 @@ (void)this; }); - PIC.registerAfterPassCallback([this, &FAM](StringRef P, Any IR, + PIC.registerAfterPassCallback([this, &MAM](StringRef P, Any IR, const PreservedAnalyses &PassPA) { #ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS assert(PassStack.pop_back_val() == P && @@ -1116,36 +1133,42 @@ #endif (void)this; - const auto **MaybeF = any_cast(&IR); - if (!MaybeF) - return; - Function &F = *const_cast(*MaybeF); - - if (auto *HashBefore = - FAM.getCachedResult(F)) { - if (HashBefore->Hash != StructuralHash(F)) { - report_fatal_error(formatv( - "Function @{0} changed by {1} without invalidating analyses", - F.getName(), P)); + // We have to get the FAM via the MAM, rather than directly use a passed in + // FAM because if MAM has not cached the FAM, it won't invalidate function + // analyses in FAM. + auto &FAM = MAM.getResult( + *const_cast(unwrapModule(IR, /*Force=*/true))) + .getManager(); + + for (Function *F : GetFunctions(IR)) { + if (auto *HashBefore = + FAM.getCachedResult(*F)) { + if (HashBefore->Hash != StructuralHash(*F)) { + report_fatal_error(formatv( + "Function @{0} changed by {1} without invalidating analyses", + F->getName(), P)); + } } - } - auto CheckCFG = [](StringRef Pass, StringRef FuncName, - const CFG &GraphBefore, const CFG &GraphAfter) { - if (GraphAfter == GraphBefore) - return; - - dbgs() << "Error: " << Pass - << " does not invalidate CFG analyses but CFG changes detected in " - "function @" - << FuncName << ":\n"; - CFG::printDiff(dbgs(), GraphBefore, GraphAfter); - report_fatal_error(Twine("CFG unexpectedly changed by ", Pass)); - }; - - if (auto *GraphBefore = FAM.getCachedResult(F)) - CheckCFG(P, F.getName(), *GraphBefore, - CFG(&F, /* TrackBBLifetime */ false)); + auto CheckCFG = [](StringRef Pass, StringRef FuncName, + const CFG &GraphBefore, const CFG &GraphAfter) { + if (GraphAfter == GraphBefore) + return; + + dbgs() + << "Error: " << Pass + << " does not invalidate CFG analyses but CFG changes detected in " + "function @" + << FuncName << ":\n"; + CFG::printDiff(dbgs(), GraphBefore, GraphAfter); + report_fatal_error(Twine("CFG unexpectedly changed by ", Pass)); + }; + + if (auto *GraphBefore = + FAM.getCachedResult(*F)) + CheckCFG(P, F->getName(), *GraphBefore, + CFG(F, /* TrackBBLifetime */ false)); + } }); } @@ -2175,7 +2198,7 @@ } void StandardInstrumentations::registerCallbacks( - PassInstrumentationCallbacks &PIC, FunctionAnalysisManager *FAM) { + PassInstrumentationCallbacks &PIC, ModuleAnalysisManager *MAM) { PrintIR.registerCallbacks(PIC); PrintPass.registerCallbacks(PIC); TimePasses.registerCallbacks(PIC); @@ -2189,8 +2212,8 @@ WebsiteChangeReporter.registerCallbacks(PIC); ChangeTester.registerCallbacks(PIC); PrintCrashIR.registerCallbacks(PIC); - if (FAM) - PreservedCFGChecker.registerCallbacks(PIC, *FAM); + if (MAM) + PreservedCFGChecker.registerCallbacks(PIC, *MAM); // TimeProfiling records the pass running time cost. // Its 'BeforePassCallback' can be appended at the tail of all the diff --git a/llvm/tools/opt/NewPMDriver.cpp b/llvm/tools/opt/NewPMDriver.cpp --- a/llvm/tools/opt/NewPMDriver.cpp +++ b/llvm/tools/opt/NewPMDriver.cpp @@ -395,7 +395,7 @@ PrintPassOpts.SkipAnalyses = DebugPM == DebugLogging::Quiet; StandardInstrumentations SI(M.getContext(), DebugPM != DebugLogging::None, VerifyEachPass, PrintPassOpts); - SI.registerCallbacks(PIC, &FAM); + SI.registerCallbacks(PIC, &MAM); DebugifyEachInstrumentation Debugify; DebugifyStatsMap DIStatsMap; DebugInfoPerPass DebugInfoBeforePass; diff --git a/llvm/unittests/IR/PassManagerTest.cpp b/llvm/unittests/IR/PassManagerTest.cpp --- a/llvm/unittests/IR/PassManagerTest.cpp +++ b/llvm/unittests/IR/PassManagerTest.cpp @@ -824,10 +824,13 @@ auto *F = M->getFunction("foo"); FunctionAnalysisManager FAM; + ModuleAnalysisManager MAM; FunctionPassManager FPM; PassInstrumentationCallbacks PIC; StandardInstrumentations SI(M->getContext(), /*DebugLogging*/ true); - SI.registerCallbacks(PIC, &FAM); + SI.registerCallbacks(PIC, &MAM); + MAM.registerPass([&] { return PassInstrumentationAnalysis(&PIC); }); + MAM.registerPass([&] { return FunctionAnalysisManagerModuleProxy(FAM); }); FAM.registerPass([&] { return PassInstrumentationAnalysis(&PIC); }); FAM.registerPass([&] { return DominatorTreeAnalysis(); }); FAM.registerPass([&] { return AssumptionAnalysis(); }); @@ -870,10 +873,13 @@ auto *F = M->getFunction("foo"); FunctionAnalysisManager FAM; + ModuleAnalysisManager MAM; FunctionPassManager FPM; PassInstrumentationCallbacks PIC; StandardInstrumentations SI(M->getContext(), /*DebugLogging*/ true); - SI.registerCallbacks(PIC, &FAM); + SI.registerCallbacks(PIC, &MAM); + MAM.registerPass([&] { return FunctionAnalysisManagerModuleProxy(FAM); }); + MAM.registerPass([&] { return PassInstrumentationAnalysis(&PIC); }); FAM.registerPass([&] { return PassInstrumentationAnalysis(&PIC); }); FAM.registerPass([&] { return DominatorTreeAnalysis(); }); FAM.registerPass([&] { return AssumptionAnalysis(); }); @@ -935,10 +941,13 @@ auto *F = M->getFunction("foo"); FunctionAnalysisManager FAM; + ModuleAnalysisManager MAM; FunctionPassManager FPM; PassInstrumentationCallbacks PIC; StandardInstrumentations SI(M->getContext(), /*DebugLogging*/ true); - SI.registerCallbacks(PIC, &FAM); + SI.registerCallbacks(PIC, &MAM); + MAM.registerPass([&] { return FunctionAnalysisManagerModuleProxy(FAM); }); + MAM.registerPass([&] { return PassInstrumentationAnalysis(&PIC); }); FAM.registerPass([&] { return PassInstrumentationAnalysis(&PIC); }); FAM.registerPass([&] { return DominatorTreeAnalysis(); }); FAM.registerPass([&] { return AssumptionAnalysis(); }); @@ -961,7 +970,7 @@ static StringRef name() { return "WrongFunctionPass"; } }; -TEST_F(PassManagerTest, FunctionAnalysisMissedInvalidation) { +TEST_F(PassManagerTest, FunctionPassMissedFunctionAnalysisInvalidation) { LLVMContext Context; auto M = parseIR(Context, "define void @foo() {\n" " %a = add i32 0, 0\n" @@ -969,9 +978,12 @@ "}\n"); FunctionAnalysisManager FAM; + ModuleAnalysisManager MAM; PassInstrumentationCallbacks PIC; StandardInstrumentations SI(M->getContext(), /*DebugLogging*/ false); - SI.registerCallbacks(PIC, &FAM); + SI.registerCallbacks(PIC, &MAM); + MAM.registerPass([&] { return FunctionAnalysisManagerModuleProxy(FAM); }); + MAM.registerPass([&] { return PassInstrumentationAnalysis(&PIC); }); FAM.registerPass([&] { return PassInstrumentationAnalysis(&PIC); }); FunctionPassManager FPM; @@ -981,6 +993,39 @@ EXPECT_DEATH(FPM.run(*F, FAM), "Function @foo changed by WrongFunctionPass without invalidating analyses"); } -#endif +struct WrongModulePass : PassInfoMixin { + PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM) { + for (Function &F : M) + F.getEntryBlock().begin()->eraseFromParent(); + + return PreservedAnalyses::all(); + } + static StringRef name() { return "WrongModulePass"; } +}; + +TEST_F(PassManagerTest, ModulePassMissedFunctionAnalysisInvalidation) { + LLVMContext Context; + auto M = parseIR(Context, "define void @foo() {\n" + " %a = add i32 0, 0\n" + " ret void\n" + "}\n"); + + FunctionAnalysisManager FAM; + ModuleAnalysisManager MAM; + PassInstrumentationCallbacks PIC; + StandardInstrumentations SI(M->getContext(), /*DebugLogging*/ false); + SI.registerCallbacks(PIC, &MAM); + MAM.registerPass([&] { return FunctionAnalysisManagerModuleProxy(FAM); }); + MAM.registerPass([&] { return PassInstrumentationAnalysis(&PIC); }); + FAM.registerPass([&] { return PassInstrumentationAnalysis(&PIC); }); + ModulePassManager MPM; + MPM.addPass(WrongModulePass()); + + EXPECT_DEATH( + MPM.run(*M, MAM), + "Function @foo changed by WrongModulePass without invalidating analyses"); +} + +#endif }