Index: include/llvm/Analysis/ProfileSummaryInfo.h =================================================================== --- include/llvm/Analysis/ProfileSummaryInfo.h +++ include/llvm/Analysis/ProfileSummaryInfo.h @@ -29,6 +29,7 @@ namespace llvm { class BasicBlock; class BlockFrequencyInfo; +class CallSite; class ProfileSummary; /// \brief Analysis providing profile information. /// @@ -63,6 +64,12 @@ bool isColdCount(uint64_t C); /// \brief Returns true if BasicBlock \p B is considered hot. bool isHotBB(const BasicBlock *B, BlockFrequencyInfo *BFI); + /// \brief Returns true if BasicBlock \p B is considered cold. + bool isColdBB(const BasicBlock *B, BlockFrequencyInfo *BFI); + /// \brief Returns true if CallSite \p CS is considered hot. + bool isHotCallSite(const CallSite &CS, BlockFrequencyInfo *BFI); + /// \brief Returns true if CallSite \p CS is considered cold. + bool isColdCallSite(const CallSite &CS, BlockFrequencyInfo *BFI); }; /// An analysis pass based on legacy pass manager to deliver ProfileSummaryInfo. Index: lib/Analysis/ProfileSummaryInfo.cpp =================================================================== --- lib/Analysis/ProfileSummaryInfo.cpp +++ lib/Analysis/ProfileSummaryInfo.cpp @@ -12,9 +12,10 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Analysis/BlockFrequencyInfo.h" #include "llvm/Analysis/ProfileSummaryInfo.h" +#include "llvm/Analysis/BlockFrequencyInfo.h" #include "llvm/IR/BasicBlock.h" +#include "llvm/IR/CallSite.h" #include "llvm/IR/Metadata.h" #include "llvm/IR/Module.h" #include "llvm/IR/ProfileSummary.h" @@ -135,12 +136,46 @@ // not update/scale branch weights. Unlike false negatives, this will not cause // performance problem. 
uint64_t TotalCount; - if (B->getTerminator()->extractProfTotalWeight(TotalCount) && - isHotCount(TotalCount)) + auto *TI = B->getTerminator(); + if (isa(TI)) + return false; + if (TI->extractProfTotalWeight(TotalCount) && isHotCount(TotalCount)) return true; return false; } +bool ProfileSummaryInfo::isColdBB(const BasicBlock *B, + BlockFrequencyInfo *BFI) { + auto Count = BFI->getBlockProfileCount(B); + return Count && isColdCount(*Count); +} + +bool ProfileSummaryInfo::isHotCallSite(const CallSite &CS, + BlockFrequencyInfo *BFI) { + auto *CallInst = CS.getInstruction(); + if (!CS) + return false; + // Check if there is a profile metadata on the instruction. If it is present, + // determine hotness solely based on that. + uint64_t TotalCount; + if (CallInst->extractProfTotalWeight(TotalCount)) + return isHotCount(TotalCount); + return BFI && isHotBB(CallInst->getParent(), BFI); +} + +bool ProfileSummaryInfo::isColdCallSite(const CallSite &CS, + BlockFrequencyInfo *BFI) { + auto *CallInst = CS.getInstruction(); + if (!CS) + return false; + // Check if there is a profile metadata on the instruction. 
If it is present + // and indicates that the callsite is not cold, return false. + uint64_t TotalCount; + if (CallInst->extractProfTotalWeight(TotalCount) && !isColdCount(TotalCount)) + return false; + return BFI && isColdBB(CallInst->getParent(), BFI); +} + INITIALIZE_PASS(ProfileSummaryInfoWrapperPass, "profile-summary-info", "Profile summary info", false, true) Index: unittests/Analysis/CMakeLists.txt =================================================================== --- unittests/Analysis/CMakeLists.txt +++ unittests/Analysis/CMakeLists.txt @@ -15,6 +15,7 @@ LazyCallGraphTest.cpp LoopPassManagerTest.cpp MemoryBuiltinsTest.cpp + ProfileSummaryInfoTest.cpp ScalarEvolutionTest.cpp TBAATest.cpp ValueTrackingTest.cpp Index: unittests/Analysis/ProfileSummaryInfoTest.cpp =================================================================== --- /dev/null +++ unittests/Analysis/ProfileSummaryInfoTest.cpp @@ -0,0 +1,135 @@ +//===- ProfileSummaryInfoTest.cpp - ProfileSummaryInfo unit tests ---------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. 
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/ProfileSummaryInfo.h" +#include "llvm/Analysis/BlockFrequencyInfo.h" +#include "llvm/Analysis/BlockFrequencyInfoImpl.h" +#include "llvm/Analysis/BranchProbabilityInfo.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/AsmParser/Parser.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/CallSite.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/MDBuilder.h" +#include "llvm/IR/Module.h" +#include "llvm/Support/DataTypes.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/raw_ostream.h" +#include "gtest/gtest.h" + +namespace llvm { +namespace { + +class ProfileSummaryInfoTest : public testing::Test { +protected: + LLVMContext C; + std::unique_ptr BPI; + std::unique_ptr DT; + std::unique_ptr LI; + + ProfileSummaryInfo buildPSI(Module *M) { + return ProfileSummaryInfo(*M); + } + BlockFrequencyInfo buildBFI(Function &F) { + DT.reset(new DominatorTree(F)); + LI.reset(new LoopInfo(*DT)); + BPI.reset(new BranchProbabilityInfo(F, *LI)); + return BlockFrequencyInfo(F, *BPI, *LI); + } + std::unique_ptr makeLLVMModule() { + const char *ModuleStrig = "define i32 @g(i32 %x) !prof !21 {\n" + " ret i32 0\n" + "}\n" + "define i32 @h(i32 %x) !prof !22 {\n" + " ret i32 0\n" + "}\n" + "define i32 @f(i32 %x) !prof !20 {\n" + "bb0:\n" + " %y1 = icmp eq i32 %x, 0 \n" + " br i1 %y1, label %bb1, label %bb2, !prof !23 \n" + "bb1:\n" + " %z1 = call i32 @g(i32 %x)\n" + " br label %bb3\n" + "bb2:\n" + " %z2 = call i32 @h(i32 %x)\n" + " br label %bb3\n" + "bb3:\n" + " %y2 = phi i32 [0, %bb1], [1, %bb2] \n" + " ret i32 %y2\n" + "}\n" + "!llvm.module.flags = !{!1}" + "!20 = !{!\"function_entry_count\", i64 400}" + "!21 = !{!\"function_entry_count\", i64 1}" + "!22 = !{!\"function_entry_count\", i64 100}" + "!23 = !{!\"branch_weights\", i32 64, i32 4}" + "!1 = !{i32 1, !\"ProfileSummary\", !2}" + 
"!2 = !{!3, !4, !5, !6, !7, !8, !9, !10}" + "!3 = !{!\"ProfileFormat\", !\"InstrProf\"}" + "!4 = !{!\"TotalCount\", i64 10000}" + "!5 = !{!\"MaxCount\", i64 10}" + "!6 = !{!\"MaxInternalCount\", i64 1}" + "!7 = !{!\"MaxFunctionCount\", i64 1000}" + "!8 = !{!\"NumCounts\", i64 3}" + "!9 = !{!\"NumFunctions\", i64 3}" + "!10 = !{!\"DetailedSummary\", !11}" + "!11 = !{!12, !13, !14}" + "!12 = !{i32 10000, i64 1000, i32 1}" + "!13 = !{i32 999000, i64 300, i32 3}" + "!14 = !{i32 999999, i64 5, i32 10}"; + SMDiagnostic Err; + return parseAssemblyString(ModuleStrig, Err, C); + } +}; + +TEST_F(ProfileSummaryInfoTest, Basic) { + auto M = makeLLVMModule(); + Function *F = M->getFunction("f"); + Function *G = M->getFunction("g"); + Function *H = M->getFunction("h"); + + ProfileSummaryInfo PSI = buildPSI(M.get()); + EXPECT_TRUE(PSI.isHotCount(400)); + EXPECT_TRUE(PSI.isColdCount(2)); + EXPECT_FALSE(PSI.isColdCount(100)); + EXPECT_FALSE(PSI.isHotCount(100)); + + EXPECT_TRUE(PSI.isFunctionEntryHot(F)); + EXPECT_FALSE(PSI.isFunctionEntryHot(G)); + EXPECT_FALSE(PSI.isFunctionEntryHot(H)); + + BasicBlock &BB0 = F->getEntryBlock(); + BasicBlock *BB1 = BB0.getTerminator()->getSuccessor(0); + BasicBlock *BB2 = BB0.getTerminator()->getSuccessor(1); + BasicBlock *BB3 = BB1->getSingleSuccessor(); + + BlockFrequencyInfo BFI = buildBFI(*F); + EXPECT_TRUE(PSI.isHotBB(&BB0, &BFI)); + EXPECT_TRUE(PSI.isHotBB(BB1, &BFI)); + EXPECT_FALSE(PSI.isHotBB(BB2, &BFI)); + EXPECT_TRUE(PSI.isHotBB(BB3, &BFI)); + + CallSite CS1(BB1->getFirstNonPHI()); + auto *CI2 = BB2->getFirstNonPHI(); + CallSite CS2(CI2); + + EXPECT_TRUE(PSI.isHotCallSite(CS1, &BFI)); + EXPECT_FALSE(PSI.isHotCallSite(CS2, &BFI)); + + // Test that CS2 is considered hot when it gets an MD_prof metadata with + // weights that exceed the hot count threshold. 
+ MDBuilder MDB(M->getContext()); + CI2->setMetadata(llvm::LLVMContext::MD_prof, MDB.createBranchWeights({300})); + EXPECT_TRUE(PSI.isHotCallSite(CS2, &BFI)); + +} + +} // end anonymous namespace +} // end namespace llvm