diff --git a/llvm/include/llvm/Analysis/InlineAdvisor.h b/llvm/include/llvm/Analysis/InlineAdvisor.h
--- a/llvm/include/llvm/Analysis/InlineAdvisor.h
+++ b/llvm/include/llvm/Analysis/InlineAdvisor.h
@@ -162,6 +162,12 @@
   /// to prepare for a partial update.
   virtual void onPassExit() {}
 
+  /// Called when the module is invalidated. We let the advisor implementation
+  /// decide what to refresh - in the case of the development mode
+  /// implementation, for example, we wouldn't want to delete the whole object
+  /// and need to re-load the model evaluator.
+  virtual void onModuleInvalidated() {}
+
 protected:
   InlineAdvisor(Module &M, FunctionAnalysisManager &FAM);
   virtual std::unique_ptr<InlineAdvice> getAdviceImpl(CallBase &CB) = 0;
@@ -226,6 +232,8 @@
     Result(Module &M, ModuleAnalysisManager &MAM) : M(M), MAM(MAM) {}
     bool invalidate(Module &, const PreservedAnalyses &PA,
                     ModuleAnalysisManager::Invalidator &) {
+      if (Advisor && !PA.areAllPreserved())
+        Advisor->onModuleInvalidated();
       // Check whether the analysis has been explicitly invalidated. Otherwise,
       // it's stateless and remains preserved.
       auto PAC = PA.getChecker<InlineAdvisorAnalysis>();
diff --git a/llvm/include/llvm/Analysis/MLInlineAdvisor.h b/llvm/include/llvm/Analysis/MLInlineAdvisor.h
--- a/llvm/include/llvm/Analysis/MLInlineAdvisor.h
+++ b/llvm/include/llvm/Analysis/MLInlineAdvisor.h
@@ -38,6 +38,7 @@
   bool isForcedToStop() const { return ForceStop; }
   int64_t getLocalCalls(Function &F);
   const MLModelRunner &getModelRunner() const { return *ModelRunner.get(); }
+  void onModuleInvalidated() override { Invalid = true; }
 
 protected:
   std::unique_ptr<InlineAdvice> getAdviceImpl(CallBase &CB) override;
@@ -55,6 +56,9 @@
 private:
   int64_t getModuleIRSize() const;
 
+  /// When true, the module-wide features (NodeCount, EdgeCount) are stale and
+  /// are recomputed by the next onPassEntry(); set by onModuleInvalidated().
+  bool Invalid = true;
   std::unique_ptr<CallGraph> CG;
 
   int64_t NodeCount = 0;
diff --git a/llvm/lib/Analysis/MLInlineAdvisor.cpp b/llvm/lib/Analysis/MLInlineAdvisor.cpp
--- a/llvm/lib/Analysis/MLInlineAdvisor.cpp
+++ b/llvm/lib/Analysis/MLInlineAdvisor.cpp
@@ -116,6 +116,8 @@
 void MLInlineAdvisor::onPassEntry() {
   // Function passes executed between InlinerPass runs may have changed the
   // module-wide features.
+  if (!Invalid)
+    return;
   NodeCount = 0;
   EdgeCount = 0;
   for (auto &F : M)
@@ -123,6 +125,7 @@
       ++NodeCount;
       EdgeCount += getLocalCalls(F);
     }
+  Invalid = false;
 }
 
 int64_t MLInlineAdvisor::getLocalCalls(Function &F) {