Index: include/llvm/Analysis/ProfileSummaryInfo.h =================================================================== --- include/llvm/Analysis/ProfileSummaryInfo.h +++ include/llvm/Analysis/ProfileSummaryInfo.h @@ -54,9 +54,12 @@ ProfileSummaryInfo(Module &M) : M(M) {} ProfileSummaryInfo(ProfileSummaryInfo &&Arg) : M(Arg.M), Summary(std::move(Arg.Summary)) {} + /// Returns the profile kind for the module. + static Optional getKind(const Module *M); /// Returns the profile count for \p CallInst. static Optional getProfileCount(const Instruction *CallInst, - BlockFrequencyInfo *BFI); + BlockFrequencyInfo *BFI, + ProfileSummary *Summary = nullptr); /// \brief Returns true if \p F has hot function entry. bool isFunctionEntryHot(const Function *F); /// Returns true if \p F has hot function entry or hot call edge. Index: include/llvm/IR/Module.h =================================================================== --- include/llvm/IR/Module.h +++ include/llvm/IR/Module.h @@ -804,7 +804,7 @@ void setProfileSummary(Metadata *M); /// \brief Returns profile summary metadata - Metadata *getProfileSummary(); + Metadata *getProfileSummary() const; /// @} /// Take ownership of the given memory buffer. Index: include/llvm/IR/ProfileSummary.h =================================================================== --- include/llvm/IR/ProfileSummary.h +++ include/llvm/IR/ProfileSummary.h @@ -18,6 +18,7 @@ #include #include +#include "llvm/ADT/Optional.h" #include "llvm/Support/Casting.h" namespace llvm { @@ -70,6 +71,8 @@ Kind getKind() const { return PSK; } /// \brief Return summary information as metadata. Metadata *getMD(LLVMContext &Context); + /// Return the profile kind from metadata. + static Optional getKindFromMD(Metadata *MD); /// \brief Construct profile summary from metdata. static ProfileSummary *getFromMD(Metadata *MD); SummaryEntryVector &getDetailedSummary() { return DetailedSummary; } Index: lib/Analysis/ProfileSummaryInfo.cpp =================================================================== --- lib/Analysis/ProfileSummaryInfo.cpp +++ lib/Analysis/ProfileSummaryInfo.cpp @@ -68,18 +68,32 @@ return true; } -Optional -ProfileSummaryInfo::getProfileCount(const Instruction *Inst, - BlockFrequencyInfo *BFI) { +Optional ProfileSummaryInfo::getKind(const Module *M) { + auto *SummaryMD = M->getProfileSummary(); + if (!SummaryMD) + return None; + return ProfileSummary::getKindFromMD(SummaryMD); +} + +Optional ProfileSummaryInfo::getProfileCount( + const Instruction *Inst, BlockFrequencyInfo *BFI, ProfileSummary *Summary) { if (!Inst) return None; assert((isa(Inst) || isa(Inst)) && "We can only get profile count for call/invoke instruction."); - // Check if there is a profile metadata on the instruction. If it is present, - // determine hotness solely based on that. - uint64_t TotalCount; - if (Inst->extractProfTotalWeight(TotalCount)) - return TotalCount; + bool IsSamplePGO = false; + if (Summary) + IsSamplePGO = Summary->getKind() == ProfileSummary::PSK_Sample; + else if (auto Kind = getKind(Inst->getModule())) + IsSamplePGO = Kind.getValue() == ProfileSummary::PSK_Sample; + if (IsSamplePGO) { + // In sample PGO mode, check if there is a profile metadata on the + // instruction. If it is present, determine hotness solely based on that, + // since the sampled entry count may not be accurate. + uint64_t TotalCount; + if (Inst->extractProfTotalWeight(TotalCount)) + return TotalCount; + } if (BFI) return BFI->getBlockProfileCount(Inst->getParent()); return None; @@ -112,7 +126,7 @@ for (const auto &BB : *F) for (const auto &I : BB) if (isa(I) || isa(I)) - if (auto CallCount = getProfileCount(&I, nullptr)) + if (auto CallCount = getProfileCount(&I, nullptr, Summary.get())) TotalCallCount += CallCount.getValue(); return isHotCount(TotalCallCount); } @@ -131,7 +145,7 @@ for (const auto &BB : *F) for (const auto &I : BB) if (isa(I) || isa(I)) - if (auto CallCount = getProfileCount(&I, nullptr)) + if (auto CallCount = getProfileCount(&I, nullptr, Summary.get())) TotalCallCount += CallCount.getValue(); return isColdCount(TotalCallCount); } @@ -189,13 +203,13 @@ bool ProfileSummaryInfo::isHotCallSite(const CallSite &CS, BlockFrequencyInfo *BFI) { - auto C = getProfileCount(CS.getInstruction(), BFI); + auto C = getProfileCount(CS.getInstruction(), BFI, Summary.get()); return C && isHotCount(*C); } bool ProfileSummaryInfo::isColdCallSite(const CallSite &CS, BlockFrequencyInfo *BFI) { - auto C = getProfileCount(CS.getInstruction(), BFI); + auto C = getProfileCount(CS.getInstruction(), BFI, Summary.get()); return C && isColdCount(*C); } Index: lib/IR/Module.cpp =================================================================== --- lib/IR/Module.cpp +++ lib/IR/Module.cpp @@ -483,7 +483,7 @@ addModuleFlag(ModFlagBehavior::Error, "ProfileSummary", M); } -Metadata *Module::getProfileSummary() { +Metadata *Module::getProfileSummary() const { return getModuleFlag("ProfileSummary"); } Index: lib/IR/ProfileSummary.cpp =================================================================== --- lib/IR/ProfileSummary.cpp +++ lib/IR/ProfileSummary.cpp @@ -143,6 +143,26 @@ return true; } +Optional ProfileSummary::getKindFromMD(Metadata *MD) { + if (!MD) + return None; + if (!isa(MD)) + return None; + MDTuple *Tuple = cast(MD); + if (Tuple->getNumOperands() != 8) + return None; + + auto &FormatMD = Tuple->getOperand(0); + if (isKeyValuePair(dyn_cast_or_null(FormatMD), "ProfileFormat", + "SampleProfile")) + return PSK_Sample; + else if (isKeyValuePair(dyn_cast_or_null(FormatMD), "ProfileFormat", + "InstrProf")) + return PSK_Instr; + + return None; +} + ProfileSummary *ProfileSummary::getFromMD(Metadata *MD) { if (!MD) return nullptr; Index: test/Bitcode/thinlto-function-summary-callgraph-profile-summary.ll =================================================================== --- test/Bitcode/thinlto-function-summary-callgraph-profile-summary.ll +++ test/Bitcode/thinlto-function-summary-callgraph-profile-summary.ll @@ -29,7 +29,7 @@ ; CHECK-NEXT: ; op4=hot1 op6=cold op8=hot2 op10=hot4 op12=none1 op14=hot3 op16=none2 op18=none3 op20=123 -; CHECK-NEXT: +; CHECK-NEXT: ; CHECK-NEXT: ; CHECK: getContext()); + CI2->setMetadata(llvm::LLVMContext::MD_prof, MDB.createBranchWeights({400})); + EXPECT_FALSE(PSI.isHotCallSite(CS2, &BFI)); } TEST_F(ProfileSummaryInfoTest, SampleProf) {