Index: include/llvm/Analysis/BlockFrequencyInfo.h
===================================================================
--- include/llvm/Analysis/BlockFrequencyInfo.h
+++ include/llvm/Analysis/BlockFrequencyInfo.h
@@ -61,6 +61,11 @@
   /// the enclosing function's count (if available) and returns the value.
   Optional<uint64_t> getBlockProfileCount(const BasicBlock *BB) const;
 
+  /// \brief Returns the estimated profile count of \p Freq.
+  /// This uses the frequency \p Freq and multiplies it by
+  /// the enclosing function's count (if available) and returns the value.
+  Optional<uint64_t> getProfileCountFromFreq(uint64_t Freq) const;
+
   // Set the frequency of the given basic block.
   void setBlockFreq(const BasicBlock *BB, uint64_t Freq);
 
Index: include/llvm/Analysis/BlockFrequencyInfoImpl.h
===================================================================
--- include/llvm/Analysis/BlockFrequencyInfoImpl.h
+++ include/llvm/Analysis/BlockFrequencyInfoImpl.h
@@ -482,6 +482,8 @@
   BlockFrequency getBlockFreq(const BlockNode &Node) const;
   Optional<uint64_t> getBlockProfileCount(const Function &F,
                                           const BlockNode &Node) const;
+  Optional<uint64_t> getProfileCountFromFreq(const Function &F,
+                                             uint64_t Freq) const;
 
   void setBlockFreq(const BlockNode &Node, uint64_t Freq);
 
@@ -925,6 +927,10 @@
                                           const BlockT *BB) const {
     return BlockFrequencyInfoImplBase::getBlockProfileCount(F, getNode(BB));
   }
+  Optional<uint64_t> getProfileCountFromFreq(const Function &F,
+                                             uint64_t Freq) const {
+    return BlockFrequencyInfoImplBase::getProfileCountFromFreq(F, Freq);
+  }
   void setBlockFreq(const BlockT *BB, uint64_t Freq);
   Scaled64 getFloatingBlockFreq(const BlockT *BB) const {
     return BlockFrequencyInfoImplBase::getFloatingBlockFreq(getNode(BB));
Index: include/llvm/CodeGen/MachineBlockFrequencyInfo.h
===================================================================
--- include/llvm/CodeGen/MachineBlockFrequencyInfo.h
+++ include/llvm/CodeGen/MachineBlockFrequencyInfo.h
@@ -52,6 +52,8 @@
   BlockFrequency getBlockFreq(const MachineBasicBlock *MBB) const;
 
   Optional<uint64_t> getBlockProfileCount(const MachineBasicBlock *MBB) const;
+  Optional<uint64_t> getProfileCountFromFreq(uint64_t Freq) const;
+
   const MachineFunction *getFunction() const;
   const MachineBranchProbabilityInfo *getMBPI() const;
 
Index: include/llvm/Transforms/Utils/CodeExtractor.h
===================================================================
--- include/llvm/Transforms/Utils/CodeExtractor.h
+++ include/llvm/Transforms/Utils/CodeExtractor.h
@@ -20,11 +20,15 @@
 namespace llvm {
 template <typename T> class ArrayRef;
   class BasicBlock;
+  class BlockFrequency;
+  class BlockFrequencyInfo;
+  class BranchProbabilityInfo;
   class DominatorTree;
   class Function;
   class Loop;
   class Module;
   class RegionNode;
+  class TargetTransformInfo;
   class Type;
   class Value;
 
@@ -47,6 +51,7 @@
     // Various bits of state computed on construction.
     DominatorTree *const DT;
     const bool AggregateArgs;
+    BlockFrequencyInfo *BFI;
 
     // Bits of intermediate state computed at various phases of extraction.
     SetVector<BasicBlock *> Blocks;
@@ -58,7 +63,8 @@
     ///
     /// In this formation, we don't require a dominator tree. The given basic
     /// block is set up for extraction.
-    CodeExtractor(BasicBlock *BB, bool AggregateArgs = false);
+    CodeExtractor(BasicBlock *BB, bool AggregateArgs = false,
+                  BlockFrequencyInfo *BFI = nullptr);
 
     /// \brief Create a code extractor for a sequence of blocks.
     ///
@@ -67,20 +73,23 @@
     /// sequence out into its new function. When a DominatorTree is also given,
     /// extra checking and transformations are enabled.
     CodeExtractor(ArrayRef<BasicBlock *> BBs, DominatorTree *DT = nullptr,
-                  bool AggregateArgs = false);
+                  bool AggregateArgs = false,
+                  BlockFrequencyInfo *BFI = nullptr);
 
     /// \brief Create a code extractor for a loop body.
     ///
     /// Behaves just like the generic code sequence constructor, but uses the
     /// block sequence of the loop.
-    CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs = false);
+    CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs = false,
+                  BlockFrequencyInfo *BFI = nullptr);
 
     /// \brief Create a code extractor for a region node.
     ///
     /// Behaves just like the generic code sequence constructor, but uses the
     /// block sequence of the region node passed in.
     CodeExtractor(DominatorTree &DT, const RegionNode &RN,
-                  bool AggregateArgs = false);
+                  bool AggregateArgs = false,
+                  BlockFrequencyInfo *BFI = nullptr);
 
     /// \brief Perform the extraction, returning the new function.
     ///
@@ -115,6 +124,12 @@
                                 Function *oldFunction, Module *M);
 
     void moveCodeToFunction(Function *newFunction);
+    /// Update the branch weights and edge probabilities of the terminator of
+    /// \p CodeReplacer from the exit frequencies in \p ExitWeights.
+    void updateNewCallTerminatorWeights(
+        BasicBlock *CodeReplacer,
+        DenseMap<BasicBlock *, BlockFrequency> &ExitWeights,
+        BranchProbabilityInfo *BPI);
 
     void emitCallAndSwitchStatement(Function *newFunction,
                                     BasicBlock *newHeader,
Index: lib/Analysis/BlockFrequencyInfo.cpp
===================================================================
--- lib/Analysis/BlockFrequencyInfo.cpp
+++ lib/Analysis/BlockFrequencyInfo.cpp
@@ -162,6 +162,14 @@
   return BFI->getBlockProfileCount(*getFunction(), BB);
 }
 
+Optional<uint64_t>
+BlockFrequencyInfo::getProfileCountFromFreq(uint64_t Freq) const {
+  if (!BFI)
+    return None;
+
+  return BFI->getProfileCountFromFreq(*getFunction(), Freq);
+}
+
 void BlockFrequencyInfo::setBlockFreq(const BasicBlock *BB, uint64_t Freq) {
   assert(BFI && "Expected analysis to be available");
   BFI->setBlockFreq(BB, Freq);
Index: lib/Analysis/BlockFrequencyInfoImpl.cpp
===================================================================
--- lib/Analysis/BlockFrequencyInfoImpl.cpp
+++ lib/Analysis/BlockFrequencyInfoImpl.cpp
@@ -533,12 +533,18 @@
 Optional<uint64_t>
 BlockFrequencyInfoImplBase::getBlockProfileCount(const Function &F,
                                                  const BlockNode &Node) const {
+  return getProfileCountFromFreq(F, getBlockFreq(Node).getFrequency());
+}
+
+Optional<uint64_t>
+BlockFrequencyInfoImplBase::getProfileCountFromFreq(const Function &F,
+                                                    uint64_t Freq) const {
   auto EntryCount = F.getEntryCount();
   if (!EntryCount)
     return None;
   // Use 128 bit APInt to do the arithmetic to avoid overflow.
   APInt BlockCount(128, EntryCount.getValue());
-  APInt BlockFreq(128, getBlockFreq(Node).getFrequency());
+  APInt BlockFreq(128, Freq);
   APInt EntryFreq(128, getEntryFreq());
   BlockCount *= BlockFreq;
   BlockCount = BlockCount.udiv(EntryFreq);
Index: lib/CodeGen/MachineBlockFrequencyInfo.cpp
===================================================================
--- lib/CodeGen/MachineBlockFrequencyInfo.cpp
+++ lib/CodeGen/MachineBlockFrequencyInfo.cpp
@@ -175,6 +175,16 @@
   return MBFI ? MBFI->getBlockProfileCount(*F, MBB) : None;
 }
 
+Optional<uint64_t>
+MachineBlockFrequencyInfo::getProfileCountFromFreq(uint64_t Freq) const {
+  // Check MBFI before dereferencing it to find the enclosing function.
+  if (!MBFI)
+    return None;
+
+  const Function *F = MBFI->getFunction()->getFunction();
+  return MBFI->getProfileCountFromFreq(*F, Freq);
+}
+
 const MachineFunction *MachineBlockFrequencyInfo::getFunction() const {
   return MBFI ? MBFI->getFunction() : nullptr;
 }
Index: lib/Transforms/IPO/PartialInlining.cpp
===================================================================
--- lib/Transforms/IPO/PartialInlining.cpp
+++ lib/Transforms/IPO/PartialInlining.cpp
@@ -12,6 +12,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include <functional>
 #include "llvm/Transforms/IPO/PartialInlining.h"
 #include "llvm/ADT/Statistic.h"
 #include "llvm/IR/CFG.h"
@@ -29,13 +30,16 @@
 STATISTIC(NumPartialInlined, "Number of functions partially inlined");
 
 namespace {
+typedef std::function<BlockFrequencyInfo *(Function &)> GetBlockFreqFn;
 struct PartialInlinerImpl {
-  PartialInlinerImpl(InlineFunctionInfo IFI) : IFI(IFI) {}
+  PartialInlinerImpl(InlineFunctionInfo IFI, GetBlockFreqFn GetBlockFreq)
+      : IFI(IFI), GetBlockFrequency(GetBlockFreq) {}
   bool run(Module &M);
   Function *unswitchFunction(Function *F);
 
 private:
   InlineFunctionInfo IFI;
+  GetBlockFreqFn GetBlockFrequency;
 };
 struct PartialInlinerLegacyPass : public ModulePass {
   static char ID; // Pass identification, replacement for typeid
@@ -45,6 +49,7 @@
 
   void getAnalysisUsage(AnalysisUsage &AU) const override {
     AU.addRequired<AssumptionCacheTracker>();
+    AU.addRequired<BlockFrequencyInfoWrapperPass>();
   }
   bool runOnModule(Module &M) override {
     if (skipModule(M))
@@ -55,8 +60,12 @@
         [&ACT](Function &F) -> AssumptionCache & {
       return ACT->getAssumptionCache(F);
     };
+    GetBlockFreqFn GetBlockFreq =
+        [this](Function &F) -> BlockFrequencyInfo * {
+      return &getAnalysis<BlockFrequencyInfoWrapperPass>(F).getBFI();
+    };
     InlineFunctionInfo IFI(nullptr, &GetAssumptionCache);
-    return PartialInlinerImpl(IFI).run(M);
+    return PartialInlinerImpl(IFI, GetBlockFreq).run(M);
   }
 };
 }
@@ -133,9 +142,11 @@
   DominatorTree DT;
   DT.recalculate(*DuplicateFunction);
 
+  BlockFrequencyInfo *BFI = GetBlockFrequency(*DuplicateFunction);
+
   // Extract the body of the if.
   Function *ExtractedFunction =
-      CodeExtractor(ToExtract, &DT).extractCodeRegion();
+      CodeExtractor(ToExtract, &DT, false, BFI).extractCodeRegion();
 
   // Inline the top-level if test into all callers.
   std::vector<User *> Users(DuplicateFunction->user_begin(),
@@ -181,8 +192,8 @@
     if (Recursive)
       continue;
 
-    if (Function *newFunc = unswitchFunction(CurrFunc)) {
-      Worklist.push_back(newFunc);
+    if (Function *NewFunc = unswitchFunction(CurrFunc)) {
+      Worklist.push_back(NewFunc);
       Changed = true;
     }
   }
@@ -208,8 +219,12 @@
       [&FAM](Function &F) -> AssumptionCache & {
     return FAM.getResult<AssumptionAnalysis>(F);
   };
+  GetBlockFreqFn GetBlockFreq =
+      [&FAM](Function &F) -> BlockFrequencyInfo * {
+    return &FAM.getResult<BlockFrequencyAnalysis>(F);
+  };
   InlineFunctionInfo IFI(nullptr, &GetAssumptionCache);
-  if (PartialInlinerImpl(IFI).run(M))
+  if (PartialInlinerImpl(IFI, GetBlockFreq).run(M))
     return PreservedAnalyses::none();
   return PreservedAnalyses::all();
 }
Index: lib/Transforms/Utils/CodeExtractor.cpp
===================================================================
--- lib/Transforms/Utils/CodeExtractor.cpp
+++ lib/Transforms/Utils/CodeExtractor.cpp
@@ -17,18 +17,24 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/StringExtras.h"
+#include "llvm/Analysis/BlockFrequencyInfo.h"
+#include "llvm/Analysis/BlockFrequencyInfoImpl.h"
+#include "llvm/Analysis/BranchProbabilityInfo.h"
 #include "llvm/Analysis/LoopInfo.h"
 #include "llvm/Analysis/RegionInfo.h"
 #include "llvm/Analysis/RegionIterator.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/Dominators.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/Verifier.h"
 #include "llvm/Pass.h"
+#include "llvm/Support/BlockFrequency.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/ErrorHandling.h"
@@ -119,23 +125,25 @@
   return buildExtractionBlockSet(R.block_begin(), R.block_end());
 }
 
-CodeExtractor::CodeExtractor(BasicBlock *BB, bool AggregateArgs)
-  : DT(nullptr), AggregateArgs(AggregateArgs||AggregateArgsOpt),
+CodeExtractor::CodeExtractor(BasicBlock *BB, bool AggregateArgs,
+                             BlockFrequencyInfo *BFI)
+  : DT(nullptr), AggregateArgs(AggregateArgs||AggregateArgsOpt), BFI(BFI),
     Blocks(buildExtractionBlockSet(BB)), NumExitBlocks(~0U) {}
 
 CodeExtractor::CodeExtractor(ArrayRef<BasicBlock *> BBs, DominatorTree *DT,
-                             bool AggregateArgs)
-  : DT(DT), AggregateArgs(AggregateArgs||AggregateArgsOpt),
+                             bool AggregateArgs, BlockFrequencyInfo *BFI)
+  : DT(DT), AggregateArgs(AggregateArgs||AggregateArgsOpt), BFI(BFI),
     Blocks(buildExtractionBlockSet(BBs)), NumExitBlocks(~0U) {}
 
-CodeExtractor::CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs)
-  : DT(&DT), AggregateArgs(AggregateArgs||AggregateArgsOpt),
+CodeExtractor::CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs,
+                             BlockFrequencyInfo *BFI)
+  : DT(&DT), AggregateArgs(AggregateArgs||AggregateArgsOpt), BFI(BFI),
     Blocks(buildExtractionBlockSet(L.getBlocks())), NumExitBlocks(~0U) {}
 
 CodeExtractor::CodeExtractor(DominatorTree &DT, const RegionNode &RN,
-                             bool AggregateArgs)
-  : DT(&DT), AggregateArgs(AggregateArgs||AggregateArgsOpt),
-    Blocks(buildExtractionBlockSet(RN)), NumExitBlocks(~0U) {}
+                             bool AggregateArgs, BlockFrequencyInfo *BFI)
+    : DT(&DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI),
+      Blocks(buildExtractionBlockSet(RN)), NumExitBlocks(~0U) {}
 
 /// definedInRegion - Return true if the specified value is defined in the
 /// extracted region.
@@ -339,7 +347,7 @@
   // If the old function is no-throw, so is the new one.
   if (oldFunction->doesNotThrow())
     newFunction->setDoesNotThrow();
-  
+
   newFunction->getBasicBlockList().push_back(newRootNode);
 
   // Create an iterator to name all of the arguments we inserted.
@@ -413,7 +421,7 @@
   // Emit a call to the new function, passing in: *pointer to struct (if
   // aggregating parameters), or plan inputs and allocated memory for outputs
   std::vector<Value *> params, StructValues, ReloadOutputs, Reloads;
-  
+
   LLVMContext &Context = newFunction->getContext();
 
   // Add inputs as params, or to be filled into the struct
@@ -578,12 +586,12 @@
 
       if (DT) {
         DominatesDef = DT->dominates(DefBlock, OldTarget);
-        
+
         // If the output value is used by a phi in the target block,
         // then we need to test for dominance of the phi's predecessor
         // instead.  Unfortunately, this a little complicated since we
         // have already rewritten uses of the value to uses of the reload.
-        BasicBlock* pred = FindPhiPredForUseInBlock(Reloads[out], 
+        BasicBlock* pred = FindPhiPredForUseInBlock(Reloads[out],
                                                     OldTarget);
         if (pred && DT && DT->dominates(DefBlock, pred))
           DominatesDef = true;
@@ -630,7 +638,7 @@
     } else {
       // Otherwise we must have code extracted an unwind or something, just
       // return whatever we want.
-      ReturnInst::Create(Context, 
+      ReturnInst::Create(Context,
                          Constant::getNullValue(OldFnRetTy), TheSwitch);
     }
 
@@ -672,6 +680,51 @@
   }
 }
 
+/// Set branch weights and edge probabilities on the terminator of
+/// \p CodeReplacer from the exit frequencies collected in \p ExitWeights.
+///
+/// Nothing is updated if any collected exit frequency is zero.
+void CodeExtractor::updateNewCallTerminatorWeights(
+    BasicBlock *CodeReplacer,
+    DenseMap<BasicBlock *, BlockFrequency> &ExitWeights,
+    BranchProbabilityInfo *BPI) {
+  // Update the branch weights for the exit block.
+  TerminatorInst *TI = CodeReplacer->getTerminator();
+
+  // Make sure that the weights are valid; we can't have a weight of zero.
+  for (auto &Weight : ExitWeights)
+    if (Weight.second.getFrequency() == 0)
+      return;
+
+  // Block Frequency distribution with dummy node.
+  BlockFrequencyInfoImplBase::Distribution BranchDist;
+
+  // Add each of the frequencies of the successors.
+  for (unsigned I = 0, E = TI->getNumSuccessors(); I < E; ++I) {
+    BlockFrequencyInfoImplBase::BlockNode DummyNode(I);
+    BranchDist.addExit(DummyNode,
+                       ExitWeights[TI->getSuccessor(I)].getFrequency());
+  }
+
+  // Normalize the distribution so that they can fit in unsigned.
+  BranchDist.normalize();
+
+  // Create normalized branch weights and set the metadata. Each weight is
+  // stored at the successor index carried by its dummy node, not at its
+  // position in the normalized list, so the two cannot get out of step.
+  SmallVector<unsigned, 8> BranchWeights(TI->getNumSuccessors(), 0);
+  for (const auto &Weight : BranchDist.Weights) {
+    // Get the weight and update the current BPI.
+    const unsigned SuccIdx = Weight.TargetNode.Index;
+    BranchWeights[SuccIdx] = Weight.Amount;
+    BranchProbability BP(Weight.Amount, BranchDist.Total);
+    BPI->setEdgeProbability(CodeReplacer, SuccIdx, BP);
+  }
+  TI->setMetadata(
+      LLVMContext::MD_prof,
+      MDBuilder(TI->getContext()).createBranchWeights(BranchWeights));
+}
+
 Function *CodeExtractor::extractCodeRegion() {
   if (!isEligible())
     return nullptr;
@@ -682,6 +735,24 @@
   // block in the region.
   BasicBlock *header = *Blocks.begin();
 
+  // Get non-const in case we modify the exit probabilities when the exit
+  // count is greater than 1.
+  BranchProbabilityInfo *BPI = nullptr;
+  if (BFI)
+    BPI = const_cast<BranchProbabilityInfo *>(BFI->getBPI());
+
+  // Calculate the entry of the new function before we change the root block.
+  // This needs both analyses, so it is skipped when BPI is unavailable.
+  BlockFrequency EntryFreq;
+  if (BPI) {
+    for (BasicBlock *Pred : predecessors(header)) {
+      if (Blocks.count(Pred))
+        continue;
+      EntryFreq +=
+          BFI->getBlockFreq(Pred) * BPI->getEdgeProbability(Pred, header);
+    }
+  }
+
   // If we have to split PHI nodes or the entry block, do so now.
   severSplitPHINodes(header);
 
@@ -692,25 +763,34 @@
   Function *oldFunction = header->getParent();
 
   // This takes place of the original loop
-  BasicBlock *codeReplacer = BasicBlock::Create(header->getContext(), 
+  BasicBlock *codeReplacer = BasicBlock::Create(header->getContext(),
                                                 "codeRepl", oldFunction,
                                                 header);
 
   // The new function needs a root node because other nodes can branch to the
   // head of the region, but the entry node of a function cannot have preds.
-  BasicBlock *newFuncRoot = BasicBlock::Create(header->getContext(), 
+  BasicBlock *newFuncRoot = BasicBlock::Create(header->getContext(),
                                                "newFuncRoot");
   newFuncRoot->getInstList().push_back(BranchInst::Create(header));
 
   // Find inputs to, outputs from the code region.
   findInputsOutputs(inputs, outputs);
 
+  DenseMap<BasicBlock *, BlockFrequency> ExitWeights;
   SmallPtrSet<BasicBlock *, 1> ExitBlocks;
-  for (BasicBlock *Block : Blocks)
+  for (BasicBlock *Block : Blocks) {
     for (succ_iterator SI = succ_begin(Block), SE = succ_end(Block); SI != SE;
-         ++SI)
-      if (!Blocks.count(*SI))
+         ++SI) {
+      if (!Blocks.count(*SI)) {
+        // Update the branch weight for this successor.
+        if (BPI) {
+          BlockFrequency &BF = ExitWeights[*SI];
+          BF += BFI->getBlockFreq(Block) * BPI->getEdgeProbability(Block, *SI);
+        }
         ExitBlocks.insert(*SI);
+      }
+    }
+  }
   NumExitBlocks = ExitBlocks.size();
 
   // Construct new function based on inputs/outputs & add allocas for all defs.
@@ -719,10 +799,25 @@
                             codeReplacer, oldFunction,
                             oldFunction->getParent());
 
+  // Update the entry count of the function.
+  if (BFI && EntryFreq.getFrequency() > 0) {
+    // No count is returned when the enclosing function has no entry count;
+    // only set the new function's entry count when one is available.
+    Optional<uint64_t> OptEntryCount =
+        BFI->getProfileCountFromFreq(EntryFreq.getFrequency());
+    if (OptEntryCount.hasValue())
+      newFunction->setEntryCount(OptEntryCount.getValue());
+    BFI->setBlockFreq(codeReplacer, EntryFreq.getFrequency());
+  }
+
   emitCallAndSwitchStatement(newFunction, codeReplacer, inputs, outputs);
 
   moveCodeToFunction(newFunction);
 
+  // Update the branch weights for the exit block.
+  if (BPI && NumExitBlocks > 1)
+    updateNewCallTerminatorWeights(codeReplacer, ExitWeights, BPI);
+
   // Loop over all of the PHI nodes in the header block, and change any
   // references to the old incoming edge to be the new incoming edge.
   for (BasicBlock::iterator I = header->begin(); isa<PHINode>(I); ++I) {
@@ -760,7 +855,7 @@
   //  cerr << "OLD FUNCTION: " << *oldFunction;
   //  verifyFunction(*oldFunction);
 
-  DEBUG(if (verifyFunction(*newFunction)) 
+  DEBUG(if (verifyFunction(*newFunction))
         report_fatal_error("verifyFunction failed!"));
   return newFunction;
 }
Index: test/Transforms/CodeExtractor/2016-07-20-ExtractedFnEntryCount.ll
===================================================================
--- test/Transforms/CodeExtractor/2016-07-20-ExtractedFnEntryCount.ll
+++ test/Transforms/CodeExtractor/2016-07-20-ExtractedFnEntryCount.ll
@@ -0,0 +1,33 @@
+; RUN: opt < %s -partial-inliner -S | FileCheck %s
+
+; This test checks to make sure that the CodeExtractor
+; properly sets the entry count for the function that is
+; extracted based on the root block being extracted and also
+; takes into consideration if the block has edges coming from
+; a block that is also being extracted.
+
+define i32 @inlinedFunc(i1 %cond) !prof !1 {
+entry:
+  br i1 %cond, label %if.then, label %return, !prof !2
+if.then:
+  br i1 %cond, label %if.then, label %return, !prof !3
+return:             ; preds = %entry
+  ret i32 0
+}
+
+
+define internal i32 @dummyCaller(i1 %cond) !prof !1 {
+entry:
+  %val = call i32 @inlinedFunc(i1 %cond)
+  ret i32 %val
+}
+
+; CHECK: @inlinedFunc.1_if.then(i1 %cond) !prof [[COUNT1:![0-9]+]]
+
+
+!llvm.module.flags = !{!0}
+; CHECK: [[COUNT1]] = !{!"function_entry_count", i64 250}
+!0 = !{i32 1, !"MaxFunctionCount", i32 1000}
+!1 = !{!"function_entry_count", i64 1000}
+!2 = !{!"branch_weights", i32 250, i32 750}
+!3 = !{!"branch_weights", i32 125, i32 125}
Index: test/Transforms/CodeExtractor/2016-07-20-MultipleExitBranchProb.ll
===================================================================
--- test/Transforms/CodeExtractor/2016-07-20-MultipleExitBranchProb.ll
+++ test/Transforms/CodeExtractor/2016-07-20-MultipleExitBranchProb.ll
@@ -0,0 +1,30 @@
+; RUN: opt < %s -partial-inliner -S | FileCheck %s
+
+; This test checks to make sure that CodeExtractor updates
+; the exit branch probabilities for multiple exit blocks.
+
+define i32 @inlinedFunc(i1 %cond) !prof !1 {
+entry:
+  br i1 %cond, label %if.then, label %return, !prof !2
+if.then:
+  br i1 %cond, label %return, label %return.2, !prof !3
+return.2:
+  ret i32 10
+return:             ; preds = %entry
+  ret i32 0
+}
+
+
+define internal i32 @dummyCaller(i1 %cond) !prof !1 {
+entry:
+  %val = call i32 @inlinedFunc(i1 %cond)
+  ret i32 %val
+}
+
+!llvm.module.flags = !{!0}
+!0 = !{i32 1, !"MaxFunctionCount", i32 10000}
+!1 = !{!"function_entry_count", i64 10000}
+!2 = !{!"branch_weights", i32 5, i32 5}
+!3 = !{!"branch_weights", i32 4, i32 1}
+
+; CHECK: {{![0-9]+}} = !{!"branch_weights", i32 8, i32 31}