Index: include/llvm/Analysis/ProfileSummaryInfo.h =================================================================== --- include/llvm/Analysis/ProfileSummaryInfo.h +++ include/llvm/Analysis/ProfileSummaryInfo.h @@ -49,6 +49,10 @@ void computeThresholds(); // Count thresholds to answer isHotCount and isColdCount queries. Optional HotCountThreshold, ColdCountThreshold; + /// Returns the profile count for \p CallInst. + static Optional getProfileCount(const Instruction *CallInst, + BlockFrequencyInfo *BFI, + bool IsSamplePGO); public: ProfileSummaryInfo(Module &M) : M(M) {} Index: include/llvm/IR/Module.h =================================================================== --- include/llvm/IR/Module.h +++ include/llvm/IR/Module.h @@ -804,7 +804,7 @@ void setProfileSummary(Metadata *M); /// \brief Returns profile summary metadata - Metadata *getProfileSummary(); + Metadata *getProfileSummary() const; /// @} /// Take ownership of the given memory buffer. Index: include/llvm/IR/ProfileSummary.h =================================================================== --- include/llvm/IR/ProfileSummary.h +++ include/llvm/IR/ProfileSummary.h @@ -18,6 +18,7 @@ #include #include +#include "llvm/ADT/Optional.h" #include "llvm/Support/Casting.h" namespace llvm { @@ -70,6 +71,8 @@ Kind getKind() const { return PSK; } /// \brief Return summary information as metadata. Metadata *getMD(LLVMContext &Context); + /// Return the profile kind from metadata. + static Optional getKindFromMD(Metadata *MD); /// \brief Construct profile summary from metdata. static ProfileSummary *getFromMD(Metadata *MD); SummaryEntryVector &getDetailedSummary() { return DetailedSummary; } Index: lib/Analysis/ProfileSummaryInfo.cpp =================================================================== --- lib/Analysis/ProfileSummaryInfo.cpp +++ lib/Analysis/ProfileSummaryInfo.cpp @@ -68,23 +68,44 @@ return true; } -Optional -ProfileSummaryInfo::getProfileCount(const Instruction *Inst, - BlockFrequencyInfo *BFI) { +static Optional getKind(const Module *M) { + auto *SummaryMD = M->getProfileSummary(); + if (!SummaryMD) + return None; + return ProfileSummary::getKindFromMD(SummaryMD); +} + +Optional ProfileSummaryInfo::getProfileCount(const Instruction *Inst, + BlockFrequencyInfo *BFI, + bool IsSamplePGO) { if (!Inst) return None; assert((isa(Inst) || isa(Inst)) && "We can only get profile count for call/invoke instruction."); - // Check if there is a profile metadata on the instruction. If it is present, - // determine hotness solely based on that. - uint64_t TotalCount; - if (Inst->extractProfTotalWeight(TotalCount)) - return TotalCount; + if (IsSamplePGO) { + // In sample PGO mode, check if there is a profile metadata on the + // instruction. If it is present, determine hotness solely based on that, + // since the sampled entry count may not be accurate. + uint64_t TotalCount; + if (Inst->extractProfTotalWeight(TotalCount)) + return TotalCount; + } if (BFI) return BFI->getBlockProfileCount(Inst->getParent()); return None; } +Optional +ProfileSummaryInfo::getProfileCount(const Instruction *Inst, + BlockFrequencyInfo *BFI) { + if (!Inst) + return None; + bool IsSamplePGO = false; + if (auto Kind = getKind(Inst->getModule())) + IsSamplePGO = Kind.getValue() == ProfileSummary::PSK_Sample; + return getProfileCount(Inst, BFI, IsSamplePGO); +} + /// Returns true if the function's entry is hot. If it returns false, it /// either means it is not hot or it is unknown whether it is hot or not (for /// example, no profile data is available). @@ -108,11 +129,13 @@ if (isHotCount(FunctionCount.getValue())) return true; + assert(Summary.get()); uint64_t TotalCallCount = 0; for (const auto &BB : *F) for (const auto &I : BB) if (isa(I) || isa(I)) - if (auto CallCount = getProfileCount(&I, nullptr)) + if (auto CallCount = getProfileCount( + &I, nullptr, Summary->getKind() == ProfileSummary::PSK_Sample)) TotalCallCount += CallCount.getValue(); return isHotCount(TotalCallCount); } @@ -126,12 +149,14 @@ if (auto FunctionCount = F->getEntryCount()) if (!isColdCount(FunctionCount.getValue())) return false; - + + assert(Summary.get()); uint64_t TotalCallCount = 0; for (const auto &BB : *F) for (const auto &I : BB) if (isa(I) || isa(I)) - if (auto CallCount = getProfileCount(&I, nullptr)) + if (auto CallCount = getProfileCount( + &I, nullptr, Summary->getKind() == ProfileSummary::PSK_Sample)) TotalCallCount += CallCount.getValue(); return isColdCount(TotalCallCount); } @@ -189,13 +214,19 @@ bool ProfileSummaryInfo::isHotCallSite(const CallSite &CS, BlockFrequencyInfo *BFI) { - auto C = getProfileCount(CS.getInstruction(), BFI); + if (!computeSummary()) + return false; + auto C = getProfileCount(CS.getInstruction(), BFI, + Summary->getKind() == ProfileSummary::PSK_Sample); return C && isHotCount(*C); } bool ProfileSummaryInfo::isColdCallSite(const CallSite &CS, BlockFrequencyInfo *BFI) { - auto C = getProfileCount(CS.getInstruction(), BFI); + if (!computeSummary()) + return false; + auto C = getProfileCount(CS.getInstruction(), BFI, + Summary->getKind() == ProfileSummary::PSK_Sample); return C && isColdCount(*C); } Index: lib/IR/Module.cpp =================================================================== --- lib/IR/Module.cpp +++ lib/IR/Module.cpp @@ -483,7 +483,7 @@ addModuleFlag(ModFlagBehavior::Error, "ProfileSummary", M); } -Metadata *Module::getProfileSummary() { +Metadata *Module::getProfileSummary() const { return getModuleFlag("ProfileSummary"); } Index: lib/IR/ProfileSummary.cpp =================================================================== --- lib/IR/ProfileSummary.cpp +++ lib/IR/ProfileSummary.cpp @@ -143,6 +143,26 @@ return true; } +Optional ProfileSummary::getKindFromMD(Metadata *MD) { + if (!MD) + return None; + if (!isa(MD)) + return None; + MDTuple *Tuple = cast(MD); + if (Tuple->getNumOperands() != 8) + return None; + + auto &FormatMD = Tuple->getOperand(0); + if (isKeyValuePair(dyn_cast_or_null(FormatMD), "ProfileFormat", + "SampleProfile")) + return PSK_Sample; + else if (isKeyValuePair(dyn_cast_or_null(FormatMD), "ProfileFormat", + "InstrProf")) + return PSK_Instr; + + return None; +} + ProfileSummary *ProfileSummary::getFromMD(Metadata *MD) { if (!MD) return nullptr; Index: test/Bitcode/thinlto-function-summary-callgraph-profile-summary.ll =================================================================== --- test/Bitcode/thinlto-function-summary-callgraph-profile-summary.ll +++ test/Bitcode/thinlto-function-summary-callgraph-profile-summary.ll @@ -29,7 +29,7 @@ ; CHECK-NEXT: ; op4=hot1 op6=cold op8=hot2 op10=hot4 op12=none1 op14=hot3 op16=none2 op18=none3 op20=123 -; CHECK-NEXT: +; CHECK-NEXT: ; CHECK-NEXT: ; CHECK: getContext()); + CI2->setMetadata(llvm::LLVMContext::MD_prof, MDB.createBranchWeights({400})); + EXPECT_FALSE(PSI.isHotCallSite(CS2, &BFI)); } TEST_F(ProfileSummaryInfoTest, SampleProf) {