Index: include/llvm/Analysis/BlockFrequencyInfo.h =================================================================== --- include/llvm/Analysis/BlockFrequencyInfo.h +++ include/llvm/Analysis/BlockFrequencyInfo.h @@ -61,6 +61,11 @@ /// the enclosing function's count (if available) and returns the value. Optional getBlockProfileCount(const BasicBlock *BB) const; + /// \brief Returns the estimated profile count of \p Freq. + /// This uses the frequency \p Freq and multiplies it by + /// the enclosing function's count (if available) and returns the value. + Optional getFreqToProfileCount(uint64_t Freq) const; + // Set the frequency of the given basic block. void setBlockFreq(const BasicBlock *BB, uint64_t Freq); Index: include/llvm/Analysis/BlockFrequencyInfoImpl.h =================================================================== --- include/llvm/Analysis/BlockFrequencyInfoImpl.h +++ include/llvm/Analysis/BlockFrequencyInfoImpl.h @@ -482,6 +482,8 @@ BlockFrequency getBlockFreq(const BlockNode &Node) const; Optional getBlockProfileCount(const Function &F, const BlockNode &Node) const; + Optional getFreqToProfileCount(const Function &F, + uint64_t Freq) const; void setBlockFreq(const BlockNode &Node, uint64_t Freq); @@ -925,6 +927,10 @@ const BlockT *BB) const { return BlockFrequencyInfoImplBase::getBlockProfileCount(F, getNode(BB)); } + Optional getFreqToProfileCount(const Function &F, + uint64_t Freq) const { + return BlockFrequencyInfoImplBase::getFreqToProfileCount(F, Freq); + } void setBlockFreq(const BlockT *BB, uint64_t Freq); Scaled64 getFloatingBlockFreq(const BlockT *BB) const { return BlockFrequencyInfoImplBase::getFloatingBlockFreq(getNode(BB)); Index: include/llvm/CodeGen/MachineBlockFrequencyInfo.h =================================================================== --- include/llvm/CodeGen/MachineBlockFrequencyInfo.h +++ include/llvm/CodeGen/MachineBlockFrequencyInfo.h @@ -52,6 +52,8 @@ BlockFrequency getBlockFreq(const MachineBasicBlock *MBB) const; 
Optional getBlockProfileCount(const MachineBasicBlock *MBB) const; + Optional getFreqToProfileCount(uint64_t Freq) const; + const MachineFunction *getFunction() const; const MachineBranchProbabilityInfo *getMBPI() const; Index: include/llvm/Transforms/Utils/CodeExtractor.h =================================================================== --- include/llvm/Transforms/Utils/CodeExtractor.h +++ include/llvm/Transforms/Utils/CodeExtractor.h @@ -18,8 +18,9 @@ #include "llvm/ADT/SetVector.h" namespace llvm { -template class ArrayRef; + template class ArrayRef; class BasicBlock; + class BlockFrequencyInfo; class DominatorTree; class Function; class Loop; @@ -47,6 +48,7 @@ // Various bits of state computed on construction. DominatorTree *const DT; const bool AggregateArgs; + BlockFrequencyInfo *BFI; // Bits of intermediate state computed at various phases of extraction. SetVector Blocks; @@ -58,7 +60,8 @@ /// /// In this formation, we don't require a dominator tree. The given basic /// block is set up for extraction. - CodeExtractor(BasicBlock *BB, bool AggregateArgs = false); + CodeExtractor(BasicBlock *BB, bool AggregateArgs = false, + BlockFrequencyInfo *BFI = nullptr); /// \brief Create a code extractor for a sequence of blocks. /// @@ -67,20 +70,22 @@ /// sequence out into its new function. When a DominatorTree is also given, /// extra checking and transformations are enabled. CodeExtractor(ArrayRef BBs, DominatorTree *DT = nullptr, - bool AggregateArgs = false); + bool AggregateArgs = false, BlockFrequencyInfo *BFI = nullptr); /// \brief Create a code extractor for a loop body. /// /// Behaves just like the generic code sequence constructor, but uses the /// block sequence of the loop. - CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs = false); + CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs = false, + BlockFrequencyInfo *BFI = nullptr); /// \brief Create a code extractor for a region node. 
/// /// Behaves just like the generic code sequence constructor, but uses the /// block sequence of the region node passed in. CodeExtractor(DominatorTree &DT, const RegionNode &RN, - bool AggregateArgs = false); + bool AggregateArgs = false, + BlockFrequencyInfo *BFI = nullptr); /// \brief Perform the extraction, returning the new function. /// @@ -123,4 +128,4 @@ }; } -#endif +#endif \ No newline at end of file Index: lib/Analysis/BlockFrequencyInfo.cpp =================================================================== --- lib/Analysis/BlockFrequencyInfo.cpp +++ lib/Analysis/BlockFrequencyInfo.cpp @@ -162,6 +162,14 @@ return BFI->getBlockProfileCount(*getFunction(), BB); } +Optional +BlockFrequencyInfo::getFreqToProfileCount(uint64_t Freq) const { + if (!BFI) + return None; + + return BFI->getFreqToProfileCount(*getFunction(), Freq); +} + void BlockFrequencyInfo::setBlockFreq(const BasicBlock *BB, uint64_t Freq) { assert(BFI && "Expected analysis to be available"); BFI->setBlockFreq(BB, Freq); Index: lib/Analysis/BlockFrequencyInfoImpl.cpp =================================================================== --- lib/Analysis/BlockFrequencyInfoImpl.cpp +++ lib/Analysis/BlockFrequencyInfoImpl.cpp @@ -533,12 +533,18 @@ Optional BlockFrequencyInfoImplBase::getBlockProfileCount(const Function &F, const BlockNode &Node) const { + return getFreqToProfileCount(F, getBlockFreq(Node).getFrequency()); +} + +Optional +BlockFrequencyInfoImplBase::getFreqToProfileCount(const Function &F, + uint64_t Freq) const { auto EntryCount = F.getEntryCount(); if (!EntryCount) return None; // Use 128 bit APInt to do the arithmetic to avoid overflow. 
APInt BlockCount(128, EntryCount.getValue()); - APInt BlockFreq(128, getBlockFreq(Node).getFrequency()); + APInt BlockFreq(128, Freq); APInt EntryFreq(128, getEntryFreq()); BlockCount *= BlockFreq; BlockCount = BlockCount.udiv(EntryFreq); Index: lib/CodeGen/MachineBlockFrequencyInfo.cpp =================================================================== --- lib/CodeGen/MachineBlockFrequencyInfo.cpp +++ lib/CodeGen/MachineBlockFrequencyInfo.cpp @@ -175,6 +175,13 @@ return MBFI ? MBFI->getBlockProfileCount(*F, MBB) : None; } +Optional MachineBlockFrequencyInfo::getFreqToProfileCount( + uint64_t Freq) const { + // Test MBFI before dereferencing it; a null MBFI yields None. + const Function *F = MBFI ? MBFI->getFunction()->getFunction() : nullptr; + return F ? MBFI->getFreqToProfileCount(*F, Freq) : None; +} + const MachineFunction *MachineBlockFrequencyInfo::getFunction() const { return MBFI ? MBFI->getFunction() : nullptr; } Index: lib/Transforms/Utils/CodeExtractor.cpp =================================================================== --- lib/Transforms/Utils/CodeExtractor.cpp +++ lib/Transforms/Utils/CodeExtractor.cpp @@ -17,6 +17,9 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/StringExtras.h" +#include "llvm/Analysis/BlockFrequencyInfo.h" +#include "llvm/Analysis/BlockFrequencyInfoImpl.h" +#include "llvm/Analysis/BranchProbabilityInfo.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/RegionInfo.h" #include "llvm/Analysis/RegionIterator.h" @@ -26,9 +29,11 @@ #include "llvm/IR/Instructions.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/LLVMContext.h" +#include "llvm/IR/MDBuilder.h" #include "llvm/IR/Module.h" #include "llvm/IR/Verifier.h" #include "llvm/Pass.h" +#include "llvm/Support/BlockFrequency.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" @@ -45,8 +50,8 @@ // extracted functions to pthread-based code, as only one argument (void*) can // be passed in to pthread_create(). 
static cl::opt -AggregateArgsOpt("aggregate-extracted-args", cl::Hidden, - cl::desc("Aggregate arguments to code-extracted functions")); + AggregateArgsOpt("aggregate-extracted-args", cl::Hidden, + cl::desc("Aggregate arguments to code-extracted functions")); /// \brief Test whether a block is valid for extraction. static bool isBlockValidForExtraction(const BasicBlock &BB) { @@ -59,9 +64,9 @@ if (isa(I) || isa(I)) return false; if (const CallInst *CI = dyn_cast(I)) - if (const Function *F = CI->getCalledFunction()) - if (F->getIntrinsicID() == Intrinsic::vastart) - return false; + if (const Function *F = CI->getCalledFunction()) + if (F->getIntrinsicID() == Intrinsic::vastart) + return false; } return true; @@ -89,13 +94,13 @@ #ifndef NDEBUG for (SetVector::iterator I = std::next(Result.begin()), - E = Result.end(); + E = Result.end(); I != E; ++I) for (pred_iterator PI = pred_begin(*I), PE = pred_end(*I); PI != PE; ++PI) assert(Result.count(*PI) && "No blocks in this region may have entries from outside the region" - " except for the first block!"); + " except for the first block!"); #endif return Result; @@ -119,30 +124,32 @@ return buildExtractionBlockSet(R.block_begin(), R.block_end()); } -CodeExtractor::CodeExtractor(BasicBlock *BB, bool AggregateArgs) - : DT(nullptr), AggregateArgs(AggregateArgs||AggregateArgsOpt), - Blocks(buildExtractionBlockSet(BB)), NumExitBlocks(~0U) {} +CodeExtractor::CodeExtractor(BasicBlock *BB, bool AggregateArgs, + BlockFrequencyInfo *BFI) + : DT(nullptr), AggregateArgs(AggregateArgs||AggregateArgsOpt), BFI(BFI), + Blocks(buildExtractionBlockSet(BB)), NumExitBlocks(~0U) {} CodeExtractor::CodeExtractor(ArrayRef BBs, DominatorTree *DT, - bool AggregateArgs) - : DT(DT), AggregateArgs(AggregateArgs||AggregateArgsOpt), - Blocks(buildExtractionBlockSet(BBs)), NumExitBlocks(~0U) {} - -CodeExtractor::CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs) - : DT(&DT), AggregateArgs(AggregateArgs||AggregateArgsOpt), - 
Blocks(buildExtractionBlockSet(L.getBlocks())), NumExitBlocks(~0U) {} + bool AggregateArgs, BlockFrequencyInfo *BFI) + : DT(DT), AggregateArgs(AggregateArgs||AggregateArgsOpt), BFI(BFI), + Blocks(buildExtractionBlockSet(BBs)), NumExitBlocks(~0U) {} + +CodeExtractor::CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs, + BlockFrequencyInfo *BFI) + : DT(&DT), AggregateArgs(AggregateArgs||AggregateArgsOpt), BFI(BFI), + Blocks(buildExtractionBlockSet(L.getBlocks())), NumExitBlocks(~0U) {} CodeExtractor::CodeExtractor(DominatorTree &DT, const RegionNode &RN, - bool AggregateArgs) - : DT(&DT), AggregateArgs(AggregateArgs||AggregateArgsOpt), - Blocks(buildExtractionBlockSet(RN)), NumExitBlocks(~0U) {} + bool AggregateArgs, BlockFrequencyInfo *BFI) + : DT(&DT), AggregateArgs(AggregateArgs||AggregateArgsOpt), BFI(BFI), + Blocks(buildExtractionBlockSet(RN)), NumExitBlocks(~0U) {} /// definedInRegion - Return true if the specified value is defined in the /// extracted region. static bool definedInRegion(const SetVector &Blocks, Value *V) { if (Instruction *I = dyn_cast(V)) - if (Blocks.count(I->getParent())) - return true; + if (Blocks.count(I->getParent())) + return true; return false; } @@ -152,8 +159,8 @@ static bool definedInCaller(const SetVector &Blocks, Value *V) { if (isa(V)) return true; if (Instruction *I = dyn_cast(V)) - if (!Blocks.count(I->getParent())) - return true; + if (!Blocks.count(I->getParent())) + return true; return false; } @@ -292,10 +299,10 @@ // This function returns unsigned, outputs will go back by reference. 
switch (NumExitBlocks) { - case 0: - case 1: RetTy = Type::getVoidTy(header->getContext()); break; - case 2: RetTy = Type::getInt1Ty(header->getContext()); break; - default: RetTy = Type::getInt16Ty(header->getContext()); break; + case 0: + case 1: RetTy = Type::getVoidTy(header->getContext()); break; + case 2: RetTy = Type::getInt1Ty(header->getContext()); break; + default: RetTy = Type::getInt16Ty(header->getContext()); break; } std::vector paramTy; @@ -316,11 +323,11 @@ } DEBUG({ - dbgs() << "Function type: " << *RetTy << " f("; - for (Type *i : paramTy) - dbgs() << *i << ", "; - dbgs() << ")\n"; - }); + dbgs() << "Function type: " << *RetTy << " f("; + for (Type *i : paramTy) + dbgs() << *i << ", "; + dbgs() << ")\n"; + }); StructType *StructTy; if (AggregateArgs && (inputs.size() + outputs.size() > 0)) { @@ -329,7 +336,7 @@ paramTy.push_back(PointerType::getUnqual(StructTy)); } FunctionType *funcType = - FunctionType::get(RetTy, paramTy, false); + FunctionType::get(RetTy, paramTy, false); // Create the new function Function *newFunction = Function::Create(funcType, @@ -339,7 +346,7 @@ // If the old function is no-throw, so is the new one. if (oldFunction->doesNotThrow()) newFunction->setDoesNotThrow(); - + newFunction->getBasicBlockList().push_back(newRootNode); // Create an iterator to name all of the arguments we inserted. @@ -363,8 +370,8 @@ std::vector Users(inputs[i]->user_begin(), inputs[i]->user_end()); for (User *use : Users) if (Instruction *inst = dyn_cast(use)) - if (Blocks.count(inst->getParent())) - inst->replaceUsesOfWith(inputs[i], RewriteVal); + if (Blocks.count(inst->getParent())) + inst->replaceUsesOfWith(inputs[i], RewriteVal); } // Set names for input and output arguments. 
@@ -384,9 +391,9 @@ // The BasicBlock which contains the branch is not in the region // modify the branch target to a new block if (TerminatorInst *TI = dyn_cast(Users[i])) - if (!Blocks.count(TI->getParent()) && - TI->getParent()->getParent() == oldFunction) - TI->replaceUsesOfWith(header, newHeader); + if (!Blocks.count(TI->getParent()) && + TI->getParent()->getParent() == oldFunction) + TI->replaceUsesOfWith(header, newHeader); return newFunction; } @@ -396,9 +403,9 @@ /// block associated with that use, or return 0 if none is found. static BasicBlock* FindPhiPredForUseInBlock(Value* Used, BasicBlock* BB) { for (Use &U : Used->uses()) { - PHINode *P = dyn_cast(U.getUser()); - if (P && P->getParent() == BB) - return P->getIncomingBlock(U); + PHINode *P = dyn_cast(U.getUser()); + if (P && P->getParent() == BB) + return P->getIncomingBlock(U); } return nullptr; @@ -413,7 +420,7 @@ // Emit a call to the new function, passing in: *pointer to struct (if // aggregating parameters), or plan inputs and allocated memory for outputs std::vector params, StructValues, ReloadOutputs, Reloads; - + LLVMContext &Context = newFunction->getContext(); // Add inputs as params, or to be filled into the struct @@ -441,7 +448,7 @@ if (AggregateArgs && (inputs.size() + outputs.size() > 0)) { std::vector ArgTypes; for (ValueSet::iterator v = StructValues.begin(), - ve = StructValues.end(); v != ve; ++v) + ve = StructValues.end(); v != ve; ++v) ArgTypes.push_back((*v)->getType()); // Allocate a struct at the beginning of this function @@ -527,14 +534,14 @@ Value *brVal = nullptr; switch (NumExitBlocks) { - case 0: - case 1: break; // No value needed. - case 2: // Conditional branch, return a bool - brVal = ConstantInt::get(Type::getInt1Ty(Context), !SuccNum); - break; - default: - brVal = ConstantInt::get(Type::getInt16Ty(Context), SuccNum); - break; + case 0: + case 1: break; // No value needed. 
+ case 2: // Conditional branch, return a bool + brVal = ConstantInt::get(Type::getInt1Ty(Context), !SuccNum); + break; + default: + brVal = ConstantInt::get(Type::getInt16Ty(Context), SuccNum); + break; } ReturnInst *NTRet = ReturnInst::Create(Context, brVal, NewTarget); @@ -578,12 +585,12 @@ if (DT) { DominatesDef = DT->dominates(DefBlock, OldTarget); - + // If the output value is used by a phi in the target block, // then we need to test for dominance of the phi's predecessor // instead. Unfortunately, this a little complicated since we // have already rewritten uses of the value to uses of the reload. - BasicBlock* pred = FindPhiPredForUseInBlock(Reloads[out], + BasicBlock* pred = FindPhiPredForUseInBlock(Reloads[out], OldTarget); if (pred && DT && DT->dominates(DefBlock, pred)) DominatesDef = true; @@ -616,45 +623,45 @@ // Now that we've done the deed, simplify the switch instruction. Type *OldFnRetTy = TheSwitch->getParent()->getParent()->getReturnType(); switch (NumExitBlocks) { - case 0: - // There are no successors (the block containing the switch itself), which - // means that previously this was the last part of the function, and hence - // this should be rewritten as a `ret' - - // Check if the function should return a value - if (OldFnRetTy->isVoidTy()) { - ReturnInst::Create(Context, nullptr, TheSwitch); // Return void - } else if (OldFnRetTy == TheSwitch->getCondition()->getType()) { - // return what we have - ReturnInst::Create(Context, TheSwitch->getCondition(), TheSwitch); - } else { - // Otherwise we must have code extracted an unwind or something, just - // return whatever we want. - ReturnInst::Create(Context, - Constant::getNullValue(OldFnRetTy), TheSwitch); - } - - TheSwitch->eraseFromParent(); - break; - case 1: - // Only a single destination, change the switch into an unconditional - // branch. 
- BranchInst::Create(TheSwitch->getSuccessor(1), TheSwitch); - TheSwitch->eraseFromParent(); - break; - case 2: - BranchInst::Create(TheSwitch->getSuccessor(1), TheSwitch->getSuccessor(2), - call, TheSwitch); - TheSwitch->eraseFromParent(); - break; - default: - // Otherwise, make the default destination of the switch instruction be one - // of the other successors. - TheSwitch->setCondition(call); - TheSwitch->setDefaultDest(TheSwitch->getSuccessor(NumExitBlocks)); - // Remove redundant case - TheSwitch->removeCase(SwitchInst::CaseIt(TheSwitch, NumExitBlocks-1)); - break; + case 0: + // There are no successors (the block containing the switch itself), which + // means that previously this was the last part of the function, and hence + // this should be rewritten as a `ret' + + // Check if the function should return a value + if (OldFnRetTy->isVoidTy()) { + ReturnInst::Create(Context, nullptr, TheSwitch); // Return void + } else if (OldFnRetTy == TheSwitch->getCondition()->getType()) { + // return what we have + ReturnInst::Create(Context, TheSwitch->getCondition(), TheSwitch); + } else { + // Otherwise we must have code extracted an unwind or something, just + // return whatever we want. + ReturnInst::Create(Context, + Constant::getNullValue(OldFnRetTy), TheSwitch); + } + + TheSwitch->eraseFromParent(); + break; + case 1: + // Only a single destination, change the switch into an unconditional + // branch. + BranchInst::Create(TheSwitch->getSuccessor(1), TheSwitch); + TheSwitch->eraseFromParent(); + break; + case 2: + BranchInst::Create(TheSwitch->getSuccessor(1), TheSwitch->getSuccessor(2), + call, TheSwitch); + TheSwitch->eraseFromParent(); + break; + default: + // Otherwise, make the default destination of the switch instruction be one + // of the other successors. 
+ TheSwitch->setCondition(call); + TheSwitch->setDefaultDest(TheSwitch->getSuccessor(NumExitBlocks)); + // Remove redundant case + TheSwitch->removeCase(SwitchInst::CaseIt(TheSwitch, NumExitBlocks-1)); + break; } } @@ -682,6 +689,22 @@ // block in the region. BasicBlock *header = *Blocks.begin(); + // Get non-const in case we modify the exit probabilities when the exit count is + // greater than 1. + BranchProbabilityInfo *BPI = nullptr; + if (BFI) + BPI = const_cast(BFI->getBPI()); + + // Calculate the entry of the new function before we change the root block. + BlockFrequency EntryCount; + if (BFI) { + for (BasicBlock *Pred : predecessors(header)) { + if (Blocks.count(Pred)) + continue; + EntryCount += BFI->getBlockFreq(Pred) * BPI->getEdgeProbability(Pred, header); + } + } + // If we have to split PHI nodes or the entry block, do so now. severSplitPHINodes(header); @@ -692,25 +715,34 @@ Function *oldFunction = header->getParent(); // This takes place of the original loop - BasicBlock *codeReplacer = BasicBlock::Create(header->getContext(), + BasicBlock *codeReplacer = BasicBlock::Create(header->getContext(), "codeRepl", oldFunction, header); // The new function needs a root node because other nodes can branch to the // head of the region, but the entry node of a function cannot have preds. - BasicBlock *newFuncRoot = BasicBlock::Create(header->getContext(), + BasicBlock *newFuncRoot = BasicBlock::Create(header->getContext(), "newFuncRoot"); newFuncRoot->getInstList().push_back(BranchInst::Create(header)); // Find inputs to, outputs from the code region. findInputsOutputs(inputs, outputs); + DenseMap ExitWeights; SmallPtrSet ExitBlocks; - for (BasicBlock *Block : Blocks) + for (BasicBlock *Block : Blocks) { for (succ_iterator SI = succ_begin(Block), SE = succ_end(Block); SI != SE; - ++SI) - if (!Blocks.count(*SI)) + ++SI) { + if (!Blocks.count(*SI)) { + // Update the branch weight for this successor. 
+ if (BPI) { + BlockFrequency &BF = ExitWeights[*SI]; + BF += BFI->getBlockFreq(Block) * BPI->getEdgeProbability(Block, *SI); + } ExitBlocks.insert(*SI); + } + } + } NumExitBlocks = ExitBlocks.size(); // Construct new function based on inputs/outputs & add allocas for all defs. @@ -719,10 +751,61 @@ codeReplacer, oldFunction, oldFunction->getParent()); + // Set the entry count only when a profile count is available. + if (BFI) { + if (Optional OptEntryCount = + BFI->getFreqToProfileCount(EntryCount.getFrequency())) + newFunction->setEntryCount(OptEntryCount.getValue()); + BFI->setBlockFreq(codeReplacer, EntryCount.getFrequency()); + } + emitCallAndSwitchStatement(newFunction, codeReplacer, inputs, outputs); moveCodeToFunction(newFunction); + // Update the branch weights for the exit block. + TerminatorInst *TI = codeReplacer->getTerminator(); + if (BPI && TI->getNumSuccessors() > 1) { + // Make sure that the weights are valid; we can't have a weight of zero. + bool HasValidWeights = true; + for (auto &Weight : ExitWeights) + if (Weight.second.getFrequency() == 0) { + HasValidWeights = false; + break; + } + + if (HasValidWeights) { + // Block Frequency distribution for creating normalized weights. + BlockFrequencyInfoImplBase::Distribution BranchDist; + + // Add each of the frequencies of the successors + for (unsigned I = 0, E = TI->getNumSuccessors(); I < E; ++I) { + // Dummy node to represent this successor. + BlockFrequencyInfoImplBase::BlockNode DummyNode(I); + BranchDist.addExit(DummyNode, + ExitWeights[TI->getSuccessor(I)].getFrequency()); + } + + // Normalize the distribution so that the weights fit in unsigned. + BranchDist.normalize(); + + // Create normalized branch weights and set the metadata. + SmallVector BranchWeights; + BranchWeights.reserve(TI->getNumSuccessors()); + for (unsigned I = 0, E = BranchDist.Weights.size(); I < E; ++I) { + const auto &Weight = BranchDist.Weights[I]; + + // Get the weight and update the current BPI. 
+ BranchWeights.push_back(Weight.Amount); + BranchProbability BP(Weight.Amount, BranchDist.Total); + BPI->setEdgeProbability(codeReplacer, Weight.TargetNode.Index, BP); + } + TI->setMetadata(LLVMContext::MD_prof, + MDBuilder(TI->getContext()). + createBranchWeights(BranchWeights)); + } + } + // Loop over all of the PHI nodes in the header block, and change any // references to the old incoming edge to be the new incoming edge. for (BasicBlock::iterator I = header->begin(); isa(I); ++I) { @@ -760,7 +843,7 @@ // cerr << "OLD FUNCTION: " << *oldFunction; // verifyFunction(*oldFunction); - DEBUG(if (verifyFunction(*newFunction)) - report_fatal_error("verifyFunction failed!")); + DEBUG(if (verifyFunction(*newFunction)) + report_fatal_error("verifyFunction failed!")); return newFunction; -} +} \ No newline at end of file Index: test/Transforms/CodeExtractor/2016-07-20-ExtractedFnEntryCount.ll =================================================================== --- test/Transforms/CodeExtractor/2016-07-20-ExtractedFnEntryCount.ll +++ test/Transforms/CodeExtractor/2016-07-20-ExtractedFnEntryCount.ll @@ -0,0 +1,33 @@ +; RUN: opt < %s -partial-inliner -S | FileCheck %s + +; This test checks to make sure that the CodeExtractor +; properly sets the entry count for the function that is +; extracted based on the root block being extracted and also +; takes into consideration if the block has edges coming from +; a block that is also being extracted. 
+ +define i32 @inlinedFunc(i1 %cond) !prof !1 { +entry: + br i1 %cond, label %if.then, label %return, !prof !2 +if.then: + br i1 %cond, label %if.then, label %return, !prof !3 +return: ; preds = %entry + ret i32 0 +} + + +define internal i32 @dummyCaller(i1 %cond) !prof !1 { +entry: + %val = call i32 @inlinedFunc(i1 %cond) + ret i32 %val +} + +; CHECK: @inlinedFunc.1_if.then(i1 %cond) !prof [[COUNT1:![0-9]+]] + + +!llvm.module.flags = !{!0} +; CHECK: [[COUNT1]] = !{!"function_entry_count", i64 250} +!0 = !{i32 1, !"MaxFunctionCount", i32 1000} +!1 = !{!"function_entry_count", i64 1000} +!2 = !{!"branch_weights", i32 250, i32 750} +!3 = !{!"branch_weights", i32 125, i32 125} Index: test/Transforms/CodeExtractor/2016-07-20-MultipleExitBranchProb.ll =================================================================== --- test/Transforms/CodeExtractor/2016-07-20-MultipleExitBranchProb.ll +++ test/Transforms/CodeExtractor/2016-07-20-MultipleExitBranchProb.ll @@ -0,0 +1,30 @@ +; RUN: opt < %s -partial-inliner -S | FileCheck %s + +; This test checks to make sure that CodeExtractor updates +; the exit branch probabilities for multiple exit blocks. + +define i32 @inlinedFunc(i1 %cond) !prof !1 { +entry: + br i1 %cond, label %if.then, label %return, !prof !2 +if.then: + br i1 %cond, label %return, label %return.2, !prof !3 +return.2: + ret i32 10 +return: ; preds = %entry + ret i32 0 +} + + +define internal i32 @dummyCaller(i1 %cond) !prof !1 { +entry: +%val = call i32 @inlinedFunc(i1 %cond) +ret i32 %val +} + +!llvm.module.flags = !{!0} +!0 = !{i32 1, !"MaxFunctionCount", i32 10000} +!1 = !{!"function_entry_count", i64 10000} +!2 = !{!"branch_weights", i32 5, i32 5} +!3 = !{!"branch_weights", i32 4, i32 1} + +; CHECK: {{![0-9]+}} = !{!"branch_weights", i32 8, i32 31}