Index: llvm/trunk/include/llvm/Analysis/BlockFrequencyInfo.h =================================================================== --- llvm/trunk/include/llvm/Analysis/BlockFrequencyInfo.h +++ llvm/trunk/include/llvm/Analysis/BlockFrequencyInfo.h @@ -61,6 +61,11 @@ /// the enclosing function's count (if available) and returns the value. Optional getBlockProfileCount(const BasicBlock *BB) const; + /// \brief Returns the estimated profile count of \p Freq. + /// This uses the frequency \p Freq and multiplies it by + /// the enclosing function's count (if available) and returns the value. + Optional getProfileCountFromFreq(uint64_t Freq) const; + // Set the frequency of the given basic block. void setBlockFreq(const BasicBlock *BB, uint64_t Freq); Index: llvm/trunk/include/llvm/Analysis/BlockFrequencyInfoImpl.h =================================================================== --- llvm/trunk/include/llvm/Analysis/BlockFrequencyInfoImpl.h +++ llvm/trunk/include/llvm/Analysis/BlockFrequencyInfoImpl.h @@ -482,6 +482,8 @@ BlockFrequency getBlockFreq(const BlockNode &Node) const; Optional getBlockProfileCount(const Function &F, const BlockNode &Node) const; + Optional getProfileCountFromFreq(const Function &F, + uint64_t Freq) const; void setBlockFreq(const BlockNode &Node, uint64_t Freq); @@ -925,6 +927,10 @@ const BlockT *BB) const { return BlockFrequencyInfoImplBase::getBlockProfileCount(F, getNode(BB)); } + Optional getProfileCountFromFreq(const Function &F, + uint64_t Freq) const { + return BlockFrequencyInfoImplBase::getProfileCountFromFreq(F, Freq); + } void setBlockFreq(const BlockT *BB, uint64_t Freq); Scaled64 getFloatingBlockFreq(const BlockT *BB) const { return BlockFrequencyInfoImplBase::getFloatingBlockFreq(getNode(BB)); Index: llvm/trunk/include/llvm/CodeGen/MachineBlockFrequencyInfo.h =================================================================== --- llvm/trunk/include/llvm/CodeGen/MachineBlockFrequencyInfo.h +++ 
llvm/trunk/include/llvm/CodeGen/MachineBlockFrequencyInfo.h @@ -52,6 +52,7 @@ BlockFrequency getBlockFreq(const MachineBasicBlock *MBB) const; Optional getBlockProfileCount(const MachineBasicBlock *MBB) const; + Optional getProfileCountFromFreq(uint64_t Freq) const; const MachineFunction *getFunction() const; const MachineBranchProbabilityInfo *getMBPI() const; Index: llvm/trunk/include/llvm/Transforms/Utils/CodeExtractor.h =================================================================== --- llvm/trunk/include/llvm/Transforms/Utils/CodeExtractor.h +++ llvm/trunk/include/llvm/Transforms/Utils/CodeExtractor.h @@ -20,6 +20,9 @@ namespace llvm { template class ArrayRef; class BasicBlock; + class BlockFrequency; + class BlockFrequencyInfo; + class BranchProbabilityInfo; class DominatorTree; class Function; class Loop; @@ -47,6 +50,8 @@ // Various bits of state computed on construction. DominatorTree *const DT; const bool AggregateArgs; + BlockFrequencyInfo *BFI; + BranchProbabilityInfo *BPI; // Bits of intermediate state computed at various phases of extraction. SetVector Blocks; @@ -64,7 +69,9 @@ /// /// In this formation, we don't require a dominator tree. The given basic /// block is set up for extraction. - CodeExtractor(BasicBlock *BB, bool AggregateArgs = false); + CodeExtractor(BasicBlock *BB, bool AggregateArgs = false, + BlockFrequencyInfo *BFI = nullptr, + BranchProbabilityInfo *BPI = nullptr); /// \brief Create a code extractor for a sequence of blocks. /// @@ -73,20 +80,24 @@ /// sequence out into its new function. When a DominatorTree is also given, /// extra checking and transformations are enabled. CodeExtractor(ArrayRef BBs, DominatorTree *DT = nullptr, - bool AggregateArgs = false); + bool AggregateArgs = false, BlockFrequencyInfo *BFI = nullptr, + BranchProbabilityInfo *BPI = nullptr); /// \brief Create a code extractor for a loop body. /// /// Behaves just like the generic code sequence constructor, but uses the /// block sequence of the loop. 
- CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs = false); + CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs = false, + BlockFrequencyInfo *BFI = nullptr, + BranchProbabilityInfo *BPI = nullptr); /// \brief Create a code extractor for a region node. /// /// Behaves just like the generic code sequence constructor, but uses the /// block sequence of the region node passed in. CodeExtractor(DominatorTree &DT, const RegionNode &RN, - bool AggregateArgs = false); + bool AggregateArgs = false, BlockFrequencyInfo *BFI = nullptr, + BranchProbabilityInfo *BPI = nullptr); /// \brief Perform the extraction, returning the new function. /// @@ -122,6 +133,11 @@ void moveCodeToFunction(Function *newFunction); + void calculateNewCallTerminatorWeights( + BasicBlock *CodeReplacer, + DenseMap &ExitWeights, + BranchProbabilityInfo *BPI); + void emitCallAndSwitchStatement(Function *newFunction, BasicBlock *newHeader, ValueSet &inputs, Index: llvm/trunk/lib/Analysis/BlockFrequencyInfo.cpp =================================================================== --- llvm/trunk/lib/Analysis/BlockFrequencyInfo.cpp +++ llvm/trunk/lib/Analysis/BlockFrequencyInfo.cpp @@ -162,6 +162,13 @@ return BFI->getBlockProfileCount(*getFunction(), BB); } +Optional +BlockFrequencyInfo::getProfileCountFromFreq(uint64_t Freq) const { + if (!BFI) + return None; + return BFI->getProfileCountFromFreq(*getFunction(), Freq); +} + void BlockFrequencyInfo::setBlockFreq(const BasicBlock *BB, uint64_t Freq) { assert(BFI && "Expected analysis to be available"); BFI->setBlockFreq(BB, Freq); Index: llvm/trunk/lib/Analysis/BlockFrequencyInfoImpl.cpp =================================================================== --- llvm/trunk/lib/Analysis/BlockFrequencyInfoImpl.cpp +++ llvm/trunk/lib/Analysis/BlockFrequencyInfoImpl.cpp @@ -533,12 +533,18 @@ Optional BlockFrequencyInfoImplBase::getBlockProfileCount(const Function &F, const BlockNode &Node) const { + return getProfileCountFromFreq(F, 
getBlockFreq(Node).getFrequency()); +} + +Optional +BlockFrequencyInfoImplBase::getProfileCountFromFreq(const Function &F, + uint64_t Freq) const { auto EntryCount = F.getEntryCount(); if (!EntryCount) return None; // Use 128 bit APInt to do the arithmetic to avoid overflow. APInt BlockCount(128, EntryCount.getValue()); - APInt BlockFreq(128, getBlockFreq(Node).getFrequency()); + APInt BlockFreq(128, Freq); APInt EntryFreq(128, getEntryFreq()); BlockCount *= BlockFreq; BlockCount = BlockCount.udiv(EntryFreq); Index: llvm/trunk/lib/CodeGen/MachineBlockFrequencyInfo.cpp =================================================================== --- llvm/trunk/lib/CodeGen/MachineBlockFrequencyInfo.cpp +++ llvm/trunk/lib/CodeGen/MachineBlockFrequencyInfo.cpp @@ -175,6 +175,18 @@ return MBFI ? MBFI->getBlockProfileCount(*F, MBB) : None; } +/// Forwards \p Freq to the implementation to turn it into a profile count. +/// Returns None when MBFI is null. +Optional +MachineBlockFrequencyInfo::getProfileCountFromFreq(uint64_t Freq) const { + // MBFI may be null; test it before the dereference below, not after. + if (!MBFI) + return None; + const Function *F = MBFI->getFunction()->getFunction(); + return MBFI->getProfileCountFromFreq(*F, Freq); +} + const MachineFunction *MachineBlockFrequencyInfo::getFunction() const { return MBFI ? MBFI->getFunction() : nullptr; } Index: llvm/trunk/lib/Transforms/IPO/PartialInlining.cpp =================================================================== --- llvm/trunk/lib/Transforms/IPO/PartialInlining.cpp +++ llvm/trunk/lib/Transforms/IPO/PartialInlining.cpp @@ -14,6 +14,9 @@ #include "llvm/Transforms/IPO/PartialInlining.h" #include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/BlockFrequencyInfo.h" +#include "llvm/Analysis/BranchProbabilityInfo.h" +#include "llvm/Analysis/LoopInfo.h" #include "llvm/IR/CFG.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Instructions.h" @@ -133,9 +136,15 @@ DominatorTree DT; DT.recalculate(*DuplicateFunction); + // Manually calculate a BlockFrequencyInfo and BranchProbabilityInfo. 
+ LoopInfo LI(DT); + BranchProbabilityInfo BPI(*DuplicateFunction, LI); + BlockFrequencyInfo BFI(*DuplicateFunction, BPI, LI); + // Extract the body of the if. Function *ExtractedFunction = - CodeExtractor(ToExtract, &DT).extractCodeRegion(); + CodeExtractor(ToExtract, &DT, /*AggregateArgs*/ false, &BFI, &BPI) + .extractCodeRegion(); // Inline the top-level if test into all callers. std::vector Users(DuplicateFunction->user_begin(), @@ -181,8 +190,8 @@ if (Recursive) continue; - if (Function *newFunc = unswitchFunction(CurrFunc)) { - Worklist.push_back(newFunc); + if (Function *NewFunc = unswitchFunction(CurrFunc)) { + Worklist.push_back(NewFunc); Changed = true; } } Index: llvm/trunk/lib/Transforms/Utils/CodeExtractor.cpp =================================================================== --- llvm/trunk/lib/Transforms/Utils/CodeExtractor.cpp +++ llvm/trunk/lib/Transforms/Utils/CodeExtractor.cpp @@ -17,6 +17,9 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/StringExtras.h" +#include "llvm/Analysis/BlockFrequencyInfo.h" +#include "llvm/Analysis/BlockFrequencyInfoImpl.h" +#include "llvm/Analysis/BranchProbabilityInfo.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/RegionInfo.h" #include "llvm/Analysis/RegionIterator.h" @@ -26,9 +29,11 @@ #include "llvm/IR/Instructions.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/LLVMContext.h" +#include "llvm/IR/MDBuilder.h" #include "llvm/IR/Module.h" #include "llvm/IR/Verifier.h" #include "llvm/Pass.h" +#include "llvm/Support/BlockFrequency.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" @@ -119,23 +124,30 @@ return buildExtractionBlockSet(R.block_begin(), R.block_end()); } -CodeExtractor::CodeExtractor(BasicBlock *BB, bool AggregateArgs) - : DT(nullptr), AggregateArgs(AggregateArgs||AggregateArgsOpt), - Blocks(buildExtractionBlockSet(BB)), NumExitBlocks(~0U) {} +CodeExtractor::CodeExtractor(BasicBlock *BB, 
bool AggregateArgs, + BlockFrequencyInfo *BFI, + BranchProbabilityInfo *BPI) + : DT(nullptr), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI), + BPI(BPI), Blocks(buildExtractionBlockSet(BB)), NumExitBlocks(~0U) {} CodeExtractor::CodeExtractor(ArrayRef BBs, DominatorTree *DT, - bool AggregateArgs) - : DT(DT), AggregateArgs(AggregateArgs||AggregateArgsOpt), - Blocks(buildExtractionBlockSet(BBs)), NumExitBlocks(~0U) {} - -CodeExtractor::CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs) - : DT(&DT), AggregateArgs(AggregateArgs||AggregateArgsOpt), - Blocks(buildExtractionBlockSet(L.getBlocks())), NumExitBlocks(~0U) {} + bool AggregateArgs, BlockFrequencyInfo *BFI, + BranchProbabilityInfo *BPI) + : DT(DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI), + BPI(BPI), Blocks(buildExtractionBlockSet(BBs)), NumExitBlocks(~0U) {} + +CodeExtractor::CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs, + BlockFrequencyInfo *BFI, + BranchProbabilityInfo *BPI) + : DT(&DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI), + BPI(BPI), Blocks(buildExtractionBlockSet(L.getBlocks())), + NumExitBlocks(~0U) {} CodeExtractor::CodeExtractor(DominatorTree &DT, const RegionNode &RN, - bool AggregateArgs) - : DT(&DT), AggregateArgs(AggregateArgs||AggregateArgsOpt), - Blocks(buildExtractionBlockSet(RN)), NumExitBlocks(~0U) {} + bool AggregateArgs, BlockFrequencyInfo *BFI, + BranchProbabilityInfo *BPI) + : DT(&DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI), + BPI(BPI), Blocks(buildExtractionBlockSet(RN)), NumExitBlocks(~0U) {} /// definedInRegion - Return true if the specified value is defined in the /// extracted region. 
@@ -687,6 +699,51 @@ } } +void CodeExtractor::calculateNewCallTerminatorWeights( + BasicBlock *CodeReplacer, + DenseMap &ExitWeights, + BranchProbabilityInfo *BPI) { + typedef BlockFrequencyInfoImplBase::Distribution Distribution; + typedef BlockFrequencyInfoImplBase::BlockNode BlockNode; + + // Update the branch weights for the exit block. + TerminatorInst *TI = CodeReplacer->getTerminator(); + SmallVector BranchWeights(TI->getNumSuccessors(), 0); + + // Block Frequency distribution with dummy node. + Distribution BranchDist; + + // Add each of the frequencies of the successors. + for (unsigned i = 0, e = TI->getNumSuccessors(); i < e; ++i) { + BlockNode ExitNode(i); + uint64_t ExitFreq = ExitWeights[TI->getSuccessor(i)].getFrequency(); + if (ExitFreq != 0) + BranchDist.addExit(ExitNode, ExitFreq); + else + BPI->setEdgeProbability(CodeReplacer, i, BranchProbability::getZero()); + } + + // Check for no total weight. + if (BranchDist.Total == 0) + return; + + // Normalize the distribution so that they can fit in unsigned. + BranchDist.normalize(); + + // Create normalized branch weights and set the metadata. + for (unsigned I = 0, E = BranchDist.Weights.size(); I < E; ++I) { + const auto &Weight = BranchDist.Weights[I]; + + // Get the weight and update the current BFI. + BranchWeights[Weight.TargetNode.Index] = Weight.Amount; + BranchProbability BP(Weight.Amount, BranchDist.Total); + BPI->setEdgeProbability(CodeReplacer, Weight.TargetNode.Index, BP); + } + TI->setMetadata( + LLVMContext::MD_prof, + MDBuilder(TI->getContext()).createBranchWeights(BranchWeights)); +} + Function *CodeExtractor::extractCodeRegion() { if (!isEligible()) return nullptr; @@ -697,6 +754,19 @@ // block in the region. BasicBlock *header = *Blocks.begin(); + // Calculate the entry frequency of the new function before we change the root + // block. 
+ BlockFrequency EntryFreq; + if (BFI) { + assert(BPI && "Both BPI and BFI are required to preserve profile info"); + for (BasicBlock *Pred : predecessors(header)) { + if (Blocks.count(Pred)) + continue; + EntryFreq += + BFI->getBlockFreq(Pred) * BPI->getEdgeProbability(Pred, header); + } + } + // If we have to split PHI nodes or the entry block, do so now. severSplitPHINodes(header); @@ -720,12 +790,23 @@ // Find inputs to, outputs from the code region. findInputsOutputs(inputs, outputs); + // Calculate the exit blocks for the extracted region and the total exit + // weights for each of those blocks. + DenseMap ExitWeights; SmallPtrSet ExitBlocks; - for (BasicBlock *Block : Blocks) + for (BasicBlock *Block : Blocks) { for (succ_iterator SI = succ_begin(Block), SE = succ_end(Block); SI != SE; - ++SI) - if (!Blocks.count(*SI)) + ++SI) { + if (!Blocks.count(*SI)) { + // Update the branch weight for this successor. + if (BFI) { + BlockFrequency &BF = ExitWeights[*SI]; + BF += BFI->getBlockFreq(Block) * BPI->getEdgeProbability(Block, *SI); + } ExitBlocks.insert(*SI); + } + } + } NumExitBlocks = ExitBlocks.size(); // Construct new function based on inputs/outputs & add allocas for all defs. @@ -734,10 +815,23 @@ codeReplacer, oldFunction, oldFunction->getParent()); + // Update the entry count of the function. + if (BFI) { + Optional EntryCount = + BFI->getProfileCountFromFreq(EntryFreq.getFrequency()); + if (EntryCount.hasValue()) + newFunction->setEntryCount(EntryCount.getValue()); + BFI->setBlockFreq(codeReplacer, EntryFreq.getFrequency()); + } + emitCallAndSwitchStatement(newFunction, codeReplacer, inputs, outputs); moveCodeToFunction(newFunction); + // Update the branch weights for the exit block. + if (BFI && NumExitBlocks > 1) + calculateNewCallTerminatorWeights(codeReplacer, ExitWeights, BPI); + // Loop over all of the PHI nodes in the header block, and change any // references to the old incoming edge to be the new incoming edge. 
for (BasicBlock::iterator I = header->begin(); isa(I); ++I) { Index: llvm/trunk/test/Transforms/CodeExtractor/ExtractedFnEntryCount.ll =================================================================== --- llvm/trunk/test/Transforms/CodeExtractor/ExtractedFnEntryCount.ll +++ llvm/trunk/test/Transforms/CodeExtractor/ExtractedFnEntryCount.ll @@ -0,0 +1,33 @@ +; RUN: opt < %s -partial-inliner -S | FileCheck %s + +; This test checks to make sure that the CodeExtractor +; properly sets the entry count for the function that is +; extracted based on the root block being extracted and also +; takes into consideration if the block has edges coming from +; a block that is also being extracted. + +define i32 @inlinedFunc(i1 %cond) !prof !1 { +entry: + br i1 %cond, label %if.then, label %return, !prof !2 +if.then: + br i1 %cond, label %if.then, label %return, !prof !3 +return: ; preds = %entry + ret i32 0 +} + + +define internal i32 @dummyCaller(i1 %cond) !prof !1 { +entry: + %val = call i32 @inlinedFunc(i1 %cond) + ret i32 %val +} + +; CHECK: @inlinedFunc.1_if.then(i1 %cond) !prof [[COUNT1:![0-9]+]] + + +!llvm.module.flags = !{!0} +; CHECK: [[COUNT1]] = !{!"function_entry_count", i64 250} +!0 = !{i32 1, !"MaxFunctionCount", i32 1000} +!1 = !{!"function_entry_count", i64 1000} +!2 = !{!"branch_weights", i32 250, i32 750} +!3 = !{!"branch_weights", i32 125, i32 125} Index: llvm/trunk/test/Transforms/CodeExtractor/MultipleExitBranchProb.ll =================================================================== --- llvm/trunk/test/Transforms/CodeExtractor/MultipleExitBranchProb.ll +++ llvm/trunk/test/Transforms/CodeExtractor/MultipleExitBranchProb.ll @@ -0,0 +1,34 @@ +; RUN: opt < %s -partial-inliner -S | FileCheck %s + +; This test checks to make sure that CodeExtractor updates +; the exit branch probabilities for multiple exit blocks. 
+ +define i32 @inlinedFunc(i1 %cond) !prof !1 { +entry: + br i1 %cond, label %if.then, label %return, !prof !2 +if.then: + br i1 %cond, label %return, label %return.2, !prof !3 +return.2: + ret i32 10 +return: ; preds = %entry + ret i32 0 +} + + +define internal i32 @dummyCaller(i1 %cond) !prof !1 { +entry: +%val = call i32 @inlinedFunc(i1 %cond) +ret i32 %val + +; CHECK-LABEL: @dummyCaller +; CHECK: call +; CHECK-NEXT: br i1 {{.*}}!prof [[COUNT1:![0-9]+]] +} + +!llvm.module.flags = !{!0} +!0 = !{i32 1, !"MaxFunctionCount", i32 10000} +!1 = !{!"function_entry_count", i64 10000} +!2 = !{!"branch_weights", i32 5, i32 5} +!3 = !{!"branch_weights", i32 4, i32 1} + +; CHECK: [[COUNT1]] = !{!"branch_weights", i32 8, i32 31}