diff --git a/llvm/include/llvm/Transforms/IPO/HotColdSplitting.h b/llvm/include/llvm/Transforms/IPO/HotColdSplitting.h
--- a/llvm/include/llvm/Transforms/IPO/HotColdSplitting.h
+++ b/llvm/include/llvm/Transforms/IPO/HotColdSplitting.h
@@ -13,6 +13,7 @@
 #define LLVM_TRANSFORMS_IPO_HOTCOLDSPLITTING_H
 
 #include "llvm/IR/PassManager.h"
+#include "llvm/Support/BranchProbability.h"
 
 namespace llvm {
 
@@ -42,6 +43,15 @@
 
 private:
   bool isFunctionCold(const Function &F) const;
+  /// Return true if \p BB should be treated as cold: it already belongs to an
+  /// outlining region, profile data (\p BFI, or branch weights when \p BFI is
+  /// null) marks it cold, or static analysis deems it unlikely to execute.
+  /// May add successors of \p BB to \p AnnotatedColdBlocks.
+  bool isBasicBlockCold(BasicBlock *BB,
+                        BranchProbability ColdProbThresh,
+                        SmallPtrSetImpl<BasicBlock *> &ColdBlocks,
+                        SmallPtrSetImpl<BasicBlock *> &AnnotatedColdBlocks,
+                        BlockFrequencyInfo *BFI) const;
   bool shouldOutlineFrom(const Function &F) const;
   bool outlineColdRegions(Function &F, bool HasProfileSummary);
   Function *extractColdRegion(const BlockSequence &Region,
diff --git a/llvm/lib/Transforms/IPO/HotColdSplitting.cpp b/llvm/lib/Transforms/IPO/HotColdSplitting.cpp
--- a/llvm/lib/Transforms/IPO/HotColdSplitting.cpp
+++ b/llvm/lib/Transforms/IPO/HotColdSplitting.cpp
@@ -44,6 +44,7 @@
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/PassManager.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/IR/User.h"
 #include "llvm/IR/Value.h"
 #include "llvm/Support/CommandLine.h"
@@ -86,6 +87,11 @@
     "hotcoldsplit-max-params", cl::init(4), cl::Hidden,
     cl::desc("Maximum number of parameters for a split function"));
 
+static cl::opt<int> ColdBranchProbDenom(
+    "hotcoldsplit-cold-probability-denom", cl::init(100), cl::Hidden,
+    cl::desc("Divisor of cold branch probability. "
+             "BranchProbability = 1/ColdBranchProbDenom"));
+
 namespace {
 // Same as blockEndsInUnreachable in CodeGen/BranchFolding.cpp. Do not modify
 // this function unless you modify the MBB version as well.
@@ -102,6 +108,37 @@
   return !(isa<ReturnInst>(I) || isa<IndirectBrInst>(I));
 }
 
+/// Inspect the branch weights on the terminator of \p BB and record every
+/// successor whose edge probability is at most \p ColdProbThresh in
+/// \p AnnotatedColdBlocks. Blocks without a two-way branch or without usable
+/// branch weights are left untouched.
+void analyzeProfMetadata(BasicBlock *BB,
+                         BranchProbability ColdProbThresh,
+                         SmallPtrSetImpl<BasicBlock *> &AnnotatedColdBlocks) {
+  // TODO: Handle branches with > 2 successors.
+  BranchInst *CondBr = dyn_cast<BranchInst>(BB->getTerminator());
+  if (!CondBr || !CondBr->isConditional())
+    return;
+
+  uint64_t TrueWt, FalseWt;
+  if (!extractBranchWeights(*CondBr, TrueWt, FalseWt))
+    return;
+
+  // A zero total gives no usable probability for either edge.
+  auto SumWt = TrueWt + FalseWt;
+  if (SumWt == 0)
+    return;
+
+  auto TrueProb = BranchProbability::getBranchProbability(TrueWt, SumWt);
+  auto FalseProb = BranchProbability::getBranchProbability(FalseWt, SumWt);
+
+  if (TrueProb <= ColdProbThresh)
+    AnnotatedColdBlocks.insert(CondBr->getSuccessor(0));
+
+  if (FalseProb <= ColdProbThresh)
+    AnnotatedColdBlocks.insert(CondBr->getSuccessor(1));
+}
+
 bool unlikelyExecuted(BasicBlock &BB) {
   // Exception handling blocks are unlikely executed.
   if (BB.isEHPad() || isa<ResumeInst>(BB.getTerminator()))
@@ -183,6 +220,38 @@
   return false;
 }
 
+/// Return true if \p BB is cold. Blocks already in \p ColdBlocks are reported
+/// as cold, so callers that must not revisit them have to filter them first.
+/// With \p BFI the profile summary decides; without it, branch weights are
+/// used and cold successors of \p BB are added to \p AnnotatedColdBlocks.
+bool HotColdSplitting::isBasicBlockCold(
+    BasicBlock *BB, BranchProbability ColdProbThresh,
+    SmallPtrSetImpl<BasicBlock *> &ColdBlocks,
+    SmallPtrSetImpl<BasicBlock *> &AnnotatedColdBlocks,
+    BlockFrequencyInfo *BFI) const {
+  // This block is already part of some outlining region.
+  if (ColdBlocks.count(BB))
+    return true;
+
+  if (BFI) {
+    if (PSI->isColdBlock(BB, BFI))
+      return true;
+  } else {
+    // Find cold blocks of successors of BB during a reverse postorder traversal.
+    analyzeProfMetadata(BB, ColdProbThresh, AnnotatedColdBlocks);
+
+    // A statically cold BB would be known before it is visited
+    // because the prof-data of incoming edges are 'analyzed' as part of RPOT.
+    if (AnnotatedColdBlocks.count(BB))
+      return true;
+  }
+
+  if (EnableStaticAnalysis && unlikelyExecuted(*BB))
+    return true;
+
+  return false;
+}
+
 // Returns false if the function should not be considered for hot-cold split
 // optimization.
 bool HotColdSplitting::shouldOutlineFrom(const Function &F) const {
@@ -565,6 +634,9 @@
   // The set of cold blocks.
   SmallPtrSet<BasicBlock *, 4> ColdBlocks;
 
+  // Set of cold blocks obtained with RPOT.
+  SmallPtrSet<BasicBlock *, 4> AnnotatedColdBlocks;
+
   // The worklist of non-intersecting regions left to outline.
   SmallVector<OutliningRegion, 2> OutliningWorklist;
 
@@ -587,16 +659,22 @@
   TargetTransformInfo &TTI = GetTTI(F);
   OptimizationRemarkEmitter &ORE = (*GetORE)(F);
   AssumptionCache *AC = LookupAC(F);
+  // Edges at or below this probability are cold: by default the complement of
+  // the target's predictable-branch threshold, or 1/ColdBranchProbDenom when
+  // the option is given explicitly.
+  auto ColdProbThresh = TTI.getPredictableBranchThreshold().getCompl();
+
+  if (ColdBranchProbDenom.getNumOccurrences())
+    ColdProbThresh = BranchProbability(1, ColdBranchProbDenom.getValue());
 
   // Find all cold regions.
   for (BasicBlock *BB : RPOT) {
     // This block is already part of some outlining region.
     if (ColdBlocks.count(BB))
       continue;
 
-    bool Cold = (BFI && PSI->isColdBlock(BB, BFI)) ||
-                (EnableStaticAnalysis && unlikelyExecuted(*BB));
-    if (!Cold)
+    if (!isBasicBlockCold(BB, ColdProbThresh, ColdBlocks, AnnotatedColdBlocks,
+                          BFI))
       continue;
 
     LLVM_DEBUG({
diff --git a/llvm/test/Transforms/HotColdSplit/split-static-profile.ll b/llvm/test/Transforms/HotColdSplit/split-static-profile.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/Transforms/HotColdSplit/split-static-profile.ll
@@ -0,0 +1,125 @@
+; Check that the unlikely branch is outlined. Override internal branch thresholds with -hotcoldsplit-cold-probability-denom
+
+; RUN: opt -S -passes=hotcoldsplit < %s | FileCheck %s --check-prefixes=CHECK-OUTLINE,CHECK-NOOUTLINE-BAZ
+; RUN: opt -S -passes=hotcoldsplit -hotcoldsplit-cold-probability-denom=50 < %s | FileCheck --check-prefixes=CHECK-OUTLINE,CHECK-PROB %s
+
+; int cold(const char*);
+; int hot(const char*);
+; void foo(int a, int b) {
+;   if (a == b) [[unlikely]] { // Should be outlined.
+;     cold("same");
+;     cold("same");
+;   } else {
+;     hot("different");
+;   }
+; }
+
+; void bar(int a, int b) {
+;   if (a == b) [[likely]] {
+;     hot("same");
+;   } else { // Should be outlined.
+;     cold("different");
+;     cold("different");
+;   }
+; }
+
+; void baz(int a, int b) {
+;   if (a == b) [[likely]] {
+;     hot("same");
+;   } else { // Outlined only with -hotcoldsplit-cold-probability-denom=50.
+;     cold("different");
+;     cold("different");
+;   }
+; }
+
+; All the outlined cold functions are emitted after the hot functions.
+; CHECK-OUTLINE: @foo
+; CHECK-OUTLINE: @bar
+; CHECK-OUTLINE: @baz
+
+; CHECK-OUTLINE: internal void @foo.cold.1() #[[ATTR0:[0-9]+]]
+; CHECK-OUTLINE-NEXT: newFuncRoot
+; CHECK-OUTLINE: tail call noundef i32 @cold
+; CHECK-OUTLINE: tail call noundef i32 @cold
+
+; CHECK-OUTLINE: internal void @bar.cold.1() #[[ATTR0:[0-9]+]]
+; CHECK-OUTLINE-NEXT: newFuncRoot
+; CHECK-OUTLINE: tail call noundef i32 @cold
+; CHECK-OUTLINE: tail call noundef i32 @cold
+
+; CHECK-NOOUTLINE-BAZ-NOT: internal void @baz.cold.1()
+
+; CHECK-PROB: internal void @baz.cold.1() #[[ATTR0:[0-9]+]]
+; CHECK-PROB-NEXT: newFuncRoot
+; CHECK-PROB: tail call noundef i32 @cold
+; CHECK-PROB: tail call noundef i32 @cold
+; CHECK-OUTLINE: attributes #[[ATTR0]] = { cold minsize }
+
+source_filename = "/app/example.cpp"
+target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
+target triple = "x86_64-unknown-linux-gnu"
+
+@.str = private unnamed_addr constant [5 x i8] c"same\00", align 1
+@.str.1 = private unnamed_addr constant [10 x i8] c"different\00", align 1
+
+define dso_local void @foo(i32 noundef %a, i32 noundef %b) local_unnamed_addr {
+entry:
+  %cmp = icmp eq i32 %a, %b
+  br i1 %cmp, label %if.then, label %if.else, !prof !1
+
+if.then:
+  %call = tail call noundef i32 @cold(ptr noundef nonnull @.str)
+  %call1 = tail call noundef i32 @cold(ptr noundef nonnull @.str)
+  br label %if.end
+
+if.else:
+  %call2 = tail call noundef i32 @hot(ptr noundef nonnull @.str.1)
+  br label %if.end
+
+if.end:
+  ret void
+}
+
+declare noundef i32 @cold(ptr noundef) local_unnamed_addr #1
+
+declare noundef i32 @hot(ptr noundef) local_unnamed_addr #1
+
+define dso_local void @bar(i32 noundef %a, i32 noundef %b) local_unnamed_addr {
+entry:
+  %cmp = icmp eq i32 %a, %b
+  br i1 %cmp, label %if.then, label %if.else, !prof !2
+
+if.then:
+  %call = tail call noundef i32 @hot(ptr noundef nonnull @.str)
+  br label %if.end
+
+if.else:
+  %call1 = tail call noundef i32 @cold(ptr noundef nonnull @.str.1)
+  %call2 = tail call noundef i32 @cold(ptr noundef nonnull @.str.1)
+  br label %if.end
+
+if.end:
+  ret void
+}
+
+define dso_local void @baz(i32 noundef %a, i32 noundef %b) local_unnamed_addr {
+entry:
+  %cmp = icmp eq i32 %a, %b
+  br i1 %cmp, label %if.then, label %if.else, !prof !3
+
+if.then:
+  %call = tail call noundef i32 @hot(ptr noundef nonnull @.str)
+  br label %if.end
+
+if.else:
+  %call1 = tail call noundef i32 @cold(ptr noundef nonnull @.str.1)
+  %call2 = tail call noundef i32 @cold(ptr noundef nonnull @.str.1)
+  br label %if.end
+
+if.end:
+  ret void
+}
+
+!1 = !{!"branch_weights", i32 1, i32 2000}
+!2 = !{!"branch_weights", i32 2000, i32 1}
+!3 = !{!"branch_weights", i32 50, i32 1}