Index: include/llvm/Analysis/TargetTransformInfo.h =================================================================== --- include/llvm/Analysis/TargetTransformInfo.h +++ include/llvm/Analysis/TargetTransformInfo.h @@ -197,6 +197,9 @@ int getIntrinsicCost(Intrinsic::ID IID, Type *RetTy, ArrayRef Arguments) const; + /// \return The estimated number of case clusters for \p SI when lowered. + int getEstimatedNumberOfCaseClusters(const SwitchInst &SI) const; + /// \brief Estimate the cost of a given IR user when lowered. /// /// This can estimate the cost of either a ConstantExpr or Instruction when @@ -755,6 +758,7 @@ ArrayRef ParamTys) = 0; virtual int getIntrinsicCost(Intrinsic::ID IID, Type *RetTy, ArrayRef Arguments) = 0; + virtual int getEstimatedNumberOfCaseClusters(const SwitchInst &SI) = 0; virtual int getUserCost(const User *U) = 0; virtual bool hasBranchDivergence() = 0; virtual bool isSourceOfDivergence(const Value *V) = 0; @@ -1052,6 +1056,9 @@ unsigned getMaxInterleaveFactor(unsigned VF) override { return Impl.getMaxInterleaveFactor(VF); } + int getEstimatedNumberOfCaseClusters(const SwitchInst &SI) override { + return Impl.getEstimatedNumberOfCaseClusters(SI); + } unsigned getArithmeticInstrCost(unsigned Opcode, Type *Ty, OperandValueKind Opd1Info, OperandValueKind Opd2Info, Index: include/llvm/Analysis/TargetTransformInfoImpl.h =================================================================== --- include/llvm/Analysis/TargetTransformInfoImpl.h +++ include/llvm/Analysis/TargetTransformInfoImpl.h @@ -114,6 +114,10 @@ return TTI::TCC_Free; } + int getEstimatedNumberOfCaseClusters(const SwitchInst &SI) { + return SI.getNumCases(); // Default: one cluster per case. + } + unsigned getCallCost(FunctionType *FTy, int NumArgs) { assert(FTy && "FunctionType must be provided to this routine."); Index: include/llvm/CodeGen/BasicTTIImpl.h =================================================================== --- include/llvm/CodeGen/BasicTTIImpl.h +++ include/llvm/CodeGen/BasicTTIImpl.h @@ -17,11
+17,12 @@ #define LLVM_CODEGEN_BASICTTIIMPL_H #include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/TargetTransformInfoImpl.h" +#include "llvm/CodeGen/SwitchCaseCluster.h" #include "llvm/Support/CommandLine.h" #include "llvm/Target/TargetLowering.h" #include "llvm/Target/TargetSubtargetInfo.h" -#include "llvm/Analysis/TargetLibraryInfo.h" namespace llvm { @@ -171,6 +172,13 @@ return BaseT::getIntrinsicCost(IID, RetTy, ParamTys); } + int getEstimatedNumberOfCaseClusters(const SwitchInst &SI) { + SwitchCaseClusterFinder CaseClusters( + this->getDataLayout(), *getST()->getTargetLowering(), + getTLI()->getTargetMachine().getOptLevel()); + return CaseClusters.getEstimatedNumberOfCluster(SI); + } + unsigned getJumpBufAlignment() { return getTLI()->getJumpBufAlignment(); } unsigned getJumpBufSize() { return getTLI()->getJumpBufSize(); } Index: include/llvm/CodeGen/SwitchCaseCluster.h =================================================================== --- include/llvm/CodeGen/SwitchCaseCluster.h +++ include/llvm/CodeGen/SwitchCaseCluster.h @@ -119,6 +119,10 @@ /// Calculate clusters for cases in SI and store them in Clusters. const BasicBlock *findClusters(const SwitchInst &SI, CaseClusterVector &Clusters); + + /// Returns the estimated number of case clusters for \p SI when lowered. + unsigned getEstimatedNumberOfCluster(const SwitchInst &SI); + private: /// Extract cases from the switch and build initial form of case clusters. void formInitalCaseClusers(const SwitchInst &SI, CaseClusterVector &Clusters); Index: lib/Analysis/InlineCost.cpp =================================================================== --- lib/Analysis/InlineCost.cpp +++ lib/Analysis/InlineCost.cpp @@ -1003,22 +1003,40 @@ if (isa(V)) return true; - // Otherwise, we need to accumulate a cost proportional to the number of - // distinct successor blocks.
This fan-out in the CFG cannot be represented - // for free even if we can represent the core switch as a jumptable that - // takes a single instruction. + // Otherwise, we assume the most general case where the big switch is lowered + // into a balanced binary tree consisting of case clusters, the probability of + // entering each case cluster is equal. The cost of the switch is proportional + // to the size of the tree. // // NB: We convert large switches which are just used to initialize large phi // nodes to lookup tables instead in simplify-cfg, so this shouldn't prevent // inlining those. It will prevent inlining in cases where the optimization // does not (yet) fire. - SmallPtrSet SuccessorBlocks; - SuccessorBlocks.insert(SI.getDefaultDest()); - for (auto I = SI.case_begin(), E = SI.case_end(); I != E; ++I) - SuccessorBlocks.insert(I.getCaseSuccessor()); - // Add cost corresponding to the number of distinct destinations. The first - // we model as free because of fallthrough. - Cost += (SuccessorBlocks.size() - 1) * InlineConstants::InstrCost; + int NumCaseCluster = TTI.getEstimatedNumberOfCaseClusters(SI); + SmallVector SwitchWorkList; + SwitchWorkList.push_back(NumCaseCluster); + Cost -= InlineConstants::InstrCost; + while (!SwitchWorkList.empty()) { + unsigned NumCases = SwitchWorkList.back(); + SwitchWorkList.pop_back(); + if (NumCases <= 3) + // Do not split the tree if the number of remaining cases is at most 3. + // Just compare switch condition with each case value. Suppose each + // comparison includes one compare and one conditional branch. + Cost += (2 * NumCases * InlineConstants::InstrCost); + else { + // Split the remaining nodes and add one more comparison. + unsigned NumLeft = NumCases / 2; + unsigned NumRight = NumCases - NumLeft; + SwitchWorkList.push_back(NumLeft); + SwitchWorkList.push_back(NumRight); + Cost += (2 * InlineConstants::InstrCost); + } + // Exit early if Cost is already larger than Threshold.
+ if (Cost > Threshold) + return false; + } + return false; } Index: lib/Analysis/TargetTransformInfo.cpp =================================================================== --- lib/Analysis/TargetTransformInfo.cpp +++ lib/Analysis/TargetTransformInfo.cpp @@ -83,6 +83,11 @@ return Cost; } +int TargetTransformInfo::getEstimatedNumberOfCaseClusters( + const SwitchInst &SI) const { + return TTIImpl->getEstimatedNumberOfCaseClusters(SI); +} + int TargetTransformInfo::getUserCost(const User *U) const { int Cost = TTIImpl->getUserCost(U); assert(Cost >= 0 && "TTI should not produce negative costs!"); Index: lib/CodeGen/SelectionDAG/SwitchCaseCluster.cpp =================================================================== --- lib/CodeGen/SelectionDAG/SwitchCaseCluster.cpp +++ lib/CodeGen/SelectionDAG/SwitchCaseCluster.cpp @@ -61,6 +61,59 @@ return OptForSize ? OptsizeJumpTableDensity : JumpTableDensity; } +/// Returns the estimated number of clusters. Note that the number of clusters +/// identified in this function could be different from the actual numbers +/// found for lowering by findClusters(). This function ignores switches that +/// are lowered with a mix of jump table / bit test / BTree. This function was +/// mainly intended to be used when estimating the cost of switch in inline cost +/// heuristic.
+unsigned +SwitchCaseClusterFinder::getEstimatedNumberOfCluster(const SwitchInst &SI) { + unsigned N = SI.getNumCases(); + const bool OptForSize = SI.getParent()->getParent()->optForSize(); + const unsigned MaxJumpTableSize = getMaxJumpTableSize(OptForSize, TLI); + bool IsJTAllowed = areJTsAllowed(TLI, &SI); + + if (N < 1 || + (DL.getPointerSizeInBits() < MaxJumpTableSize && MaxJumpTableSize < N)) + return N; // No cases, or too many for both a bit test and one table. + + if (!IsJTAllowed && DL.getPointerSizeInBits() < N) + return N; // No jump tables, and too many cases for a bit test. + + APInt MaxCaseVal = (SI.case_begin()).getCaseValue()->getValue(); + APInt MinCaseVal = MaxCaseVal; + for (auto I = SI.case_begin(), E = SI.case_end(); I != E; ++I) { + const APInt &CaseVal = I.getCaseValue()->getValue(); + if (CaseVal.sgt(MaxCaseVal)) + MaxCaseVal = CaseVal; + if (CaseVal.slt(MinCaseVal)) + MinCaseVal = CaseVal; + } + + // Check if suitable for a bit test. + if (N <= DL.getPointerSizeInBits()) { + SmallPtrSet<const BasicBlock *, 4> Dests; + for (auto I = SI.case_begin(), E = SI.case_end(); I != E; ++I) + Dests.insert(I.getCaseSuccessor()); + + if (isSuitableForBitTests(Dests.size(), N, MinCaseVal, MaxCaseVal)) + return 1; + } + + // Check if suitable for a jump table. Keep the range 64-bit to avoid wrap.
+ if (IsJTAllowed && + !isTooSmallForJumptable(N, TLI.getMinimumJumpTableEntries())) { + const unsigned MinDensity = getJumptableMinDensity(OptForSize); + uint64_t JumpTableSize = + (MaxCaseVal - MinCaseVal).getLimitedValue(UINT64_MAX - 1) + 1; + if (JumpTableSize <= MaxJumpTableSize && + ::isDense(JumpTableSize, N, MinDensity)) + return 1; + } + return N; +} + bool SwitchCaseClusterFinder::isDense( const CaseClusterVector &Clusters, const SmallVectorImpl &TotalCases, unsigned First, unsigned Last, Index: test/Transforms/Inline/switch.ll =================================================================== --- test/Transforms/Inline/switch.ll +++ test/Transforms/Inline/switch.ll @@ -1,7 +1,7 @@ ; RUN: opt < %s -inline -inline-threshold=20 -S | FileCheck %s ; RUN: opt < %s -passes='cgscc(inline)' -inline-threshold=20 -S | FileCheck %s -define i32 @callee(i32 %a) { +define i32 @callee1(i32 %a) { switch i32 %a, label %sw.default [ i32 0, label %sw.bb0 i32 1, label %sw.bb1 @@ -52,10 +52,36 @@ ret i32 42 } +define i32 @callee2(i32 %a) { + switch i32 %a, label %sw.default [ + i32 0, label %sw.bb0 + i32 1, label %sw.bb0 + i32 2, label %sw.bb0 + i32 3, label %sw.bb0 + i32 4, label %sw.bb0 + i32 5, label %sw.bb0 + i32 6, label %sw.bb0 + i32 7, label %sw.bb0 + i32 8, label %sw.bb0 + i32 9, label %sw.bb0 + ] + +sw.default: + br label %return + +sw.bb0: + br label %return + +return: + ret i32 42 +} + define i32 @caller(i32 %a) { ; CHECK-LABEL: @caller( -; CHECK: call i32 @callee( +; CHECK: call i32 @callee1( +; CHECK: call i32 @callee2( - %result = call i32 @callee(i32 %a) - ret i32 %result + %result1 = call i32 @callee1(i32 %a) + %result2 = call i32 @callee2(i32 %a) + ret i32 %result1 }