diff --git a/llvm/include/llvm/Analysis/InlineAdvisor.h b/llvm/include/llvm/Analysis/InlineAdvisor.h --- a/llvm/include/llvm/Analysis/InlineAdvisor.h +++ b/llvm/include/llvm/Analysis/InlineAdvisor.h @@ -10,6 +10,7 @@ #define LLVM_ANALYSIS_INLINEADVISOR_H #include "llvm/Analysis/InlineCost.h" +#include "llvm/Analysis/LazyCallGraph.h" #include "llvm/Analysis/Utils/ImportedFunctionsInliningStatistics.h" #include "llvm/Config/llvm-config.h" #include "llvm/IR/PassManager.h" @@ -161,14 +162,13 @@ /// This must be called when the Inliner pass is exited, as function passes /// may be run subsequently. This allows an implementation of InlineAdvisor - /// to prepare for a partial update. - virtual void onPassExit() {} + /// to prepare for a partial update, based on the optional SCC. + virtual void onPassExit(LazyCallGraph::SCC *SCC = nullptr) {} - /// Called when the module is invalidated. We let the advisor implementation - /// decide what to refresh - in the case of the development mode - /// implementation, for example, we wouldn't want to delete the whole object - /// and need to re-load the model evaluator. - virtual void onModuleInvalidated() {} + /// Support for printer pass: dump the advisor's state to OS. + virtual void print(raw_ostream &OS) const { + OS << "Unimplemented InlineAdvisor print\n"; + } protected: InlineAdvisor(Module &M, FunctionAnalysisManager &FAM); @@ -217,8 +217,6 @@ Result(Module &M, ModuleAnalysisManager &MAM) : M(M), MAM(MAM) {} bool invalidate(Module &, const PreservedAnalyses &PA, ModuleAnalysisManager::Invalidator &) { - if (Advisor && !PA.areAllPreserved()) - Advisor->onModuleInvalidated(); // Check whether the analysis has been explicitly invalidated. Otherwise, // it's stateless and remains preserved. auto PAC = PA.getChecker<InlineAdvisorAnalysis>(); @@ -237,6 +235,17 @@ Result run(Module &M, ModuleAnalysisManager &MAM) { return Result(M, MAM); } }; +/// Printer pass for the InlineAdvisorAnalysis results.
+class InlineAdvisorAnalysisPrinterPass + : public PassInfoMixin<InlineAdvisorAnalysisPrinterPass> { + raw_ostream &OS; + +public: + explicit InlineAdvisorAnalysisPrinterPass(raw_ostream &OS) : OS(OS) {} + + PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM); +}; + #ifdef LLVM_HAVE_TF_AOT std::unique_ptr<InlineAdvisor> getReleaseModeAdvisor(Module &M, ModuleAnalysisManager &MAM); diff --git a/llvm/include/llvm/Analysis/MLInlineAdvisor.h b/llvm/include/llvm/Analysis/MLInlineAdvisor.h --- a/llvm/include/llvm/Analysis/MLInlineAdvisor.h +++ b/llvm/include/llvm/Analysis/MLInlineAdvisor.h @@ -14,8 +14,8 @@ #include "llvm/Analysis/MLModelRunner.h" #include "llvm/IR/PassManager.h" +#include <deque> #include <memory> -#include <unordered_map> namespace llvm { class Module; @@ -29,6 +29,7 @@ virtual ~MLInlineAdvisor() = default; void onPassEntry() override; + void onPassExit(LazyCallGraph::SCC *SCC) override; int64_t getIRSize(const Function &F) const { return F.getInstructionCount(); } void onSuccessfulInlining(const MLInlineAdvice &Advice, @@ -37,7 +38,6 @@ bool isForcedToStop() const { return ForceStop; } int64_t getLocalCalls(Function &F); const MLModelRunner &getModelRunner() const { return *ModelRunner.get(); } - void onModuleInvalidated() override { Invalid = true; } protected: std::unique_ptr<InlineAdvice> getAdviceImpl(CallBase &CB) override; @@ -60,15 +60,22 @@ private: int64_t getModuleIRSize() const; - bool Invalid = true; + void print(raw_ostream &OS) const override { + OS << "[MLInlineAdvisor] Nodes: " << NodeCount << " Edges: " << EdgeCount + << "\n"; + } + LazyCallGraph &CG; int64_t NodeCount = 0; int64_t EdgeCount = 0; + int64_t EdgesOfLastSeenNodes = 0; + std::map<const LazyCallGraph::Node *, unsigned> FunctionLevels; const int32_t InitialIRSize = 0; int32_t CurrentIRSize = 0; - + std::deque<const LazyCallGraph::Node *> NodesInLastSCC; + DenseSet<const LazyCallGraph::Node *> AllNodes; bool ForceStop = false; }; diff --git a/llvm/lib/Analysis/InlineAdvisor.cpp b/llvm/lib/Analysis/InlineAdvisor.cpp --- a/llvm/lib/Analysis/InlineAdvisor.cpp +++ b/llvm/lib/Analysis/InlineAdvisor.cpp @@ -21,6 +21,7 @@ #include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/DebugInfoMetadata.h" #include "llvm/IR/Instructions.h" +#include "llvm/IR/PassManager.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/raw_ostream.h" @@ -554,3 +555,13 @@ OptimizationRemarkEmitter &InlineAdvisor::getCallerORE(CallBase &CB) { return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*CB.getCaller()); } + +PreservedAnalyses +InlineAdvisorAnalysisPrinterPass::run(Module &M, ModuleAnalysisManager &MAM) { + const auto *IA = MAM.getCachedResult<InlineAdvisorAnalysis>(M); + if (!IA || !IA->getAdvisor()) + OS << "No Inline Advisor\n"; + else + IA->getAdvisor()->print(OS); + return PreservedAnalyses::all(); +} diff --git a/llvm/lib/Analysis/MLInlineAdvisor.cpp b/llvm/lib/Analysis/MLInlineAdvisor.cpp --- a/llvm/lib/Analysis/MLInlineAdvisor.cpp +++ b/llvm/lib/Analysis/MLInlineAdvisor.cpp @@ -129,6 +129,11 @@ FunctionLevels[&CG.get(*F)] = Level; } } + for (const auto &KVP : FunctionLevels) { + AllNodes.insert(KVP.first); + EdgeCount += getLocalCalls(KVP.first->getFunction()); + } + NodeCount = AllNodes.size(); } unsigned MLInlineAdvisor::getInitialFunctionLevel(const Function &F) const { @@ -138,16 +143,56 @@ void MLInlineAdvisor::onPassEntry() { // Function passes executed between InlinerPass runs may have changed the // module-wide features. - if (!Invalid) - return; - NodeCount = 0; - EdgeCount = 0; - for (auto &F : M) - if (!F.isDeclaration()) { - ++NodeCount; - EdgeCount += getLocalCalls(F); + // The cgscc pass manager rules are such that: + // - if a pass leads to merging SCCs, then the pipeline is restarted on the + // merged SCC + // - if a pass leads to splitting the SCC, then we continue with one of the + // splits + // This means that the NodesInLastSCC is a superset (not strict) of the nodes + // that subsequent passes would have processed + // - in addition, if new Nodes were created by a pass (e.g. CoroSplit), + // they'd be adjacent to Nodes in the last SCC. So we just need to check the + // boundary of Nodes in NodesInLastSCC for Nodes we haven't seen. We don't
+ // care about the nature of the Edge (call or ref). + NodeCount -= static_cast<int64_t>(NodesInLastSCC.size()); + while (!NodesInLastSCC.empty()) { + const auto *N = NodesInLastSCC.front(); + NodesInLastSCC.pop_front(); + // The Function wrapped by N could have been deleted since we last saw it. + if (N->isDead()) { + assert(!N->getFunction().isDeclaration()); + continue; } - Invalid = false; + ++NodeCount; + EdgeCount += getLocalCalls(N->getFunction()); + for (const auto &E : *(*N)) { + const auto *AdjNode = &E.getNode(); + assert(!AdjNode->isDead() && !AdjNode->getFunction().isDeclaration()); + auto I = AllNodes.insert(AdjNode); + if (I.second) + NodesInLastSCC.push_back(AdjNode); + } + } + + EdgeCount -= EdgesOfLastSeenNodes; + EdgesOfLastSeenNodes = 0; +} + +void MLInlineAdvisor::onPassExit(LazyCallGraph::SCC *LastSCC) { + if (!LastSCC) + return; + // Keep track of the nodes and edges we last saw. Then, in onPassEntry, + // we update the node count and edge count from the subset of these nodes that + // survived.
+ assert(NodesInLastSCC.empty()); + assert(NodeCount >= LastSCC->size()); + EdgesOfLastSeenNodes = 0; + for (const auto &N : *LastSCC) { + assert(!N.isDead()); + EdgesOfLastSeenNodes += getLocalCalls(N.getFunction()); + NodesInLastSCC.push_back(&N); + } + assert(EdgeCount >= EdgesOfLastSeenNodes); } int64_t MLInlineAdvisor::getLocalCalls(Function &F) { diff --git a/llvm/lib/Passes/PassRegistry.def b/llvm/lib/Passes/PassRegistry.def --- a/llvm/lib/Passes/PassRegistry.def +++ b/llvm/lib/Passes/PassRegistry.def @@ -62,6 +62,7 @@ MODULE_PASS("hotcoldsplit", HotColdSplittingPass()) MODULE_PASS("inferattrs", InferFunctionAttrsPass()) MODULE_PASS("inliner-wrapper", ModuleInlinerWrapperPass()) +MODULE_PASS("print<inline-advisor>", InlineAdvisorAnalysisPrinterPass(dbgs())) MODULE_PASS("inliner-wrapper-no-mandatory-first", ModuleInlinerWrapperPass( getInlineParams(), false)) diff --git a/llvm/lib/Transforms/IPO/Inliner.cpp b/llvm/lib/Transforms/IPO/Inliner.cpp --- a/llvm/lib/Transforms/IPO/Inliner.cpp +++ b/llvm/lib/Transforms/IPO/Inliner.cpp @@ -92,6 +92,11 @@ DisableInlinedAllocaMerging("disable-inlined-alloca-merging", cl::init(false), cl::Hidden); +/// A flag for test, so we can print the content of the advisor when running it +/// as part of the default (e.g. -O3) pipeline. +static cl::opt<bool> KeepAdvisorForPrinting("keep-inline-advisor-for-printing", + cl::init(false), cl::Hidden); + extern cl::opt<InlinerFunctionImportStatsOpts> InlinerFunctionImportStats; static cl::opt<std::string> CGSCCInlineReplayFile( @@ -741,7 +746,7 @@ InlineAdvisor &Advisor = getAdvisor(MAMProxy, FAM, M); Advisor.onPassEntry(); - auto AdvisorOnExit = make_scope_exit([&] { Advisor.onPassExit(); }); + auto AdvisorOnExit = make_scope_exit([&] { Advisor.onPassExit(&InitialC); }); // We use a single common worklist for calls across the entire SCC. We // process these in-order and append new calls introduced during inlining to @@ -1124,7 +1129,8 @@ // Discard the InlineAdvisor, a subsequent inlining session should construct // its own.
auto PA = PreservedAnalyses::all(); - PA.abandon<InlineAdvisorAnalysis>(); + if (!KeepAdvisorForPrinting) + PA.abandon<InlineAdvisorAnalysis>(); return PA; } diff --git a/llvm/test/Transforms/Inline/ML/state-tracking-coro.ll b/llvm/test/Transforms/Inline/ML/state-tracking-coro.ll new file mode 100644 --- /dev/null +++ b/llvm/test/Transforms/Inline/ML/state-tracking-coro.ll @@ -0,0 +1,59 @@ +; Based on llvm/test/Transforms/Coroutines/coro-split-02.ll +; Corosplit will keep @a and add 3 more functions. +; RUN: opt -passes='default<O3>,print<inline-advisor>' -training-log=/dev/null \ +; RUN: -S -enable-ml-inliner=development -keep-inline-advisor-for-printing < %s 2>&1 | FileCheck %s +; REQUIRES: have_tf_api +; +; CHECK: [MLInlineAdvisor] Nodes: 4 Edges: 0 + +%"struct.std::coroutine_handle" = type { i8* } +%"struct.std::coroutine_handle.0" = type { %"struct.std::coroutine_handle" } +%"struct.lean_future::Awaiter" = type { i32, %"struct.std::coroutine_handle.0" } + +declare i8* @malloc(i64) +declare void @print(i32) + +define void @a() "coroutine.presplit"="1" { +entry: + %ref.tmp7 = alloca %"struct.lean_future::Awaiter", align 8 + %testval = alloca i32 + %id = call token @llvm.coro.id(i32 0, i8* null, i8* null, i8* null) + %alloc = call i8* @malloc(i64 16) #3 + %vFrame = call noalias nonnull i8* @llvm.coro.begin(token %id, i8* %alloc) + + %save = call token @llvm.coro.save(i8* null) + %Result.i19 = getelementptr inbounds %"struct.lean_future::Awaiter", %"struct.lean_future::Awaiter"* %ref.tmp7, i64 0, i32 0 + %suspend = call i8 @llvm.coro.suspend(token %save, i1 false) + switch i8 %suspend, label %exit [ + i8 0, label %await.ready + i8 1, label %exit + ] +await.ready: + %StrayCoroSave = call token @llvm.coro.save(i8* null) + %val = load i32, i32* %Result.i19 + %cast = bitcast i32* %testval to i8* + call void @llvm.lifetime.start.p0i8(i64 4, i8* %cast) + %test = load i32, i32* %testval + call void @print(i32 %test) + call void @llvm.lifetime.end.p0i8(i64 4, i8* %cast) + call void @print(i32 %val) + br label %exit +exit: + call i1 @llvm.coro.end(i8* null, i1 false)
+ ret void +} + +declare token @llvm.coro.id(i32, i8* readnone, i8* nocapture readonly, i8*) +declare i1 @llvm.coro.alloc(token) #3 +declare noalias nonnull i8* @"\01??2@YAPEAX_K@Z"(i64) local_unnamed_addr +declare i64 @llvm.coro.size.i64() #5 +declare i8* @llvm.coro.begin(token, i8* writeonly) #3 +declare void @"\01?puts@@YAXZZ"(...) +declare token @llvm.coro.save(i8*) #3 +declare i8* @llvm.coro.frame() #5 +declare i8 @llvm.coro.suspend(token, i1) #3 +declare void @"\01??3@YAXPEAX@Z"(i8*) local_unnamed_addr #10 +declare i8* @llvm.coro.free(token, i8* nocapture readonly) #2 +declare i1 @llvm.coro.end(i8*, i1) #3 +declare void @llvm.lifetime.start.p0i8(i64, i8* nocapture) #4 +declare void @llvm.lifetime.end.p0i8(i64, i8* nocapture) #4 diff --git a/llvm/test/Transforms/Inline/ML/state-tracking-scc-splits.ll b/llvm/test/Transforms/Inline/ML/state-tracking-scc-splits.ll new file mode 100644 --- /dev/null +++ b/llvm/test/Transforms/Inline/ML/state-tracking-scc-splits.ll @@ -0,0 +1,293 @@ +; Based on llvm/test/Other/cgscc-iterate-function-mutation.ll +; RUN: opt -passes='default<O3>,print<inline-advisor>' -training-log=/dev/null \ +; RUN: -S -enable-ml-inliner=development -keep-inline-advisor-for-printing < %s 2>&1 | FileCheck %s +; REQUIRES: have_tf_api +; +; CHECK: [MLInlineAdvisor] Nodes: 36 Edges: 0 + +declare void @readnone() nofree nosync readnone +declare void @unknown() +declare void @reference_function_pointer(void()*) nofree nosync readnone + +; The @test1_* set of functions checks that when we mutate functions with +; simplifycfg to delete call edges and this ends up splitting both the SCCs +; and the RefSCCs that those functions are in, we re-run the CGSCC passes to +; observe the refined call graph structure.
+ +define void @test1_a() { + call void @test1_b1() + call void @test1_b2() + call void @test1_b3() + call void @test1_b4() + ret void +} + +define void @test1_b1() { + call void @readnone() + ret void +} + +define void @test1_b2() { + call void @readnone() + br i1 false, label %dead, label %exit + +dead: + call void @test1_a() + br label %exit + +exit: + ret void +} + +define void @test1_b3() { + call void @unknown() + br i1 false, label %dead, label %exit + +dead: + call void @test1_a() + br label %exit + +exit: + ret void +} + +define void @test1_b4() { + call void @readnone() + br i1 false, label %dead, label %exit + +dead: + call void @test1_a() + br label %exit + +exit: + ret void +} + +define void @test2_a() { + call void @test2_b1() + call void @test2_b2() + call void @test2_b3() + call void @test2_b4() + ret void +} + +define void @test2_b1() { + call void @readnone() + ret void +} + +define void @test2_b2() { + call void @reference_function_pointer(void()* @test2_a) + br i1 false, label %dead, label %exit + +dead: + call void @test2_a() + br label %exit + +exit: + ret void +} + +define void @test2_b3() { + call void @reference_function_pointer(void()* @test2_a) + call void @unknown() + br i1 false, label %dead, label %exit + +dead: + call void @test2_a() + br label %exit + +exit: + ret void +} + +define void @test2_b4() { + call void @reference_function_pointer(void()* @test2_a) + br i1 false, label %dead, label %exit + +dead: + call void @test2_a() + br label %exit + +exit: + ret void +} + +define void @test3_a() { + call void @test3_b11() + call void @test3_b21() + call void @test3_b31() + call void @test3_b41() + ret void +} + +define void @test3_b11() { + call void @test3_b12() + ret void +} + +define void @test3_b12() { + call void @test3_b13() + ret void +} + +define void @test3_b13() { + call void @readnone() + ret void +} + +define void @test3_b21() { + call void @test3_b22() + ret void +} + +define void @test3_b22() { + call void @test3_b23() + ret void
+} + +define void @test3_b23() { + call void @readnone() + br i1 false, label %dead, label %exit + +dead: + call void @test3_a() + br label %exit + +exit: + ret void +} + +define void @test3_b31() { + call void @test3_b32() + ret void +} + +define void @test3_b32() { + call void @test3_b33() + ret void +} + +define void @test3_b33() { + call void @unknown() + br i1 false, label %dead, label %exit + +dead: + call void @test3_a() + br label %exit + +exit: + ret void +} + +define void @test3_b41() { + call void @test3_b42() + ret void +} + +define void @test3_b42() { + call void @test3_b43() + ret void +} + +define void @test3_b43() { + call void @readnone() + br i1 false, label %dead, label %exit + +dead: + call void @test3_a() + br label %exit + +exit: + ret void +} + +define void @test4_a() { + call void @test4_b11() + call void @test4_b21() + call void @test4_b31() + call void @test4_b41() + ret void +} + +define void @test4_b11() { + call void @test4_b12() + ret void +} + +define void @test4_b12() { + call void @test4_b13() + ret void +} + +define void @test4_b13() { + call void @readnone() + ret void +} + +define void @test4_b21() { + call void @test4_b22() + ret void +} + +define void @test4_b22() { + call void @test4_b23() + ret void +} + +define void @test4_b23() { + call void @reference_function_pointer(void()* @test4_a) + br i1 false, label %dead, label %exit + +dead: + call void @test4_a() + br label %exit + +exit: + ret void +} + +define void @test4_b31() { + call void @test4_b32() + ret void +} + +define void @test4_b32() { + call void @test4_b33() + ret void +} + +define void @test4_b33() { + call void @reference_function_pointer(void()* @test4_a) + call void @unknown() + br i1 false, label %dead, label %exit + +dead: + call void @test4_a() + br label %exit + +exit: + ret void +} + +define void @test4_b41() { + call void @test4_b42() + ret void +} + +define void @test4_b42() { + call void @test4_b43() + ret void +} + +define void @test4_b43() {
+ call void @reference_function_pointer(void()* @test4_a) + br i1 false, label %dead, label %exit + +dead: + call void @test4_a() + br label %exit + +exit: + ret void +} diff --git a/llvm/test/Transforms/Inline/ML/state-tracking.ll b/llvm/test/Transforms/Inline/ML/state-tracking.ll new file mode 100644 --- /dev/null +++ b/llvm/test/Transforms/Inline/ML/state-tracking.ll @@ -0,0 +1,24 @@ +; RUN: opt -passes='default<O3>,print<inline-advisor>' -training-log=/dev/null \ +; RUN: -S -enable-ml-inliner=development -keep-inline-advisor-for-printing < %s 2>&1 | FileCheck %s +; REQUIRES: have_tf_api +; +; CHECK: [MLInlineAdvisor] Nodes: 3 Edges: 1 + +declare i32 @f1() + +define i32 @f2() { + ret i32 1 +} + +define i32 @f3() noinline { + ret i32 2 +} + +define i32 @f4() { + %a = call i32 @f1() + %b = call i32 @f2() + %c = call i32 @f3() + %d = add i32 %a, %b + %e = add i32 %d, %c + ret i32 %e +}