diff --git a/llvm/include/llvm/IR/OptBisect.h b/llvm/include/llvm/IR/OptBisect.h --- a/llvm/include/llvm/IR/OptBisect.h +++ b/llvm/include/llvm/IR/OptBisect.h @@ -29,7 +29,8 @@ - /// IRDescription is a textual description of the IR unit the pass is running - /// over. - virtual bool shouldRunPass(const Pass *P, StringRef IRDescription) { + /// PassName is the name of the pass being considered and IRDescription is a + /// textual description of the IR unit the pass is running over. + virtual bool shouldRunPass(const StringRef PassName, + StringRef IRDescription) { return true; } @@ -55,7 +56,9 @@ /// Checks the bisect limit to determine if the specified pass should run. /// - /// This forwards to checkPass(). - bool shouldRunPass(const Pass *P, StringRef IRDescription) override; + /// Assigns the pass the next bisect number, prints a message about it and + /// returns true if the limit is -1 or has not yet been exceeded. + bool shouldRunPass(const StringRef PassName, + StringRef IRDescription) override; /// isEnabled() should return true before calling shouldRunPass(). bool isEnabled() const override { return BisectLimit != Disabled; } @@ -89,7 +92,7 @@ -/// Singleton instance of the OptBisect class, so multiple pass managers don't -/// need to coordinate their uses of OptBisect. -OptBisect &getOptBisector(); +/// Returns the global OptPassGate (the OptBisect singleton), which contexts +/// fall back to when no gate is set, so pass managers share one bisect count. +OptPassGate &getGlobalPassGate(); } // end namespace llvm diff --git a/llvm/include/llvm/Passes/StandardInstrumentations.h b/llvm/include/llvm/Passes/StandardInstrumentations.h --- a/llvm/include/llvm/Passes/StandardInstrumentations.h +++ b/llvm/include/llvm/Passes/StandardInstrumentations.h @@ -74,12 +74,12 @@ bool shouldRun(StringRef PassID, Any IR); }; -class OptBisectInstrumentation { +class OptPassGateInstrumentation { bool HasWrittenIR = false; public: - OptBisectInstrumentation() = default; - void registerCallbacks(PassInstrumentationCallbacks &PIC); + OptPassGateInstrumentation() = default; + void registerCallbacks(PassInstrumentationCallbacks &PIC, LLVMContext *Cntxt); }; struct PrintPassOptions { @@ -528,7 +528,7 @@ TimePassesHandler TimePasses; TimeProfilingPassesHandler TimeProfilingPasses; OptNoneInstrumentation OptNone; - OptBisectInstrumentation OptBisect; + OptPassGateInstrumentation OptPassGate; PreservedCFGCheckerInstrumentation
PreservedCFGChecker; IRChangedPrinter PrintChangedIR; PseudoProbeVerifier PseudoProbeVerification; @@ -546,7 +546,9 @@ // Register all the standard instrumentation callbacks. If \p FAM is nullptr - // then PreservedCFGChecker is not enabled. + // then PreservedCFGChecker is not enabled. If \p Cntxt is nullptr, the + // pass gate from getGlobalPassGate() is used instead of the context's gate. void registerCallbacks(PassInstrumentationCallbacks &PIC, - FunctionAnalysisManager *FAM = nullptr); + FunctionAnalysisManager *FAM = nullptr, + LLVMContext *Cntxt = nullptr); TimePassesHandler &getTimePasses() { return TimePasses; } }; diff --git a/llvm/lib/Analysis/CallGraphSCCPass.cpp b/llvm/lib/Analysis/CallGraphSCCPass.cpp --- a/llvm/lib/Analysis/CallGraphSCCPass.cpp +++ b/llvm/lib/Analysis/CallGraphSCCPass.cpp @@ -751,7 +751,8 @@ bool CallGraphSCCPass::skipSCC(CallGraphSCC &SCC) const { OptPassGate &Gate = SCC.getCallGraph().getModule().getContext().getOptPassGate(); - return Gate.isEnabled() && !Gate.shouldRunPass(this, getDescription(SCC)); + return Gate.isEnabled() && + !Gate.shouldRunPass(this->getPassName(), getDescription(SCC)); } char DummyCGSCCPass::ID = 0; diff --git a/llvm/lib/Analysis/LoopPass.cpp b/llvm/lib/Analysis/LoopPass.cpp --- a/llvm/lib/Analysis/LoopPass.cpp +++ b/llvm/lib/Analysis/LoopPass.cpp @@ -373,7 +373,8 @@ return false; // Check the opt bisect limit. OptPassGate &Gate = F->getContext().getOptPassGate(); - if (Gate.isEnabled() && !Gate.shouldRunPass(this, getDescription(*L))) + if (Gate.isEnabled() && + !Gate.shouldRunPass(this->getPassName(), getDescription(*L))) return true; // Check for the OptimizeNone attribute.
if (F->hasOptNone()) { diff --git a/llvm/lib/Analysis/RegionPass.cpp b/llvm/lib/Analysis/RegionPass.cpp --- a/llvm/lib/Analysis/RegionPass.cpp +++ b/llvm/lib/Analysis/RegionPass.cpp @@ -283,7 +283,8 @@ bool RegionPass::skipRegion(Region &R) const { Function &F = *R.getEntry()->getParent(); OptPassGate &Gate = F.getContext().getOptPassGate(); - if (Gate.isEnabled() && !Gate.shouldRunPass(this, getDescription(R))) + if (Gate.isEnabled() && + !Gate.shouldRunPass(this->getPassName(), getDescription(R))) return true; if (F.hasOptNone()) { diff --git a/llvm/lib/IR/LLVMContextImpl.cpp b/llvm/lib/IR/LLVMContextImpl.cpp --- a/llvm/lib/IR/LLVMContextImpl.cpp +++ b/llvm/lib/IR/LLVMContextImpl.cpp @@ -240,7 +240,7 @@ /// singleton OptBisect if not explicitly set. OptPassGate &LLVMContextImpl::getOptPassGate() const { if (!OPG) - OPG = &getOptBisector(); + OPG = &getGlobalPassGate(); return *OPG; } diff --git a/llvm/lib/IR/OptBisect.cpp b/llvm/lib/IR/OptBisect.cpp --- a/llvm/lib/IR/OptBisect.cpp +++ b/llvm/lib/IR/OptBisect.cpp @@ -20,10 +20,15 @@ using namespace llvm; +static OptBisect &getOptBisector() { + static OptBisect OptBisector; + return OptBisector; +} + static cl::opt<int> OptBisectLimit("opt-bisect-limit", cl::Hidden, cl::init(OptBisect::Disabled), cl::Optional, cl::cb([](int Limit) { - llvm::getOptBisector().setLimit(Limit); + getOptBisector().setLimit(Limit); }), cl::desc("Maximum optimization to perform")); @@ -34,25 +39,16 @@ << "(" << PassNum << ") " << Name << " on " << TargetDesc << "\n"; } -bool OptBisect::shouldRunPass(const Pass *P, StringRef IRDescription) { - assert(isEnabled()); - - return checkPass(P->getPassName(), IRDescription); -} - -bool OptBisect::checkPass(const StringRef PassName, - const StringRef TargetDesc) { +bool OptBisect::shouldRunPass(const StringRef PassName, + StringRef IRDescription) { assert(isEnabled()); int CurBisectNum = ++LastBisectNum; bool ShouldRun = (BisectLimit == -1 || CurBisectNum <= BisectLimit); - printPassMessage(PassName,
CurBisectNum, TargetDesc, ShouldRun); + printPassMessage(PassName, CurBisectNum, IRDescription, ShouldRun); return ShouldRun; } const int OptBisect::Disabled; -OptBisect &llvm::getOptBisector() { - static OptBisect OptBisector; - return OptBisector; -} +OptPassGate &llvm::getGlobalPassGate() { return getOptBisector(); } diff --git a/llvm/lib/IR/Pass.cpp b/llvm/lib/IR/Pass.cpp --- a/llvm/lib/IR/Pass.cpp +++ b/llvm/lib/IR/Pass.cpp @@ -62,7 +62,8 @@ bool ModulePass::skipModule(Module &M) const { OptPassGate &Gate = M.getContext().getOptPassGate(); - return Gate.isEnabled() && !Gate.shouldRunPass(this, getDescription(M)); + return Gate.isEnabled() && + !Gate.shouldRunPass(this->getPassName(), getDescription(M)); } bool Pass::mustPreserveAnalysisID(char &AID) const { @@ -172,7 +173,8 @@ bool FunctionPass::skipFunction(const Function &F) const { OptPassGate &Gate = F.getContext().getOptPassGate(); - if (Gate.isEnabled() && !Gate.shouldRunPass(this, getDescription(F))) + if (Gate.isEnabled() && + !Gate.shouldRunPass(this->getPassName(), getDescription(F))) return true; if (F.hasOptNone()) { diff --git a/llvm/lib/LTO/LTOBackend.cpp b/llvm/lib/LTO/LTOBackend.cpp --- a/llvm/lib/LTO/LTOBackend.cpp +++ b/llvm/lib/LTO/LTOBackend.cpp @@ -257,7 +257,7 @@ PassInstrumentationCallbacks PIC; StandardInstrumentations SI(Conf.DebugPassManager); - SI.registerCallbacks(PIC, &FAM); + SI.registerCallbacks(PIC, &FAM, &Mod.getContext()); PassBuilder PB(TM, Conf.PTO, PGOOpt, &PIC); RegisterPassPlugins(Conf.PassPlugins, PB); diff --git a/llvm/lib/LTO/ThinLTOCodeGenerator.cpp b/llvm/lib/LTO/ThinLTOCodeGenerator.cpp --- a/llvm/lib/LTO/ThinLTOCodeGenerator.cpp +++ b/llvm/lib/LTO/ThinLTOCodeGenerator.cpp @@ -245,7 +245,7 @@ PassInstrumentationCallbacks PIC; StandardInstrumentations SI(DebugPassManager); - SI.registerCallbacks(PIC, &FAM); + SI.registerCallbacks(PIC, &FAM, &TheModule.getContext()); PipelineTuningOptions PTO; PTO.LoopVectorization = true; PTO.SLPVectorization = true; diff --git
a/llvm/lib/Passes/PassBuilderBindings.cpp b/llvm/lib/Passes/PassBuilderBindings.cpp --- a/llvm/lib/Passes/PassBuilderBindings.cpp +++ b/llvm/lib/Passes/PassBuilderBindings.cpp @@ -66,7 +66,7 @@ PB.crossRegisterProxies(LAM, FAM, CGAM, MAM); StandardInstrumentations SI(Debug, VerifyEach); - SI.registerCallbacks(PIC, &FAM); + SI.registerCallbacks(PIC, &FAM, &Mod->getContext()); ModulePassManager MPM; if (VerifyEach) { MPM.addPass(VerifierPass()); diff --git a/llvm/lib/Passes/StandardInstrumentations.cpp b/llvm/lib/Passes/StandardInstrumentations.cpp --- a/llvm/lib/Passes/StandardInstrumentations.cpp +++ b/llvm/lib/Passes/StandardInstrumentations.cpp @@ -766,14 +766,16 @@ return ShouldRun; } -void OptBisectInstrumentation::registerCallbacks( - PassInstrumentationCallbacks &PIC) { - if (!getOptBisector().isEnabled()) +void OptPassGateInstrumentation::registerCallbacks( + PassInstrumentationCallbacks &PIC, LLVMContext *Cntxt) { + OptPassGate &PassGate = Cntxt ? Cntxt->getOptPassGate() : getGlobalPassGate(); + if (!PassGate.isEnabled()) return; - PIC.registerShouldRunOptionalPassCallback([this](StringRef PassID, Any IR) { - if (isIgnored(PassID)) + + PIC.registerShouldRunOptionalPassCallback([&](StringRef PassName, Any IR) { + if (isIgnored(PassName)) return true; - bool ShouldRun = getOptBisector().checkPass(PassID, getIRName(IR)); + bool ShouldRun = PassGate.shouldRunPass(PassName, getIRName(IR)); if (!ShouldRun && !this->HasWrittenIR && !OptBisectPrintIRPath.empty()) { // FIXME: print IR if limit is higher than number of opt-bisect // invocations @@ -2093,12 +2095,13 @@ } void StandardInstrumentations::registerCallbacks( - PassInstrumentationCallbacks &PIC, FunctionAnalysisManager *FAM) { + PassInstrumentationCallbacks &PIC, FunctionAnalysisManager *FAM, + LLVMContext *Cntxt) { PrintIR.registerCallbacks(PIC); PrintPass.registerCallbacks(PIC); TimePasses.registerCallbacks(PIC); OptNone.registerCallbacks(PIC); - OptBisect.registerCallbacks(PIC); +
OptPassGate.registerCallbacks(PIC, Cntxt); if (FAM) PreservedCFGChecker.registerCallbacks(PIC, *FAM); PrintChangedIR.registerCallbacks(PIC); diff --git a/llvm/unittests/IR/LegacyPassManagerTest.cpp b/llvm/unittests/IR/LegacyPassManagerTest.cpp --- a/llvm/unittests/IR/LegacyPassManagerTest.cpp +++ b/llvm/unittests/IR/LegacyPassManagerTest.cpp @@ -359,10 +359,9 @@ struct CustomOptPassGate : public OptPassGate { bool Skip; CustomOptPassGate(bool Skip) : Skip(Skip) { } - bool shouldRunPass(const Pass *P, StringRef IRDescription) override { - if (P->getPassKind() == PT_Module) - return !Skip; - return OptPassGate::shouldRunPass(P, IRDescription); + bool shouldRunPass(const StringRef PassName, + StringRef IRDescription) override { + return !Skip; } bool isEnabled() const override { return true; } };