Index: include/llvm-c/Transforms/IPO.h
===================================================================
--- include/llvm-c/Transforms/IPO.h
+++ include/llvm-c/Transforms/IPO.h
@@ -65,6 +65,9 @@
 /** See llvm::createInternalizePass function. */
 void LLVMAddInternalizePass(LLVMPassManagerRef, unsigned AllButMain);
 
+/** See llvm::createPGOIndirectCallTransformPass function. */
+void LLVMAddPGOIndirectCallTransformPass(LLVMPassManagerRef PM);
+
 /** See llvm::createStripDeadPrototypesPass function. */
 void LLVMAddStripDeadPrototypesPass(LLVMPassManagerRef PM);
 
Index: include/llvm/InitializePasses.h
===================================================================
--- include/llvm/InitializePasses.h
+++ include/llvm/InitializePasses.h
@@ -214,6 +214,7 @@
 void initializeOptimizePHIsPass(PassRegistry&);
 void initializePartiallyInlineLibCallsPass(PassRegistry&);
 void initializePEIPass(PassRegistry&);
+void initializePGOIndirectCallTransformPass(PassRegistry&);
 void initializePHIEliminationPass(PassRegistry&);
 void initializePartialInlinerPass(PassRegistry&);
 void initializePeepholeOptimizerPass(PassRegistry&);
Index: include/llvm/LinkAllPasses.h
===================================================================
--- include/llvm/LinkAllPasses.h
+++ include/llvm/LinkAllPasses.h
@@ -119,6 +119,7 @@
       (void) llvm::createObjCARCContractPass();
       (void) llvm::createObjCARCOptPass();
       (void) llvm::createPAEvalPass();
+      (void) llvm::createPGOIndirectCallTransformPass();
       (void) llvm::createPromoteMemoryToRegisterPass();
       (void) llvm::createDemoteRegisterToMemoryPass();
       (void) llvm::createPruneEHPass();
Index: include/llvm/Transforms/IPO.h
===================================================================
--- include/llvm/Transforms/IPO.h
+++ include/llvm/Transforms/IPO.h
@@ -211,6 +211,12 @@
 ModulePass *createLowerBitSetsPass();
 
 //===----------------------------------------------------------------------===//
+/// PGOIndirectCallTransform - Convert indirect calls to direct calls
+/// using profile data.
+///
+ModulePass *createPGOIndirectCallTransformPass();
+
+//===----------------------------------------------------------------------===//
 // SampleProfilePass - Loads sample profile data from disk and generates
 // IR metadata to reflect the profile.
 ModulePass *createSampleProfileLoaderPass();
Index: lib/Transforms/IPO/CMakeLists.txt
===================================================================
--- lib/Transforms/IPO/CMakeLists.txt
+++ lib/Transforms/IPO/CMakeLists.txt
@@ -19,6 +19,7 @@
   MergeFunctions.cpp
   PartialInlining.cpp
   PassManagerBuilder.cpp
+  PGOIndirectCallTransform.cpp
   PruneEH.cpp
   SampleProfile.cpp
   StripDeadPrototypes.cpp
Index: lib/Transforms/IPO/IPO.cpp
===================================================================
--- lib/Transforms/IPO/IPO.cpp
+++ lib/Transforms/IPO/IPO.cpp
@@ -39,6 +39,7 @@
   initializeLowerBitSetsPass(Registry);
   initializeMergeFunctionsPass(Registry);
   initializePartialInlinerPass(Registry);
+  initializePGOIndirectCallTransformPass(Registry);
   initializePruneEHPass(Registry);
   initializeStripDeadPrototypesLegacyPassPass(Registry);
   initializeStripSymbolsPass(Registry);
@@ -105,6 +106,10 @@
   unwrap(PM)->add(createInternalizePass(Export));
 }
 
+void LLVMAddPGOIndirectCallTransformPass(LLVMPassManagerRef PM) {
+  unwrap(PM)->add(createPGOIndirectCallTransformPass());
+}
+
 void LLVMAddStripDeadPrototypesPass(LLVMPassManagerRef PM) {
   unwrap(PM)->add(createStripDeadPrototypesPass());
 }
Index: lib/Transforms/IPO/PGOIndirectCallTransform.cpp
===================================================================
--- /dev/null
+++ lib/Transforms/IPO/PGOIndirectCallTransform.cpp
@@ -0,0 +1,485 @@
+//===- PGOIndirectCallTransform.cpp - PGO-based indirect call promotion ---===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass uses profile information to transform indirect calls to
+// direct calls when it is profitable.
+//
+// A call or invoke is a candidate when it carries !prof metadata of the form
+//
+//   !{!"indirect_call_targets", i64 <CallCount>,
+//     !"<Target1>", i64 <Count1>, !"<Target2>", i64 <Count2>, ...}
+//
+// Only the first two targets are examined, so they are expected to be listed
+// hottest first. Each target that is hot enough and has a compatible
+// signature is "peeled": the call is guarded by a comparison of the called
+// pointer against the target, the true branch calls the target directly, and
+// the false branch keeps the indirect call with the peeled target (and its
+// count) removed from the metadata.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/IPO.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/IR/CallSite.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/MDBuilder.h"
+#include "llvm/IR/Module.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/MathExtras.h"
+#include "llvm/Support/raw_ostream.h"
+#include <algorithm>
+#include <climits>
+#include <utility>
+using namespace llvm;
+
+#define DEBUG_TYPE "ic-opt"
+
+STATISTIC(NumIndCalls, "Number of indirect calls transformed");
+STATISTIC(NumIndInvokes, "Number of indirect invokes transformed");
+
+static cl::opt<unsigned>
+    NumICTs("numICT", cl::Hidden, cl::ZeroOrMore, cl::init(UINT_MAX),
+            cl::desc("Allow up to num IC transformations"));
+
+static cl::opt<float>
+    CallHotnessT("callHotnessThreshold", cl::init(0.001f), cl::Hidden,
+                 cl::ZeroOrMore, cl::desc("IC callsite hotness threshold"));
+
+static cl::opt<unsigned>
+    TargetHotnessT("targetHotnessThreshold", cl::Hidden, cl::ZeroOrMore,
+                   cl::init(40), cl::desc("IC target hotness threshold"));
+
+static cl::opt<bool> EnableTarget2("enable-second-target", cl::init(true),
+                                   cl::Hidden, cl::ZeroOrMore,
+                                   cl::desc("Allow peeling of second target"));
+
+static cl::opt<unsigned>
+    Target2HotnessT("target2HotnessThreshold", cl::Hidden, cl::ZeroOrMore,
+                    cl::init(30), cl::desc("IC 2nd target hotness threshold"));
+
+static cl::opt<unsigned> InlineHintT("inlineHintThreshold", cl::Hidden,
+                                     cl::ZeroOrMore, cl::init(1),
+                                     cl::desc("inline hint threshold"));
+
+namespace {
+/// Promotes hot indirect calls and invokes to guarded direct calls using
+/// "indirect_call_targets" profile metadata.
+class PGOIndirectCallTransform : public ModulePass {
+  /// Number of targets peeled so far; backs the -numICT debugging limit.
+  /// STATISTIC counters cannot be used for this: they do not count in
+  /// release builds without LLVM_ENABLE_STATS.
+  unsigned NumTransformed = 0;
+
+  bool runOnModule(Module &M) override;
+
+public:
+  static char ID; // Pass identification, replacement for typeid
+  PGOIndirectCallTransform() : ModulePass(ID) {
+    initializePGOIndirectCallTransformPass(*PassRegistry::getPassRegistry());
+  }
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override {}
+
+  const char *getPassName() const override {
+    return "Profile-Based Indirect Call Optimization";
+  }
+};
+} // end anonymous namespace
+
+char PGOIndirectCallTransform::ID = 0;
+INITIALIZE_PASS(PGOIndirectCallTransform, "ic-opt",
+                "Profile-Based Indirect Call Optimization", false, false)
+
+ModulePass *llvm::createPGOIndirectCallTransformPass() {
+  return new PGOIndirectCallTransform();
+}
+
+/// Returns true if \p I is an indirect call or invoke (not inline asm) that
+/// carries well-formed "indirect_call_targets" !prof metadata with at least
+/// one target. On success, \p ProfCnt is set to the call site's execution
+/// count recorded in that metadata.
+static bool isIndCallWithProfileCount(Instruction &I, uint64_t &ProfCnt) {
+  CallSite CS(&I);
+  if (!CS.isCall() && !CS.isInvoke())
+    return false;
+  // Inline asm is not a function pointer, so it can be neither compared with
+  // nor replaced by a direct target.
+  if (CS.isInlineAsm())
+    return false;
+
+  // A function callee, possibly behind a constant cast, is a direct call.
+  Value *Callee = CS.getCalledValue();
+  if (auto *CE = dyn_cast<ConstantExpr>(Callee))
+    Callee = CE->getOperand(0);
+  if (isa<Function>(Callee))
+    return false;
+
+  // Expected operands: the tag, the call count and at least one
+  // (target name, target count) pair.
+  MDNode *MD = I.getMetadata(LLVMContext::MD_prof);
+  if (!MD || MD->getNumOperands() < 4)
+    return false;
+  // The MD_prof slot is shared by several metadata kinds and operands may be
+  // null, so check the operand kinds instead of casting.
+  auto *Tag = dyn_cast_or_null<MDString>(MD->getOperand(0).get());
+  if (!Tag || !Tag->getString().equals("indirect_call_targets"))
+    return false;
+  auto *CI =
+      mdconst::dyn_extract_or_null<ConstantInt>(MD->getOperand(1).get());
+  if (!CI)
+    return false;
+  ProfCnt = CI->getZExtValue();
+  return true;
+}
+
+/// Scales \p Weights down uniformly (by a power of two) so that the largest
+/// one fits in 32 bits, the width of branch weight metadata entries.
+/// \p Weights must not be empty.
+static void fit32BitWeights(MutableArrayRef<uint64_t> Weights) {
+  uint64_t Max = *std::max_element(Weights.begin(), Weights.end());
+  if (Max > UINT_MAX) {
+    unsigned Offset = 32 - countLeadingZeros(Max);
+    for (uint64_t &W : Weights)
+      W >>= Offset;
+  }
+}
+
+/// Returns true if the call site \p CS can be redirected to \p TargetFn: the
+/// calling conventions and return types must match, the argument count must
+/// equal the target's parameter count, and each argument must either have
+/// the parameter's type or, like the parameter, be a pointer in the same
+/// address space (such arguments are bitcast when the target is peeled).
+static bool isCompatibleTarget(CallSite CS, Function *TargetFn) {
+  if (TargetFn->getCallingConv() != CS.getCallingConv())
+    return false;
+
+  Type *ARTy = CS.getInstruction()->getType();
+  Type *BRTy = TargetFn->getReturnType();
+  DEBUG(dbgs() << "\nindTy = " << *ARTy);
+  DEBUG(dbgs() << "\ndirTy = " << *BRTy);
+  if (ARTy != BRTy) // TODO: allow compatible types
+    return false;
+
+  FunctionType *TargetType = TargetFn->getFunctionType();
+  unsigned ArgNum = CS.arg_size();
+  if (ArgNum != TargetType->getNumParams())
+    return false;
+
+  for (unsigned I = 0; I < ArgNum; ++I) {
+    Type *ATy = CS.getArgument(I)->getType();
+    Type *PTy = TargetType->getParamType(I);
+    DEBUG(dbgs() << "\nargTy = " << *ATy);
+    DEBUG(dbgs() << "\nparTy = " << *PTy);
+    if (ATy == PTy)
+      continue;
+    // TODO: allow additional compatible types
+    auto *APtrTy = dyn_cast<PointerType>(ATy);
+    auto *PPtrTy = dyn_cast<PointerType>(PTy);
+    if (!APtrTy || !PPtrTy ||
+        APtrTy->getAddressSpace() != PPtrTy->getAddressSpace())
+      return false;
+  }
+  return true;
+}
+
+/// Peels \p HotTargetFn off the indirect call or invoke \p IC, producing
+///
+///   BB:       %c = icmp eq <callee>, @HotTargetFn
+///             br %c, if.true, if.false   ; weights: peeled, remaining
+///   if.true:  direct call/invoke of @HotTargetFn
+///   if.false: IC, with the first target removed from its !prof metadata
+///   if.merge: phi of both results (non-void only) + code that followed IC
+///
+/// \p IC must carry "indirect_call_targets" metadata whose first target is
+/// \p HotTargetFn, executed \p TargetCount times, and the target must have
+/// passed isCompatibleTarget. If \p Hint is set, the direct call gets the
+/// inlinehint attribute.
+static void peelOneTarget(Instruction *IC, Function *HotTargetFn,
+                          uint64_t TargetCount, bool Hint) {
+  LLVMContext &Ctx = IC->getContext();
+  MDNode *MD = IC->getMetadata(LLVMContext::MD_prof);
+  CallSite CS(IC);
+  Value *OrigCallee = CS.getCalledValue();
+  bool IsCall = CS.isCall();
+
+  // Clamp the peeled count so that a malformed profile (target count above
+  // the call count) cannot make the remaining count wrap around.
+  uint64_t CSCnt =
+      mdconst::extract<ConstantInt>(MD->getOperand(1))->getZExtValue();
+  TargetCount = std::min(TargetCount, CSCnt);
+  uint64_t RestCount = CSCnt - TargetCount;
+
+  BasicBlock *BB = IC->getParent();
+  Function *F = BB->getParent();
+  DEBUG(dbgs() << "\n\n== Basic Block Before ==\n");
+  DEBUG(dbgs() << *BB << "\n");
+  // IC and everything after it move to MergeBB; the unconditional branch
+  // left at the end of BB is replaced by the guard below.
+  BasicBlock *MergeBB = BB->splitBasicBlock(IC, "if.merge");
+  BB->getTerminator()->eraseFromParent();
+
+  // Create a phi to unify the return values of the two calls.
+  PHINode *CallRetPhi = nullptr;
+  if (!IC->getType()->isVoidTy())
+    CallRetPhi = PHINode::Create(IC->getType(), 0);
+
+  //
+  // Create and process the block with the direct call (TrueBB).
+  //
+  Instruction *NewI = IC->clone();
+  NewI->setMetadata(LLVMContext::MD_prof, nullptr);
+  BasicBlock *TrueBB = BasicBlock::Create(Ctx, "if.true", F);
+  TrueBB->getInstList().push_back(NewI);
+
+  // Retarget the clone and also update the function type it caches, which
+  // would otherwise still describe the indirect callee.
+  FunctionType *NewCalleeType = HotTargetFn->getFunctionType();
+  if (auto *CallI = dyn_cast<CallInst>(NewI)) {
+    CallI->setCalledFunction(HotTargetFn);
+    CallI->mutateFunctionType(NewCalleeType);
+  } else {
+    auto *InvI = cast<InvokeInst>(NewI);
+    InvI->setCalledFunction(HotTargetFn);
+    InvI->mutateFunctionType(NewCalleeType);
+  }
+
+  // Bitcast pointer arguments whose types differ from the target's params.
+  CallSite NewCS(NewI);
+  for (unsigned I = 0, E = NewCS.arg_size(); I < E; ++I) {
+    Value *Arg = NewCS.getArgument(I);
+    Type *PTy = NewCalleeType->getParamType(I);
+    if (Arg->getType() != PTy)
+      NewCS.setArgument(I, new BitCastInst(Arg, PTy, "", NewI));
+  }
+
+  if (IsCall) {
+    BranchInst::Create(MergeBB, TrueBB);
+    if (Hint)
+      cast<CallInst>(NewI)->addAttribute(AttributeSet::FunctionIndex,
+                                         Attribute::InlineHint);
+  } else {
+    // The clone still targets the original normal destination; route both
+    // invokes through MergeBB, which now branches there instead.
+    auto *InvI = cast<InvokeInst>(NewI);
+    BranchInst::Create(InvI->getNormalDest(), MergeBB);
+    InvI->setNormalDest(MergeBB);
+    if (Hint)
+      InvI->addAttribute(AttributeSet::FunctionIndex, Attribute::InlineHint);
+  }
+  if (CallRetPhi)
+    CallRetPhi->addIncoming(NewI, TrueBB);
+
+  //
+  // Process MergeBB: the phi takes over IC's uses and IC moves to FalseBB.
+  //
+  if (CallRetPhi) {
+    MergeBB->getInstList().push_front(CallRetPhi);
+    IC->replaceAllUsesWith(CallRetPhi);
+  }
+  IC->removeFromParent();
+
+  //
+  // Create and process the block with the indirect call (FalseBB).
+  //
+  BasicBlock *FalseBB = BasicBlock::Create(Ctx, "if.false", F);
+
+  // Rebuild the metadata without the peeled target: keep the tag, lower the
+  // call count by the peeled count, and drop operands 2 and 3.
+  SmallVector<Metadata *, 8> Vals;
+  Vals.push_back(MD->getOperand(0));
+  Vals.push_back(ValueAsMetadata::getConstant(
+      ConstantInt::get(Type::getInt64Ty(Ctx), RestCount)));
+  for (unsigned OpI = 4, OpE = MD->getNumOperands(); OpI < OpE; ++OpI)
+    Vals.push_back(MD->getOperand(OpI));
+  IC->setMetadata(LLVMContext::MD_prof, MDNode::get(Ctx, Vals));
+  FalseBB->getInstList().push_back(IC);
+
+  if (IsCall) {
+    BranchInst::Create(MergeBB, FalseBB);
+  } else {
+    auto *InvI = cast<InvokeInst>(IC);
+    InvI->setNormalDest(MergeBB);
+    // The unwind destination used to be reached from MergeBB; it is now
+    // reached from the invokes in TrueBB and FalseBB instead.
+    for (Instruction &UI : *InvI->getUnwindDest()) {
+      auto *PN = dyn_cast<PHINode>(&UI);
+      if (!PN)
+        break;
+      Value *MBBVal = PN->getIncomingValueForBlock(MergeBB);
+      PN->addIncoming(MBBVal, TrueBB);
+      PN->addIncoming(MBBVal, FalseBB);
+      PN->removeIncomingValue(MergeBB);
+    }
+  }
+  if (CallRetPhi)
+    CallRetPhi->addIncoming(IC, FalseBB);
+
+  //
+  // Guard the two paths in the original block BB, comparing the callee as
+  // written against the target cast to the callee's pointer type.
+  //
+  Constant *TargetPtr = ConstantExpr::getPointerBitCastOrAddrSpaceCast(
+      HotTargetFn, OrigCallee->getType());
+  ICmpInst *CmpI =
+      new ICmpInst(*BB, ICmpInst::ICMP_EQ, OrigCallee, TargetPtr);
+  BranchInst *BrI = BranchInst::Create(TrueBB, FalseBB, CmpI, BB);
+  SmallVector<uint64_t, 2> NewWeights;
+  NewWeights.push_back(TargetCount);
+  NewWeights.push_back(RestCount);
+  fit32BitWeights(NewWeights);
+  SmallVector<uint32_t, 2> MDWeights(NewWeights.begin(), NewWeights.end());
+  BrI->setMetadata(LLVMContext::MD_prof,
+                   MDBuilder(Ctx).createBranchWeights(MDWeights));
+
+  DEBUG(dbgs() << "\n== Basic Blocks After ==\n");
+  DEBUG(dbgs() << *BB << *TrueBB << *FalseBB << *MergeBB << "\n");
+
+  if (IsCall)
+    ++NumIndCalls;
+  else
+    ++NumIndInvokes;
+}
+
+/// Appends to \p Hot the (function, count) pairs of the targets of \p CS to
+/// peel, in peeling order. The first target (operands 2 and 3 of \p MD) must
+/// reach TargetHotnessT percent of the call site count \p CSCnt. The second
+/// (operands 4 and 5) is considered only if the first was accepted and
+/// -enable-second-target is set, and must reach Target2HotnessT percent.
+/// Each target must also name a function in \p M with a compatible
+/// signature. \p CSCnt must be nonzero.
+static void
+findHotTargets(Module &M, CallSite CS, MDNode *MD, uint64_t CSCnt,
+               SmallVectorImpl<std::pair<Function *, uint64_t>> &Hot) {
+  // The OpI + 1 bound guards against a trailing name without a count.
+  for (unsigned OpI = 2, OpE = MD->getNumOperands(); OpI <= 4 && OpI + 1 < OpE;
+       OpI += 2) {
+    bool IsFirst = (OpI == 2);
+    if (!IsFirst && !EnableTarget2)
+      break;
+    auto *Name = dyn_cast_or_null<MDString>(MD->getOperand(OpI).get());
+    const MDOperand &CntOp = MD->getOperand(OpI + 1);
+    auto *Cnt = mdconst::dyn_extract_or_null<ConstantInt>(CntOp.get());
+    if (!Name || !Cnt)
+      break; // something is bad; bail out
+
+    uint64_t Count = Cnt->getZExtValue();
+    unsigned Threshold = IsFirst ? TargetHotnessT : Target2HotnessT;
+    DEBUG(dbgs() << (IsFirst ? "\nTarget1" : "\n\nTarget2") << " hotness% = "
+                 << (Count * 100) / CSCnt);
+    if (Count < (uint64_t)(Threshold * (double)CSCnt) / 100)
+      break; // this target is not hot enough; skip the rest
+
+    Function *TargetFn = M.getFunction(Name->getString());
+    if (!TargetFn || !isCompatibleTarget(CS, TargetFn))
+      break;
+    Hot.push_back(std::make_pair(TargetFn, Count));
+  }
+}
+
+bool PGOIndirectCallTransform::runOnModule(Module &M) {
+  DEBUG(dbgs() << "\n**** INDIRECT CALL OPTIMIZATION ***\n");
+
+  if (TargetHotnessT < 10 || TargetHotnessT > 100) {
+    DEBUG(dbgs() << "Bad ic target hotness threshold = " << TargetHotnessT);
+    DEBUG(dbgs() << " - should be between 10 and 100\n");
+    return false;
+  }
+  DEBUG(dbgs() << "IC target hotness threshold = " << TargetHotnessT << "\n");
+
+  //
+  // Compute the sum of all indirect calls/invokes execution counts; call site
+  // and inline-hint hotness are measured relative to it.
+  //
+  double TotalICCount = 0.0;
+  for (auto &F : M)
+    for (auto &BB : F)
+      for (auto &I : BB) {
+        uint64_t IProfCnt = 0;
+        if (isIndCallWithProfileCount(I, IProfCnt))
+          TotalICCount += (double)IProfCnt;
+      }
+  DEBUG(dbgs() << "Total IC execution count = " << TotalICCount << "\n");
+  if (TotalICCount == 0.0)
+    return false;
+
+  bool Changed = false;
+  for (auto &F : M) {
+    // Collect candidates before transforming anything: peeling splits
+    // blocks, which would disturb a walk over F's instructions.
+    SmallVector<Instruction *, 16> CSVec;
+    for (auto &BB : F)
+      for (auto &I : BB) {
+        uint64_t IProfCnt = 0;
+        if (!isIndCallWithProfileCount(I, IProfCnt))
+          continue;
+        // A musttail call must immediately precede its ret, which peeling
+        // would break.
+        if (auto *CI = dyn_cast<CallInst>(&I))
+          if (CI->isMustTailCall())
+            continue;
+        // The site needs a nonzero count of at least CallHotnessT times the
+        // total indirect call count.
+        if (IProfCnt > 0 && IProfCnt >= CallHotnessT * TotalICCount)
+          CSVec.push_back(&I);
+      }
+
+    //
+    // Go over indirect calls and for each call IC: peel the hottest target(s)
+    // from IC if its percentage is above threshold.
+    //
+    while (!CSVec.empty()) {
+      Instruction *IC = CSVec.pop_back_val();
+      if (NumTransformed >= NumICTs) // for debugging
+        break;
+
+      MDNode *MD = IC->getMetadata(LLVMContext::MD_prof);
+      DEBUG(dbgs() << "\nAttempting IC_opt on: " << *IC);
+      DEBUG(dbgs() << "\nwith: " << *MD);
+      DEBUG(dbgs() << "in function: " << F.getName());
+
+      // Get the total number of times IC is executed.
+      uint64_t CSCnt =
+          mdconst::extract<ConstantInt>(MD->getOperand(1))->getZExtValue();
+      DEBUG(dbgs() << "\nCS hotness% = " << (CSCnt * 100) / TotalICCount
+                   << "\n");
+
+      SmallVector<std::pair<Function *, uint64_t>, 2> HotTargets;
+      findHotTargets(M, CallSite(IC), MD, CSCnt, HotTargets);
+      if (HotTargets.empty()) {
+        DEBUG(dbgs() << "\nnot a good candidate, skip it.\n");
+        continue;
+      }
+
+      // After each peel IC sits in the fallback block and its metadata lists
+      // the next target first, as peelOneTarget expects.
+      for (unsigned TI = 0, TE = HotTargets.size(); TI < TE; ++TI) {
+        if (NumTransformed >= NumICTs) // for debugging
+          break;
+        Function *TargetFn = HotTargets[TI].first;
+        uint64_t TargetCount = HotTargets[TI].second;
+        DEBUG(dbgs() << "\nSuccess for target" << (TI + 1) << ": "
+                     << TargetFn->getName());
+        // Add an inline hint if the target count is at least InlineHintT
+        // percent of the total IC count.
+        bool Hint = (TargetCount >= InlineHintT * TotalICCount / 100);
+        peelOneTarget(IC, TargetFn, TargetCount, Hint);
+        ++NumTransformed;
+        Changed = true;
+      }
+    }
+  }
+  return Changed;
+}
Index: test/Transforms/IndirectCallOpt/ico0.ll
===================================================================
--- /dev/null
+++ test/Transforms/IndirectCallOpt/ico0.ll
@@ -0,0 +1,27 @@
+; RUN: opt -ic-opt -S < %s | FileCheck %s
+define void @foo(i32 %a) {
+entry:
+  %a1 = add i32 %a, 1
+  ret void
+}
+
+define void @bar(i32 %a) {
+entry:
+  %a2 = add i32 %a, 2
+  ret void
+
+}
+
+define void @main(void (i32)* %fun) {
+entry:
+  call void %fun(i32 10), !prof !1
+  ret void
+}
+
+!1 = !{!"indirect_call_targets", i64 6000, !"foo", i64 5000, !"bar", i64 100}
+
+; CHECK: call void @foo
+; CHECK-NOT: call void @bar
+; CHECK: !"branch_weights", i32 5000, i32 1000}
+; CHECK: !"indirect_call_targets", i64 1000, !"bar", i64 100}
+
Index: test/Transforms/IndirectCallOpt/ico1.ll
===================================================================
--- /dev/null
+++ test/Transforms/IndirectCallOpt/ico1.ll
@@ -0,0 +1,27 @@
+; RUN: opt -ic-opt -S < %s | FileCheck %s
+define i32 @foo(i32 %a) {
+entry:
+  %a1 = add i32 %a, 1
+  ret i32 %a1
+}
+
+define i32 @bar(i32 %a) {
+entry:
+  %a2 = add i32 %a, 2
+  ret i32 %a2
+
+}
+
+define i32 @main(i32 (i32)* %fun) {
+entry:
+  %b = call i32 %fun(i32 10), !prof !1
+  ret i32 %b
+}
+
+!1 = !{!"indirect_call_targets", i64 6000, !"foo", i64 5000, !"bar", i64 100}
+
+; CHECK: call i32 @foo
+; CHECK-NOT: call i32 @bar
+; CHECK: !"branch_weights", i32 5000, i32 1000}
+; CHECK: !"indirect_call_targets", i64 1000, !"bar", i64 100}
+
Index: test/Transforms/IndirectCallOpt/ico2.ll
===================================================================
--- /dev/null
+++ test/Transforms/IndirectCallOpt/ico2.ll
@@ -0,0 +1,25 @@
+; RUN: opt -ic-opt -S < %s | FileCheck %s
+define void @foo(i32 %a) {
+entry:
+  %a1 = add i32 %a, 1
+  ret void
+}
+
+define void @bar(i32 %a) {
+entry:
+  %a2 = add i32 %a, 2
+  ret void
+
+}
+
+define void @main(void (i32)* %fun) {
+entry:
+  call void %fun(i32 10), !prof !1
+  ret void
+}
+
+!1 = !{!"indirect_call_targets", i64 6000, !"foo", i64 500, !"bar", i64 100}
+
+; CHECK-NOT: call void @foo
+; CHECK-NOT: call void @bar
+
Index: test/Transforms/IndirectCallOpt/ico3.ll
===================================================================
--- /dev/null
+++ test/Transforms/IndirectCallOpt/ico3.ll
@@ -0,0 +1,28 @@
+; RUN: opt -ic-opt -S < %s | FileCheck %s
+define void @foo(i32 %a) {
+entry:
+  %a1 = add i32 %a, 1
+  ret void
+}
+
+define void @bar(i32 %a) {
+entry:
+  %a2 = add i32 %a, 2
+  ret void
+
+}
+
+define void @main(void (i32)* %fun) {
+entry:
+  call void %fun(i32 10), !prof !1
+  ret void
+}
+
+!1 = !{!"indirect_call_targets", i64 6000, !"foo", i64 3000, !"bar", i64 2500}
+
+; CHECK: call void @foo
+; CHECK: call void @bar
+; CHECK: !"branch_weights", i32 3000, i32 3000}
+; CHECK: !"branch_weights", i32 2500, i32 500}
+; CHECK: !"indirect_call_targets", i64 500}
+
Index: test/Transforms/IndirectCallOpt/ico4.ll
===================================================================
--- /dev/null
+++ test/Transforms/IndirectCallOpt/ico4.ll
@@ -0,0 +1,39 @@
+; RUN: opt -ic-opt -S < %s | FileCheck %s
+; Call sites that must be left alone: a musttail call (peeling would separate
+; it from its ret) and a call whose only profiled target has a different
+; return type.
+
+@fp = global void (i32)* null
+
+define void @foo(i32 %a) {
+entry:
+  ret void
+}
+
+define i32 @baz(i32 %a) {
+entry:
+  ret i32 %a
+}
+
+define void @tail(i32 %a) {
+entry:
+  %fun = load void (i32)*, void (i32)** @fp
+  musttail call void %fun(i32 %a), !prof !1
+  ret void
+}
+
+define void @mismatch(void (i32)* %fun) {
+entry:
+  call void %fun(i32 10), !prof !2
+  ret void
+}
+
+!1 = !{!"indirect_call_targets", i64 6000, !"foo", i64 5000}
+!2 = !{!"indirect_call_targets", i64 6000, !"baz", i64 5000}
+
+; CHECK-LABEL: define void @tail(
+; CHECK-NOT: icmp
+; CHECK: musttail call void %fun(i32 %a)
+; CHECK-LABEL: define void @mismatch(
+; CHECK-NOT: icmp
+; CHECK: call void %fun(i32 10)