diff --git a/llvm/lib/Passes/StandardInstrumentations.cpp b/llvm/lib/Passes/StandardInstrumentations.cpp
--- a/llvm/lib/Passes/StandardInstrumentations.cpp
+++ b/llvm/lib/Passes/StandardInstrumentations.cpp
@@ -1067,6 +1067,25 @@
 
 AnalysisKey PreservedFunctionHashAnalysis::Key;
 
+// Caches a structural hash of a module. Used to detect module passes that
+// change the IR while claiming to preserve module analyses.
+struct PreservedModuleHashAnalysis
+    : public AnalysisInfoMixin<PreservedModuleHashAnalysis> {
+  static AnalysisKey Key;
+
+  struct ModuleHash {
+    uint64_t Hash;
+  };
+
+  using Result = ModuleHash;
+
+  Result run(Module &M, ModuleAnalysisManager &MAM) {
+    return Result{StructuralHash(M)};
+  }
+};
+
+AnalysisKey PreservedModuleHashAnalysis::Key;
+
 bool PreservedCFGCheckerInstrumentation::CFG::invalidate(
     Function &F, const PreservedAnalyses &PA,
     FunctionAnalysisManager::Invalidator &) {
@@ -1106,6 +1125,7 @@
     if (!Registered) {
       FAM.registerPass([&] { return PreservedCFGCheckerAnalysis(); });
       FAM.registerPass([&] { return PreservedFunctionHashAnalysis(); });
+      MAM.registerPass([&] { return PreservedModuleHashAnalysis(); });
       Registered = true;
     }
 
@@ -1114,6 +1134,12 @@
       FAM.getResult<PreservedCFGCheckerAnalysis>(*F);
       FAM.getResult<PreservedFunctionHashAnalysis>(*F);
     }
+
+    // Snapshot the module's hash so the after-pass callback can compare it.
+    if (auto *MaybeM = any_cast<const Module *>(&IR)) {
+      Module &M = **const_cast<Module **>(MaybeM);
+      MAM.getResult<PreservedModuleHashAnalysis>(M);
+    }
   });
 
   PIC.registerAfterPassInvalidatedCallback(
@@ -1169,6 +1195,18 @@
         CheckCFG(P, F->getName(), *GraphBefore,
                  CFG(F, /* TrackBBLifetime */ false));
     }
+    // If the module hash is still cached, the pass reported module analyses
+    // as preserved; make sure the module really is unchanged.
+    if (auto *MaybeM = any_cast<const Module *>(&IR)) {
+      Module &M = **const_cast<Module **>(MaybeM);
+      if (auto *HashBefore =
+              MAM.getCachedResult<PreservedModuleHashAnalysis>(M)) {
+        if (HashBefore->Hash != StructuralHash(M)) {
+          report_fatal_error(formatv(
+              "Module changed by {0} without invalidating analyses", P));
+        }
+      }
+    }
   });
 }
 
diff --git a/llvm/unittests/IR/PassManagerTest.cpp b/llvm/unittests/IR/PassManagerTest.cpp
--- a/llvm/unittests/IR/PassManagerTest.cpp
+++ b/llvm/unittests/IR/PassManagerTest.cpp
@@ -985,6 +985,7 @@
   MAM.registerPass([&] { return FunctionAnalysisManagerModuleProxy(FAM); });
   MAM.registerPass([&] { return PassInstrumentationAnalysis(&PIC); });
   FAM.registerPass([&] { return PassInstrumentationAnalysis(&PIC); });
+  FAM.registerPass([&] { return ModuleAnalysisManagerFunctionProxy(MAM); });
 
   FunctionPassManager FPM;
   FPM.addPass(WrongFunctionPass());
@@ -998,7 +999,10 @@
     for (Function &F : M)
       F.getEntryBlock().begin()->eraseFromParent();
 
-    return PreservedAnalyses::all();
+    PreservedAnalyses PA;
+    PA.preserveSet<AllAnalysesOn<Function>>();
+    PA.preserve<FunctionAnalysisManagerModuleProxy>();
+    return PA;
   }
   static StringRef name() { return "WrongModulePass"; }
 };
@@ -1018,6 +1022,7 @@
   MAM.registerPass([&] { return FunctionAnalysisManagerModuleProxy(FAM); });
   MAM.registerPass([&] { return PassInstrumentationAnalysis(&PIC); });
   FAM.registerPass([&] { return PassInstrumentationAnalysis(&PIC); });
+  FAM.registerPass([&] { return ModuleAnalysisManagerFunctionProxy(MAM); });
 
   ModulePassManager MPM;
   MPM.addPass(WrongModulePass());
@@ -1027,5 +1032,45 @@
       "Function @foo changed by WrongModulePass without invalidating analyses");
 }
 
+// Mutates every function but still reports all module analyses as preserved,
+// which the module hash verification is expected to catch.
+struct WrongModulePass2 : PassInfoMixin<WrongModulePass2> {
+  PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM) {
+    for (Function &F : M)
+      F.getEntryBlock().begin()->eraseFromParent();
+
+    PreservedAnalyses PA;
+    PA.preserveSet<AllAnalysesOn<Module>>();
+    PA.abandon<FunctionAnalysisManagerModuleProxy>();
+    return PA;
+  }
+  static StringRef name() { return "WrongModulePass2"; }
+};
+
+TEST_F(PassManagerTest, ModulePassMissedModuleAnalysisInvalidation) {
+  LLVMContext Context;
+  auto M = parseIR(Context, "define void @foo() {\n"
+                            "  %a = add i32 0, 0\n"
+                            "  ret void\n"
+                            "}\n");
+
+  FunctionAnalysisManager FAM;
+  ModuleAnalysisManager MAM;
+  PassInstrumentationCallbacks PIC;
+  StandardInstrumentations SI(M->getContext(), /*DebugLogging*/ false);
+  SI.registerCallbacks(PIC, &MAM);
+  MAM.registerPass([&] { return FunctionAnalysisManagerModuleProxy(FAM); });
+  MAM.registerPass([&] { return PassInstrumentationAnalysis(&PIC); });
+  FAM.registerPass([&] { return PassInstrumentationAnalysis(&PIC); });
+  FAM.registerPass([&] { return ModuleAnalysisManagerFunctionProxy(MAM); });
+
+  ModulePassManager MPM;
+  MPM.addPass(WrongModulePass2());
+
+  EXPECT_DEATH(
+      MPM.run(*M, MAM),
+      "Module changed by WrongModulePass2 without invalidating analyses");
+}
+
 #endif
 }