Index: include/llvm/Analysis/TargetTransformInfo.h =================================================================== --- include/llvm/Analysis/TargetTransformInfo.h +++ include/llvm/Analysis/TargetTransformInfo.h @@ -197,6 +197,12 @@ int getIntrinsicCost(Intrinsic::ID IID, Type *RetTy, ArrayRef Arguments) const; + /// \return The estimated number of case clusters when lowering \p SI. + /// \p JTSize Set to the jump table size only when \p SI is suitable for a + /// jump table; otherwise it is set to 0. + int getEstimatedNumberOfCaseClusters(const SwitchInst &SI, + unsigned &JTSize) const; + /// \brief Estimate the cost of a given IR user when lowered. /// /// This can estimate the cost of either a ConstantExpr or Instruction when @@ -764,6 +770,8 @@ ArrayRef ParamTys) = 0; virtual int getIntrinsicCost(Intrinsic::ID IID, Type *RetTy, ArrayRef Arguments) = 0; + virtual int getEstimatedNumberOfCaseClusters(const SwitchInst &SI, + unsigned &JTSize) = 0; virtual int getUserCost(const User *U) = 0; virtual bool hasBranchDivergence() = 0; virtual bool isSourceOfDivergence(const Value *V) = 0; @@ -1067,6 +1075,10 @@ unsigned getMaxInterleaveFactor(unsigned VF) override { return Impl.getMaxInterleaveFactor(VF); } + int getEstimatedNumberOfCaseClusters(const SwitchInst &SI, + unsigned &JTSize) override { + return Impl.getEstimatedNumberOfCaseClusters(SI, JTSize); + } unsigned getArithmeticInstrCost(unsigned Opcode, Type *Ty, OperandValueKind Opd1Info, OperandValueKind Opd2Info, Index: include/llvm/Analysis/TargetTransformInfoImpl.h =================================================================== --- include/llvm/Analysis/TargetTransformInfoImpl.h +++ include/llvm/Analysis/TargetTransformInfoImpl.h @@ -114,6 +114,11 @@ return TTI::TCC_Free; } + int getEstimatedNumberOfCaseClusters(const SwitchInst &SI, unsigned &JTSize) { + JTSize = 0; + return SI.getNumCases(); + } + unsigned getCallCost(FunctionType *FTy, int NumArgs) { assert(FTy && "FunctionType must be provided to this routine."); Index: 
include/llvm/CodeGen/BasicTTIImpl.h =================================================================== --- include/llvm/CodeGen/BasicTTIImpl.h +++ include/llvm/CodeGen/BasicTTIImpl.h @@ -171,6 +171,75 @@ return BaseT::getIntrinsicCost(IID, RetTy, ParamTys); } + int getEstimatedNumberOfCaseClusters(const SwitchInst &SI, + unsigned &JumpTableSize) { + /// Try to find the estimated number of clusters. Note that the number of + /// clusters identified in this function could be different from the actual + /// numbers found in lowering. This function ignores switches that are + /// lowered with a mix of jump table / bit test / BTree. This function was + /// initially intended to be used when estimating the cost of switch in + /// inline cost heuristic, but it's a generic cost model to be used in other + /// places (e.g., in loop unrolling). + unsigned N = SI.getNumCases(); + const TargetLoweringBase *TLI = getTLI(); + const DataLayout &DL = this->getDataLayout(); + const bool OptForSize = SI.getParent()->getParent()->optForSize(); + const unsigned MaxJumpTableSize = + OptForSize || TLI->getMaximumJumpTableSize() == 0 + ? 
UINT_MAX + : TLI->getMaximumJumpTableSize(); + + bool IsJTAllowed = TLI->areJTsAllowed(&SI); + JumpTableSize = 0; + + if (N < 1 || + (MaxJumpTableSize < N && DL.getPointerSizeInBits() < N)) + return N; + + if (!IsJTAllowed && DL.getPointerSizeInBits() < N) + return N; + + auto CI = SI.case_begin(); + APInt MaxCaseVal = CI->getCaseValue()->getValue(); + APInt MinCaseVal = MaxCaseVal; + ++CI; + for (auto CE = SI.case_end(); CI != CE; ++CI) { + const APInt &CaseVal = CI->getCaseValue()->getValue(); + if (CaseVal.sgt(MaxCaseVal)) + MaxCaseVal = CaseVal; + if (CaseVal.slt(MinCaseVal)) + MinCaseVal = CaseVal; + } + + // Check if suitable for a bit test + if (N <= DL.getPointerSizeInBits()) { + SmallPtrSet Dests; + for (auto I : SI.cases()) + Dests.insert(I.getCaseSuccessor()); + + if (TLI->isSuitableForBitTests(Dests.size(), N, MinCaseVal, MaxCaseVal, + DL)) + return 1; + } + + // Check if suitable for a jump table. + if (IsJTAllowed) { + if (N < 2 || N < TLI->getMinimumJumpTableEntries()) + return N; + + const unsigned MinDensity = TLI->getMinimumJumpTableDensity(OptForSize); + uint64_t Range = + (MaxCaseVal - MinCaseVal).getLimitedValue(UINT64_MAX - 1) + 1; + + // Check if dense enough. Use 64-bit math so wide ranges cannot wrap. + if (Range <= MaxJumpTableSize && (N * 100ULL >= Range * MinDensity)) { + JumpTableSize = Range; + return 1; + } + } + return N; + } + unsigned getJumpBufAlignment() { return getTLI()->getJumpBufAlignment(); } unsigned getJumpBufSize() { return getTLI()->getJumpBufSize(); } Index: include/llvm/Target/TargetLowering.h =================================================================== --- include/llvm/Target/TargetLowering.h +++ include/llvm/Target/TargetLowering.h @@ -762,6 +762,49 @@ return (!isTypeLegal(VT) && getOperationAction(Op, VT) == Custom); } + /// Return true if lowering to a jump table is allowed.
+ bool areJTsAllowed(const SwitchInst *SI) const { + const Function *Fn = SI->getParent()->getParent(); + if (Fn->getFnAttribute("no-jump-tables").getValueAsString() == "true") + return false; + return isOperationLegalOrCustom(ISD::BR_JT, MVT::Other) || + isOperationLegalOrCustom(ISD::BRIND, MVT::Other); + } + + /// Check whether the range [Low,High] fits in a machine word. + bool rangeFitsInWord(const APInt &Low, const APInt &High, + const DataLayout &DL) const { + // FIXME: Using the pointer type doesn't seem ideal. + uint64_t BW = DL.getPointerSizeInBits(); + uint64_t Range = (High - Low).getLimitedValue(UINT64_MAX - 1) + 1; + return Range <= BW; + } + + /// Return true if lowering to a bit test is suitable for a SwitchInst which + /// contains \p NumDests unique destinations, \p Low and \p High as its lowest + /// and highest case values, and expects \p NumCmps case value comparisons. + bool isSuitableForBitTests(unsigned NumDests, unsigned NumCmps, + const APInt &Low, const APInt &High, + const DataLayout &DL) const { + // FIXME: I don't think NumCmps is the correct metric: a single case and a + // range of cases both require only one branch to lower. Just looking at the + // number of clusters and destinations should be enough to decide whether to + // build bit tests. + + // To lower a range with bit tests, the range must fit the bitwidth of a + // machine word. + if (!rangeFitsInWord(Low, High, DL)) + return false; + + // Decide whether it's profitable to lower this range with bit tests. Each + // destination requires a bit test and branch, and there is an overall range + // check branch. For a small number of clusters, separate comparisons might + // be cheaper, and for many destinations, splitting the range might be + // better.
+ return (NumDests == 1 && NumCmps >= 3) || (NumDests == 2 && NumCmps >= 5) || + (NumDests == 3 && NumCmps >= 6); + } + /// Return true if the specified operation is illegal on this target or /// unlikely to be made legal with custom lowering. This is used to help guide /// high-level lowering decisions. @@ -1136,6 +1179,9 @@ /// Return lower limit for number of blocks in a jump table. unsigned getMinimumJumpTableEntries() const; + /// Return the lower limit, in percent, of the density of a jump table. + unsigned getMinimumJumpTableDensity(bool OptForSize) const; + /// Return upper limit for number of entries in a jump table. /// Zero if no limit. unsigned getMaximumJumpTableSize() const; Index: lib/Analysis/InlineCost.cpp =================================================================== --- lib/Analysis/InlineCost.cpp +++ lib/Analysis/InlineCost.cpp @@ -54,6 +54,11 @@ cl::init(45), cl::desc("Threshold for inlining cold callsites")); +static cl::opt + EnableGenericSwitchCost("inline-generic-switch-cost", cl::Hidden, + cl::init(false), + cl::desc("Enable generic switch cost model")); + // We introduce this threshold to help performance of instrumentation based // PGO before we actually hook up inliner with analysis passes such as BPI and // BFI. @@ -997,23 +1002,72 @@ if (Value *V = SimplifiedValues.lookup(SI.getCondition())) if (isa(V)) return true; + if (!EnableGenericSwitchCost) { + // Use a simple switch cost model where we accumulate a cost proportional to + // the number of distinct successor blocks. This fan-out in the CFG cannot + // be represented for free even if we can represent the core switch as a + // jumptable that takes a single instruction. + SmallPtrSet SuccessorBlocks; + SuccessorBlocks.insert(SI.getDefaultDest()); + for (auto Case : SI.cases()) + SuccessorBlocks.insert(Case.getCaseSuccessor()); + // Add cost corresponding to the number of distinct destinations. The first + // we model as free because of fallthrough.
+ Cost += (SuccessorBlocks.size() - 1) * InlineConstants::InstrCost; + return false; + } - // Otherwise, we need to accumulate a cost proportional to the number of - // distinct successor blocks. This fan-out in the CFG cannot be represented - // for free even if we can represent the core switch as a jumptable that - // takes a single instruction. + // Otherwise, we assume the most general case where the switch is lowered into + // either a jump table, bit test, or a balanced binary tree consisting of + // case clusters without merging adjacent clusters with the same destination. + // We do not consider the switches that are lowered with a mix of jump table/ + // bit test/BTree. The cost of the switch is proportional to the size of + // the tree or the size of jump table range. // // NB: We convert large switches which are just used to initialize large phi // nodes to lookup tables instead in simplify-cfg, so this shouldn't prevent // inlining those. It will prevent inlining in cases where the optimization // does not (yet) fire. - SmallPtrSet SuccessorBlocks; - SuccessorBlocks.insert(SI.getDefaultDest()); - for (auto Case : SI.cases()) - SuccessorBlocks.insert(Case.getCaseSuccessor()); - // Add cost corresponding to the number of distinct destinations. The first - // we model as free because of fallthrough. - Cost += (SuccessorBlocks.size() - 1) * InlineConstants::InstrCost; + // Exit early for a large switch, assuming one case needs at least one + // instruction. + // FIXME: This is not true for a bit test, but ignore such case for now to + // save compile-time.
+ int CostLowerBound = Cost + SI.getNumCases() * InlineConstants::InstrCost; + if (CostLowerBound > Threshold) { + Cost = CostLowerBound; + return false; + } + + unsigned JumpTableSize = 0; + int NumCaseCluster = TTI.getEstimatedNumberOfCaseClusters(SI, JumpTableSize); + SmallVector SwitchWorkList; + SwitchWorkList.push_back(NumCaseCluster); + Cost -= InlineConstants::InstrCost; + + // If suitable for a jump table, consider the jump table size. + if (JumpTableSize) + Cost += JumpTableSize * InlineConstants::InstrCost; + + while (!SwitchWorkList.empty()) { + unsigned NumCases = SwitchWorkList.back(); + SwitchWorkList.pop_back(); + if (NumCases <= 3) + // Do not split the tree if the number of remaining cases is at most 3. + // Just compare switch condition with each case value. Suppose each + // comparison includes one compare and one conditional branch. + Cost += (2 * NumCases * InlineConstants::InstrCost); + else { + // Split the remaining nodes and add one more comparison. + unsigned NumLeft = NumCases / 2; + unsigned NumRight = NumCases - NumLeft; + SwitchWorkList.push_back(NumLeft); + SwitchWorkList.push_back(NumRight); + Cost += (2 * InlineConstants::InstrCost); + } + // Exit early if Cost is already larger than Threshold.
+ if (Cost > Threshold) + return false; + } return false; } Index: lib/Analysis/TargetTransformInfo.cpp =================================================================== --- lib/Analysis/TargetTransformInfo.cpp +++ lib/Analysis/TargetTransformInfo.cpp @@ -83,6 +83,11 @@ return Cost; } +int TargetTransformInfo::getEstimatedNumberOfCaseClusters( + const SwitchInst &SI, unsigned &JTSize) const { + return TTIImpl->getEstimatedNumberOfCaseClusters(SI, JTSize); +} + int TargetTransformInfo::getUserCost(const User *U) const { int Cost = TTIImpl->getUserCost(U); assert(Cost >= 0 && "TTI should not produce negative costs!"); Index: lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp =================================================================== --- lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp +++ lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp @@ -83,20 +83,6 @@ "for some float libcalls"), cl::location(LimitFloatPrecision), cl::init(0)); - -/// Minimum jump table density for normal functions. -static cl::opt -JumpTableDensity("jump-table-density", cl::init(10), cl::Hidden, - cl::desc("Minimum density for building a jump table in " - "a normal function")); - -/// Minimum jump table density for -Os or -Oz functions. -static cl::opt -OptsizeJumpTableDensity("optsize-jump-table-density", cl::init(40), cl::Hidden, - cl::desc("Minimum density for building a jump table in " - "an optsize function")); - - // Limit the width of DAG chains. This is important in general to prevent // DAG-based analysis from blowing up. For example, alias analysis and // load clustering may not complete in reasonable time.
It is difficult to @@ -8663,10 +8649,11 @@ JTProbs[Clusters[I].MBB] += Clusters[I].Prob; } + const TargetLowering &TLI = DAG.getTargetLoweringInfo(); unsigned NumDests = JTProbs.size(); - if (isSuitableForBitTests(NumDests, NumCmps, - Clusters[First].Low->getValue(), - Clusters[Last].High->getValue())) { + if (TLI.isSuitableForBitTests( + NumDests, NumCmps, Clusters[First].Low->getValue(), + Clusters[Last].High->getValue(), DAG.getDataLayout())) { // Clusters[First..Last] should be lowered as bit tests instead. return false; } @@ -8687,7 +8674,6 @@ } JumpTableMBB->normalizeSuccProbs(); - const TargetLowering &TLI = DAG.getTargetLoweringInfo(); unsigned JTI = CurMF->getOrCreateJumpTableInfo(TLI.getJumpTableEncoding()) ->createJumpTableIndex(Table); @@ -8741,8 +8727,7 @@ TotalCases[i] += TotalCases[i - 1]; } - const unsigned MinDensity = - OptForSize ? OptsizeJumpTableDensity : JumpTableDensity; + const unsigned MinDensity = TLI.getMinimumJumpTableDensity(OptForSize); // Cheap case: the whole range may be suitable for jump table. unsigned JumpTableSize = (Clusters[N - 1].High->getValue() - @@ -8850,36 +8835,6 @@ Clusters.resize(DstIndex); } -bool SelectionDAGBuilder::rangeFitsInWord(const APInt &Low, const APInt &High) { - // FIXME: Using the pointer type doesn't seem ideal. - uint64_t BW = DAG.getDataLayout().getPointerSizeInBits(); - uint64_t Range = (High - Low).getLimitedValue(UINT64_MAX - 1) + 1; - return Range <= BW; -} - -bool SelectionDAGBuilder::isSuitableForBitTests(unsigned NumDests, - unsigned NumCmps, - const APInt &Low, - const APInt &High) { - // FIXME: I don't think NumCmps is the correct metric: a single case and a - // range of cases both require only one branch to lower. Just looking at the - // number of clusters and destinations should be enough to decide whether to - // build bit tests. - - // To lower a range with bit tests, the range must fit the bitwidth of a - // machine word.
- if (!rangeFitsInWord(Low, High)) - return false; - - // Decide whether it's profitable to lower this range with bit tests. Each - // destination requires a bit test and branch, and there is an overall range - // check branch. For a small number of clusters, separate comparisons might be - // cheaper, and for many destinations, splitting the range might be better. - return (NumDests == 1 && NumCmps >= 3) || - (NumDests == 2 && NumCmps >= 5) || - (NumDests == 3 && NumCmps >= 6); -} - bool SelectionDAGBuilder::buildBitTests(CaseClusterVector &Clusters, unsigned First, unsigned Last, const SwitchInst *SI, @@ -8901,16 +8856,17 @@ APInt High = Clusters[Last].High->getValue(); assert(Low.slt(High)); - if (!isSuitableForBitTests(NumDests, NumCmps, Low, High)) + const TargetLowering &TLI = DAG.getTargetLoweringInfo(); + const DataLayout &DL = DAG.getDataLayout(); + if (!TLI.isSuitableForBitTests(NumDests, NumCmps, Low, High, DL)) return false; APInt LowBound; APInt CmpRange; - const int BitWidth = DAG.getTargetLoweringInfo() - .getPointerTy(DAG.getDataLayout()) - .getSizeInBits(); - assert(rangeFitsInWord(Low, High) && "Case range must fit in bit mask!"); + const int BitWidth = TLI.getPointerTy(DL).getSizeInBits(); + assert(TLI.rangeFitsInWord(Low, High, DL) && + "Case range must fit in bit mask!"); // Check if the clusters cover a contiguous range such that no value in the // range will jump to the default statement. @@ -9000,7 +8956,9 @@ // If target does not have legal shift left, do not emit bit tests at all. const TargetLowering &TLI = DAG.getTargetLoweringInfo(); - EVT PTy = TLI.getPointerTy(DAG.getDataLayout()); + const DataLayout &DL = DAG.getDataLayout(); + + EVT PTy = TLI.getPointerTy(DL); if (!TLI.isOperationLegal(ISD::SHL, PTy)) return; @@ -9031,8 +8989,8 @@ // Try building a partition from Clusters[i..j]. // Check the range.
- if (!rangeFitsInWord(Clusters[i].Low->getValue(), - Clusters[j].High->getValue())) + if (!TLI.rangeFitsInWord(Clusters[i].Low->getValue(), + Clusters[j].High->getValue(), DL)) continue; // Check nbr of destinations and cluster types. Index: lib/CodeGen/TargetLoweringBase.cpp =================================================================== --- lib/CodeGen/TargetLoweringBase.cpp +++ lib/CodeGen/TargetLoweringBase.cpp @@ -53,6 +53,18 @@ ("max-jump-table-size", cl::init(0), cl::Hidden, cl::desc("Set maximum size of jump tables; zero for no limit.")); +/// Minimum jump table density, in percent, for normal functions. +static cl::opt + JumpTableDensity("jump-table-density", cl::init(10), cl::Hidden, + cl::desc("Minimum density for building a jump table in " + "a normal function")); + +/// Minimum jump table density, in percent, for -Os or -Oz functions. +static cl::opt OptsizeJumpTableDensity( + "optsize-jump-table-density", cl::init(40), cl::Hidden, + cl::desc("Minimum density for building a jump table in " + "an optsize function")); + // Although this default value is arbitrary, it is not random. It is assumed // that a condition that evaluates the same way by a higher percentage than this // is best represented as control flow. Therefore, the default value N should be @@ -1902,6 +1914,10 @@ MinimumJumpTableEntries = Val; } +unsigned TargetLoweringBase::getMinimumJumpTableDensity(bool OptForSize) const { + return OptForSize ? 
OptsizeJumpTableDensity : JumpTableDensity; +} + unsigned TargetLoweringBase::getMaximumJumpTableSize() const { return MaximumJumpTableSize; } Index: test/Transforms/Inline/AArch64/switch.ll =================================================================== --- /dev/null +++ test/Transforms/Inline/AArch64/switch.ll @@ -0,0 +1,123 @@ +; RUN: opt < %s -inline -inline-threshold=20 -S -mtriple=aarch64-none-linux -inline-generic-switch-cost=true | FileCheck %s +; RUN: opt < %s -passes='cgscc(inline)' -inline-threshold=20 -S -mtriple=aarch64-none-linux -inline-generic-switch-cost=true | FileCheck %s + +define i32 @callee_range(i32 %a, i32* %P) { + switch i32 %a, label %sw.default [ + i32 0, label %sw.bb0 + i32 1000, label %sw.bb1 + i32 2000, label %sw.bb1 + i32 3000, label %sw.bb1 + i32 4000, label %sw.bb1 + i32 5000, label %sw.bb1 + i32 6000, label %sw.bb1 + i32 7000, label %sw.bb1 + i32 8000, label %sw.bb1 + i32 9000, label %sw.bb1 + ] + +sw.default: + store volatile i32 %a, i32* %P + br label %return +sw.bb0: + store volatile i32 %a, i32* %P + br label %return +sw.bb1: + store volatile i32 %a, i32* %P + br label %return +return: + ret i32 42 +} + +define i32 @caller_range(i32 %a, i32* %P) { +; CHECK-LABEL: @caller_range( +; CHECK: call i32 @callee_range + %r = call i32 @callee_range(i32 %a, i32* %P) + ret i32 %r +} + +define i32 @callee_bittest(i32 %a, i32* %P) { + switch i32 %a, label %sw.default [ + i32 0, label %sw.bb0 + i32 1, label %sw.bb1 + i32 2, label %sw.bb2 + i32 3, label %sw.bb0 + i32 4, label %sw.bb1 + i32 5, label %sw.bb2 + i32 6, label %sw.bb0 + i32 7, label %sw.bb1 + i32 8, label %sw.bb2 + ] + +sw.default: + store volatile i32 %a, i32* %P + br label %return + +sw.bb0: + store volatile i32 %a, i32* %P + br label %return + +sw.bb1: + store volatile i32 %a, i32* %P + br label %return + +sw.bb2: + br label %return + +return: + ret i32 42 +} + + +define i32 @caller_bittest(i32 %a, i32* %P) { +; CHECK-LABEL: @caller_bittest( +; CHECK-NOT: call i32
@callee_bittest + %r= call i32 @callee_bittest(i32 %a, i32* %P) + ret i32 %r +} + +define i32 @callee_jumptable(i32 %a, i32* %P) { + switch i32 %a, label %sw.default [ + i32 1001, label %sw.bb101 + i32 1002, label %sw.bb102 + i32 1003, label %sw.bb103 + i32 1004, label %sw.bb104 + i32 1005, label %sw.bb101 + i32 1006, label %sw.bb102 + i32 1007, label %sw.bb103 + i32 1008, label %sw.bb104 + i32 1009, label %sw.bb101 + i32 1010, label %sw.bb102 + i32 1011, label %sw.bb103 + i32 1012, label %sw.bb104 + ] + +sw.default: + br label %return + +sw.bb101: + store volatile i32 %a, i32* %P + br label %return + +sw.bb102: + store volatile i32 %a, i32* %P + br label %return + +sw.bb103: + store volatile i32 %a, i32* %P + br label %return + +sw.bb104: + store volatile i32 %a, i32* %P + br label %return + +return: + ret i32 42 +} + +define i32 @caller_jumptable(i32 %a, i32 %b, i32* %P) { +; CHECK-LABEL: @caller_jumptable( +; CHECK: call i32 @callee_jumptable + %r = call i32 @callee_jumptable(i32 %b, i32* %P) + ret i32 %r +} +