diff --git a/llvm/include/llvm/Analysis/ProfileSummaryInfo.h b/llvm/include/llvm/Analysis/ProfileSummaryInfo.h --- a/llvm/include/llvm/Analysis/ProfileSummaryInfo.h +++ b/llvm/include/llvm/Analysis/ProfileSummaryInfo.h @@ -24,7 +24,7 @@ namespace llvm { class BasicBlock; class BlockFrequencyInfo; -class CallSite; +class CallBase; class Instruction; class ProfileSummary; /// Analysis providing profile information. @@ -144,10 +144,10 @@ /// cold percentile cutoff value. bool isColdBlockNthPercentile(int PercentileCutoff, const BasicBlock *BB, BlockFrequencyInfo *BFI); - /// Returns true if CallSite \p CS is considered hot. - bool isHotCallSite(const CallSite &CS, BlockFrequencyInfo *BFI); - /// Returns true if Callsite \p CS is considered cold. - bool isColdCallSite(const CallSite &CS, BlockFrequencyInfo *BFI); + /// Returns true if the call site \p CB is considered hot. + bool isHotCallSite(const CallBase &CB, BlockFrequencyInfo *BFI); + /// Returns true if call site \p CB is considered cold. + bool isColdCallSite(const CallBase &CB, BlockFrequencyInfo *BFI); /// Returns HotCountThreshold if set. Recompute HotCountThreshold /// if not set. uint64_t getOrCompHotCountThreshold(); diff --git a/llvm/lib/Analysis/InlineCost.cpp b/llvm/lib/Analysis/InlineCost.cpp --- a/llvm/lib/Analysis/InlineCost.cpp +++ b/llvm/lib/Analysis/InlineCost.cpp @@ -1173,7 +1173,7 @@ // If global profile summary is available, then callsite's coldness is // determined based on that. if (PSI && PSI->hasProfileSummary()) - return PSI->isColdCallSite(CallSite(&Call), CallerBFI); + return PSI->isColdCallSite(Call, CallerBFI); // Otherwise we need BFI to be available. if (!CallerBFI) @@ -1197,8 +1197,7 @@ // If global profile summary is available, then callsite's hotness is // determined based on that. 
- if (PSI && PSI->hasProfileSummary() && - PSI->isHotCallSite(CallSite(&Call), CallerBFI)) + if (PSI && PSI->hasProfileSummary() && PSI->isHotCallSite(Call, CallerBFI)) return Params.HotCallSiteThreshold; // Otherwise we need BFI to be available and to have a locally hot callsite diff --git a/llvm/lib/Analysis/ProfileSummaryInfo.cpp b/llvm/lib/Analysis/ProfileSummaryInfo.cpp --- a/llvm/lib/Analysis/ProfileSummaryInfo.cpp +++ b/llvm/lib/Analysis/ProfileSummaryInfo.cpp @@ -101,6 +101,7 @@ return true; } +// FIXME(CallSite): the parameter should be a CallBase. Optional<uint64_t> ProfileSummaryInfo::getProfileCount(const Instruction *Inst, BlockFrequencyInfo *BFI, @@ -385,21 +386,21 @@ return isHotOrColdBlockNthPercentile(PercentileCutoff, BB, BFI); } -bool ProfileSummaryInfo::isHotCallSite(const CallSite &CS, +bool ProfileSummaryInfo::isHotCallSite(const CallBase &CB, BlockFrequencyInfo *BFI) { - auto C = getProfileCount(CS.getInstruction(), BFI); + auto C = getProfileCount(&CB, BFI); return C && isHotCount(*C); } -bool ProfileSummaryInfo::isColdCallSite(const CallSite &CS, +bool ProfileSummaryInfo::isColdCallSite(const CallBase &CB, BlockFrequencyInfo *BFI) { - auto C = getProfileCount(CS.getInstruction(), BFI); + auto C = getProfileCount(&CB, BFI); if (C) return isColdCount(*C); // In SamplePGO, if the caller has been sampled, and there is no profile // annotated on the callsite, we consider the callsite as cold. 
- return hasSampleProfile() && CS.getCaller()->hasProfileData(); + return hasSampleProfile() && CB.getCaller()->hasProfileData(); } INITIALIZE_PASS(ProfileSummaryInfoWrapperPass, "profile-summary-info", diff --git a/llvm/unittests/Analysis/ProfileSummaryInfoTest.cpp b/llvm/unittests/Analysis/ProfileSummaryInfoTest.cpp --- a/llvm/unittests/Analysis/ProfileSummaryInfoTest.cpp +++ b/llvm/unittests/Analysis/ProfileSummaryInfoTest.cpp @@ -12,7 +12,6 @@ #include "llvm/Analysis/LoopInfo.h" #include "llvm/AsmParser/Parser.h" #include "llvm/IR/BasicBlock.h" -#include "llvm/IR/CallSite.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" #include "llvm/IR/LLVMContext.h" @@ -134,7 +133,7 @@ EXPECT_FALSE(PSI.isHotBlock(&BB0, &BFI)); EXPECT_FALSE(PSI.isColdBlock(&BB0, &BFI)); - CallSite CS1(BB1->getFirstNonPHI()); + CallBase &CS1 = *cast<CallBase>(BB1->getFirstNonPHI()); EXPECT_FALSE(PSI.isHotCallSite(CS1, &BFI)); EXPECT_FALSE(PSI.isColdCallSite(CS1, &BFI)); } @@ -228,9 +227,9 @@ EXPECT_TRUE(PSI.isColdBlockNthPercentile(10000, BB2, &BFI)); EXPECT_TRUE(PSI.isColdBlockNthPercentile(10000, BB3, &BFI)); - CallSite CS1(BB1->getFirstNonPHI()); + CallBase &CS1 = *cast<CallBase>(BB1->getFirstNonPHI()); auto *CI2 = BB2->getFirstNonPHI(); - CallSite CS2(CI2); + CallBase &CS2 = *cast<CallBase>(CI2); EXPECT_TRUE(PSI.isHotCallSite(CS1, &BFI)); EXPECT_FALSE(PSI.isHotCallSite(CS2, &BFI)); @@ -323,14 +322,14 @@ EXPECT_TRUE(PSI.isColdBlockNthPercentile(10000, BB2, &BFI)); EXPECT_TRUE(PSI.isColdBlockNthPercentile(10000, BB3, &BFI)); - CallSite CS1(BB1->getFirstNonPHI()); + CallBase &CS1 = *cast<CallBase>(BB1->getFirstNonPHI()); auto *CI2 = BB2->getFirstNonPHI(); // Manually attach branch weights metadata to the call instruction. SmallVector<uint32_t, 1> Weights; Weights.push_back(1000); MDBuilder MDB(M->getContext()); CI2->setMetadata(LLVMContext::MD_prof, MDB.createBranchWeights(Weights)); - CallSite CS2(CI2); + CallBase &CS2 = *cast<CallBase>(CI2); EXPECT_FALSE(PSI.isHotCallSite(CS1, &BFI)); EXPECT_TRUE(PSI.isHotCallSite(CS2, &BFI));