Index: include/llvm/IR/Instruction.h
===================================================================
--- include/llvm/IR/Instruction.h
+++ include/llvm/IR/Instruction.h
@@ -252,6 +252,11 @@
   /// Returns false if no metadata was found.
   bool extractProfTotalWeight(uint64_t &TotalVal) const;
 
+  /// Scales every weight in this instruction's branch_weights metadata by the
+  /// ratio \p R. If \p Reverse is true, the weights are scaled by 1 - \p R
+  /// instead. Does nothing if no branch_weights metadata is attached.
+  void updateProfWeight(double R, bool Reverse);
+
   /// Set the debug location information for this instruction.
   void setDebugLoc(DebugLoc Loc) { DbgLoc = std::move(Loc); }
 
Index: lib/IR/Instruction.cpp
===================================================================
--- lib/IR/Instruction.cpp
+++ lib/IR/Instruction.cpp
@@ -17,6 +17,7 @@
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/Module.h"
+#include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/Operator.h"
 #include "llvm/IR/Type.h"
 using namespace llvm;
@@ -674,3 +675,32 @@
   New->copyMetadata(*this);
   return New;
 }
+
+void Instruction::updateProfWeight(double R, bool Reverse) {
+  auto *ProfileData = getMetadata(LLVMContext::MD_prof);
+  if (!ProfileData)
+    return;
+
+  auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(0));
+  if (!ProfDataName || !ProfDataName->getString().equals("branch_weights"))
+    return;
+
+  SmallVector<uint32_t, 4> Weights;
+  for (unsigned i = 1; i < ProfileData->getNumOperands(); i++) {
+    auto *V = mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(i));
+    if (!V)
+      return;
+    uint64_t Val = V->getValue().getZExtValue();
+    // Scale in floating point, then clamp to the uint32_t range before
+    // narrowing: converting a negative, NaN or too-large double to an
+    // unsigned integer is undefined behavior.
+    double Scaled = Reverse ? Val - Val * R : Val * R;
+    if (!(Scaled > 0.0))
+      Scaled = 0.0;
+    else if (Scaled > UINT32_MAX)
+      Scaled = UINT32_MAX;
+    Weights.push_back(static_cast<uint32_t>(Scaled));
+  }
+  MDBuilder MDB(getContext());
+  setMetadata(LLVMContext::MD_prof, MDB.createBranchWeights(Weights));
+}
Index: lib/Transforms/Utils/InlineFunction.cpp
===================================================================
--- lib/Transforms/Utils/InlineFunction.cpp
+++ lib/Transforms/Utils/InlineFunction.cpp
@@ -1425,12 +1425,44 @@
                                   ClonedBBs);
 }
 
+/// Update the branch_weights metadata of call instructions after inlining.
+///
+/// Calls cloned into the caller are scaled by the ratio of the callsite count
+/// to the callee's entry count; calls left in the callee keep the remainder.
+static void updateCallProfile(Function *Callee,
+                              const ValueToValueMapTy &VMap,
+                              const Optional<uint64_t> &CalleeEntryCount,
+                              const Instruction *TheCall) {
+  uint64_t CallCount = 0;
+  if (!CalleeEntryCount || CalleeEntryCount.getValue() < 1)
+    return;
+  if (!TheCall->extractProfTotalWeight(CallCount))
+    CallCount = 0;
+  double Ratio = std::min(
+      static_cast<double>(CallCount) / CalleeEntryCount.getValue(), 1.0);
+
+  SmallPtrSet<BasicBlock *, 16> ClonedBBs;
+  for (auto const &Entry : VMap) {
+    if (!isa<BasicBlock>(Entry.first) || !Entry.second)
+      continue;
+    BasicBlock *ClonedBB = cast<BasicBlock>(Entry.second);
+    if (ClonedBBs.insert(ClonedBB).second)
+      for (Instruction &I : *ClonedBB)
+        if (CallInst *CI = dyn_cast<CallInst>(&I))
+          CI->updateProfWeight(Ratio, false);
+  }
+  for (BasicBlock &BB : *Callee)
+    for (Instruction &I : BB)
+      if (CallInst *CI = dyn_cast<CallInst>(&I))
+        CI->updateProfWeight(Ratio, true);
+}
+
 /// Update the entry count of callee after inlining.
 ///
 /// The callsite's block count is subtracted from the callee's function entry
 /// count.
-static void updateCalleeCount(BlockFrequencyInfo &CallerBFI, BasicBlock *CallBB,
-                              Function *Callee) {
+static void updateCalleeCount(BlockFrequencyInfo *CallerBFI, BasicBlock *CallBB,
+                              Instruction *TheCall, Function *Callee) {
   // If the callee has a original count of N, and the estimated count of
   // callsite is M, the new callee count is set to N - M. M is estimated from
   // the caller's entry count, its entry block frequency and the block frequency
@@ -1438,15 +1470,21 @@
   Optional<uint64_t> CalleeCount = Callee->getEntryCount();
   if (!CalleeCount)
     return;
-  Optional<uint64_t> CallSiteCount = CallerBFI.getBlockProfileCount(CallBB);
-  if (!CallSiteCount)
+  uint64_t CallCount = 0;
+  if (!TheCall->extractProfTotalWeight(CallCount))
+    if (CallerBFI) {
+      Optional<uint64_t> CallSiteCount = CallerBFI->getBlockProfileCount(CallBB);
+      if (CallSiteCount)
+        CallCount = CallSiteCount.getValue();
+    }
+  if (CallCount == 0)
     return;
-  // Since CallSiteCount is an estimate, it could exceed the original callee
+  // Since CallCount may be an estimate, it could exceed the original callee
   // count and has to be set to 0.
-  if (CallSiteCount.getValue() > CalleeCount.getValue())
+  if (CallCount > CalleeCount.getValue())
     Callee->setEntryCount(0);
   else
-    Callee->setEntryCount(CalleeCount.getValue() - CallSiteCount.getValue());
+    Callee->setEntryCount(CalleeCount.getValue() - CallCount);
 }
 
 /// This function inlines the called function into the basic block of the
@@ -1636,13 +1674,14 @@
     // Remember the first block that is newly cloned over.
     FirstNewBlock = LastBlock; ++FirstNewBlock;
 
-    if (IFI.CallerBFI != nullptr && IFI.CalleeBFI != nullptr) {
+    if (IFI.CallerBFI != nullptr && IFI.CalleeBFI != nullptr)
       // Update the BFI of blocks cloned into the caller.
       updateCallerBFI(OrigBB, VMap, IFI.CallerBFI, IFI.CalleeBFI,
                       CalledFunc->front());
-      // Update the profile count of callee.
-      updateCalleeCount(*IFI.CallerBFI, OrigBB, CalledFunc);
-    }
+
+    updateCallProfile(CalledFunc, VMap, CalledFunc->getEntryCount(), TheCall);
+    // Update the profile count of callee.
+    updateCalleeCount(IFI.CallerBFI, OrigBB, TheCall, CalledFunc);
 
     // Inject byval arguments initialization.
     for (std::pair<Value*, Value*> &Init : ByValInit)
Index: test/Transforms/Inline/prof-update.ll
===================================================================
--- /dev/null
+++ test/Transforms/Inline/prof-update.ll
@@ -0,0 +1,28 @@
+; RUN: opt < %s -inline -S | FileCheck %s
+; Checks if inliner updates branch_weights annotation for call instructions.
+
+declare void @ext();
+
+; CHECK: define void @callee() !prof ![[ENTRY_COUNT:[0-9]*]]
+define void @callee() !prof !1 {
+; CHECK: call void @ext(), !prof ![[COUNT_CALLEE:[0-9]*]]
+  call void @ext(), !prof !2
+  ret void
+}
+
+; CHECK: define void @caller()
+define void @caller() {
+; CHECK: call void @ext(), !prof ![[COUNT_CALLER:[0-9]*]]
+  call void @callee(), !prof !3
+  ret void
+}
+
+!llvm.module.flags = !{!0}
+!0 = !{i32 1, !"MaxFunctionCount", i32 2000}
+!1 = !{!"function_entry_count", i64 1000}
+!2 = !{!"branch_weights", i64 2000}
+!3 = !{!"branch_weights", i64 400}
+attributes #0 = { alwaysinline }
+; CHECK: ![[ENTRY_COUNT]] = !{!"function_entry_count", i64 600}
+; CHECK: ![[COUNT_CALLEE]] = !{!"branch_weights", i32 1200}
+; CHECK: ![[COUNT_CALLER]] = !{!"branch_weights", i32 800}