Index: include/llvm/Analysis/TargetTransformInfo.h
===================================================================
--- include/llvm/Analysis/TargetTransformInfo.h
+++ include/llvm/Analysis/TargetTransformInfo.h
@@ -194,6 +194,10 @@
   int getIntrinsicCost(Intrinsic::ID IID, Type *RetTy,
                        ArrayRef<const Value *> Arguments) const;
 
+  /// \return The estimated number of case clusters that will be created when
+  /// lowering \p SI.
+  int getNumberOfCaseClusters(const SwitchInst &SI) const;
+
   /// \brief Estimate the cost of a given IR user when lowered.
   ///
   /// This can estimate the cost of either a ConstantExpr or Instruction when
@@ -746,6 +750,7 @@
                                ArrayRef<Type *> ParamTys) = 0;
   virtual int getIntrinsicCost(Intrinsic::ID IID, Type *RetTy,
                                ArrayRef<const Value *> Arguments) = 0;
+  virtual int getNumberOfCaseClusters(const SwitchInst &SI) = 0;
   virtual int getUserCost(const User *U) = 0;
   virtual bool hasBranchDivergence() = 0;
   virtual bool isSourceOfDivergence(const Value *V) = 0;
@@ -1037,6 +1042,9 @@
   unsigned getMaxInterleaveFactor(unsigned VF) override {
     return Impl.getMaxInterleaveFactor(VF);
   }
+  int getNumberOfCaseClusters(const SwitchInst &SI) override {
+    return Impl.getNumberOfCaseClusters(SI);
+  }
   unsigned getArithmeticInstrCost(unsigned Opcode, Type *Ty,
                                   OperandValueKind Opd1Info,
                                   OperandValueKind Opd2Info,
Index: include/llvm/Analysis/TargetTransformInfoImpl.h
===================================================================
--- include/llvm/Analysis/TargetTransformInfoImpl.h
+++ include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -114,6 +114,8 @@
     return TTI::TCC_Free;
   }
 
+  int getNumberOfCaseClusters(const SwitchInst &SI) { return SI.getNumCases(); }
+
   unsigned getCallCost(FunctionType *FTy, int NumArgs) {
     assert(FTy && "FunctionType must be provided to this routine.");
 
Index: include/llvm/CodeGen/BasicTTIImpl.h
===================================================================
--- include/llvm/CodeGen/BasicTTIImpl.h
+++ include/llvm/CodeGen/BasicTTIImpl.h
@@ -17,11 +17,12 @@
 #define LLVM_CODEGEN_BASICTTIIMPL_H
 
 #include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
 #include "llvm/Analysis/TargetTransformInfoImpl.h"
+#include "llvm/CodeGen/SwitchLoweringCaseCluster.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Target/TargetLowering.h"
 #include "llvm/Target/TargetSubtargetInfo.h"
-#include "llvm/Analysis/TargetLibraryInfo.h"
 
 namespace llvm {
 
@@ -171,6 +172,15 @@
     return BaseT::getIntrinsicCost(IID, RetTy, ParamTys);
   }
 
+  int getNumberOfCaseClusters(const SwitchInst &SI) {
+    CaseClusterVector Clusters;
+    SwitchLoweringCaseCluster CaseClusters(
+        this->getDataLayout(), *getST()->getTargetLowering(),
+        getTLI()->getTargetMachine().getOptLevel());
+    CaseClusters.findCaseClusters(SI, Clusters, nullptr);
+    return Clusters.size();
+  }
+
   unsigned getJumpBufAlignment() { return getTLI()->getJumpBufAlignment(); }
 
   unsigned getJumpBufSize() { return getTLI()->getJumpBufSize(); }
Index: lib/Analysis/InlineCost.cpp
===================================================================
--- lib/Analysis/InlineCost.cpp
+++ lib/Analysis/InlineCost.cpp
@@ -1003,22 +1003,45 @@
     if (isa<ConstantInt>(V))
       return true;
 
-  // Otherwise, we need to accumulate a cost proportional to the number of
-  // distinct successor blocks. This fan-out in the CFG cannot be represented
-  // for free even if we can represent the core switch as a jumptable that
-  // takes a single instruction.
+  // Otherwise, assume the most general case where the switch is lowered into
+  // a balanced binary tree of case clusters, and that each case cluster is
+  // equally likely to be entered. The cost of the switch is proportional to
+  // the number of comparisons in that tree.
   //
   // NB: We convert large switches which are just used to initialize large phi
   // nodes to lookup tables instead in simplify-cfg, so this shouldn't prevent
   // inlining those. It will prevent inlining in cases where the optimization
   // does not (yet) fire.
-  SmallPtrSet<BasicBlock *, 8> SuccessorBlocks;
-  SuccessorBlocks.insert(SI.getDefaultDest());
-  for (auto I = SI.case_begin(), E = SI.case_end(); I != E; ++I)
-    SuccessorBlocks.insert(I.getCaseSuccessor());
-  // Add cost corresponding to the number of distinct destinations. The first
-  // we model as free because of fallthrough.
-  Cost += (SuccessorBlocks.size() - 1) * InlineConstants::InstrCost;
+  unsigned NumCaseClusters = TTI.getNumberOfCaseClusters(SI);
+
+  // Model one InstrCost as free, mirroring the previous heuristic which
+  // treated the first destination as free because of fallthrough.
+  Cost -= InlineConstants::InstrCost;
+
+  // Each work-list entry is the number of case clusters under one tree node.
+  SmallVector<unsigned, 8> SwitchWorkList;
+  SwitchWorkList.push_back(NumCaseClusters);
+  while (!SwitchWorkList.empty()) {
+    unsigned NumCases = SwitchWorkList.pop_back_val();
+
+    if (NumCases <= 3) {
+      // Do not split a node with at most 3 clusters: just compare the switch
+      // condition against each case value in turn. Assume each comparison is
+      // one compare plus one conditional branch.
+      Cost += 2 * NumCases * InlineConstants::InstrCost;
+    } else {
+      // Split the clusters into two halves, which costs one more comparison.
+      unsigned NumLeft = NumCases / 2;
+      SwitchWorkList.push_back(NumLeft);
+      SwitchWorkList.push_back(NumCases - NumLeft);
+      Cost += 2 * InlineConstants::InstrCost;
+    }
+    // Exit early once Cost exceeds Threshold; the remaining nodes can only
+    // increase it.
+    if (Cost > Threshold)
+      return false;
+  }
+
   return false;
 }
 
Index: lib/Analysis/TargetTransformInfo.cpp
===================================================================
--- lib/Analysis/TargetTransformInfo.cpp
+++ lib/Analysis/TargetTransformInfo.cpp
@@ -83,6 +83,10 @@
   return Cost;
 }
 
+int TargetTransformInfo::getNumberOfCaseClusters(const SwitchInst &SI) const {
+  return TTIImpl->getNumberOfCaseClusters(SI);
+}
+
 int TargetTransformInfo::getUserCost(const User *U) const {
   int Cost = TTIImpl->getUserCost(U);
   assert(Cost >= 0 && "TTI should not produce negative costs!");
Index: test/Transforms/Inline/switch.ll
===================================================================
--- test/Transforms/Inline/switch.ll
+++ test/Transforms/Inline/switch.ll
@@ -1,7 +1,7 @@
 ; RUN: opt < %s -inline -inline-threshold=20 -S | FileCheck %s
 ; RUN: opt < %s -passes='cgscc(inline)' -inline-threshold=20 -S | FileCheck %s
 
-define i32 @callee(i32 %a) {
+define i32 @callee1(i32 %a) {
   switch i32 %a, label %sw.default [
     i32 0, label %sw.bb0
     i32 1, label %sw.bb1
@@ -52,10 +52,39 @@
   ret i32 42
 }
 
+; callee2 has ten case values but only two distinct destinations. Its switch
+; is costed by its number of case clusters rather than its number of distinct
+; successors, so callee2 must not be inlined either.
+define i32 @callee2(i32 %a) {
+  switch i32 %a, label %sw.default [
+    i32 0, label %sw.bb0
+    i32 1, label %sw.bb0
+    i32 2, label %sw.bb0
+    i32 3, label %sw.bb0
+    i32 4, label %sw.bb0
+    i32 5, label %sw.bb0
+    i32 6, label %sw.bb0
+    i32 7, label %sw.bb0
+    i32 8, label %sw.bb0
+    i32 9, label %sw.bb0
+  ]
+
+sw.default:
+  br label %return
+
+sw.bb0:
+  br label %return
+
+return:
+  ret i32 42
+}
+
 define i32 @caller(i32 %a) {
 ; CHECK-LABEL: @caller(
-; CHECK: call i32 @callee(
+; CHECK: call i32 @callee1(
+; CHECK: call i32 @callee2(
 
-  %result = call i32 @callee(i32 %a)
-  ret i32 %result
+  %result1 = call i32 @callee1(i32 %a)
+  %result2 = call i32 @callee2(i32 %a)
+  ret i32 %result1
 }