Preserve profile data in CodeExtractor: add getProfileCountFromFreq() to the
block-frequency analyses, let CodeExtractor take optional BFI/BPI so it can set
the extracted function's entry count and the branch weights on the codeReplacer
block's terminator, and have the partial inliner pass both analyses in.

Index: include/llvm/Analysis/BlockFrequencyInfo.h
===================================================================
--- include/llvm/Analysis/BlockFrequencyInfo.h
+++ include/llvm/Analysis/BlockFrequencyInfo.h
@@ -61,6 +61,11 @@
   /// the enclosing function's count (if available) and returns the value.
   Optional getBlockProfileCount(const BasicBlock *BB) const;
 
+  /// \brief Returns the estimated profile count of \p Freq.
+  /// This uses the frequency \p Freq and multiplies it by
+  /// the enclosing function's count (if available) and returns the value.
+  Optional getProfileCountFromFreq(uint64_t Freq) const;
+
   // Set the frequency of the given basic block.
   void setBlockFreq(const BasicBlock *BB, uint64_t Freq);
 
Index: include/llvm/Analysis/BlockFrequencyInfoImpl.h
===================================================================
--- include/llvm/Analysis/BlockFrequencyInfoImpl.h
+++ include/llvm/Analysis/BlockFrequencyInfoImpl.h
@@ -482,6 +482,8 @@
   BlockFrequency getBlockFreq(const BlockNode &Node) const;
   Optional getBlockProfileCount(const Function &F,
                                 const BlockNode &Node) const;
+  Optional getProfileCountFromFreq(const Function &F,
+                                   uint64_t Freq) const;
 
   void setBlockFreq(const BlockNode &Node, uint64_t Freq);
 
@@ -925,6 +927,10 @@
                                 const BlockT *BB) const {
     return BlockFrequencyInfoImplBase::getBlockProfileCount(F, getNode(BB));
   }
+  Optional getProfileCountFromFreq(const Function &F,
+                                   uint64_t Freq) const {
+    return BlockFrequencyInfoImplBase::getProfileCountFromFreq(F, Freq);
+  }
   void setBlockFreq(const BlockT *BB, uint64_t Freq);
   Scaled64 getFloatingBlockFreq(const BlockT *BB) const {
     return BlockFrequencyInfoImplBase::getFloatingBlockFreq(getNode(BB));
Index: include/llvm/CodeGen/MachineBlockFrequencyInfo.h
===================================================================
--- include/llvm/CodeGen/MachineBlockFrequencyInfo.h
+++ include/llvm/CodeGen/MachineBlockFrequencyInfo.h
@@ -52,6 +52,7 @@
   BlockFrequency getBlockFreq(const MachineBasicBlock *MBB) const;
 
   Optional getBlockProfileCount(const MachineBasicBlock *MBB) const;
+  Optional getProfileCountFromFreq(uint64_t Freq) const;
 
   const MachineFunction *getFunction() const;
   const MachineBranchProbabilityInfo *getMBPI() const;
Index: include/llvm/Transforms/Utils/CodeExtractor.h
===================================================================
--- include/llvm/Transforms/Utils/CodeExtractor.h
+++ include/llvm/Transforms/Utils/CodeExtractor.h
@@ -20,6 +20,9 @@
 namespace llvm {
 template class ArrayRef;
   class BasicBlock;
+  class BlockFrequency;
+  class BlockFrequencyInfo;
+  class BranchProbabilityInfo;
   class DominatorTree;
   class Function;
   class Loop;
@@ -47,6 +50,8 @@
     // Various bits of state computed on construction.
     DominatorTree *const DT;
     const bool AggregateArgs;
+    BlockFrequencyInfo *BFI;
+    BranchProbabilityInfo *BPI;
 
     // Bits of intermediate state computed at various phases of extraction.
     SetVector Blocks;
@@ -64,7 +69,9 @@
     ///
     /// In this formation, we don't require a dominator tree. The given basic
     /// block is set up for extraction.
-    CodeExtractor(BasicBlock *BB, bool AggregateArgs = false);
+    CodeExtractor(BasicBlock *BB, bool AggregateArgs = false,
+                  BlockFrequencyInfo *BFI = nullptr,
+                  BranchProbabilityInfo *BPI = nullptr);
 
     /// \brief Create a code extractor for a sequence of blocks.
     ///
@@ -73,20 +80,24 @@
     /// sequence out into its new function. When a DominatorTree is also given,
     /// extra checking and transformations are enabled.
     CodeExtractor(ArrayRef BBs, DominatorTree *DT = nullptr,
-                  bool AggregateArgs = false);
+                  bool AggregateArgs = false, BlockFrequencyInfo *BFI = nullptr,
+                  BranchProbabilityInfo *BPI = nullptr);
 
     /// \brief Create a code extractor for a loop body.
     ///
     /// Behaves just like the generic code sequence constructor, but uses the
     /// block sequence of the loop.
-    CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs = false);
+    CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs = false,
+                  BlockFrequencyInfo *BFI = nullptr,
+                  BranchProbabilityInfo *BPI = nullptr);
 
     /// \brief Create a code extractor for a region node.
     ///
     /// Behaves just like the generic code sequence constructor, but uses the
     /// block sequence of the region node passed in.
     CodeExtractor(DominatorTree &DT, const RegionNode &RN,
-                  bool AggregateArgs = false);
+                  bool AggregateArgs = false, BlockFrequencyInfo *BFI = nullptr,
+                  BranchProbabilityInfo *BPI = nullptr);
 
     /// \brief Perform the extraction, returning the new function.
     ///
@@ -122,6 +133,11 @@
 
     void moveCodeToFunction(Function *newFunction);
 
+    void calculateNewCallTerminatorWeights(
+        BasicBlock *CodeReplacer,
+        DenseMap &ExitWeights,
+        BranchProbabilityInfo *BPI);
+
     void emitCallAndSwitchStatement(Function *newFunction,
                                     BasicBlock *newHeader,
                                     ValueSet &inputs,
Index: lib/Analysis/BlockFrequencyInfo.cpp
===================================================================
--- lib/Analysis/BlockFrequencyInfo.cpp
+++ lib/Analysis/BlockFrequencyInfo.cpp
@@ -162,6 +162,13 @@
   return BFI->getBlockProfileCount(*getFunction(), BB);
 }
 
+Optional
+BlockFrequencyInfo::getProfileCountFromFreq(uint64_t Freq) const {
+  if (!BFI)
+    return None;
+  return BFI->getProfileCountFromFreq(*getFunction(), Freq);
+}
+
 void BlockFrequencyInfo::setBlockFreq(const BasicBlock *BB, uint64_t Freq) {
   assert(BFI && "Expected analysis to be available");
   BFI->setBlockFreq(BB, Freq);
Index: lib/Analysis/BlockFrequencyInfoImpl.cpp
===================================================================
--- lib/Analysis/BlockFrequencyInfoImpl.cpp
+++ lib/Analysis/BlockFrequencyInfoImpl.cpp
@@ -533,12 +533,18 @@
 Optional
 BlockFrequencyInfoImplBase::getBlockProfileCount(const Function &F,
                                                  const BlockNode &Node) const {
+  return getProfileCountFromFreq(F, getBlockFreq(Node).getFrequency());
+}
+
+Optional
+BlockFrequencyInfoImplBase::getProfileCountFromFreq(const Function &F,
+                                                    uint64_t Freq) const {
   auto EntryCount = F.getEntryCount();
   if (!EntryCount)
     return None;
   // Use 128 bit APInt to do the arithmetic to avoid overflow.
   APInt BlockCount(128, EntryCount.getValue());
-  APInt BlockFreq(128, getBlockFreq(Node).getFrequency());
+  APInt BlockFreq(128, Freq);
   APInt EntryFreq(128, getEntryFreq());
   BlockCount *= BlockFreq;
   BlockCount = BlockCount.udiv(EntryFreq);
Index: lib/CodeGen/MachineBlockFrequencyInfo.cpp
===================================================================
--- lib/CodeGen/MachineBlockFrequencyInfo.cpp
+++ lib/CodeGen/MachineBlockFrequencyInfo.cpp
@@ -175,6 +175,16 @@
   return MBFI ? MBFI->getBlockProfileCount(*F, MBB) : None;
 }
 
+Optional
+MachineBlockFrequencyInfo::getProfileCountFromFreq(uint64_t Freq) const {
+  // MBFI can be null (the neighbouring accessors test it), so check it before
+  // dereferencing it to reach the enclosing IR function.
+  if (!MBFI)
+    return None;
+  const Function *F = MBFI->getFunction()->getFunction();
+  return MBFI->getProfileCountFromFreq(*F, Freq);
+}
+
 const MachineFunction *MachineBlockFrequencyInfo::getFunction() const {
   return MBFI ? MBFI->getFunction() : nullptr;
 }
Index: lib/Transforms/IPO/PartialInlining.cpp
===================================================================
--- lib/Transforms/IPO/PartialInlining.cpp
+++ lib/Transforms/IPO/PartialInlining.cpp
@@ -14,6 +14,8 @@
 
 #include "llvm/Transforms/IPO/PartialInlining.h"
 #include "llvm/ADT/Statistic.h"
+#include "llvm/Analysis/BlockFrequencyInfo.h"
+#include "llvm/Analysis/BranchProbabilityInfo.h"
 #include "llvm/IR/CFG.h"
 #include "llvm/IR/Dominators.h"
 #include "llvm/IR/Instructions.h"
@@ -29,13 +31,18 @@
 STATISTIC(NumPartialInlined, "Number of functions partially inlined");
 
 namespace {
+typedef std::function(
+    Function &)>
+    GetProfileDataFn;
 struct PartialInlinerImpl {
-  PartialInlinerImpl(InlineFunctionInfo IFI) : IFI(IFI) {}
+  PartialInlinerImpl(InlineFunctionInfo IFI, GetProfileDataFn GetProfileInfo)
+      : IFI(IFI), GetProfileInfo(GetProfileInfo) {}
   bool run(Module &M);
   Function *unswitchFunction(Function *F);
 
 private:
   InlineFunctionInfo IFI;
+  GetProfileDataFn GetProfileInfo;
 };
 struct PartialInlinerLegacyPass : public ModulePass {
   static char ID; // Pass identification, replacement for typeid
@@ -45,6 +52,8 @@
 
   void getAnalysisUsage(AnalysisUsage &AU) const override {
     AU.addRequired();
+    AU.addRequired();
+    AU.addRequired();
   }
   bool runOnModule(Module &M) override {
     if (skipModule(M))
@@ -55,8 +64,14 @@
         [&ACT](Function &F) -> AssumptionCache & {
       return ACT->getAssumptionCache(F);
     };
+    GetProfileDataFn GetProfileData = [this](Function &F)
+        -> std::pair {
+      auto *BFI = &getAnalysis(F).getBFI();
+      auto *BPI = &getAnalysis(F).getBPI();
+      return std::make_pair(BFI, BPI);
+    };
     InlineFunctionInfo IFI(nullptr, &GetAssumptionCache);
-    return PartialInlinerImpl(IFI).run(M);
+    return PartialInlinerImpl(IFI, GetProfileData).run(M);
   }
 };
 }
@@ -133,9 +148,13 @@
   DominatorTree DT;
   DT.recalculate(*DuplicateFunction);
 
+  auto ProfileInfo = GetProfileInfo(*DuplicateFunction);
+
   // Extract the body of the if.
   Function *ExtractedFunction =
-      CodeExtractor(ToExtract, &DT).extractCodeRegion();
+      CodeExtractor(ToExtract, &DT, /*AggregateArgs*/false, ProfileInfo.first,
+                    ProfileInfo.second)
+          .extractCodeRegion();
 
   // Inline the top-level if test into all callers.
   std::vector Users(DuplicateFunction->user_begin(),
@@ -181,8 +200,8 @@
     if (Recursive)
       continue;
 
-    if (Function *newFunc = unswitchFunction(CurrFunc)) {
-      Worklist.push_back(newFunc);
+    if (Function *NewFunc = unswitchFunction(CurrFunc)) {
+      Worklist.push_back(NewFunc);
       Changed = true;
     }
   }
@@ -194,6 +213,8 @@
 INITIALIZE_PASS_BEGIN(PartialInlinerLegacyPass, "partial-inliner",
                       "Partial Inliner", false, false)
 INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
+INITIALIZE_PASS_DEPENDENCY(BlockFrequencyInfoWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(BranchProbabilityInfoWrapperPass)
 INITIALIZE_PASS_END(PartialInlinerLegacyPass, "partial-inliner",
                     "Partial Inliner", false, false)
 
@@ -208,8 +229,14 @@
       [&FAM](Function &F) -> AssumptionCache & {
     return FAM.getResult(F);
   };
+  GetProfileDataFn GetProfileData = [&FAM](
+      Function &F) -> std::pair {
+    auto *BFI = &FAM.getResult(F);
+    auto *BPI = &FAM.getResult(F);
+    return std::make_pair(BFI, BPI);
+  };
   InlineFunctionInfo IFI(nullptr, &GetAssumptionCache);
-  if (PartialInlinerImpl(IFI).run(M))
+  if (PartialInlinerImpl(IFI, GetProfileData).run(M))
     return PreservedAnalyses::none();
   return PreservedAnalyses::all();
 }
Index: lib/Transforms/Utils/CodeExtractor.cpp
===================================================================
--- lib/Transforms/Utils/CodeExtractor.cpp
+++ lib/Transforms/Utils/CodeExtractor.cpp
@@ -17,6 +17,9 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/StringExtras.h"
+#include "llvm/Analysis/BlockFrequencyInfo.h"
+#include "llvm/Analysis/BlockFrequencyInfoImpl.h"
+#include "llvm/Analysis/BranchProbabilityInfo.h"
 #include "llvm/Analysis/LoopInfo.h"
 #include "llvm/Analysis/RegionInfo.h"
 #include "llvm/Analysis/RegionIterator.h"
@@ -26,9 +29,11 @@
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/Verifier.h"
 #include "llvm/Pass.h"
+#include "llvm/Support/BlockFrequency.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/ErrorHandling.h"
@@ -119,23 +124,30 @@
   return buildExtractionBlockSet(R.block_begin(), R.block_end());
 }
 
-CodeExtractor::CodeExtractor(BasicBlock *BB, bool AggregateArgs)
-  : DT(nullptr), AggregateArgs(AggregateArgs||AggregateArgsOpt),
-    Blocks(buildExtractionBlockSet(BB)), NumExitBlocks(~0U) {}
+CodeExtractor::CodeExtractor(BasicBlock *BB, bool AggregateArgs,
+                             BlockFrequencyInfo *BFI,
+                             BranchProbabilityInfo *BPI)
+    : DT(nullptr), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI),
+      BPI(BPI), Blocks(buildExtractionBlockSet(BB)), NumExitBlocks(~0U) {}
 
 CodeExtractor::CodeExtractor(ArrayRef BBs, DominatorTree *DT,
-                             bool AggregateArgs)
-  : DT(DT), AggregateArgs(AggregateArgs||AggregateArgsOpt),
-    Blocks(buildExtractionBlockSet(BBs)), NumExitBlocks(~0U) {}
-
-CodeExtractor::CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs)
-  : DT(&DT), AggregateArgs(AggregateArgs||AggregateArgsOpt),
-    Blocks(buildExtractionBlockSet(L.getBlocks())), NumExitBlocks(~0U) {}
+                             bool AggregateArgs, BlockFrequencyInfo *BFI,
+                             BranchProbabilityInfo *BPI)
+    : DT(DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI),
+      BPI(BPI), Blocks(buildExtractionBlockSet(BBs)), NumExitBlocks(~0U) {}
+
+CodeExtractor::CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs,
+                             BlockFrequencyInfo *BFI,
+                             BranchProbabilityInfo *BPI)
+    : DT(&DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI),
+      BPI(BPI), Blocks(buildExtractionBlockSet(L.getBlocks())),
+      NumExitBlocks(~0U) {}
 
 CodeExtractor::CodeExtractor(DominatorTree &DT, const RegionNode &RN,
-                             bool AggregateArgs)
-  : DT(&DT), AggregateArgs(AggregateArgs||AggregateArgsOpt),
-    Blocks(buildExtractionBlockSet(RN)), NumExitBlocks(~0U) {}
+                             bool AggregateArgs, BlockFrequencyInfo *BFI,
+                             BranchProbabilityInfo *BPI)
+    : DT(&DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI),
+      BPI(BPI), Blocks(buildExtractionBlockSet(RN)), NumExitBlocks(~0U) {}
 
 /// definedInRegion - Return true if the specified value is defined in the
 /// extracted region.
@@ -672,16 +684,74 @@
   }
 }
 
+void CodeExtractor::calculateNewCallTerminatorWeights(
+    BasicBlock *CodeReplacer,
+    DenseMap &ExitWeights,
+    BranchProbabilityInfo *BPI) {
+  typedef BlockFrequencyInfoImplBase::Distribution Distribution;
+  typedef BlockFrequencyInfoImplBase::BlockNode BlockNode;
+
+  // Update the branch weights for the exit block.
+  TerminatorInst *TI = CodeReplacer->getTerminator();
+  SmallVector BranchWeights(TI->getNumSuccessors(), 0);
+
+  // Block Frequency distribution with dummy node.
+  Distribution BranchDist;
+
+  // Add each of the frequencies of the successors.
+  for (unsigned i = 0, e = TI->getNumSuccessors(); i < e; ++i) {
+    BlockNode ExitNode(i);
+    uint64_t ExitFreq = ExitWeights[TI->getSuccessor(i)].getFrequency();
+    if (ExitFreq != 0)
+      BranchDist.addExit(ExitNode, ExitFreq);
+    else
+      BPI->setEdgeProbability(CodeReplacer, i, BranchProbability::getZero());
+  }
+
+  // Check for no total weight.
+  if (BranchDist.Total == 0)
+    return;
+
+  // Normalize the distribution so that they can fit in unsigned.
+  BranchDist.normalize();
+
+  // Create normalized branch weights and set the metadata.
+  for (unsigned I = 0, E = BranchDist.Weights.size(); I < E; ++I) {
+    const auto &Weight = BranchDist.Weights[I];
+
+    // Get the weight and update the current BPI.
+    BranchWeights[Weight.TargetNode.Index] = Weight.Amount;
+    BranchProbability BP(Weight.Amount, BranchDist.Total);
+    BPI->setEdgeProbability(CodeReplacer, Weight.TargetNode.Index, BP);
+  }
+  TI->setMetadata(
+      LLVMContext::MD_prof,
+      MDBuilder(TI->getContext()).createBranchWeights(BranchWeights));
+}
+
 Function *CodeExtractor::extractCodeRegion() {
   if (!isEligible())
     return nullptr;
 
   ValueSet inputs, outputs;
 
   // Assumption: this is a single-entry code region, and the header is the first
   // block in the region.
   BasicBlock *header = *Blocks.begin();
 
+  // Calculate the entry frequency of the new function before we change the root
+  // block.
+  BlockFrequency EntryFreq;
+  if (BFI) {
+    assert(BPI && "Both BPI and BFI are required to preserve profile info");
+    for (BasicBlock *Pred : predecessors(header)) {
+      if (Blocks.count(Pred))
+        continue;
+      EntryFreq +=
+          BFI->getBlockFreq(Pred) * BPI->getEdgeProbability(Pred, header);
+    }
+  }
+
   // If we have to split PHI nodes or the entry block, do so now.
   severSplitPHINodes(header);
 
@@ -705,12 +775,23 @@
   // Find inputs to, outputs from the code region.
   findInputsOutputs(inputs, outputs);
 
+  // Calculate the exit blocks for the extracted region and the total exit
+  // weights for each of those blocks.
+  DenseMap ExitWeights;
   SmallPtrSet ExitBlocks;
-  for (BasicBlock *Block : Blocks)
+  for (BasicBlock *Block : Blocks) {
     for (succ_iterator SI = succ_begin(Block), SE = succ_end(Block); SI != SE;
-         ++SI)
-      if (!Blocks.count(*SI))
+         ++SI) {
+      if (!Blocks.count(*SI)) {
+        // Update the branch weight for this successor.
+        if (BFI) {
+          BlockFrequency &BF = ExitWeights[*SI];
+          BF += BFI->getBlockFreq(Block) * BPI->getEdgeProbability(Block, *SI);
+        }
         ExitBlocks.insert(*SI);
+      }
+    }
+  }
   NumExitBlocks = ExitBlocks.size();
 
   // Construct new function based on inputs/outputs & add allocas for all defs.
@@ -719,10 +800,23 @@
                                         codeReplacer, oldFunction,
                                         oldFunction->getParent());
 
+  // Update the entry count of the function.
+  if (BFI) {
+    Optional EntryCount =
+        BFI->getProfileCountFromFreq(EntryFreq.getFrequency());
+    if (EntryCount.hasValue())
+      newFunction->setEntryCount(EntryCount.getValue());
+    BFI->setBlockFreq(codeReplacer, EntryFreq.getFrequency());
+  }
+
   emitCallAndSwitchStatement(newFunction, codeReplacer, inputs, outputs);
 
   moveCodeToFunction(newFunction);
 
+  // Update the branch weights for the exit block.
+  if (BFI && NumExitBlocks > 1)
+    calculateNewCallTerminatorWeights(codeReplacer, ExitWeights, BPI);
+
   // Loop over all of the PHI nodes in the header block, and change any
   // references to the old incoming edge to be the new incoming edge.
   for (BasicBlock::iterator I = header->begin(); isa(I); ++I) {
Index: test/Transforms/CodeExtractor/ExtractedFnEntryCount.ll
===================================================================
--- /dev/null
+++ test/Transforms/CodeExtractor/ExtractedFnEntryCount.ll
@@ -0,0 +1,33 @@
+; RUN: opt < %s -partial-inliner -S | FileCheck %s
+
+; This test checks to make sure that the CodeExtractor
+; properly sets the entry count for the function that is
+; extracted based on the root block being extracted and also
+; takes into consideration if the block has edges coming from
+; a block that is also being extracted.
+
+define i32 @inlinedFunc(i1 %cond) !prof !1 {
+entry:
+  br i1 %cond, label %if.then, label %return, !prof !2
+if.then:
+  br i1 %cond, label %if.then, label %return, !prof !3
+return: ; preds = %entry
+  ret i32 0
+}
+
+
+define internal i32 @dummyCaller(i1 %cond) !prof !1 {
+entry:
+  %val = call i32 @inlinedFunc(i1 %cond)
+  ret i32 %val
+}
+
+; CHECK: @inlinedFunc.1_if.then(i1 %cond) !prof [[COUNT1:![0-9]+]]
+
+
+!llvm.module.flags = !{!0}
+; CHECK: [[COUNT1]] = !{!"function_entry_count", i64 250}
+!0 = !{i32 1, !"MaxFunctionCount", i32 1000}
+!1 = !{!"function_entry_count", i64 1000}
+!2 = !{!"branch_weights", i32 250, i32 750}
+!3 = !{!"branch_weights", i32 125, i32 125}
Index: test/Transforms/CodeExtractor/MultipleExitBranchProb.ll
===================================================================
--- /dev/null
+++ test/Transforms/CodeExtractor/MultipleExitBranchProb.ll
@@ -0,0 +1,34 @@
+; RUN: opt < %s -partial-inliner -S | FileCheck %s
+
+; This test checks to make sure that CodeExtractor updates
+; the exit branch probabilities for multiple exit blocks.
+
+define i32 @inlinedFunc(i1 %cond) !prof !1 {
+entry:
+  br i1 %cond, label %if.then, label %return, !prof !2
+if.then:
+  br i1 %cond, label %return, label %return.2, !prof !3
+return.2:
+  ret i32 10
+return: ; preds = %entry
+  ret i32 0
+}
+
+
+define internal i32 @dummyCaller(i1 %cond) !prof !1 {
+entry:
+%val = call i32 @inlinedFunc(i1 %cond)
+ret i32 %val
+
+; CHECK-LABEL: @dummyCaller
+; CHECK: call
+; CHECK-NEXT: br i1 {{.*}}!prof [[COUNT1:![0-9]+]]
+}
+
+!llvm.module.flags = !{!0}
+!0 = !{i32 1, !"MaxFunctionCount", i32 10000}
+!1 = !{!"function_entry_count", i64 10000}
+!2 = !{!"branch_weights", i32 5, i32 5}
+!3 = !{!"branch_weights", i32 4, i32 1}
+
+; CHECK: [[COUNT1]] = !{!"branch_weights", i32 8, i32 31}