Index: include/llvm/Analysis/TargetTransformInfo.h
===================================================================
--- include/llvm/Analysis/TargetTransformInfo.h
+++ include/llvm/Analysis/TargetTransformInfo.h
@@ -197,6 +197,12 @@
   int getIntrinsicCost(Intrinsic::ID IID, Type *RetTy,
                        ArrayRef<const Value *> Arguments) const;
 
+  /// \return The estimated number of case clusters for the 'SI' when lowered.
+  /// \p JTSize Set a jump table size only when \p SI is suitable for a jump
+  /// table.
+  int getEstimatedNumberOfCaseClusters(const SwitchInst &SI,
+                                       unsigned *JTSize = nullptr) const;
+
   /// \brief Estimate the cost of a given IR user when lowered.
   ///
   /// This can estimate the cost of either a ConstantExpr or Instruction when
@@ -755,6 +761,8 @@
                                ArrayRef<Type *> ParamTys) = 0;
   virtual int getIntrinsicCost(Intrinsic::ID IID, Type *RetTy,
                                ArrayRef<const Value *> Arguments) = 0;
+  virtual int getEstimatedNumberOfCaseClusters(const SwitchInst &SI,
+                                               unsigned *JTSize = nullptr) = 0;
   virtual int getUserCost(const User *U) = 0;
   virtual bool hasBranchDivergence() = 0;
   virtual bool isSourceOfDivergence(const Value *V) = 0;
@@ -1052,6 +1060,10 @@
   unsigned getMaxInterleaveFactor(unsigned VF) override {
     return Impl.getMaxInterleaveFactor(VF);
   }
+  int getEstimatedNumberOfCaseClusters(const SwitchInst &SI,
+                                       unsigned *JTSize = nullptr) override {
+    return Impl.getEstimatedNumberOfCaseClusters(SI, JTSize);
+  }
   unsigned getArithmeticInstrCost(unsigned Opcode, Type *Ty,
                                   OperandValueKind Opd1Info,
                                   OperandValueKind Opd2Info,
Index: include/llvm/Analysis/TargetTransformInfoImpl.h
===================================================================
--- include/llvm/Analysis/TargetTransformInfoImpl.h
+++ include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -114,6 +114,11 @@
     return TTI::TCC_Free;
   }
 
+  int getEstimatedNumberOfCaseClusters(const SwitchInst &SI,
+                                       unsigned *JTSize = nullptr) {
+    return SI.getNumCases();
+  }
+
   unsigned getCallCost(FunctionType *FTy, int NumArgs) {
     assert(FTy && "FunctionType must be provided to this routine.");
Index: include/llvm/CodeGen/BasicTTIImpl.h
===================================================================
--- include/llvm/CodeGen/BasicTTIImpl.h
+++ include/llvm/CodeGen/BasicTTIImpl.h
@@ -17,11 +17,12 @@
 #define LLVM_CODEGEN_BASICTTIIMPL_H
 
 #include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
 #include "llvm/Analysis/TargetTransformInfoImpl.h"
+#include "llvm/CodeGen/SwitchCaseCluster.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Target/TargetLowering.h"
 #include "llvm/Target/TargetSubtargetInfo.h"
-#include "llvm/Analysis/TargetLibraryInfo.h"
 
 namespace llvm {
 
@@ -171,6 +172,14 @@
     return BaseT::getIntrinsicCost(IID, RetTy, ParamTys);
   }
 
+  int getEstimatedNumberOfCaseClusters(const SwitchInst &SI,
+                                       unsigned *JumpTableSize) {
+    SwitchCaseClusterFinder CaseClusters(
+        this->getDataLayout(), *getST()->getTargetLowering(),
+        getTLI()->getTargetMachine().getOptLevel());
+    return CaseClusters.getEstimatedNumberOfClusters(SI, *JumpTableSize);
+  }
+
   unsigned getJumpBufAlignment() { return getTLI()->getJumpBufAlignment(); }
 
   unsigned getJumpBufSize() { return getTLI()->getJumpBufSize(); }
 
Index: include/llvm/CodeGen/SwitchCaseCluster.h
===================================================================
--- include/llvm/CodeGen/SwitchCaseCluster.h
+++ include/llvm/CodeGen/SwitchCaseCluster.h
@@ -119,6 +119,11 @@
   /// Calculate clusters for cases in SI and store them in Clusters.
   const BasicBlock *findClusters(const SwitchInst &SI,
                                  CaseClusterVector &Clusters);
+
+  /// Return the estimated number of clusters.
+  unsigned getEstimatedNumberOfClusters(const SwitchInst &SI,
+                                        unsigned &JumpTableSize);
+
 private:
   /// Extract cases from the switch and build initial form of case clusters.
  void formInitalCaseClusers(const SwitchInst &SI, CaseClusterVector &Clusters);
Index: lib/Analysis/InlineCost.cpp
===================================================================
--- lib/Analysis/InlineCost.cpp
+++ lib/Analysis/InlineCost.cpp
@@ -54,6 +54,11 @@
                              cl::init(45),
                              cl::desc("Threshold for inlining cold callsites"));
 
+static cl::opt<bool>
+    EnableGenericSwitchCost("inline-generic-switch-cost", cl::Hidden,
+                            cl::init(false),
+                            cl::desc("Enable generic switch cost model"));
+
 // We introduce this threshold to help performance of instrumentation based
 // PGO before we actually hook up inliner with analysis passes such as BPI and
 // BFI.
@@ -1003,22 +1008,74 @@
     if (isa<ConstantInt>(V))
       return true;
 
-  // Otherwise, we need to accumulate a cost proportional to the number of
-  // distinct successor blocks. This fan-out in the CFG cannot be represented
-  // for free even if we can represent the core switch as a jumptable that
-  // takes a single instruction.
+  if (!EnableGenericSwitchCost) {
+    // In this simple switch cost model, we accumulate a cost proportional to
+    // the number of distinct successor blocks. This fan-out in the CFG cannot
+    // be represented for free even if we can represent the core switch as a
+    // jumptable that takes a single instruction.
+    SmallPtrSet<BasicBlock *, 16> SuccessorBlocks;
+    SuccessorBlocks.insert(SI.getDefaultDest());
+    for (auto I = SI.case_begin(), E = SI.case_end(); I != E; ++I)
+      SuccessorBlocks.insert(I.getCaseSuccessor());
+    // Add cost corresponding to the number of distinct destinations. The first
+    // we model as free because of fallthrough.
+    Cost += (SuccessorBlocks.size() - 1) * InlineConstants::InstrCost;
+    return false;
+  }
+
+  // Otherwise, we assume the most general case where the switch is lowered into
+  // either a jump table, bit test, or a balanced binary tree consisting of
+  // case clusters without merging adjacent clusters with the same destination.
+  // We do not consider the switches that are lowered with a mix of jump table/
+  // bit test/BTree. The cost of the switch is proportional to the size of
+  // the tree or the size of jump table range.
   //
   // NB: We convert large switches which are just used to initialize large phi
   // nodes to lookup tables instead in simplify-cfg, so this shouldn't prevent
   // inlining those. It will prevent inlining in cases where the optimization
   // does not (yet) fire.
-  SmallPtrSet<BasicBlock *, 16> SuccessorBlocks;
-  SuccessorBlocks.insert(SI.getDefaultDest());
-  for (auto I = SI.case_begin(), E = SI.case_end(); I != E; ++I)
-    SuccessorBlocks.insert(I.getCaseSuccessor());
-  // Add cost corresponding to the number of distinct destinations. The first
-  // we model as free because of fallthrough.
-  Cost += (SuccessorBlocks.size() - 1) * InlineConstants::InstrCost;
+
+  // Exit early for a large switch, assuming one case needs at least one
+  // instruction.
+  // FIXME: This is not true for a bit test, but ignore such case for now to
+  // save compile-time.
+  int CostLowerBound = Cost + SI.getNumCases() * InlineConstants::InstrCost;
+  if (CostLowerBound > Threshold) {
+    Cost = CostLowerBound;
+    return false;
+  }
+
+  unsigned JumpTableSize = 0;
+  int NumCaseCluster = TTI.getEstimatedNumberOfCaseClusters(SI, &JumpTableSize);
+  SmallVector<unsigned, 8> SwitchWorkList;
+  SwitchWorkList.push_back(NumCaseCluster);
+  Cost -= InlineConstants::InstrCost;
+
+  // If suitable for a jump table, consider the jump table size.
+  if (JumpTableSize)
+    Cost += JumpTableSize * InlineConstants::InstrCost;
+
+  while (!SwitchWorkList.empty()) {
+    unsigned NumCases = SwitchWorkList.back();
+    SwitchWorkList.pop_back();
+    if (NumCases <= 3)
+      // Do not split the tree if the number of remaining cases is 3 or less.
+      // Just compare switch condition with each case value. Suppose each
+      // comparison includes one compare and one conditional branch.
+      Cost += (2 * NumCases * InlineConstants::InstrCost);
+    else {
+      // Split the remaining nodes and add one more comparison.
+      unsigned NumLeft = NumCases / 2;
+      unsigned NumRight = NumCases - NumLeft;
+      SwitchWorkList.push_back(NumLeft);
+      SwitchWorkList.push_back(NumRight);
+      Cost += (2 * InlineConstants::InstrCost);
+    }
+    // Exit early if Cost is already larger than Threshold.
+    if (Cost > Threshold)
+      return false;
+  }
+  return false;
 }
Index: lib/Analysis/TargetTransformInfo.cpp
===================================================================
--- lib/Analysis/TargetTransformInfo.cpp
+++ lib/Analysis/TargetTransformInfo.cpp
@@ -83,6 +83,11 @@
   return Cost;
 }
 
+int TargetTransformInfo::getEstimatedNumberOfCaseClusters(
+    const SwitchInst &SI, unsigned *JTSize) const {
+  return TTIImpl->getEstimatedNumberOfCaseClusters(SI, JTSize);
+}
+
 int TargetTransformInfo::getUserCost(const User *U) const {
   int Cost = TTIImpl->getUserCost(U);
   assert(Cost >= 0 && "TTI should not produce negative costs!");
Index: lib/CodeGen/SelectionDAG/SwitchCaseCluster.cpp
===================================================================
--- lib/CodeGen/SelectionDAG/SwitchCaseCluster.cpp
+++ lib/CodeGen/SelectionDAG/SwitchCaseCluster.cpp
@@ -61,6 +61,64 @@
   return OptForSize ? OptsizeJumpTableDensity : JumpTableDensity;
 }
 
+/// Return the estimated number of clusters. Note that the number of clusters
+/// identified in this function could be different from the actual numbers
+/// found for lowering by findClusters(). This function ignores switches that
+/// are lowered with a mix of jump table / bit test / BTree. This function was
+/// initially intended to be used when estimating the cost of switch in inline
+/// cost heuristic, but it's a generic cost model to be used in other places
+/// (e.g., in loop unrolling).
+unsigned
+SwitchCaseClusterFinder::getEstimatedNumberOfClusters(const SwitchInst &SI,
+                                                      unsigned &JumpTableSize) {
+  unsigned N = SI.getNumCases();
+  const bool OptForSize = SI.getParent()->getParent()->optForSize();
+  const unsigned MaxJumpTableSize = getMaxJumpTableSize(OptForSize, TLI);
+  bool IsJTAllowed = areJTsAllowed(TLI, &SI);
+  JumpTableSize = 0;
+
+  if (N < 1 ||
+      (DL.getPointerSizeInBits() < MaxJumpTableSize && MaxJumpTableSize < N))
+    return N;
+
+  if (!IsJTAllowed && DL.getPointerSizeInBits() < N)
+    return N;
+
+  APInt MaxCaseVal = (SI.case_begin()).getCaseValue()->getValue();
+  APInt MinCaseVal = MaxCaseVal;
+  for (auto I = (SI.case_begin() + 1), E = SI.case_end(); I != E; ++I) {
+    const APInt &CaseVal = I.getCaseValue()->getValue();
+    if (CaseVal.sgt(MaxCaseVal))
+      MaxCaseVal = CaseVal;
+    if (CaseVal.slt(MinCaseVal))
+      MinCaseVal = CaseVal;
+  }
+
+  // Check if suitable for a bit test
+  if (N <= DL.getPointerSizeInBits()) {
+    SmallPtrSet<const BasicBlock *, 4> Dests;
+    for (auto I = SI.case_begin(), E = SI.case_end(); I != E; ++I)
+      Dests.insert(I.getCaseSuccessor());
+
+    if (isSuitableForBitTests(Dests.size(), N, MinCaseVal, MaxCaseVal))
+      return 1;
+  }
+
+  // Check if suitable for a jump table.
+  if (IsJTAllowed &&
+      !isTooSmallForJumptable(N, TLI.getMinimumJumpTableEntries())) {
+    const unsigned MinDensity = getJumptableMinDensity(OptForSize);
+    unsigned Range =
+        (MaxCaseVal - MinCaseVal).getLimitedValue(UINT64_MAX - 1) + 1;
+
+    if (Range <= MaxJumpTableSize && ::isDense(Range, N, MinDensity)) {
+      JumpTableSize = Range;
+      return 1;
+    }
+  }
+  return N;
+}
+
 bool SwitchCaseClusterFinder::isDense(
     const CaseClusterVector &Clusters,
     const SmallVectorImpl<unsigned> &TotalCases, unsigned First, unsigned Last,
Index: test/Transforms/Inline/AArch64/switch.ll
===================================================================
--- /dev/null
+++ test/Transforms/Inline/AArch64/switch.ll
@@ -0,0 +1,123 @@
+; RUN: opt < %s -inline -inline-threshold=20 -S -mtriple=aarch64-none-linux -inline-generic-switch-cost=true | FileCheck %s
+; RUN: opt < %s -passes='cgscc(inline)' -inline-threshold=20 -S -mtriple=aarch64-none-linux -inline-generic-switch-cost=true | FileCheck %s
+
+define i32 @callee_range(i32 %a, i32* %P) {
+  switch i32 %a, label %sw.default [
+  i32 0, label %sw.bb0
+  i32 1000, label %sw.bb1
+  i32 2000, label %sw.bb1
+  i32 3000, label %sw.bb1
+  i32 4000, label %sw.bb1
+  i32 5000, label %sw.bb1
+  i32 6000, label %sw.bb1
+  i32 7000, label %sw.bb1
+  i32 8000, label %sw.bb1
+  i32 9000, label %sw.bb1
+  ]
+
+sw.default:
+  store volatile i32 %a, i32* %P
+  br label %return
+sw.bb0:
+  store volatile i32 %a, i32* %P
+  br label %return
+sw.bb1:
+  store volatile i32 %a, i32* %P
+  br label %return
+return:
+  ret i32 42
+}
+
+define i32 @caller_range(i32 %a, i32* %P) {
+; CHECK-LABEL: @caller_range(
+; CHECK: call i32 @callee_range
+  %r = call i32 @callee_range(i32 %a, i32* %P)
+  ret i32 %r
+}
+
+define i32 @callee_bittest(i32 %a, i32* %P) {
+  switch i32 %a, label %sw.default [
+  i32 0, label %sw.bb0
+  i32 1, label %sw.bb1
+  i32 2, label %sw.bb2
+  i32 3, label %sw.bb0
+  i32 4, label %sw.bb1
+  i32 5, label %sw.bb2
+  i32 6, label %sw.bb0
+  i32 7, label %sw.bb1
+  i32 8, label %sw.bb2
+  ]
+
+sw.default:
+  store volatile i32 %a, i32* %P
+  br label %return
+
+sw.bb0:
+  store volatile i32 %a, i32* %P
+  br label %return
+
+sw.bb1:
+  store volatile i32 %a, i32* %P
+  br label %return
+
+sw.bb2:
+  br label %return
+
+return:
+  ret i32 42
+}
+
+
+define i32 @caller_bittest(i32 %a, i32* %P) {
+; CHECK-LABEL: @caller_bittest(
+; CHECK-NOT: call i32 @callee_bittest
+  %r= call i32 @callee_bittest(i32 %a, i32* %P)
+  ret i32 %r
+}
+
+define i32 @callee_jumptable(i32 %a, i32* %P) {
+  switch i32 %a, label %sw.default [
+  i32 1001, label %sw.bb101
+  i32 1002, label %sw.bb102
+  i32 1003, label %sw.bb103
+  i32 1004, label %sw.bb104
+  i32 1005, label %sw.bb101
+  i32 1006, label %sw.bb102
+  i32 1007, label %sw.bb103
+  i32 1008, label %sw.bb104
+  i32 1009, label %sw.bb101
+  i32 1010, label %sw.bb102
+  i32 1011, label %sw.bb103
+  i32 1012, label %sw.bb104
+  ]
+
+sw.default:
+  br label %return
+
+sw.bb101:
+  store volatile i32 %a, i32* %P
+  br label %return
+
+sw.bb102:
+  store volatile i32 %a, i32* %P
+  br label %return
+
+sw.bb103:
+  store volatile i32 %a, i32* %P
+  br label %return
+
+sw.bb104:
+  store volatile i32 %a, i32* %P
+  br label %return
+
+return:
+  ret i32 42
+}
+
+define i32 @caller_jumptable(i32 %a, i32 %b, i32* %P) {
+; CHECK-LABEL: @caller_jumptable(
+; CHECK: call i32 @callee_jumptable
+  %r = call i32 @callee_jumptable(i32 %b, i32* %P)
+  ret i32 %r
+}