diff --git a/llvm/include/llvm/Analysis/InlineAdvisor.h b/llvm/include/llvm/Analysis/InlineAdvisor.h
--- a/llvm/include/llvm/Analysis/InlineAdvisor.h
+++ b/llvm/include/llvm/Analysis/InlineAdvisor.h
@@ -9,7 +9,9 @@
 #ifndef LLVM_ANALYSIS_INLINEADVISOR_H
 #define LLVM_ANALYSIS_INLINEADVISOR_H
 
+#include "llvm/Analysis/CallGraph.h"
 #include "llvm/Analysis/InlineCost.h"
+#include "llvm/Analysis/LazyCallGraph.h"
 #include "llvm/Analysis/Utils/ImportedFunctionsInliningStatistics.h"
 #include "llvm/Config/llvm-config.h"
 #include "llvm/IR/PassManager.h"
@@ -159,14 +161,13 @@
 
   /// This must be called when the Inliner pass is exited, as function passes
   /// may be run subsequently. This allows an implementation of InlineAdvisor
-  /// to prepare for a partial update.
-  virtual void onPassExit() {}
+  /// to prepare for a partial update, based on the optional SCC.
+  virtual void onPassExit(LazyCallGraph::SCC *SCC = nullptr) {}
 
-  /// Called when the module is invalidated. We let the advisor implementation
-  /// decide what to refresh - in the case of the development mode
-  /// implementation, for example, we wouldn't want to delete the whole object
-  /// and need to re-load the model evaluator.
-  virtual void onModuleInvalidated() {}
+  /// Support for printer pass
+  virtual void print(raw_ostream &OS) const {
+    OS << "Unimplemented InlineAdvisor print\n";
+  }
 
 protected:
   InlineAdvisor(Module &M, FunctionAnalysisManager &FAM);
@@ -217,7 +218,7 @@
 private:
   std::unique_ptr<InlineAdvice> getAdviceImpl(CallBase &CB) override;
 
-  void onPassExit() override { freeDeletedFunctions(); }
+  void onPassExit(LazyCallGraph::SCC *SCC) override { freeDeletedFunctions(); }
 
   InlineParams Params;
 };
@@ -232,8 +233,6 @@
     Result(Module &M, ModuleAnalysisManager &MAM) : M(M), MAM(MAM) {}
     bool invalidate(Module &, const PreservedAnalyses &PA,
                     ModuleAnalysisManager::Invalidator &) {
-      if (Advisor && !PA.areAllPreserved())
-        Advisor->onModuleInvalidated();
       // Check whether the analysis has been explicitly invalidated. Otherwise,
       // it's stateless and remains preserved.
       auto PAC = PA.getChecker<InlineAdvisorAnalysis>();
@@ -252,6 +251,17 @@
   Result run(Module &M, ModuleAnalysisManager &MAM) { return Result(M, MAM); }
 };
 
+/// Printer pass for the InlineAdvisorAnalysis results.
+class InlineAdvisorAnalysisPrinterPass
+    : public PassInfoMixin<InlineAdvisorAnalysisPrinterPass> {
+  raw_ostream &OS;
+
+public:
+  explicit InlineAdvisorAnalysisPrinterPass(raw_ostream &OS) : OS(OS) {}
+
+  PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM);
+};
+
 #ifdef LLVM_HAVE_TF_AOT
 std::unique_ptr<InlineAdvisor>
 getReleaseModeAdvisor(Module &M, ModuleAnalysisManager &MAM);
diff --git a/llvm/include/llvm/Analysis/LazyCallGraph.h b/llvm/include/llvm/Analysis/LazyCallGraph.h
--- a/llvm/include/llvm/Analysis/LazyCallGraph.h
+++ b/llvm/include/llvm/Analysis/LazyCallGraph.h
@@ -930,6 +930,7 @@
 
   EdgeSequence::iterator begin() { return EntryEdges.begin(); }
   EdgeSequence::iterator end() { return EntryEdges.end(); }
+  size_t nodeCount() const { return NodeMap.size(); }
 
   void buildRefSCCs();
 
diff --git a/llvm/include/llvm/Analysis/MLInlineAdvisor.h b/llvm/include/llvm/Analysis/MLInlineAdvisor.h
--- a/llvm/include/llvm/Analysis/MLInlineAdvisor.h
+++ b/llvm/include/llvm/Analysis/MLInlineAdvisor.h
@@ -9,8 +9,10 @@
 #ifndef LLVM_ANALYSIS_MLINLINEADVISOR_H
 #define LLVM_ANALYSIS_MLINLINEADVISOR_H
 
+#include "llvm/ADT/StringSet.h"
 #include "llvm/Analysis/CallGraph.h"
 #include "llvm/Analysis/InlineAdvisor.h"
+#include "llvm/Analysis/LazyCallGraph.h"
 #include "llvm/Analysis/MLModelRunner.h"
 #include "llvm/IR/PassManager.h"
 
@@ -30,6 +32,7 @@
   virtual ~MLInlineAdvisor() = default;
 
   void onPassEntry() override;
+  void onPassExit(LazyCallGraph::SCC *SCC) override;
 
   int64_t getIRSize(const Function &F) const { return F.getInstructionCount(); }
   void onSuccessfulInlining(const MLInlineAdvice &Advice,
@@ -38,7 +41,6 @@
   bool isForcedToStop() const { return ForceStop; }
   int64_t getLocalCalls(Function &F);
   const MLModelRunner &getModelRunner() const { return *ModelRunner.get(); }
-  void onModuleInvalidated() override { Invalid = true; }
 
 protected:
   std::unique_ptr<InlineAdvice> getAdviceImpl(CallBase &CB) override;
@@ -55,16 +57,26 @@
 
 private:
   int64_t getModuleIRSize() const;
-
-  bool Invalid = true;
+  void print(raw_ostream &OS) const override {
+    OS << "[MLInlineAdvisor] Nodes: " << NodeCount << " Edges: " << EdgeCount
+       << "\n";
+  }
 
   std::unique_ptr<CallGraph> CG;
 
+  LazyCallGraph &LCG;
   int64_t NodeCount = 0;
   int64_t EdgeCount = 0;
+  /// Node and edge counts of the SCC seen by the last onPassExit. They are
+  /// subtracted on the next onPassEntry, after the contribution of the
+  /// surviving LastNodes has been added back.
+  int64_t NodeCountToRemove = 0;
+  int64_t EdgeCountToRemove = 0;
+
   std::map<const Function *, unsigned> FunctionLevels;
   const int32_t InitialIRSize = 0;
   int32_t CurrentIRSize = 0;
-
+  /// Nodes of the SCC seen by the last onPassExit.
+  DenseSet<const LazyCallGraph::Node *> LastNodes;
   bool ForceStop = false;
 };
 
diff --git a/llvm/lib/Analysis/InlineAdvisor.cpp b/llvm/lib/Analysis/InlineAdvisor.cpp
--- a/llvm/lib/Analysis/InlineAdvisor.cpp
+++ b/llvm/lib/Analysis/InlineAdvisor.cpp
@@ -21,6 +21,7 @@
 #include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/IR/DebugInfoMetadata.h"
 #include "llvm/IR/Instructions.h"
+#include "llvm/IR/PassManager.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/raw_ostream.h"
 
@@ -569,3 +570,13 @@
 OptimizationRemarkEmitter &InlineAdvisor::getCallerORE(CallBase &CB) {
   return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*CB.getCaller());
 }
+
+PreservedAnalyses
+InlineAdvisorAnalysisPrinterPass::run(Module &M, ModuleAnalysisManager &MAM) {
+  const auto *IA = MAM.getCachedResult<InlineAdvisorAnalysis>(M);
+  if (!IA || !IA->getAdvisor())
+    OS << "No Inline Advisor\n";
+  else
+    IA->getAdvisor()->print(OS);
+  return PreservedAnalyses::all();
+}
diff --git a/llvm/lib/Analysis/MLInlineAdvisor.cpp b/llvm/lib/Analysis/MLInlineAdvisor.cpp
--- a/llvm/lib/Analysis/MLInlineAdvisor.cpp
+++ b/llvm/lib/Analysis/MLInlineAdvisor.cpp
@@ -22,6 +22,7 @@
 #include "llvm/Analysis/CallGraph.h"
 #include "llvm/Analysis/FunctionPropertiesAnalysis.h"
 #include "llvm/Analysis/InlineCost.h"
+#include "llvm/Analysis/LazyCallGraph.h"
 #include "llvm/Analysis/MLInlineAdvisor.h"
 #include "llvm/Analysis/MLModelRunner.h"
 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
@@ -76,6 +77,7 @@
     : InlineAdvisor(
           M, MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager()),
       ModelRunner(std::move(Runner)), CG(new CallGraph(M)),
+      LCG(MAM.getResult<LazyCallGraphAnalysis>(M)),
       InitialIRSize(getModuleIRSize()), CurrentIRSize(InitialIRSize) {
   assert(ModelRunner);
 
@@ -111,21 +113,59 @@
       FunctionLevels[F] = Level;
     }
   }
+  for (auto &F : M)
+    if (!F.isDeclaration()) {
+      ++NodeCount;
+      EdgeCount += getLocalCalls(F);
+    }
 }
 
 void MLInlineAdvisor::onPassEntry() {
   // Function passes executed between InlinerPass runs may have changed the
   // module-wide features.
-  if (!Invalid)
+  for (const auto *N : LastNodes) {
+    if (N->isDead() || N->getFunction().isDeclaration())
+      continue;
+    ++NodeCount;
+    EdgeCount += getLocalCalls(N->getFunction());
+  }
+  LastNodes.clear();
+  NodeCount -= NodeCountToRemove;
+  EdgeCount -= EdgeCountToRemove;
+  NodeCountToRemove = 0;
+  EdgeCountToRemove = 0;
+
+  int64_t ExpectedNodeCount = static_cast<int64_t>(LCG.nodeCount());
+  if (ExpectedNodeCount == NodeCount)
     return;
-  NodeCount = 0;
+  // Slow update: the node count drifted since we were here last (e.g. new
+  // nodes were accrued), and we don't know which nodes those were, so rebuild
+  // the total edge count. The drift may go either way: NodeCount was seeded
+  // from every defined function in M, while ExpectedNodeCount is the size of
+  // the LazyCallGraph's node map, so don't assert on its direction.
+  NodeCount = ExpectedNodeCount;
   EdgeCount = 0;
   for (auto &F : M)
-    if (!F.isDeclaration()) {
-      ++NodeCount;
+    if (!F.isDeclaration())
       EdgeCount += getLocalCalls(F);
-    }
-  Invalid = false;
+}
+
+void MLInlineAdvisor::onPassExit(LazyCallGraph::SCC *LastSCC) {
+  if (!LastSCC)
+    return;
+  // Keep track of the nodes and edges we last saw. Then, in onPassEntry,
+  // we update the node count and edge count from the subset of these nodes that
+  // survived.
+  assert(LastNodes.empty());
+  NodeCountToRemove = LastSCC->size();
+  assert(NodeCount >= NodeCountToRemove);
+  EdgeCountToRemove = 0;
+  for (const auto &N : *LastSCC) {
+    assert(!N.isDead());
+    EdgeCountToRemove += getLocalCalls(N.getFunction());
+    LastNodes.insert(&N);
+  }
+  assert(EdgeCount >= EdgeCountToRemove);
 }
 
 int64_t MLInlineAdvisor::getLocalCalls(Function &F) {
diff --git a/llvm/lib/Passes/PassRegistry.def b/llvm/lib/Passes/PassRegistry.def
--- a/llvm/lib/Passes/PassRegistry.def
+++ b/llvm/lib/Passes/PassRegistry.def
@@ -62,6 +62,7 @@
 MODULE_PASS("hotcoldsplit", HotColdSplittingPass())
 MODULE_PASS("inferattrs", InferFunctionAttrsPass())
 MODULE_PASS("inliner-wrapper", ModuleInlinerWrapperPass())
+MODULE_PASS("print<inline-advisor>", InlineAdvisorAnalysisPrinterPass(dbgs()))
 MODULE_PASS("inliner-wrapper-no-mandatory-first", ModuleInlinerWrapperPass(
   getInlineParams(),
   false))
diff --git a/llvm/lib/Transforms/IPO/Inliner.cpp b/llvm/lib/Transforms/IPO/Inliner.cpp
--- a/llvm/lib/Transforms/IPO/Inliner.cpp
+++ b/llvm/lib/Transforms/IPO/Inliner.cpp
@@ -92,6 +92,11 @@
     DisableInlinedAllocaMerging("disable-inlined-alloca-merging",
                                 cl::init(false), cl::Hidden);
 
+/// A flag for test, so we can print the content of the advisor when running it
+/// as part of the default (e.g. -O3) pipeline.
+static cl::opt<bool> KeepAdvisorForPrinting("keep-inline-advisor-for-printing",
+                                            cl::init(false), cl::Hidden);
+
 extern cl::opt<bool> InlinerFunctionImportStats;
 
 static cl::opt<std::string> CGSCCInlineReplayFile(
@@ -741,7 +746,7 @@
   InlineAdvisor &Advisor = getAdvisor(MAMProxy, FAM, M);
   Advisor.onPassEntry();
 
-  auto AdvisorOnExit = make_scope_exit([&] { Advisor.onPassExit(); });
+  auto AdvisorOnExit = make_scope_exit([&] { Advisor.onPassExit(&InitialC); });
 
   // We use a single common worklist for calls across the entire SCC. We
   // process these in-order and append new calls introduced during inlining to
@@ -1118,7 +1123,8 @@
   // Discard the InlineAdvisor, a subsequent inlining session should construct
   // its own.
   auto PA = PreservedAnalyses::all();
-  PA.abandon<InlineAdvisorAnalysis>();
+  if (!KeepAdvisorForPrinting)
+    PA.abandon<InlineAdvisorAnalysis>();
   return PA;
 }
 
diff --git a/llvm/test/Transforms/Inline/ML/state-tracking-coro.ll b/llvm/test/Transforms/Inline/ML/state-tracking-coro.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/Transforms/Inline/ML/state-tracking-coro.ll
@@ -0,0 +1,59 @@
+; Based on llvm/test/Transforms/Coroutines/coro-split-02.ll
+; Corosplit will keep @a and add 3 more functions.
+; RUN: opt -passes='default<O3>,print<inline-advisor>' -training-log=/dev/null \
+; RUN:   -S -enable-ml-inliner=development -keep-inline-advisor-for-printing < %s 2>&1 | FileCheck %s
+; REQUIRES: have_tf_api
+;
+; CHECK: [MLInlineAdvisor] Nodes: 4 Edges: 0
+
+%"struct.std::coroutine_handle" = type { i8* }
+%"struct.std::coroutine_handle.0" = type { %"struct.std::coroutine_handle" }
+%"struct.lean_future::Awaiter" = type { i32, %"struct.std::coroutine_handle.0" }
+
+declare i8* @malloc(i64)
+declare void @print(i32)
+
+define void @a() "coroutine.presplit"="1" {
+entry:
+  %ref.tmp7 = alloca %"struct.lean_future::Awaiter", align 8
+  %testval = alloca i32
+  %id = call token @llvm.coro.id(i32 0, i8* null, i8* null, i8* null)
+  %alloc = call i8* @malloc(i64 16) #3
+  %vFrame = call noalias nonnull i8* @llvm.coro.begin(token %id, i8* %alloc)
+
+  %save = call token @llvm.coro.save(i8* null)
+  %Result.i19 = getelementptr inbounds %"struct.lean_future::Awaiter", %"struct.lean_future::Awaiter"* %ref.tmp7, i64 0, i32 0
+  %suspend = call i8 @llvm.coro.suspend(token %save, i1 false)
+  switch i8 %suspend, label %exit [
+    i8 0, label %await.ready
+    i8 1, label %exit
+  ]
+await.ready:
+  %StrayCoroSave = call token @llvm.coro.save(i8* null)
+  %val = load i32, i32* %Result.i19
+  %cast = bitcast i32* %testval to i8*
+  call void @llvm.lifetime.start.p0i8(i64 4, i8* %cast)
+  %test = load i32, i32* %testval
+  call void @print(i32 %test)
+  call void @llvm.lifetime.end.p0i8(i64 4, i8* %cast)
+  call void @print(i32 %val)
+  br label %exit
+exit:
+  call i1 @llvm.coro.end(i8* null, i1 false)
+  ret void
+}
+
+declare token @llvm.coro.id(i32, i8* readnone, i8* nocapture readonly, i8*)
+declare i1 @llvm.coro.alloc(token) #3
+declare noalias nonnull i8* @"\01??2@YAPEAX_K@Z"(i64) local_unnamed_addr
+declare i64 @llvm.coro.size.i64() #5
+declare i8* @llvm.coro.begin(token, i8* writeonly) #3
+declare void @"\01?puts@@YAXZZ"(...)
+declare token @llvm.coro.save(i8*) #3
+declare i8* @llvm.coro.frame() #5
+declare i8 @llvm.coro.suspend(token, i1) #3
+declare void @"\01??3@YAXPEAX@Z"(i8*) local_unnamed_addr #10
+declare i8* @llvm.coro.free(token, i8* nocapture readonly) #2
+declare i1 @llvm.coro.end(i8*, i1) #3
+declare void @llvm.lifetime.start.p0i8(i64, i8* nocapture) #4
+declare void @llvm.lifetime.end.p0i8(i64, i8* nocapture) #4
diff --git a/llvm/test/Transforms/Inline/ML/state-tracking-scc-splits.ll b/llvm/test/Transforms/Inline/ML/state-tracking-scc-splits.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/Transforms/Inline/ML/state-tracking-scc-splits.ll
@@ -0,0 +1,293 @@
+; Based on llvm/test/Other/cgscc-iterate-function-mutation.ll
+; RUN: opt -passes='default<O3>,print<inline-advisor>' -training-log=/dev/null \
+; RUN:   -S -enable-ml-inliner=development -keep-inline-advisor-for-printing < %s 2>&1 | FileCheck %s
+; REQUIRES: have_tf_api
+;
+; CHECK: [MLInlineAdvisor] Nodes: 36 Edges: 0
+
+declare void @readnone() nofree nosync readnone
+declare void @unknown()
+declare void @reference_function_pointer(void()*) nofree nosync readnone
+
+; The @test1_* set of functions checks that when we mutate functions with
+; simplifycfg to delete call edges and this ends up splitting both the SCCs
+; and the RefSCCs that those functions are in, we re-run the CGSCC passes to
+; observe the refined call graph structure.
+
+define void @test1_a() {
+  call void @test1_b1()
+  call void @test1_b2()
+  call void @test1_b3()
+  call void @test1_b4()
+  ret void
+}
+
+define void @test1_b1() {
+  call void @readnone()
+  ret void
+}
+
+define void @test1_b2() {
+  call void @readnone()
+  br i1 false, label %dead, label %exit
+
+dead:
+  call void @test1_a()
+  br label %exit
+
+exit:
+  ret void
+}
+
+define void @test1_b3() {
+  call void @unknown()
+  br i1 false, label %dead, label %exit
+
+dead:
+  call void @test1_a()
+  br label %exit
+
+exit:
+  ret void
+}
+
+define void @test1_b4() {
+  call void @readnone()
+  br i1 false, label %dead, label %exit
+
+dead:
+  call void @test1_a()
+  br label %exit
+
+exit:
+  ret void
+}
+
+define void @test2_a() { ; As @test1_*, but @test2_b2..b4 also pass @test2_a's address.
+  call void @test2_b1()
+  call void @test2_b2()
+  call void @test2_b3()
+  call void @test2_b4()
+  ret void
+}
+
+define void @test2_b1() {
+  call void @readnone()
+  ret void
+}
+
+define void @test2_b2() {
+  call void @reference_function_pointer(void()* @test2_a)
+  br i1 false, label %dead, label %exit
+
+dead:
+  call void @test2_a()
+  br label %exit
+
+exit:
+  ret void
+}
+
+define void @test2_b3() {
+  call void @reference_function_pointer(void()* @test2_a)
+  call void @unknown()
+  br i1 false, label %dead, label %exit
+
+dead:
+  call void @test2_a()
+  br label %exit
+
+exit:
+  ret void
+}
+
+define void @test2_b4() {
+  call void @reference_function_pointer(void()* @test2_a)
+  br i1 false, label %dead, label %exit
+
+dead:
+  call void @test2_a()
+  br label %exit
+
+exit:
+  ret void
+}
+
+define void @test3_a() { ; As @test1_*, but each callee is reached via a three-function call chain.
+  call void @test3_b11()
+  call void @test3_b21()
+  call void @test3_b31()
+  call void @test3_b41()
+  ret void
+}
+
+define void @test3_b11() {
+  call void @test3_b12()
+  ret void
+}
+
+define void @test3_b12() {
+  call void @test3_b13()
+  ret void
+}
+
+define void @test3_b13() {
+  call void @readnone()
+  ret void
+}
+
+define void @test3_b21() {
+  call void @test3_b22()
+  ret void
+}
+
+define void @test3_b22() {
+  call void @test3_b23()
+  ret void
+}
+
+define void @test3_b23() {
+  call void @readnone()
+  br i1 false, label %dead, label %exit
+
+dead:
+  call void @test3_a()
+  br label %exit
+
+exit:
+  ret void
+}
+
+define void @test3_b31() {
+  call void @test3_b32()
+  ret void
+}
+
+define void @test3_b32() {
+  call void @test3_b33()
+  ret void
+}
+
+define void @test3_b33() {
+  call void @unknown()
+  br i1 false, label %dead, label %exit
+
+dead:
+  call void @test3_a()
+  br label %exit
+
+exit:
+  ret void
+}
+
+define void @test3_b41() {
+  call void @test3_b42()
+  ret void
+}
+
+define void @test3_b42() {
+  call void @test3_b43()
+  ret void
+}
+
+define void @test3_b43() {
+  call void @readnone()
+  br i1 false, label %dead, label %exit
+
+dead:
+  call void @test3_a()
+  br label %exit
+
+exit:
+  ret void
+}
+
+define void @test4_a() { ; @test3_*-style call chains plus @test2_*-style references to @test4_a.
+  call void @test4_b11()
+  call void @test4_b21()
+  call void @test4_b31()
+  call void @test4_b41()
+  ret void
+}
+
+define void @test4_b11() {
+  call void @test4_b12()
+  ret void
+}
+
+define void @test4_b12() {
+  call void @test4_b13()
+  ret void
+}
+
+define void @test4_b13() {
+  call void @readnone()
+  ret void
+}
+
+define void @test4_b21() {
+  call void @test4_b22()
+  ret void
+}
+
+define void @test4_b22() {
+  call void @test4_b23()
+  ret void
+}
+
+define void @test4_b23() {
+  call void @reference_function_pointer(void()* @test4_a)
+  br i1 false, label %dead, label %exit
+
+dead:
+  call void @test4_a()
+  br label %exit
+
+exit:
+  ret void
+}
+
+define void @test4_b31() {
+  call void @test4_b32()
+  ret void
+}
+
+define void @test4_b32() {
+  call void @test4_b33()
+  ret void
+}
+
+define void @test4_b33() {
+  call void @reference_function_pointer(void()* @test4_a)
+  call void @unknown()
+  br i1 false, label %dead, label %exit
+
+dead:
+  call void @test4_a()
+  br label %exit
+
+exit:
+  ret void
+}
+
+define void @test4_b41() {
+  call void @test4_b42()
+  ret void
+}
+
+define void @test4_b42() {
+  call void @test4_b43()
+  ret void
+}
+
+define void @test4_b43() {
+  call void @reference_function_pointer(void()* @test4_a)
+  br i1 false, label %dead, label %exit
+
+dead:
+  call void @test4_a()
+  br label %exit
+
+exit:
+  ret void
+}
diff --git a/llvm/test/Transforms/Inline/ML/state-tracking.ll b/llvm/test/Transforms/Inline/ML/state-tracking.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/Transforms/Inline/ML/state-tracking.ll
@@ -0,0 +1,24 @@
+; RUN: opt -passes='default<O3>,print<inline-advisor>' -training-log=/dev/null \
+; RUN:   -S -enable-ml-inliner=development -keep-inline-advisor-for-printing < %s 2>&1 | FileCheck %s
+; REQUIRES: have_tf_api
+;
+; CHECK: [MLInlineAdvisor] Nodes: 3 Edges: 1
+
+declare i32 @f1()
+
+define i32 @f2() {
+  ret i32 1
+}
+
+define i32 @f3() noinline {
+  ret i32 2
+}
+
+define i32 @f4() {
+  %a = call i32 @f1()
+  %b = call i32 @f2()
+  %c = call i32 @f3()
+  %d = add i32 %a, %b
+  %e = add i32 %d, %c
+  ret i32 %e
+}