Index: include/llvm/IR/IntrinsicInst.h
===================================================================
--- include/llvm/IR/IntrinsicInst.h
+++ include/llvm/IR/IntrinsicInst.h
@@ -361,6 +361,18 @@
     }
   };
 
+  /// This represents the llvm.instrprof_increment_step intrinsic.
+  class InstrProfIncrementInstStep : public InstrProfIncrementInst {
+  public:
+    static inline bool classof(const IntrinsicInst *I) {
+      return I->getIntrinsicID() == Intrinsic::instrprof_increment_step;
+    }
+    static inline bool classof(const Value *V) {
+      return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V));
+    }
+    Value *getStep() const { return const_cast<Value *>(getArgOperand(4)); }
+  };
+
   /// This represents the llvm.instrprof_value_profile intrinsic.
   class InstrProfValueProfileInst : public IntrinsicInst {
   public:
Index: include/llvm/IR/Intrinsics.td
===================================================================
--- include/llvm/IR/Intrinsics.td
+++ include/llvm/IR/Intrinsics.td
@@ -346,6 +346,12 @@
                                          llvm_i32_ty, llvm_i32_ty],
                                         []>;
 
+// A counter increment with step for instrumentation based profiling.
+def int_instrprof_increment_step : Intrinsic<[],
+                                        [llvm_ptr_ty, llvm_i64_ty,
+                                         llvm_i32_ty, llvm_i32_ty, llvm_i64_ty],
+                                        []>;
+
 // A call to profile runtime for value profiling of target expressions
 // through instrumentation based profiling.
 def int_instrprof_value_profile : Intrinsic<[],
Index: lib/Transforms/Instrumentation/InstrProfiling.cpp
===================================================================
--- lib/Transforms/Instrumentation/InstrProfiling.cpp
+++ lib/Transforms/Instrumentation/InstrProfiling.cpp
@@ -107,6 +107,14 @@
   return getInstrProfCoverageSectionName(isMachO());
 }
 
+/// Return \p Instr as an instrprof increment intrinsic (with or without an
+/// explicit step), or null if it is neither.
+static InstrProfIncrementInst *castToIncrementInst(Instruction *Instr) {
+  if (auto *Inc = dyn_cast<InstrProfIncrementInstStep>(Instr))
+    return Inc;
+  return dyn_cast<InstrProfIncrementInst>(Instr);
+}
+
 bool InstrProfiling::run(Module &M) {
   bool MadeChange = false;
 
@@ -138,7 +146,7 @@
     for (BasicBlock &BB : F)
       for (auto I = BB.begin(), E = BB.end(); I != E;) {
         auto Instr = I++;
-        if (auto *Inc = dyn_cast<InstrProfIncrementInst>(Instr)) {
+        if (auto *Inc = castToIncrementInst(&*Instr)) {
           lowerIncrement(Inc);
           MadeChange = true;
         } else if (auto *Ind = dyn_cast<InstrProfValueProfileInst>(Instr)) {
@@ -214,6 +222,16 @@
   Ind->eraseFromParent();
 }
 
+/// Return the value \p Inc adds to its counter: the step operand of
+/// llvm.instrprof.increment.step, or \p DefaultStep for a plain
+/// llvm.instrprof.increment.
+static Value *getIncrementStep(InstrProfIncrementInst *Inc,
+                               Value *DefaultStep) {
+  if (auto *IncWithStep = dyn_cast<InstrProfIncrementInstStep>(Inc))
+    return IncWithStep->getStep();
+  return DefaultStep;
+}
+
 void InstrProfiling::lowerIncrement(InstrProfIncrementInst *Inc) {
   GlobalVariable *Counters = getOrCreateRegionCounters(Inc);
 
@@ -221,7 +239,7 @@
   uint64_t Index = Inc->getIndex()->getZExtValue();
   Value *Addr = Builder.CreateConstInBoundsGEP2_64(Counters, 0, Index);
   Value *Count = Builder.CreateLoad(Addr, "pgocount");
-  Count = Builder.CreateAdd(Count, Builder.getInt64(1));
+  Count = Builder.CreateAdd(Count, getIncrementStep(Inc, Builder.getInt64(1)));
   Inc->replaceAllUsesWith(Builder.CreateStore(Count, Addr));
   Inc->eraseFromParent();
 }
Index: lib/Transforms/Instrumentation/PGOInstrumentation.cpp
===================================================================
--- lib/Transforms/Instrumentation/PGOInstrumentation.cpp
+++ lib/Transforms/Instrumentation/PGOInstrumentation.cpp
@@ -86,6 +86,7 @@
 #define DEBUG_TYPE "pgo-instrumentation"
 
 STATISTIC(NumOfPGOInstrument, "Number of edges instrumented.");
+STATISTIC(NumOfPGOSelectInsts, "Number of select instructions instrumented.");
 STATISTIC(NumOfPGOEdge, "Number of edges.");
 STATISTIC(NumOfPGOBB, "Number of basic-blocks.");
 STATISTIC(NumOfPGOSplit, "Number of critical edge splits.");
@@ -133,7 +134,56 @@
 static cl::opt<bool> NoPGOWarnMismatch("no-pgo-warn-mismatch", cl::init(false),
                                        cl::Hidden);
 
+// Command line option to enable/disable select instruction instrumentation.
+static cl::opt<bool> PGOInstrSelect("pgo-instr-select", cl::init(true),
+                                    cl::Hidden);
 namespace {
+
+/// The select instruction visitor plays three roles specified by the mode.
+/// In \c VM_counting mode, it counts the select instructions. In
+/// \c VM_instrument mode, it inserts code counting how many times the true
+/// value of each select is taken. In \c VM_annotate mode, it reads the
+/// profile data and annotates each select with branch weight metadata.
+enum VisitMode { VM_counting, VM_instrument, VM_annotate };
+class PGOUseFunc;
+
+/// Instruction Visitor class to visit select instructions.
+struct SelectInstVisitor : public InstVisitor<SelectInstVisitor> {
+  Function &F;
+  unsigned NSIs = 0;             // Number of select instructions instrumented.
+  VisitMode Mode = VM_counting;  // Visiting mode.
+  unsigned *CI = nullptr;        // pointer to current counter index.
+  unsigned NC = 0;               // Number of counters
+  GlobalVariable *FuncNameVar = nullptr;
+  uint64_t FuncHash = 0;
+  PGOUseFunc *UseFunc = nullptr;
+
+  SelectInstVisitor(Function &Func) : F(Func) {}
+
+  // Set the visitor in \c VM_instrument mode. \p *I holds the index of the
+  // first select counter and is advanced after each instrumented select.
+  // \p C is the total number of counters in this function; \p FN and \p FH
+  // are the function name variable and hash passed to the intrinsic.
+  void SetCounterIndex(unsigned *I, unsigned C, GlobalVariable *FN,
+                       uint64_t FH) {
+    Mode = VM_instrument;
+    CI = I;
+    NC = C;
+    FuncHash = FH;
+    FuncNameVar = FN;
+  }
+  // Set the visitor in \c VM_annotate mode. \p UF is the function
+  // annotator, and \p I is the pointer to the counter index variable.
+  void SetCounterUse(PGOUseFunc *UF, unsigned *I) {
+    Mode = VM_annotate;
+    UseFunc = UF;
+    CI = I;
+  }
+  // Visit \p SI instruction and perform tasks according to visit mode.
+  void visitSelectInst(SelectInst &SI);
+  unsigned size() const { return NSIs; }
+};
+
 class PGOInstrumentationGenLegacyPass : public ModulePass {
 public:
   static char ID;
@@ -254,6 +304,7 @@
   std::unordered_multimap<Comdat *, GlobalValue *> &ComdatMembers;
 
 public:
+  SelectInstVisitor SIVisitor;
   std::string FuncName;
   GlobalVariable *FuncNameVar;
   // CFG hash value for this function.
@@ -280,8 +331,14 @@
       std::unordered_multimap<Comdat *, GlobalValue *> &ComdatMembers,
       bool CreateGlobalVar = false, BranchProbabilityInfo *BPI = nullptr,
       BlockFrequencyInfo *BFI = nullptr)
-      : F(Func), ComdatMembers(ComdatMembers), FunctionHash(0),
+      : F(Func), ComdatMembers(ComdatMembers), SIVisitor(Func), FunctionHash(0),
         MST(F, BPI, BFI) {
+
+    // Count the selects first: computeCFGHash() and getNumCounters() use it.
+    assert(SIVisitor.Mode == VM_counting && "Wrong select visiting mode!");
+    SIVisitor.visit(Func);
+    NumOfPGOSelectInsts += SIVisitor.size();
+
     FuncName = getPGOFuncName(F);
     computeCFGHash();
     if (ComdatMembers.size())
@@ -308,7 +365,7 @@
       if (!E->InMST && !E->Removed)
         NumCounters++;
     }
-    return NumCounters;
+    return NumCounters + SIVisitor.size();
   }
 };
 
@@ -328,7 +385,8 @@
     }
   }
   JC.update(Indexes);
-  FunctionHash = (uint64_t)findIndirectCallSites(F).size() << 48 |
+  FunctionHash = (uint64_t)SIVisitor.size() << 56 |
+                 (uint64_t)findIndirectCallSites(F).size() << 48 |
                  (uint64_t)MST.AllEdges.size() << 32 | JC.getCRC();
 }
 
@@ -473,6 +531,11 @@
          Builder.getInt64(FuncInfo.FunctionHash), Builder.getInt32(NumCounters),
          Builder.getInt32(I++)});
   }
+
+  // Now instrument select instructions:
+  FuncInfo.SIVisitor.SetCounterIndex(&I, NumCounters, FuncInfo.FuncNameVar,
+                                     FuncInfo.FunctionHash);
+  FuncInfo.SIVisitor.visit(F);
   assert(I == NumCounters);
 
   if (DisableValueProfiling)
@@ -594,17 +657,23 @@
   // Return the profile record for this function;
   InstrProfRecord &getProfileRecord() { return ProfileRecord; }
 
+  // Return the auxiliary BB information.
+  UseBBInfo &getBBInfo(const BasicBlock *BB) const {
+    return FuncInfo.getBBInfo(BB);
+  }
+
 private:
   Function &F;
   Module *M;
   // This member stores the shared information with class PGOGenFunc.
   FuncPGOInstrumentation<PGOUseEdge, UseBBInfo> FuncInfo;
 
-  // Return the auxiliary BB information.
-  UseBBInfo &getBBInfo(const BasicBlock *BB) const {
-    return FuncInfo.getBBInfo(BB);
-  }
-
+  // Index of the first select instruction counter in the profile counts.
+  // Set by setInstrumentedCounts() once all edge counters are consumed;
+  // populateCounters() annotates the selects from this position after every
+  // BB count is known.
+  unsigned SelectCountPosition = 0;
+
   // The maximum count value in the profile. This is only used in PGO use
   // compilation.
   uint64_t ProgramMaxCount;
@@ -677,6 +746,9 @@
     NewEdge1.InMST = true;
     getBBInfo(InstrBB).setBBInfoCount(CountValue);
   }
-  assert(I == CountFromProfile.size());
+  // The remaining counters belong to select instructions. They are annotated
+  // by populateCounters(), once every BB count is known.
+  SelectCountPosition = I;
+  assert(I + FuncInfo.SIVisitor.size() == CountFromProfile.size());
 }
 
@@ -817,10 +889,17 @@
     FuncMaxCount = std::max(FuncMaxCount, getBBInfo(&BB).CountValue);
   markFunctionAttributes(FuncEntryCount, FuncMaxCount);
 
+  // Now annotate select instructions. This must run after the BB counts
+  // above are populated: a select's false count is derived from the count
+  // of its parent block.
+  FuncInfo.SIVisitor.SetCounterUse(this, &SelectCountPosition);
+  FuncInfo.SIVisitor.visit(F);
+  assert(SelectCountPosition == ProfileRecord.Counts.size());
+
   DEBUG(FuncInfo.dumpInfo("after reading profile."));
 }
 
-static void setProfMetadata(Module *M, TerminatorInst *TI,
+static void setProfMetadata(Module *M, Instruction *TI,
                             ArrayRef<uint64_t> EdgeCounts, uint64_t MaxCount) {
   MDBuilder MDB(M->getContext());
   assert(MaxCount > 0 && "Bad max count");
@@ -869,6 +948,54 @@
   }
 }
 
+void SelectInstVisitor::visitSelectInst(SelectInst &SI) {
+  if (!PGOInstrSelect)
+    return;
+  // FIXME: selects with a vector condition are not handled yet.
+  if (SI.getCondition()->getType()->isVectorTy())
+    return;
+
+  if (Mode == VM_counting) {
+    // Only the counting pass bumps NSIs, so size() stays equal to the number
+    // of select counters after the instrument/annotate passes have run.
+    NSIs++;
+    return;
+  }
+
+  if (Mode == VM_instrument) {
+    // Count how often the true value is chosen: the step is zext(cond),
+    // i.e. 1 when the condition holds and 0 otherwise.
+    Module *M = F.getParent();
+    IRBuilder<> Builder(&SI);
+    Type *Int64Ty = Builder.getInt64Ty();
+    Type *I8PtrTy = Builder.getInt8PtrTy();
+    auto *Step = Builder.CreateZExt(SI.getCondition(), Int64Ty);
+    Builder.CreateCall(
+        Intrinsic::getDeclaration(M, Intrinsic::instrprof_increment_step),
+        {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
+         Builder.getInt64(FuncHash), Builder.getInt32(NC),
+         Builder.getInt32(*CI), Step});
+    ++(*CI);
+    return;
+  }
+
+  assert(Mode == VM_annotate && "Unknown visiting mode");
+  std::vector<uint64_t> &CountFromProfile = UseFunc->getProfileRecord().Counts;
+  assert(*CI < CountFromProfile.size() && "Out of bound access of counters");
+  uint64_t SCounts[2];
+  SCounts[0] = CountFromProfile[*CI]; // True count
+  ++(*CI);
+  // False count: the part of the parent block's count that did not take the
+  // true value.
+  uint64_t TotalCount = UseFunc->getBBInfo(SI.getParent()).CountValue;
+  SCounts[1] = (TotalCount > SCounts[0] ? TotalCount - SCounts[0] : 0);
+  uint64_t MaxCount = std::max(SCounts[0], SCounts[1]);
+  // setProfMetadata() asserts on a zero max count; a select that never
+  // executed carries no useful weights, so leave it unannotated.
+  if (MaxCount)
+    setProfMetadata(F.getParent(), &SI, SCounts, MaxCount);
+}
+
 // Traverse all the indirect callsites and annotate the instructions.
 void PGOUseFunc::annotateIndirectCallSites() {
   if (DisableValueProfiling)
Index: test/Transforms/PGOProfile/Inputs/select1.proftext
===================================================================
--- test/Transforms/PGOProfile/Inputs/select1.proftext
+++ test/Transforms/PGOProfile/Inputs/select1.proftext
@@ -0,0 +1,9 @@
+# :ir is the flag to indicate this is IR level profile.
+:ir
+test_br_2
+72057623705475732
+3
+4
+1
+1
+
Index: test/Transforms/PGOProfile/select1.ll
===================================================================
--- test/Transforms/PGOProfile/select1.ll
+++ test/Transforms/PGOProfile/select1.ll
@@ -0,0 +1,31 @@
+; RUN: opt < %s -pgo-instr-gen -pgo-instr-select=true -S | FileCheck %s --check-prefix=GEN
+; RUN: opt < %s -passes=pgo-instr-gen -pgo-instr-select=true -S | FileCheck %s --check-prefix=GEN
+; RUN: llvm-profdata merge %S/Inputs/select1.proftext -o %t.profdata
+; RUN: opt < %s -pgo-instr-use -pgo-test-profile-file=%t.profdata -pgo-instr-select=true -S | FileCheck %s --check-prefix=USE
+; RUN: opt < %s -passes=pgo-instr-use -pgo-test-profile-file=%t.profdata -pgo-instr-select=true -S | FileCheck %s --check-prefix=USE
+target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128"
+target triple = "x86_64-unknown-linux-gnu"
+
+define i32 @test_br_2(i32 %i) {
+entry:
+  %cmp = icmp sgt i32 %i, 0
+  br i1 %cmp, label %if.then, label %if.else
+
+if.then:
+  %add = add nsw i32 %i, 2
+;GEN: [[STEP:%[0-9]+]] = zext i1 %cmp to i64
+;GEN: call void @llvm.instrprof.increment.step({{.*}} i32 3, i32 2, i64 [[STEP]])
+  %s = select i1 %cmp, i32 %add, i32 0
+;USE: select i1 %cmp{{.*}}, !prof ![[BW_ENTRY:[0-9]+]]
+;USE: ![[BW_ENTRY]] = !{!"branch_weights", i32 1, i32 3}
+
+  br label %if.end
+
+if.else:
+  %sub = sub nsw i32 %i, 2
+  br label %if.end
+
+if.end:
+  %retv = phi i32 [ %add, %if.then ], [ %sub, %if.else ]
+  ret i32 %retv
+}