diff --git a/llvm/include/llvm/Analysis/ProfileSummaryInfo.h b/llvm/include/llvm/Analysis/ProfileSummaryInfo.h
--- a/llvm/include/llvm/Analysis/ProfileSummaryInfo.h
+++ b/llvm/include/llvm/Analysis/ProfileSummaryInfo.h
@@ -25,7 +25,6 @@
 class BasicBlock;
 class BlockFrequencyInfo;
 class CallBase;
-class Instruction;
 class ProfileSummary;
 /// Analysis providing profile information.
 ///
@@ -97,7 +96,7 @@
   }
 
   /// Returns the profile count for \p CallInst.
-  Optional<uint64_t> getProfileCount(const Instruction *CallInst,
+  Optional<uint64_t> getProfileCount(const CallBase &CallInst,
                                      BlockFrequencyInfo *BFI,
                                      bool AllowSynthetic = false);
   /// Returns true if the working set size of the code is considered huge.
diff --git a/llvm/lib/Analysis/ModuleSummaryAnalysis.cpp b/llvm/lib/Analysis/ModuleSummaryAnalysis.cpp
--- a/llvm/lib/Analysis/ModuleSummaryAnalysis.cpp
+++ b/llvm/lib/Analysis/ModuleSummaryAnalysis.cpp
@@ -340,7 +340,7 @@
       }
       // We should have named any anonymous globals
       assert(CalledFunction->hasName());
-      auto ScaledCount = PSI->getProfileCount(&I, BFI);
+      auto ScaledCount = PSI->getProfileCount(*CB, BFI);
       auto Hotness = ScaledCount ? getHotness(ScaledCount.getValue(), PSI)
                                  : CalleeInfo::HotnessType::Unknown;
       if (ForceSummaryEdgesCold != FunctionSummary::FSHT_None)
diff --git a/llvm/lib/Analysis/ProfileSummaryInfo.cpp b/llvm/lib/Analysis/ProfileSummaryInfo.cpp
--- a/llvm/lib/Analysis/ProfileSummaryInfo.cpp
+++ b/llvm/lib/Analysis/ProfileSummaryInfo.cpp
@@ -101,14 +101,10 @@
   return true;
 }
 
-// FIXME(CallSite): the parameter should be a CallBase.
-Optional<uint64_t>
-ProfileSummaryInfo::getProfileCount(const Instruction *Inst,
-                                    BlockFrequencyInfo *BFI,
-                                    bool AllowSynthetic) {
-  if (!Inst)
-    return None;
-  assert((isa<CallInst>(Inst) || isa<InvokeInst>(Inst)) &&
+Optional<uint64_t> ProfileSummaryInfo::getProfileCount(const CallBase &Call,
+                                                       BlockFrequencyInfo *BFI,
+                                                       bool AllowSynthetic) {
+  assert((isa<CallInst>(Call) || isa<InvokeInst>(Call)) &&
          "We can only get profile count for call/invoke instruction.");
   if (hasSampleProfile()) {
     // In sample PGO mode, check if there is a profile metadata on the
@@ -116,12 +112,12 @@
     // since the sampled entry count may not be accurate. If there is no
     // annotated on the instruction, return None.
     uint64_t TotalCount;
-    if (Inst->extractProfTotalWeight(TotalCount))
+    if (Call.extractProfTotalWeight(TotalCount))
       return TotalCount;
     return None;
   }
   if (BFI)
-    return BFI->getBlockProfileCount(Inst->getParent(), AllowSynthetic);
+    return BFI->getBlockProfileCount(Call.getParent(), AllowSynthetic);
   return None;
 }
 
@@ -156,7 +152,7 @@
     for (const auto &BB : *F)
       for (const auto &I : BB)
         if (isa<CallInst>(I) || isa<InvokeInst>(I))
-          if (auto CallCount = getProfileCount(&I, nullptr))
+          if (auto CallCount = getProfileCount(cast<CallBase>(I), nullptr))
             TotalCallCount += CallCount.getValue();
     if (isHotCount(TotalCallCount))
       return true;
@@ -185,7 +181,7 @@
     for (const auto &BB : *F)
       for (const auto &I : BB)
         if (isa<CallInst>(I) || isa<InvokeInst>(I))
-          if (auto CallCount = getProfileCount(&I, nullptr))
+          if (auto CallCount = getProfileCount(cast<CallBase>(I), nullptr))
             TotalCallCount += CallCount.getValue();
     if (!isColdCount(TotalCallCount))
       return false;
@@ -214,7 +210,7 @@
     for (const auto &BB : *F)
       for (const auto &I : BB)
         if (isa<CallInst>(I) || isa<InvokeInst>(I))
-          if (auto CallCount = getProfileCount(&I, nullptr))
+          if (auto CallCount = getProfileCount(cast<CallBase>(I), nullptr))
             TotalCallCount += CallCount.getValue();
     if (isHot && isHotCountNthPercentile(PercentileCutoff, TotalCallCount))
       return true;
@@ -388,13 +384,13 @@
 
 bool ProfileSummaryInfo::isHotCallSite(const CallBase &CB,
                                        BlockFrequencyInfo *BFI) {
-  auto C = getProfileCount(&CB, BFI);
+  auto C = getProfileCount(CB, BFI);
   return C && isHotCount(*C);
 }
 
 bool ProfileSummaryInfo::isColdCallSite(const CallBase &CB,
                                         BlockFrequencyInfo *BFI) {
-  auto C = getProfileCount(&CB, BFI);
+  auto C = getProfileCount(CB, BFI);
   if (C)
     return isColdCount(*C);
 
diff --git a/llvm/lib/Transforms/Utils/InlineFunction.cpp b/llvm/lib/Transforms/Utils/InlineFunction.cpp
--- a/llvm/lib/Transforms/Utils/InlineFunction.cpp
+++ b/llvm/lib/Transforms/Utils/InlineFunction.cpp
@@ -1559,8 +1559,7 @@
 /// Update the branch metadata for cloned call instructions.
 static void updateCallProfile(Function *Callee, const ValueToValueMapTy &VMap,
                               const ProfileCount &CalleeEntryCount,
-                              const Instruction *TheCall,
-                              ProfileSummaryInfo *PSI,
+                              const CallBase &TheCall, ProfileSummaryInfo *PSI,
                               BlockFrequencyInfo *CallerBFI) {
   if (!CalleeEntryCount.hasValue() || CalleeEntryCount.isSynthetic() ||
       CalleeEntryCount.getCount() < 1)
@@ -1810,7 +1809,7 @@
       updateCallerBFI(OrigBB, VMap, IFI.CallerBFI, IFI.CalleeBFI,
                       CalledFunc->front());
 
-    updateCallProfile(CalledFunc, VMap, CalledFunc->getEntryCount(), &CB,
+    updateCallProfile(CalledFunc, VMap, CalledFunc->getEntryCount(), CB,
                       IFI.PSI, IFI.CallerBFI);
 
     // Inject byval arguments initialization.