diff --git a/llvm/include/llvm/IR/PassManager.h b/llvm/include/llvm/IR/PassManager.h
--- a/llvm/include/llvm/IR/PassManager.h
+++ b/llvm/include/llvm/IR/PassManager.h
@@ -559,6 +559,27 @@
       Passes.emplace_back(std::move(P));
   }
 
+  /// Inserts \p Pass immediately before the first pass whose name() equals
+  /// \p Other. If no such pass is present, \p Pass is appended after all
+  /// existing passes.
+  ///
+  /// Passes built on PassInfoMixin report their C++ type name, whose spelling
+  /// is compiler-specific for types in an anonymous namespace, so prefer
+  /// passing `OtherPassT::name()` over a hard-coded string.
+  template <typename PassT>
+  std::enable_if_t<!std::is_same<PassT, PassManager>::value>
+  addPassBefore(PassT Pass, StringRef Other) {
+    using PassModelT =
+        detail::PassModel<IRUnitT, PassT, PreservedAnalyses, AnalysisManagerT,
+                          ExtraArgTs...>;
+    auto It = std::find_if(Passes.begin(), Passes.end(), [&](const auto &P) {
+      return P->name() == Other;
+    });
+    // Give the new pass an owner before touching the vector so it cannot leak
+    // if making room for it throws.
+    Passes.emplace(It, std::make_unique<PassModelT>(std::move(Pass)));
+  }
+
   /// Returns if the pass manager contains any passes.
   bool isEmpty() const { return Passes.empty(); }
 
diff --git a/llvm/unittests/IR/PassManagerTest.cpp b/llvm/unittests/IR/PassManagerTest.cpp
--- a/llvm/unittests/IR/PassManagerTest.cpp
+++ b/llvm/unittests/IR/PassManagerTest.cpp
@@ -703,6 +703,19 @@
   FuncT Func;
 };
 
+struct LambdaModulePass : public PassInfoMixin<LambdaModulePass> {
+  using FuncT =
+      std::function<PreservedAnalyses(Module &, ModuleAnalysisManager &)>;
+
+  LambdaModulePass(FuncT Func) : Func(std::move(Func)) {}
+
+  PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM) {
+    return Func(M, AM);
+  }
+
+  FuncT Func;
+};
+
 TEST_F(PassManagerTest, IndirectAnalysisInvalidation) {
   FunctionAnalysisManager FAM;
   ModuleAnalysisManager MAM;
@@ -950,4 +963,61 @@
   FPM.addPass(TestSimplifyCFGWrapperPass(InnerFPM));
   FPM.run(*F, FAM);
 }
+
+TEST_F(PassManagerTest, AddPassBefore) {
+  ModuleAnalysisManager MAM;
+  int ModuleAnalysisRuns = 0;
+  PassInstrumentationCallbacks PIC;
+  MAM.registerPass([&] { return TestModuleAnalysis(ModuleAnalysisRuns); });
+  MAM.registerPass([&] { return PassInstrumentationAnalysis(&PIC); });
+
+  // Baseline: addPass() appends, so the lambda pass runs second.
+  ModulePassManager MPM1;
+  int TestModulePassRunCount1 = 0;
+  int CustomModulePassRunCount1 = 0;
+  MPM1.addPass(TestModulePass(TestModulePassRunCount1));
+  MPM1.addPass(LambdaModulePass([&](Module &, ModuleAnalysisManager &) {
+    EXPECT_EQ(1, TestModulePassRunCount1);
+    CustomModulePassRunCount1++;
+    return PreservedAnalyses::all();
+  }));
+  MPM1.run(*M, MAM);
+  EXPECT_EQ(1, TestModulePassRunCount1);
+  EXPECT_EQ(1, CustomModulePassRunCount1);
+
+  // addPassBefore() puts the lambda pass ahead of TestModulePass. The target
+  // is looked up via TestModulePass::name() because the spelled-out name of a
+  // type in an anonymous namespace differs between GCC, Clang and MSVC.
+  ModulePassManager MPM2;
+  int TestModulePassRunCount2 = 0;
+  int CustomModulePassRunCount2 = 0;
+  MPM2.addPass(TestModulePass(TestModulePassRunCount2));
+  MPM2.addPassBefore(
+      LambdaModulePass([&](Module &, ModuleAnalysisManager &) {
+        EXPECT_EQ(0, TestModulePassRunCount2);
+        CustomModulePassRunCount2++;
+        return PreservedAnalyses::all();
+      }),
+      TestModulePass::name());
+  MPM2.run(*M, MAM);
+  EXPECT_EQ(1, TestModulePassRunCount2);
+  EXPECT_EQ(1, CustomModulePassRunCount2);
+
+  // An unknown target name falls back to appending at the end.
+  ModulePassManager MPM3;
+  int TestModulePassRunCount3 = 0;
+  int CustomModulePassRunCount3 = 0;
+  MPM3.addPass(TestModulePass(TestModulePassRunCount3));
+  MPM3.addPassBefore(
+      LambdaModulePass([&](Module &, ModuleAnalysisManager &) {
+        EXPECT_EQ(1, TestModulePassRunCount3);
+        CustomModulePassRunCount3++;
+        return PreservedAnalyses::all();
+      }),
+      "NoSuchPass");
+  MPM3.run(*M, MAM);
+  EXPECT_EQ(1, TestModulePassRunCount3);
+  EXPECT_EQ(1, CustomModulePassRunCount3);
+}
+
 }