diff --git a/llvm/include/llvm/IR/PassInstrumentation.h b/llvm/include/llvm/IR/PassInstrumentation.h --- a/llvm/include/llvm/IR/PassInstrumentation.h +++ b/llvm/include/llvm/IR/PassInstrumentation.h @@ -67,11 +67,14 @@ // to take them as constant pointers, wrapped with llvm::Any. // For the case when IRUnit has been invalidated there is a different // callback to use - AfterPassInvalidated. + // We call all BeforePassFuncs to determine if a pass should run or not. + // BeforeNonSkippedPassFuncs are called only if the pass should run. // TODO: currently AfterPassInvalidated does not accept IRUnit, since passing - // already invalidated IRUnit is unsafe. There are ways to handle invalidated IRUnits - // in a safe way, and we might pursue that as soon as there is a useful instrumentation - // that needs it. + // already invalidated IRUnit is unsafe. There are ways to handle invalidated + // IRUnits in a safe way, and we might pursue that as soon as there is a + // useful instrumentation that needs it. using BeforePassFunc = bool(StringRef, Any); + using BeforeNonSkippedPassFunc = void(StringRef, Any); using AfterPassFunc = void(StringRef, Any); using AfterPassInvalidatedFunc = void(StringRef); using BeforeAnalysisFunc = void(StringRef, Any); @@ -88,6 +91,11 @@ BeforePassCallbacks.emplace_back(std::move(C)); } + template + void registerBeforeNonSkippedPassCallback(CallableT C) { + BeforeNonSkippedPassCallbacks.emplace_back(std::move(C)); + } + template void registerAfterPassCallback(CallableT C) { AfterPassCallbacks.emplace_back(std::move(C)); } @@ -111,6 +119,8 @@ friend class PassInstrumentation; SmallVector, 4> BeforePassCallbacks; + SmallVector, 4> + BeforeNonSkippedPassCallbacks; SmallVector, 4> AfterPassCallbacks; SmallVector, 4> AfterPassInvalidatedCallbacks; @@ -165,6 +175,12 @@ for (auto &C : Callbacks->BeforePassCallbacks) ShouldRun &= C(Pass.name(), llvm::Any(&IR)); ShouldRun = ShouldRun || isRequired(Pass); + + if (ShouldRun) { + for (auto &C : Callbacks->BeforeNonSkippedPassCallbacks) + C(Pass.name(), llvm::Any(&IR)); + } + return ShouldRun; } diff --git a/llvm/unittests/IR/PassBuilderCallbacksTest.cpp b/llvm/unittests/IR/PassBuilderCallbacksTest.cpp --- a/llvm/unittests/IR/PassBuilderCallbacksTest.cpp +++ b/llvm/unittests/IR/PassBuilderCallbacksTest.cpp @@ -320,6 +320,7 @@ ON_CALL(*this, runBeforePass(_, _)).WillByDefault(Return(true)); } MOCK_METHOD2(runBeforePass, bool(StringRef PassID, llvm::Any)); + MOCK_METHOD2(runBeforeNonSkippedPass, void(StringRef PassID, llvm::Any)); MOCK_METHOD2(runAfterPass, void(StringRef PassID, llvm::Any)); MOCK_METHOD1(runAfterPassInvalidated, void(StringRef PassID)); MOCK_METHOD2(runBeforeAnalysis, void(StringRef PassID, llvm::Any)); @@ -329,6 +330,10 @@ Callbacks.registerBeforePassCallback([this](StringRef P, llvm::Any IR) { return this->runBeforePass(P, IR); }); + Callbacks.registerBeforeNonSkippedPassCallback( + [this](StringRef P, llvm::Any IR) { + this->runBeforeNonSkippedPass(P, IR); + }); Callbacks.registerAfterPassCallback( [this](StringRef P, llvm::Any IR) { this->runAfterPass(P, IR); }); Callbacks.registerAfterPassInvalidatedCallback( @@ -349,6 +354,9 @@ EXPECT_CALL(*this, runBeforePass(Not(HasNameRegex("Mock")), HasName(IRName))) .Times(AnyNumber()); + EXPECT_CALL(*this, runBeforeNonSkippedPass(Not(HasNameRegex("Mock")), + HasName(IRName))) + .Times(AnyNumber()); EXPECT_CALL(*this, runAfterPass(Not(HasNameRegex("Mock")), HasName(IRName))) .Times(AnyNumber()); EXPECT_CALL(*this, @@ -500,6 +508,10 @@ EXPECT_CALL(CallbacksHandle, runBeforePass(HasNameRegex("MockPassHandle"), HasName(""))) .InSequence(PISequence); + EXPECT_CALL(CallbacksHandle, + runBeforeNonSkippedPass(HasNameRegex("MockPassHandle"), + HasName(""))) + .InSequence(PISequence); EXPECT_CALL(CallbacksHandle, runBeforeAnalysis(HasNameRegex("MockAnalysisHandle"), HasName(""))) @@ -532,8 +544,11 @@ EXPECT_CALL(AnalysisHandle, run(HasName(""), _)).Times(0); EXPECT_CALL(PassHandle, run(HasName(""), _)).Times(0); - // As the pass is skipped there is no afterPass, beforeAnalysis/afterAnalysis - // as well. + // As the pass is skipped there is no nonskippedpass/afterPass, + // beforeAnalysis/afterAnalysis as well. + EXPECT_CALL(CallbacksHandle, + runBeforeNonSkippedPass(HasNameRegex("MockPassHandle"), _)) + .Times(0); EXPECT_CALL(CallbacksHandle, runAfterPass(HasNameRegex("MockPassHandle"), _)) .Times(0); EXPECT_CALL(CallbacksHandle, @@ -545,12 +560,35 @@ // Order is important here. `Adaptor` expectations should be checked first // because the its argument contains 'PassManager' (for example: - // ModuleToFunctionPassAdaptor{{.*}}PassManager{{.*}}). Here only check - // `runAfterPass` to show that they are not skipped. - + // ModuleToFunctionPassAdaptor{{.*}}PassManager{{.*}}). Check + // `runBeforeNonSkippedPass` and `runAfterPass` to show that they are not + // skipped. + // // Pass managers are not ignored. // 5 = (1) ModulePassManager + (2) FunctionPassMangers + (1) LoopPassManager + // (1) CGSCCPassManager + EXPECT_CALL(CallbacksHandle, + runBeforeNonSkippedPass(HasNameRegex("PassManager"), _)) + .Times(5); + EXPECT_CALL( + CallbacksHandle, + runBeforeNonSkippedPass(HasNameRegex("ModuleToFunctionPassAdaptor"), _)) + .Times(1); + EXPECT_CALL(CallbacksHandle, + runBeforeNonSkippedPass( + HasNameRegex("ModuleToPostOrderCGSCCPassAdaptor"), _)) + .Times(1); + EXPECT_CALL( + CallbacksHandle, + runBeforeNonSkippedPass(HasNameRegex("CGSCCToFunctionPassAdaptor"), _)) + .Times(1); + EXPECT_CALL( + CallbacksHandle, + runBeforeNonSkippedPass(HasNameRegex("FunctionToLoopPassAdaptor"), _)) + .Times(1); + + // The `runAfterPass` checks are the same as these of + // `runBeforeNonSkippedPass`. EXPECT_CALL(CallbacksHandle, runAfterPass(HasNameRegex("PassManager"), _)) .Times(5); EXPECT_CALL(CallbacksHandle, @@ -630,6 +668,10 @@ EXPECT_CALL(CallbacksHandle, runBeforePass(HasNameRegex("MockPassHandle"), HasName("foo"))) .InSequence(PISequence); + EXPECT_CALL( + CallbacksHandle, + runBeforeNonSkippedPass(HasNameRegex("MockPassHandle"), HasName("foo"))) + .InSequence(PISequence); EXPECT_CALL( CallbacksHandle, runBeforeAnalysis(HasNameRegex("MockAnalysisHandle"), HasName("foo"))) @@ -717,6 +759,10 @@ EXPECT_CALL(CallbacksHandle, runBeforePass(HasNameRegex("MockPassHandle"), HasName("loop"))) .InSequence(PISequence); + EXPECT_CALL( + CallbacksHandle, + runBeforeNonSkippedPass(HasNameRegex("MockPassHandle"), HasName("loop"))) + .InSequence(PISequence); EXPECT_CALL( CallbacksHandle, runBeforeAnalysis(HasNameRegex("MockAnalysisHandle"), HasName("loop"))) @@ -758,6 +804,10 @@ EXPECT_CALL(CallbacksHandle, runBeforePass(HasNameRegex("MockPassHandle"), HasName("loop"))) .InSequence(PISequence); + EXPECT_CALL( + CallbacksHandle, + runBeforeNonSkippedPass(HasNameRegex("MockPassHandle"), HasName("loop"))) + .InSequence(PISequence); EXPECT_CALL( CallbacksHandle, runBeforeAnalysis(HasNameRegex("MockAnalysisHandle"), HasName("loop"))) @@ -847,6 +897,10 @@ EXPECT_CALL(CallbacksHandle, runBeforePass(HasNameRegex("MockPassHandle"), HasName("(foo)"))) .InSequence(PISequence); + EXPECT_CALL( + CallbacksHandle, + runBeforeNonSkippedPass(HasNameRegex("MockPassHandle"), HasName("(foo)"))) + .InSequence(PISequence); EXPECT_CALL( CallbacksHandle, runBeforeAnalysis(HasNameRegex("MockAnalysisHandle"), HasName("(foo)"))) @@ -888,6 +942,10 @@ EXPECT_CALL(CallbacksHandle, runBeforePass(HasNameRegex("MockPassHandle"), HasName("(foo)"))) .InSequence(PISequence); + EXPECT_CALL( + CallbacksHandle, + runBeforeNonSkippedPass(HasNameRegex("MockPassHandle"), HasName("(foo)"))) + .InSequence(PISequence); EXPECT_CALL( CallbacksHandle, runBeforeAnalysis(HasNameRegex("MockAnalysisHandle"), HasName("(foo)")))