Index: include/llvm/Analysis/ProfileSummaryInfo.h
===================================================================
--- include/llvm/Analysis/ProfileSummaryInfo.h
+++ include/llvm/Analysis/ProfileSummaryInfo.h
@@ -29,6 +29,7 @@
 namespace llvm {
 class BasicBlock;
 class BlockFrequencyInfo;
+class CallSite;
 class ProfileSummary;
 /// \brief Analysis providing profile information.
 ///
@@ -48,6 +49,9 @@
   void computeThresholds();
   // Count thresholds to answer isHotCount and isColdCount queries.
   Optional<uint64_t> HotCountThreshold, ColdCountThreshold;
+  /// \brief Forwards to Instruction::extractProfTotalWeight for \p I, but only
+  /// when the summary is a sample profile; otherwise returns false.
+  bool extractProfTotalWeight(const Instruction *I, uint64_t &TotalCount);
 
 public:
   ProfileSummaryInfo(Module &M) : M(M) {}
@@ -63,6 +67,12 @@
   bool isColdCount(uint64_t C);
   /// \brief Returns true if BasicBlock \p B is considered hot.
   bool isHotBB(const BasicBlock *B, BlockFrequencyInfo *BFI);
+  /// \brief Returns true if BasicBlock \p B is considered cold.
+  bool isColdBB(const BasicBlock *B, BlockFrequencyInfo *BFI);
+  /// \brief Returns true if CallSite \p CS is considered hot.
+  bool isHotCallSite(const CallSite &CS, BlockFrequencyInfo *BFI);
+  /// \brief Returns true if CallSite \p CS is considered cold.
+  bool isColdCallSite(const CallSite &CS, BlockFrequencyInfo *BFI);
 };
 
 /// An analysis pass based on legacy pass manager to deliver ProfileSummaryInfo.
Index: lib/Analysis/ProfileSummaryInfo.cpp
===================================================================
--- lib/Analysis/ProfileSummaryInfo.cpp
+++ lib/Analysis/ProfileSummaryInfo.cpp
@@ -12,9 +12,10 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "llvm/Analysis/BlockFrequencyInfo.h"
 #include "llvm/Analysis/ProfileSummaryInfo.h"
+#include "llvm/Analysis/BlockFrequencyInfo.h"
 #include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/CallSite.h"
 #include "llvm/IR/Metadata.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/ProfileSummary.h"
@@ -135,10 +136,54 @@
   // not update/scale branch weights. Unlike false negatives, this will not cause
   // performance problem.
   uint64_t TotalCount;
-  if (B->getTerminator()->extractProfTotalWeight(TotalCount) &&
-      isHotCount(TotalCount))
-    return true;
-  return false;
+  auto *TI = B->getTerminator();
+  return extractProfTotalWeight(TI, TotalCount) && isHotCount(TotalCount);
+}
+
+bool ProfileSummaryInfo::isColdBB(const BasicBlock *B,
+                                  BlockFrequencyInfo *BFI) {
+  auto Count = BFI->getBlockProfileCount(B);
+  return Count && isColdCount(*Count);
+}
+
+bool ProfileSummaryInfo::extractProfTotalWeight(const Instruction *I,
+                                                uint64_t &TotalCount) {
+  // Use profile weight on metadata only for sample profiling where block counts
+  // could differ from the count of an instruction within the block.
+  // NOTE(review): Summary is dereferenced without a null check; presumably
+  // it has been computed by an earlier query -- confirm.
+  if (Summary.get()->getKind() != ProfileSummary::PSK_Sample)
+    return false;
+
+  return (isa<CallInst>(I) ||
+          (isa<TerminatorInst>(I) && !isa<InvokeInst>(I))) &&
+         I->extractProfTotalWeight(TotalCount);
+}
+
+bool ProfileSummaryInfo::isHotCallSite(const CallSite &CS,
+                                       BlockFrequencyInfo *BFI) {
+  auto *CallInst = CS.getInstruction();
+  if (!CS)
+    return false;
+  // Check if there is a profile metadata on the instruction. If it is present,
+  // determine hotness solely based on that.
+  uint64_t TotalCount;
+  if (extractProfTotalWeight(CallInst, TotalCount))
+    return isHotCount(TotalCount);
+  return BFI && isHotBB(CallInst->getParent(), BFI);
+}
+
+bool ProfileSummaryInfo::isColdCallSite(const CallSite &CS,
+                                        BlockFrequencyInfo *BFI) {
+  auto *CallInst = CS.getInstruction();
+  if (!CS)
+    return false;
+  // Check if there is a profile metadata on the instruction. If it is present
+  // and indicates that the callsite is not cold, return false.
+  uint64_t TotalCount;
+  if (extractProfTotalWeight(CallInst, TotalCount) && !isColdCount(TotalCount))
+    return false;
+  return BFI && isColdBB(CallInst->getParent(), BFI);
 }
 
 INITIALIZE_PASS(ProfileSummaryInfoWrapperPass, "profile-summary-info",
Index: unittests/Analysis/CMakeLists.txt
===================================================================
--- unittests/Analysis/CMakeLists.txt
+++ unittests/Analysis/CMakeLists.txt
@@ -15,6 +15,7 @@
   LazyCallGraphTest.cpp
   LoopPassManagerTest.cpp
   MemoryBuiltinsTest.cpp
+  ProfileSummaryInfoTest.cpp
   ScalarEvolutionTest.cpp
   TBAATest.cpp
   ValueTrackingTest.cpp
Index: unittests/Analysis/ProfileSummaryInfoTest.cpp
===================================================================
--- /dev/null
+++ unittests/Analysis/ProfileSummaryInfoTest.cpp
@@ -0,0 +1,173 @@
+//===- ProfileSummaryInfoTest.cpp - ProfileSummaryInfo unit tests ---------===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/BlockFrequencyInfo.h" +#include "llvm/Analysis/BlockFrequencyInfoImpl.h" +#include "llvm/Analysis/BranchProbabilityInfo.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/ProfileSummaryInfo.h" +#include "llvm/AsmParser/Parser.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/CallSite.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/MDBuilder.h" +#include "llvm/IR/Module.h" +#include "llvm/Support/DataTypes.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/raw_ostream.h" +#include "gtest/gtest.h" + +namespace llvm { +namespace { + +class ProfileSummaryInfoTest : public testing::Test { +protected: + LLVMContext C; + std::unique_ptr BPI; + std::unique_ptr DT; + std::unique_ptr LI; + + ProfileSummaryInfo buildPSI(Module *M) { + return ProfileSummaryInfo(*M); + } + BlockFrequencyInfo buildBFI(Function &F) { + DT.reset(new DominatorTree(F)); + LI.reset(new LoopInfo(*DT)); + BPI.reset(new BranchProbabilityInfo(F, *LI)); + return BlockFrequencyInfo(F, *BPI, *LI); + } + std::unique_ptr makeLLVMModule(StringRef ProfKind) { + const char *ModuleStrig = + "define i32 @g(i32 %x) !prof !21 {{\n" + " ret i32 0\n" + "}\n" + "define i32 @h(i32 %x) !prof !22 {{\n" + " ret i32 0\n" + "}\n" + "define i32 @f(i32 %x) !prof !20 {{\n" + "bb0:\n" + " %y1 = icmp eq i32 %x, 0 \n" + " br i1 %y1, label %bb1, label %bb2, !prof !23 \n" + "bb1:\n" + " %z1 = call i32 @g(i32 %x)\n" + " br label %bb3\n" + "bb2:\n" + " %z2 = call i32 @h(i32 %x)\n" + " br label %bb3\n" + "bb3:\n" + " %y2 = phi i32 [0, %bb1], [1, %bb2] \n" + " ret i32 %y2\n" + "}\n" + "!llvm.module.flags = !{{!1}" + "!20 = !{{!\"function_entry_count\", i64 400}" + "!21 = !{{!\"function_entry_count\", i64 1}" + "!22 = !{{!\"function_entry_count\", i64 100}" + "!23 = 
!{{!\"branch_weights\", i32 64, i32 4}" + "!1 = !{{i32 1, !\"ProfileSummary\", !2}" + "!2 = !{{!3, !4, !5, !6, !7, !8, !9, !10}" + "!3 = !{{!\"ProfileFormat\", !\"{0}\"}" + "!4 = !{{!\"TotalCount\", i64 10000}" + "!5 = !{{!\"MaxCount\", i64 10}" + "!6 = !{{!\"MaxInternalCount\", i64 1}" + "!7 = !{{!\"MaxFunctionCount\", i64 1000}" + "!8 = !{{!\"NumCounts\", i64 3}" + "!9 = !{{!\"NumFunctions\", i64 3}" + "!10 = !{{!\"DetailedSummary\", !11}" + "!11 = !{{!12, !13, !14}" + "!12 = !{{i32 10000, i64 1000, i32 1}" + "!13 = !{{i32 999000, i64 300, i32 3}" + "!14 = !{{i32 999999, i64 5, i32 10}"; + SMDiagnostic Err; + return parseAssemblyString(StringRef(formatv(ModuleStrig, ProfKind)), Err, + C); + } +}; + +TEST_F(ProfileSummaryInfoTest, TestCommon) { + auto M = makeLLVMModule("InstrProf"); + Function *F = M->getFunction("f"); + Function *G = M->getFunction("g"); + Function *H = M->getFunction("h"); + + ProfileSummaryInfo PSI = buildPSI(M.get()); + EXPECT_TRUE(PSI.isHotCount(400)); + EXPECT_TRUE(PSI.isColdCount(2)); + EXPECT_FALSE(PSI.isColdCount(100)); + EXPECT_FALSE(PSI.isHotCount(100)); + + EXPECT_TRUE(PSI.isFunctionEntryHot(F)); + EXPECT_FALSE(PSI.isFunctionEntryHot(G)); + EXPECT_FALSE(PSI.isFunctionEntryHot(H)); +} + +TEST_F(ProfileSummaryInfoTest, InstrProf) { + auto M = makeLLVMModule("InstrProf"); + Function *F = M->getFunction("f"); + ProfileSummaryInfo PSI = buildPSI(M.get()); + + BasicBlock &BB0 = F->getEntryBlock(); + BasicBlock *BB1 = BB0.getTerminator()->getSuccessor(0); + BasicBlock *BB2 = BB0.getTerminator()->getSuccessor(1); + BasicBlock *BB3 = BB1->getSingleSuccessor(); + + BlockFrequencyInfo BFI = buildBFI(*F); + EXPECT_TRUE(PSI.isHotBB(&BB0, &BFI)); + EXPECT_TRUE(PSI.isHotBB(BB1, &BFI)); + EXPECT_FALSE(PSI.isHotBB(BB2, &BFI)); + EXPECT_TRUE(PSI.isHotBB(BB3, &BFI)); + + CallSite CS1(BB1->getFirstNonPHI()); + auto *CI2 = BB2->getFirstNonPHI(); + CallSite CS2(CI2); + + EXPECT_TRUE(PSI.isHotCallSite(CS1, &BFI)); + EXPECT_FALSE(PSI.isHotCallSite(CS2, 
&BFI)); + + // Test that adding an MD_prof metadata with a hot count on CS2 does not + // change itas hotness as it has no effect in instrumented profiling. + MDBuilder MDB(M->getContext()); + CI2->setMetadata(llvm::LLVMContext::MD_prof, MDB.createBranchWeights({400})); + EXPECT_FALSE(PSI.isHotCallSite(CS2, &BFI)); +} + +TEST_F(ProfileSummaryInfoTest, SampleProf) { + auto M = makeLLVMModule("SampleProfile"); + Function *F = M->getFunction("f"); + ProfileSummaryInfo PSI = buildPSI(M.get()); + + BasicBlock &BB0 = F->getEntryBlock(); + BasicBlock *BB1 = BB0.getTerminator()->getSuccessor(0); + BasicBlock *BB2 = BB0.getTerminator()->getSuccessor(1); + BasicBlock *BB3 = BB1->getSingleSuccessor(); + + BlockFrequencyInfo BFI = buildBFI(*F); + EXPECT_TRUE(PSI.isHotBB(&BB0, &BFI)); + EXPECT_TRUE(PSI.isHotBB(BB1, &BFI)); + EXPECT_FALSE(PSI.isHotBB(BB2, &BFI)); + EXPECT_TRUE(PSI.isHotBB(BB3, &BFI)); + + CallSite CS1(BB1->getFirstNonPHI()); + auto *CI2 = BB2->getFirstNonPHI(); + CallSite CS2(CI2); + + EXPECT_TRUE(PSI.isHotCallSite(CS1, &BFI)); + EXPECT_FALSE(PSI.isHotCallSite(CS2, &BFI)); + + // Test that CS2 is considered hot when it gets an MD_prof metadata with + // weights that exceed the hot count threshold. + MDBuilder MDB(M->getContext()); + CI2->setMetadata(llvm::LLVMContext::MD_prof, MDB.createBranchWeights({400})); + EXPECT_TRUE(PSI.isHotCallSite(CS2, &BFI)); +} + +} // end anonymous namespace +} // end namespace llvm