[NewPM] Add isRequired() so pass managers and adaptor passes keep running even when a BeforePass instrumentation callback returns false; PassConcept::isRequired forwards the wrapped pass's static isRequired(), or returns false when the pass does not declare one.
NOTE(review): in this copy of the patch, angle-bracket text appears to have been stripped (bare `template` lines, `std::declval()`, `std::enable_if_t::value`, `ArrayRef)`, `HasName("")`); restore it from the original patch before applying -- TODO confirm.
diff --git a/llvm/include/llvm/Analysis/CGSCCPassManager.h b/llvm/include/llvm/Analysis/CGSCCPassManager.h --- a/llvm/include/llvm/Analysis/CGSCCPassManager.h +++ b/llvm/include/llvm/Analysis/CGSCCPassManager.h @@ -355,6 +355,8 @@ /// Runs the CGSCC pass across every SCC in the module. PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM); + static bool isRequired() { return true; } + private: CGSCCPassT Pass; }; @@ -543,6 +545,8 @@ return PA; } + static bool isRequired() { return true; } + private: FunctionPassT Pass; }; diff --git a/llvm/include/llvm/IR/PassInstrumentation.h b/llvm/include/llvm/IR/PassInstrumentation.h --- a/llvm/include/llvm/IR/PassInstrumentation.h +++ b/llvm/include/llvm/IR/PassInstrumentation.h @@ -129,6 +129,26 @@ class PassInstrumentation { PassInstrumentationCallbacks *Callbacks; + // Template argument PassT of PassInstrumentation::runBeforePass could be two + // kinds: (1) a regular pass inherited from PassInfoMixin (happens when an + // adaptor pass is created for a regular pass); (2) a type-erased PassConcept + // created from (1). Here we want to make case (1) skippable unconditionally + // since they are regular passes. We call PassConcept::isRequired to decide + // for case (2). 
+ template + using has_required_t = decltype(std::declval().isRequired()); + + template + static std::enable_if_t::value, bool> + isRequired(const PassT &Pass) { + return Pass.isRequired(); + } + template + static std::enable_if_t::value, bool> + isRequired(const PassT &Pass) { + return false; + } + public: /// Callbacks object is not owned by PassInstrumentation, its life-time /// should at least match the life-time of corresponding @@ -148,6 +168,7 @@ bool ShouldRun = true; for (auto &C : Callbacks->BeforePassCallbacks) ShouldRun &= C(Pass.name(), llvm::Any(&IR)); + ShouldRun = ShouldRun || isRequired(Pass); return ShouldRun; } diff --git a/llvm/include/llvm/IR/PassManager.h b/llvm/include/llvm/IR/PassManager.h --- a/llvm/include/llvm/IR/PassManager.h +++ b/llvm/include/llvm/IR/PassManager.h @@ -559,6 +559,8 @@ Passes.emplace_back(new PassModelT(std::move(Pass))); } + static bool isRequired() { return true; } + private: using PassConceptT = detail::PassConcept; @@ -1260,6 +1262,8 @@ return PA; } + static bool isRequired() { return true; } + private: FunctionPassT Pass; }; diff --git a/llvm/include/llvm/IR/PassManagerInternal.h b/llvm/include/llvm/IR/PassManagerInternal.h --- a/llvm/include/llvm/IR/PassManagerInternal.h +++ b/llvm/include/llvm/IR/PassManagerInternal.h @@ -48,6 +48,12 @@ /// Polymorphic method to access the name of a pass. virtual StringRef name() const = 0; + + /// Polymorphic method to let a pass optionally be exempted from skipping by + /// PassInstrumentation. + /// To opt in, a pass should implement `static bool isRequired()`. It's a + /// no-op to have `isRequired` always return false since that is the default. + virtual bool isRequired() const = 0; }; /// A template wrapper used to implement the polymorphic API. 
@@ -81,6 +87,22 @@ StringRef name() const override { return PassT::name(); } + template + using has_required_t = decltype(std::declval().isRequired()); + + template + static std::enable_if_t::value, bool> + passIsRequiredImpl() { + return T::isRequired(); + } + template + static std::enable_if_t::value, bool> + passIsRequiredImpl() { + return false; + } + + bool isRequired() const override { return passIsRequiredImpl(); } + PassT Pass; }; diff --git a/llvm/include/llvm/Transforms/Scalar/LoopPassManager.h b/llvm/include/llvm/Transforms/Scalar/LoopPassManager.h --- a/llvm/include/llvm/Transforms/Scalar/LoopPassManager.h +++ b/llvm/include/llvm/Transforms/Scalar/LoopPassManager.h @@ -366,6 +366,8 @@ return PA; } + static bool isRequired() { return true; } + private: LoopPassT Pass; diff --git a/llvm/unittests/IR/PassBuilderCallbacksTest.cpp b/llvm/unittests/IR/PassBuilderCallbacksTest.cpp --- a/llvm/unittests/IR/PassBuilderCallbacksTest.cpp +++ b/llvm/unittests/IR/PassBuilderCallbacksTest.cpp @@ -524,10 +524,10 @@ // Non-mock instrumentation run here can safely be ignored. CallbacksHandle.ignoreNonMockPassInstrumentation(""); - // Skip the pass by returning false. - EXPECT_CALL(CallbacksHandle, runBeforePass(HasNameRegex("MockPassHandle"), - HasName(""))) - .WillOnce(Return(false)); + // Skip all passes by returning false. Pass managers and adaptor passes are + // also passes that are observed by the callbacks. + EXPECT_CALL(CallbacksHandle, runBeforePass(_, _)) + .WillRepeatedly(Return(false)); EXPECT_CALL(AnalysisHandle, run(HasName(""), _)).Times(0); EXPECT_CALL(PassHandle, run(HasName(""), _)).Times(0); @@ -543,7 +543,60 @@ runAfterAnalysis(HasNameRegex("MockAnalysisHandle"), _)) .Times(0); - StringRef PipelineText = "test-transform"; + // Order is important here. `Adaptor` expectations should be checked first + // because its argument contains 'PassManager' (for example: + // ModuleToFunctionPassAdaptor{{.*}}PassManager{{.*}}). 
Here only check + // `runAfterPass` to show that they are not skipped. + + // Pass managers are not ignored. + // 5 = (1) ModulePassManager + (2) FunctionPassManagers + (1) LoopPassManager + // + (1) CGSCCPassManager + EXPECT_CALL(CallbacksHandle, runAfterPass(HasNameRegex("PassManager"), _)) + .Times(5); + EXPECT_CALL(CallbacksHandle, + runAfterPass(HasNameRegex("ModuleToFunctionPassAdaptor"), _)) + .Times(1); + EXPECT_CALL( + CallbacksHandle, + runAfterPass(HasNameRegex("ModuleToPostOrderCGSCCPassAdaptor"), _)) + .Times(1); + EXPECT_CALL(CallbacksHandle, + runAfterPass(HasNameRegex("CGSCCToFunctionPassAdaptor"), _)) + .Times(1); + EXPECT_CALL(CallbacksHandle, + runAfterPass(HasNameRegex("FunctionToLoopPassAdaptor"), _)) + .Times(1); + + // Ignore analyses introduced by adaptor passes. + EXPECT_CALL(CallbacksHandle, + runBeforeAnalysis(Not(HasNameRegex("MockAnalysisHandle")), _)) + .Times(AnyNumber()); + EXPECT_CALL(CallbacksHandle, + runAfterAnalysis(Not(HasNameRegex("MockAnalysisHandle")), _)) + .Times(AnyNumber()); + + // Register Function and Loop versions of "test-transform" for testing + PB.registerPipelineParsingCallback( + [](StringRef Name, FunctionPassManager &FPM, + ArrayRef) { + if (Name == "test-transform") { + FPM.addPass(MockPassHandle().getPass()); + return true; + } + return false; + }); + PB.registerPipelineParsingCallback( + [](StringRef Name, LoopPassManager &LPM, + ArrayRef) { + if (Name == "test-transform") { + LPM.addPass(MockPassHandle().getPass()); + return true; + } + return false; + }); + + StringRef PipelineText = "test-transform,function(test-transform),cgscc(" + "function(loop(test-transform)))"; ASSERT_THAT_ERROR(PB.parsePassPipeline(PM, PipelineText, true), Succeeded()) << "Pipeline was: " << PipelineText;