Index: include/llvm/CodeGen/SwitchCaseCluster.h =================================================================== --- /dev/null +++ include/llvm/CodeGen/SwitchCaseCluster.h @@ -0,0 +1,172 @@ +//=======-- SwitchCaseCluster.h - Form case clusters from SwitchInst --=======// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This implements routines for forming case clusters. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_CODEGEN_SWITCHCASECLUSTER_H +#define LLVM_CODEGEN_SWITCHCASECLUSTER_H + +#include "llvm/Target/TargetLowering.h" +#include + +namespace llvm { + +class SelectionDAGBuilder; + +enum CaseClusterKind { + /// A cluster of adjacent case labels with the same destination, or just one + /// case. + CC_Range, + /// A cluster of cases suitable for jump table lowering. + CC_JumpTable, + /// A cluster of cases suitable for bit test lowering. + CC_BitTests +}; + +/// A cluster of case labels. Holds a SmallVector: never memmove/memcpy it. +class CaseCluster { +public: + CaseClusterKind Kind = CC_Range; + const ConstantInt *Low = nullptr, *High = nullptr; + + /// Indexes of this cluster's cases within the SwitchInst's case list. + SmallVector Cases; + + /// Return a read-only iterator to the I-th case of this cluster. + SwitchInst::ConstCaseIt getCase(const SwitchInst *SI, unsigned I) const { + assert(I < Cases.size() && "Case index out of range"); + return SwitchInst::ConstCaseIt(SI, Cases[I]); + } + + /// Return the number of cases in this cluster. + unsigned getNumerOfCases() const { return Cases.size(); } + + /// Return case value indexed by I in this cluster. + const ConstantInt *getCaseValueAt(const SwitchInst *SI, unsigned I) const { + return getCase(SI, I).getCaseValue(); + } + + /// Return successor for the case indexed by I in this cluster.
+ const BasicBlock *getCaseSuccessorAt(const SwitchInst *SI, unsigned I) const { + return getCase(SI, I).getCaseSuccessor(); + } + + /// Return a single successor only for a range cluster. + const BasicBlock *getSingleSuccessor(const SwitchInst *SI) const { + assert(Kind == CC_Range && "Expected to be used only by CC_Range clusters"); + const BasicBlock *BB = getCase(SI, 0).getCaseSuccessor(); + return BB; + } + + static CaseCluster range(const ConstantInt *Low, const ConstantInt *High, + unsigned Index) { + CaseCluster C; + C.Kind = CC_Range; + C.Low = Low; + C.High = High; + C.Cases.push_back(Index); + return C; + } + + static CaseCluster jumpTable(const ConstantInt *Low, + const ConstantInt *High) { + CaseCluster C; + C.Kind = CC_JumpTable; + C.Low = Low; + C.High = High; + return C; + } + + static CaseCluster bitTests(const ConstantInt *Low, const ConstantInt *High) { + CaseCluster C; + C.Kind = CC_BitTests; + C.Low = Low; + C.High = High; + return C; + } +}; + +typedef std::vector CaseClusterVector; +typedef CaseClusterVector::iterator CaseClusterIt; + +class SwitchCaseClusterFinder { +public: + const DataLayout &DL; + const TargetLowering &TLI; + const CodeGenOpt::Level OptLevel; + + SwitchCaseClusterFinder(const DataLayout &DL, const TargetLowering &TLI, + const CodeGenOpt::Level OptLevel) + : DL(DL), TLI(TLI), OptLevel(OptLevel) {} + + /// Check whether the range [Low,High] fits in a machine word. + static bool rangeFitsInWord(const APInt &Low, const APInt &High, + const DataLayout &DL) { + // FIXME: Using the pointer type doesn't seem ideal. + uint64_t BW = DL.getPointerSizeInBits(); + uint64_t Range = (High - Low).getLimitedValue(UINT64_MAX - 1) + 1; + return Range <= BW; + } + + /// Calculate clusters for cases in SI and store them in Clusters. + const BasicBlock *findClusters(const SwitchInst &SI, + CaseClusterVector &Clusters); +private: + /// Extract cases from the switch and build initial form of case clusters.
+ void formInitalCaseClusers(const SwitchInst &SI, CaseClusterVector &Clusters); + + /// Find clusters of cases suitable for jump table lowering. + void findJumpTables(CaseClusterVector &Clusters, const SwitchInst *SI); + + /// Find clusters of cases suitable for bit test lowering. + void findBitTestClusters(CaseClusterVector &Clusters, const SwitchInst *SI); + + /// Replace an unreachable default with the most popular destination. + const BasicBlock *replaceUnreachableDefault(const SwitchInst &SI, + CaseClusterVector &Clusters); + + /// Check whether these clusters are suitable for lowering with bit tests + /// based on the number of destinations, comparison metric, and range. + bool isSuitableForBitTests(unsigned NumDests, unsigned NumCmps, + const APInt &Low, const APInt &High); + + /// Return true if building a jump table is feasible from + /// Clusters[First..Last]. + bool canBuildJumpTable(const CaseClusterVector &Clusters, unsigned First, + unsigned Last, const SwitchInst *SI, + CaseCluster &JTCluster); + + /// Return true if building a bit test cluster is feasible from + /// Clusters[First..Last]. + bool canBuildBitTest(CaseClusterVector &Clusters, unsigned First, + unsigned Last, const SwitchInst *SI, + CaseCluster &BTCluster); + + /// Check whether a range of clusters is dense enough for a jump table. + bool isDense(const CaseClusterVector &Clusters, + const SmallVectorImpl &TotalCases, unsigned First, + unsigned Last, unsigned MinDensity) const; + + /// Sort Clusters and merge adjacent cases. + void sortAndRangeify(const SwitchInst *SI, CaseClusterVector &Clusters); + + /// Create a jump table cluster from Clusters[First..Last]. + void createJumpTableCluster(const CaseClusterVector &Clusters, unsigned First, + unsigned Last, const SwitchInst *SI, + CaseCluster &JTCluster); + + /// Build a bit test cluster from Clusters[First..Last].
+ void createBitTestCluster(CaseClusterVector &Clusters, unsigned First, + unsigned Last, const SwitchInst *SI, + CaseCluster &BTCluster); +}; +} // end namespace llvm +#endif // LLVM_CODEGEN_SWITCHCASECLUSTER_H Index: lib/CodeGen/SelectionDAG/CMakeLists.txt =================================================================== --- lib/CodeGen/SelectionDAG/CMakeLists.txt +++ lib/CodeGen/SelectionDAG/CMakeLists.txt @@ -22,6 +22,7 @@ SelectionDAGPrinter.cpp SelectionDAGTargetInfo.cpp StatepointLowering.cpp + SwitchCaseCluster.cpp TargetLowering.cpp DEPENDS Index: lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h =================================================================== --- lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h +++ lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h @@ -20,6 +20,7 @@ #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/CodeGen/SelectionDAG.h" #include "llvm/CodeGen/SelectionDAGNodes.h" +#include "llvm/CodeGen/SwitchCaseCluster.h" #include "llvm/IR/CallSite.h" #include "llvm/IR/Constants.h" #include "llvm/IR/Statepoint.h" @@ -135,20 +136,11 @@ /// SDNodes we create. unsigned SDNodeOrder; - enum CaseClusterKind { - /// A cluster of adjacent case labels with the same destination, or just one - /// case. - CC_Range, - /// A cluster of cases suitable for jump table lowering. - CC_JumpTable, - /// A cluster of cases suitable for bit test lowering. - CC_BitTests - }; /// A cluster of case labels.
- struct CaseCluster { - CaseClusterKind Kind; - const ConstantInt *Low, *High; + class MachineCaseCluster { + public: + const CaseCluster *CC = nullptr; // Not owned; dereferenced by getKind/getLow/getHigh. union { MachineBasicBlock *MBB; unsigned JTCasesIndex; @@ -156,43 +148,43 @@ }; BranchProbability Prob; - static CaseCluster range(const ConstantInt *Low, const ConstantInt *High, - MachineBasicBlock *MBB, BranchProbability Prob) { - CaseCluster C; - C.Kind = CC_Range; - C.Low = Low; - C.High = High; + static MachineCaseCluster range(const CaseCluster *CC, + MachineBasicBlock *MBB, + BranchProbability Prob) { + MachineCaseCluster C; + C.CC = CC; C.MBB = MBB; C.Prob = Prob; return C; } - static CaseCluster jumpTable(const ConstantInt *Low, - const ConstantInt *High, unsigned JTCasesIndex, - BranchProbability Prob) { - CaseCluster C; - C.Kind = CC_JumpTable; - C.Low = Low; - C.High = High; + static MachineCaseCluster jumpTable(const CaseCluster *CC, + unsigned JTCasesIndex, + BranchProbability Prob) { + MachineCaseCluster C; + C.CC = CC; C.JTCasesIndex = JTCasesIndex; C.Prob = Prob; return C; } - static CaseCluster bitTests(const ConstantInt *Low, const ConstantInt *High, - unsigned BTCasesIndex, BranchProbability Prob) { - CaseCluster C; - C.Kind = CC_BitTests; - C.Low = Low; - C.High = High; + static MachineCaseCluster bitTests(const CaseCluster *CC, + unsigned BTCasesIndex, + BranchProbability Prob) { + MachineCaseCluster C; + C.CC = CC; C.BTCasesIndex = BTCasesIndex; C.Prob = Prob; return C; } + + CaseClusterKind getKind() const { return CC->Kind; } + const ConstantInt *getLow() const { return CC->Low; } + const ConstantInt *getHigh() const { return CC->High; } }; - typedef std::vector CaseClusterVector; - typedef CaseClusterVector::iterator CaseClusterIt; + typedef std::vector MachineCaseClusterVector; + typedef MachineCaseClusterVector::iterator MachineCaseClusterIt; struct CaseBits { uint64_t Mask; @@ -209,9 +201,6 @@ typedef std::vector CaseBitsVector; - /// Sort Clusters and merge adjacent cases.
- void sortAndRangeify(CaseClusterVector &Clusters); - /// CaseBlock - This structure is used to communicate between /// SelectionDAGBuilder and SDISel for the code generation of additional basic /// blocks needed by multi-case switch statements. @@ -304,41 +293,10 @@ BranchProbability DefaultProb; }; - /// Check whether a range of clusters is dense enough for a jump table. - bool isDense(const CaseClusterVector &Clusters, - const SmallVectorImpl &TotalCases, - unsigned First, unsigned Last, unsigned MinDensity) const; - - /// Build a jump table cluster from Clusters[First..Last]. Returns false if it - /// decides it's not a good idea. - bool buildJumpTable(const CaseClusterVector &Clusters, unsigned First, - unsigned Last, const SwitchInst *SI, - MachineBasicBlock *DefaultMBB, CaseCluster &JTCluster); - - /// Find clusters of cases suitable for jump table lowering. - void findJumpTables(CaseClusterVector &Clusters, const SwitchInst *SI, - MachineBasicBlock *DefaultMBB); - - /// Check whether the range [Low,High] fits in a machine word. - bool rangeFitsInWord(const APInt &Low, const APInt &High); - - /// Check whether these clusters are suitable for lowering with bit tests based - /// on the number of destinations, comparison metric, and range. - bool isSuitableForBitTests(unsigned NumDests, unsigned NumCmps, - const APInt &Low, const APInt &High); - - /// Build a bit test cluster from Clusters[First..Last]. Returns false if it - /// decides it's not a good idea. - bool buildBitTests(CaseClusterVector &Clusters, unsigned First, unsigned Last, - const SwitchInst *SI, CaseCluster &BTCluster); - - /// Find clusters of cases suitable for bit test lowering.
- void findBitTestClusters(CaseClusterVector &Clusters, const SwitchInst *SI); - struct SwitchWorkListItem { MachineBasicBlock *MBB; - CaseClusterIt FirstCluster; - CaseClusterIt LastCluster; + MachineCaseClusterIt FirstCluster; + MachineCaseClusterIt LastCluster; const ConstantInt *GE; const ConstantInt *LT; BranchProbability DefaultProb; @@ -347,19 +305,28 @@ /// Determine the rank by weight of CC in [First,Last]. If CC has more weight /// than each cluster in the range, its rank is 0. - static unsigned caseClusterRank(const CaseCluster &CC, CaseClusterIt First, - CaseClusterIt Last); + static unsigned caseClusterRank(const MachineCaseCluster &CC, + MachineCaseClusterIt First, + MachineCaseClusterIt Last); /// Emit comparison and split W into two subtrees. void splitWorkItem(SwitchWorkList &WorkList, const SwitchWorkListItem &W, Value *Cond, MachineBasicBlock *SwitchMBB); + /// Prepare to lower a jump table case cluster. + unsigned prepareJumpTable(const SwitchInst *SI, CaseCluster &JTCluster, + MachineBasicBlock *DefaultMBB); + + /// Prepare to lower a bit test case cluster. + unsigned prepareBitTests(const SwitchInst *SI, CaseCluster &BTCluster, + MachineBasicBlock *DefaultMBB, + BranchProbability TotalProb); + /// Lower W. void lowerWorkItem(SwitchWorkListItem W, Value *Cond, MachineBasicBlock *SwitchMBB, MachineBasicBlock *DefaultMBB); - /// A class which encapsulates all of the information needed to generate a /// stack protector check and signals to isel via its state being initialized /// that a stack protector needs to be generated. @@ -606,11 +573,22 @@ LLVMContext *Context; + /// Helper object to form case clusters for SwitchInst; owned by this object.
+ SwitchCaseClusterFinder *CaseClusterFinder; + SelectionDAGBuilder(SelectionDAG &dag, FunctionLoweringInfo &funcinfo, CodeGenOpt::Level ol) - : CurInst(nullptr), SDNodeOrder(LowestSDNodeOrder), TM(dag.getTarget()), - DAG(dag), FuncInfo(funcinfo), - HasTailCall(false) { + : CurInst(nullptr), SDNodeOrder(LowestSDNodeOrder), TM(dag.getTarget()), + DAG(dag), FuncInfo(funcinfo), HasTailCall(false), + CaseClusterFinder(nullptr) {} + + // CaseClusterFinder is owned via a raw pointer, so forbid copies. + SelectionDAGBuilder(const SelectionDAGBuilder &) = delete; + SelectionDAGBuilder &operator=(const SelectionDAGBuilder &) = delete; + + ~SelectionDAGBuilder() { + // NOTE(review): init() is not shown; confirm it frees any finder it replaces. + delete CaseClusterFinder; // Deleting a null pointer is a no-op. } void init(GCFunctionInfo *gfi, AliasAnalysis &aa, Index: lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp =================================================================== --- lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp +++ lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp @@ -84,19 +84,6 @@ cl::location(LimitFloatPrecision), cl::init(0)); -/// Minimum jump table density for normal functions. -static cl::opt -JumpTableDensity("jump-table-density", cl::init(10), cl::Hidden, - cl::desc("Minimum density for building a jump table in " - "a normal function")); - -/// Minimum jump table density for -Os or -Oz functions. -static cl::opt -OptsizeJumpTableDensity("optsize-jump-table-density", cl::init(40), cl::Hidden, - cl::desc("Minimum density for building a jump table in " - "an optsize function")); - - // Limit the width of DAG chains. This is important in general to prevent // DAG-based analysis from blowing up. For example, alias analysis and // load clustering may not complete in reasonable time.
It is difficult to @@ -2341,39 +2328,6 @@ setValue(&LP, Res); } -void SelectionDAGBuilder::sortAndRangeify(CaseClusterVector &Clusters) { -#ifndef NDEBUG - for (const CaseCluster &CC : Clusters) - assert(CC.Low == CC.High && "Input clusters must be single-case"); -#endif - - std::sort(Clusters.begin(), Clusters.end(), - [](const CaseCluster &a, const CaseCluster &b) { - return a.Low->getValue().slt(b.Low->getValue()); - }); - - // Merge adjacent clusters with the same destination. - const unsigned N = Clusters.size(); - unsigned DstIndex = 0; - for (unsigned SrcIndex = 0; SrcIndex < N; ++SrcIndex) { - CaseCluster &CC = Clusters[SrcIndex]; - const ConstantInt *CaseVal = CC.Low; - MachineBasicBlock *Succ = CC.MBB; - - if (DstIndex != 0 && Clusters[DstIndex - 1].MBB == Succ && - (CaseVal->getValue() - Clusters[DstIndex - 1].High->getValue()) == 1) { - // If this case has the same successor and is a neighbour, merge it into - // the previous cluster. - Clusters[DstIndex - 1].High = CaseVal; - Clusters[DstIndex - 1].Prob += CC.Prob; - } else { - std::memmove(&Clusters[DstIndex++], &Clusters[SrcIndex], - sizeof(Clusters[SrcIndex])); - } - } - Clusters.resize(DstIndex); -} - void SelectionDAGBuilder::UpdateSplitBlock(MachineBasicBlock *First, MachineBasicBlock *Last) { // Update JTCases. @@ -8587,495 +8541,6 @@ HasTailCall = true; } -bool SelectionDAGBuilder::isDense(const CaseClusterVector &Clusters, - const SmallVectorImpl &TotalCases, - unsigned First, unsigned Last, - unsigned Density) const { - assert(Last >= First); - assert(TotalCases[Last] >= TotalCases[First]); - - const APInt &LowCase = Clusters[First].Low->getValue(); - const APInt &HighCase = Clusters[Last].High->getValue(); - assert(LowCase.getBitWidth() == HighCase.getBitWidth()); - - // FIXME: A range of consecutive cases has 100% density, but only requires one - // comparison to lower. We should discriminate against such consecutive ranges - // in jump tables. 
- - uint64_t Diff = (HighCase - LowCase).getLimitedValue((UINT64_MAX - 1) / 100); - uint64_t Range = Diff + 1; - - uint64_t NumCases = - TotalCases[Last] - (First == 0 ? 0 : TotalCases[First - 1]); - - assert(NumCases < UINT64_MAX / 100); - assert(Range >= NumCases); - - return NumCases * 100 >= Range * Density; -} - -static inline bool areJTsAllowed(const TargetLowering &TLI, - const SwitchInst *SI) { - const Function *Fn = SI->getParent()->getParent(); - if (Fn->getFnAttribute("no-jump-tables").getValueAsString() == "true") - return false; - - return TLI.isOperationLegalOrCustom(ISD::BR_JT, MVT::Other) || - TLI.isOperationLegalOrCustom(ISD::BRIND, MVT::Other); -} - -bool SelectionDAGBuilder::buildJumpTable(const CaseClusterVector &Clusters, - unsigned First, unsigned Last, - const SwitchInst *SI, - MachineBasicBlock *DefaultMBB, - CaseCluster &JTCluster) { - assert(First <= Last); - - auto Prob = BranchProbability::getZero(); - unsigned NumCmps = 0; - std::vector Table; - DenseMap JTProbs; - - // Initialize probabilities in JTProbs. - for (unsigned I = First; I <= Last; ++I) - JTProbs[Clusters[I].MBB] = BranchProbability::getZero(); - - for (unsigned I = First; I <= Last; ++I) { - assert(Clusters[I].Kind == CC_Range); - Prob += Clusters[I].Prob; - const APInt &Low = Clusters[I].Low->getValue(); - const APInt &High = Clusters[I].High->getValue(); - NumCmps += (Low == High) ? 1 : 2; - if (I != First) { - // Fill the gap between this and the previous cluster. 
- const APInt &PreviousHigh = Clusters[I - 1].High->getValue(); - assert(PreviousHigh.slt(Low)); - uint64_t Gap = (Low - PreviousHigh).getLimitedValue() - 1; - for (uint64_t J = 0; J < Gap; J++) - Table.push_back(DefaultMBB); - } - uint64_t ClusterSize = (High - Low).getLimitedValue() + 1; - for (uint64_t J = 0; J < ClusterSize; ++J) - Table.push_back(Clusters[I].MBB); - JTProbs[Clusters[I].MBB] += Clusters[I].Prob; - } - - unsigned NumDests = JTProbs.size(); - if (isSuitableForBitTests(NumDests, NumCmps, - Clusters[First].Low->getValue(), - Clusters[Last].High->getValue())) { - // Clusters[First..Last] should be lowered as bit tests instead. - return false; - } - - // Create the MBB that will load from and jump through the table. - // Note: We create it here, but it's not inserted into the function yet. - MachineFunction *CurMF = FuncInfo.MF; - MachineBasicBlock *JumpTableMBB = - CurMF->CreateMachineBasicBlock(SI->getParent()); - - // Add successors. Note: use table order for determinism. - SmallPtrSet Done; - for (MachineBasicBlock *Succ : Table) { - if (Done.count(Succ)) - continue; - addSuccessorWithProb(JumpTableMBB, Succ, JTProbs[Succ]); - Done.insert(Succ); - } - JumpTableMBB->normalizeSuccProbs(); - - const TargetLowering &TLI = DAG.getTargetLoweringInfo(); - unsigned JTI = CurMF->getOrCreateJumpTableInfo(TLI.getJumpTableEncoding()) - ->createJumpTableIndex(Table); - - // Set up the jump table info. 
- JumpTable JT(-1U, JTI, JumpTableMBB, nullptr); - JumpTableHeader JTH(Clusters[First].Low->getValue(), - Clusters[Last].High->getValue(), SI->getCondition(), - nullptr, false); - JTCases.emplace_back(std::move(JTH), std::move(JT)); - - JTCluster = CaseCluster::jumpTable(Clusters[First].Low, Clusters[Last].High, - JTCases.size() - 1, Prob); - return true; -} - -void SelectionDAGBuilder::findJumpTables(CaseClusterVector &Clusters, - const SwitchInst *SI, - MachineBasicBlock *DefaultMBB) { -#ifndef NDEBUG - // Clusters must be non-empty, sorted, and only contain Range clusters. - assert(!Clusters.empty()); - for (CaseCluster &C : Clusters) - assert(C.Kind == CC_Range); - for (unsigned i = 1, e = Clusters.size(); i < e; ++i) - assert(Clusters[i - 1].High->getValue().slt(Clusters[i].Low->getValue())); -#endif - - const TargetLowering &TLI = DAG.getTargetLoweringInfo(); - if (!areJTsAllowed(TLI, SI)) - return; - - const bool OptForSize = DefaultMBB->getParent()->getFunction()->optForSize(); - - const int64_t N = Clusters.size(); - const unsigned MinJumpTableEntries = TLI.getMinimumJumpTableEntries(); - const unsigned SmallNumberOfEntries = MinJumpTableEntries / 2; - const unsigned MaxJumpTableSize = - OptForSize || TLI.getMaximumJumpTableSize() == 0 - ? UINT_MAX : TLI.getMaximumJumpTableSize(); - - if (N < 2 || N < MinJumpTableEntries) - return; - - // TotalCases[i]: Total nbr of cases in Clusters[0..i]. - SmallVector TotalCases(N); - for (unsigned i = 0; i < N; ++i) { - const APInt &Hi = Clusters[i].High->getValue(); - const APInt &Lo = Clusters[i].Low->getValue(); - TotalCases[i] = (Hi - Lo).getLimitedValue() + 1; - if (i != 0) - TotalCases[i] += TotalCases[i - 1]; - } - - const unsigned MinDensity = - OptForSize ? OptsizeJumpTableDensity : JumpTableDensity; - - // Cheap case: the whole range may be suitable for jump table. 
- unsigned JumpTableSize = (Clusters[N - 1].High->getValue() - - Clusters[0].Low->getValue()) - .getLimitedValue(UINT_MAX - 1) + 1; - if (JumpTableSize <= MaxJumpTableSize && - isDense(Clusters, TotalCases, 0, N - 1, MinDensity)) { - CaseCluster JTCluster; - if (buildJumpTable(Clusters, 0, N - 1, SI, DefaultMBB, JTCluster)) { - Clusters[0] = JTCluster; - Clusters.resize(1); - return; - } - } - - // The algorithm below is not suitable for -O0. - if (TM.getOptLevel() == CodeGenOpt::None) - return; - - // Split Clusters into minimum number of dense partitions. The algorithm uses - // the same idea as Kannan & Proebsting "Correction to 'Producing Good Code - // for the Case Statement'" (1994), but builds the MinPartitions array in - // reverse order to make it easier to reconstruct the partitions in ascending - // order. In the choice between two optimal partitionings, it picks the one - // which yields more jump tables. - - // MinPartitions[i] is the minimum nbr of partitions of Clusters[i..N-1]. - SmallVector MinPartitions(N); - // LastElement[i] is the last element of the partition starting at i. - SmallVector LastElement(N); - // PartitionsScore[i] is used to break ties when choosing between two - // partitionings resulting in the same number of partitions. - SmallVector PartitionsScore(N); - // For PartitionsScore, a small number of comparisons is considered as good as - // a jump table and a single comparison is considered better than a jump - // table. - enum PartitionScores : unsigned { - NoTable = 0, - Table = 1, - FewCases = 1, - SingleCase = 2 - }; - - // Base case: There is only one way to partition Clusters[N-1]. - MinPartitions[N - 1] = 1; - LastElement[N - 1] = N - 1; - PartitionsScore[N - 1] = PartitionScores::SingleCase; - - // Note: loop indexes are signed to avoid underflow. - for (int64_t i = N - 2; i >= 0; i--) { - // Find optimal partitioning of Clusters[i..N-1]. - // Baseline: Put Clusters[i] into a partition on its own. 
- MinPartitions[i] = MinPartitions[i + 1] + 1; - LastElement[i] = i; - PartitionsScore[i] = PartitionsScore[i + 1] + PartitionScores::SingleCase; - - // Search for a solution that results in fewer partitions. - for (int64_t j = N - 1; j > i; j--) { - // Try building a partition from Clusters[i..j]. - JumpTableSize = (Clusters[j].High->getValue() - - Clusters[i].Low->getValue()) - .getLimitedValue(UINT_MAX - 1) + 1; - if (JumpTableSize <= MaxJumpTableSize && - isDense(Clusters, TotalCases, i, j, MinDensity)) { - unsigned NumPartitions = 1 + (j == N - 1 ? 0 : MinPartitions[j + 1]); - unsigned Score = j == N - 1 ? 0 : PartitionsScore[j + 1]; - int64_t NumEntries = j - i + 1; - - if (NumEntries == 1) - Score += PartitionScores::SingleCase; - else if (NumEntries <= SmallNumberOfEntries) - Score += PartitionScores::FewCases; - else if (NumEntries >= MinJumpTableEntries) - Score += PartitionScores::Table; - - // If this leads to fewer partitions, or to the same number of - // partitions with better score, it is a better partitioning. - if (NumPartitions < MinPartitions[i] || - (NumPartitions == MinPartitions[i] && Score > PartitionsScore[i])) { - MinPartitions[i] = NumPartitions; - LastElement[i] = j; - PartitionsScore[i] = Score; - } - } - } - } - - // Iterate over the partitions, replacing some with jump tables in-place. 
- unsigned DstIndex = 0; - for (unsigned First = 0, Last; First < N; First = Last + 1) { - Last = LastElement[First]; - assert(Last >= First); - assert(DstIndex <= First); - unsigned NumClusters = Last - First + 1; - - CaseCluster JTCluster; - if (NumClusters >= MinJumpTableEntries && - buildJumpTable(Clusters, First, Last, SI, DefaultMBB, JTCluster)) { - Clusters[DstIndex++] = JTCluster; - } else { - for (unsigned I = First; I <= Last; ++I) - std::memmove(&Clusters[DstIndex++], &Clusters[I], sizeof(Clusters[I])); - } - } - Clusters.resize(DstIndex); -} - -bool SelectionDAGBuilder::rangeFitsInWord(const APInt &Low, const APInt &High) { - // FIXME: Using the pointer type doesn't seem ideal. - uint64_t BW = DAG.getDataLayout().getPointerSizeInBits(); - uint64_t Range = (High - Low).getLimitedValue(UINT64_MAX - 1) + 1; - return Range <= BW; -} - -bool SelectionDAGBuilder::isSuitableForBitTests(unsigned NumDests, - unsigned NumCmps, - const APInt &Low, - const APInt &High) { - // FIXME: I don't think NumCmps is the correct metric: a single case and a - // range of cases both require only one branch to lower. Just looking at the - // number of clusters and destinations should be enough to decide whether to - // build bit tests. - - // To lower a range with bit tests, the range must fit the bitwidth of a - // machine word. - if (!rangeFitsInWord(Low, High)) - return false; - - // Decide whether it's profitable to lower this range with bit tests. Each - // destination requires a bit test and branch, and there is an overall range - // check branch. For a small number of clusters, separate comparisons might be - // cheaper, and for many destinations, splitting the range might be better. 
- return (NumDests == 1 && NumCmps >= 3) || - (NumDests == 2 && NumCmps >= 5) || - (NumDests == 3 && NumCmps >= 6); -} - -bool SelectionDAGBuilder::buildBitTests(CaseClusterVector &Clusters, - unsigned First, unsigned Last, - const SwitchInst *SI, - CaseCluster &BTCluster) { - assert(First <= Last); - if (First == Last) - return false; - - BitVector Dests(FuncInfo.MF->getNumBlockIDs()); - unsigned NumCmps = 0; - for (int64_t I = First; I <= Last; ++I) { - assert(Clusters[I].Kind == CC_Range); - Dests.set(Clusters[I].MBB->getNumber()); - NumCmps += (Clusters[I].Low == Clusters[I].High) ? 1 : 2; - } - unsigned NumDests = Dests.count(); - - APInt Low = Clusters[First].Low->getValue(); - APInt High = Clusters[Last].High->getValue(); - assert(Low.slt(High)); - - if (!isSuitableForBitTests(NumDests, NumCmps, Low, High)) - return false; - - APInt LowBound; - APInt CmpRange; - - const int BitWidth = DAG.getTargetLoweringInfo() - .getPointerTy(DAG.getDataLayout()) - .getSizeInBits(); - assert(rangeFitsInWord(Low, High) && "Case range must fit in bit mask!"); - - // Check if the clusters cover a contiguous range such that no value in the - // range will jump to the default statement. - bool ContiguousRange = true; - for (int64_t I = First + 1; I <= Last; ++I) { - if (Clusters[I].Low->getValue() != Clusters[I - 1].High->getValue() + 1) { - ContiguousRange = false; - break; - } - } - - if (Low.isStrictlyPositive() && High.slt(BitWidth)) { - // Optimize the case where all the case values fit in a word without having - // to subtract minValue. In this case, we can optimize away the subtraction. - LowBound = APInt::getNullValue(Low.getBitWidth()); - CmpRange = High; - ContiguousRange = false; - } else { - LowBound = Low; - CmpRange = High - Low; - } - - CaseBitsVector CBV; - auto TotalProb = BranchProbability::getZero(); - for (unsigned i = First; i <= Last; ++i) { - // Find the CaseBits for this destination. 
- unsigned j; - for (j = 0; j < CBV.size(); ++j) - if (CBV[j].BB == Clusters[i].MBB) - break; - if (j == CBV.size()) - CBV.push_back( - CaseBits(0, Clusters[i].MBB, 0, BranchProbability::getZero())); - CaseBits *CB = &CBV[j]; - - // Update Mask, Bits and ExtraProb. - uint64_t Lo = (Clusters[i].Low->getValue() - LowBound).getZExtValue(); - uint64_t Hi = (Clusters[i].High->getValue() - LowBound).getZExtValue(); - assert(Hi >= Lo && Hi < 64 && "Invalid bit case!"); - CB->Mask |= (-1ULL >> (63 - (Hi - Lo))) << Lo; - CB->Bits += Hi - Lo + 1; - CB->ExtraProb += Clusters[i].Prob; - TotalProb += Clusters[i].Prob; - } - - BitTestInfo BTI; - std::sort(CBV.begin(), CBV.end(), [](const CaseBits &a, const CaseBits &b) { - // Sort by probability first, number of bits second. - if (a.ExtraProb != b.ExtraProb) - return a.ExtraProb > b.ExtraProb; - return a.Bits > b.Bits; - }); - - for (auto &CB : CBV) { - MachineBasicBlock *BitTestBB = - FuncInfo.MF->CreateMachineBasicBlock(SI->getParent()); - BTI.push_back(BitTestCase(CB.Mask, BitTestBB, CB.BB, CB.ExtraProb)); - } - BitTestCases.emplace_back(std::move(LowBound), std::move(CmpRange), - SI->getCondition(), -1U, MVT::Other, false, - ContiguousRange, nullptr, nullptr, std::move(BTI), - TotalProb); - - BTCluster = CaseCluster::bitTests(Clusters[First].Low, Clusters[Last].High, - BitTestCases.size() - 1, TotalProb); - return true; -} - -void SelectionDAGBuilder::findBitTestClusters(CaseClusterVector &Clusters, - const SwitchInst *SI) { -// Partition Clusters into as few subsets as possible, where each subset has a -// range that fits in a machine word and has <= 3 unique destinations. - -#ifndef NDEBUG - // Clusters must be sorted and contain Range or JumpTable clusters. 
- assert(!Clusters.empty()); - assert(Clusters[0].Kind == CC_Range || Clusters[0].Kind == CC_JumpTable); - for (const CaseCluster &C : Clusters) - assert(C.Kind == CC_Range || C.Kind == CC_JumpTable); - for (unsigned i = 1; i < Clusters.size(); ++i) - assert(Clusters[i-1].High->getValue().slt(Clusters[i].Low->getValue())); -#endif - - // The algorithm below is not suitable for -O0. - if (TM.getOptLevel() == CodeGenOpt::None) - return; - - // If target does not have legal shift left, do not emit bit tests at all. - const TargetLowering &TLI = DAG.getTargetLoweringInfo(); - EVT PTy = TLI.getPointerTy(DAG.getDataLayout()); - if (!TLI.isOperationLegal(ISD::SHL, PTy)) - return; - - int BitWidth = PTy.getSizeInBits(); - const int64_t N = Clusters.size(); - - // MinPartitions[i] is the minimum nbr of partitions of Clusters[i..N-1]. - SmallVector MinPartitions(N); - // LastElement[i] is the last element of the partition starting at i. - SmallVector LastElement(N); - - // FIXME: This might not be the best algorithm for finding bit test clusters. - - // Base case: There is only one way to partition Clusters[N-1]. - MinPartitions[N - 1] = 1; - LastElement[N - 1] = N - 1; - - // Note: loop indexes are signed to avoid underflow. - for (int64_t i = N - 2; i >= 0; --i) { - // Find optimal partitioning of Clusters[i..N-1]. - // Baseline: Put Clusters[i] into a partition on its own. - MinPartitions[i] = MinPartitions[i + 1] + 1; - LastElement[i] = i; - - // Search for a solution that results in fewer partitions. - // Note: the search is limited by BitWidth, reducing time complexity. - for (int64_t j = std::min(N - 1, i + BitWidth - 1); j > i; --j) { - // Try building a partition from Clusters[i..j]. - - // Check the range. - if (!rangeFitsInWord(Clusters[i].Low->getValue(), - Clusters[j].High->getValue())) - continue; - - // Check nbr of destinations and cluster types. - // FIXME: This works, but doesn't seem very efficient. 
- bool RangesOnly = true; - BitVector Dests(FuncInfo.MF->getNumBlockIDs()); - for (int64_t k = i; k <= j; k++) { - if (Clusters[k].Kind != CC_Range) { - RangesOnly = false; - break; - } - Dests.set(Clusters[k].MBB->getNumber()); - } - if (!RangesOnly || Dests.count() > 3) - break; - - // Check if it's a better partition. - unsigned NumPartitions = 1 + (j == N - 1 ? 0 : MinPartitions[j + 1]); - if (NumPartitions < MinPartitions[i]) { - // Found a better partition. - MinPartitions[i] = NumPartitions; - LastElement[i] = j; - } - } - } - - // Iterate over the partitions, replacing with bit-test clusters in-place. - unsigned DstIndex = 0; - for (unsigned First = 0, Last; First < N; First = Last + 1) { - Last = LastElement[First]; - assert(First <= Last); - assert(DstIndex <= First); - - CaseCluster BitTestCluster; - if (buildBitTests(Clusters, First, Last, SI, BitTestCluster)) { - Clusters[DstIndex++] = BitTestCluster; - } else { - size_t NumClusters = Last - First + 1; - std::memmove(&Clusters[DstIndex], &Clusters[First], - sizeof(Clusters[0]) * NumClusters); - DstIndex += NumClusters; - } - } - Clusters.resize(DstIndex); -} - void SelectionDAGBuilder::lowerWorkItem(SwitchWorkListItem W, Value *Cond, MachineBasicBlock *SwitchMBB, MachineBasicBlock *DefaultMBB) { @@ -9097,13 +8562,13 @@ // TODO: This could be extended to merge any 2 cases in switches with 3 // cases. // TODO: Handle cases where W.CaseBB != SwitchBB. 
- CaseCluster &Small = *W.FirstCluster; - CaseCluster &Big = *W.LastCluster; + MachineCaseCluster &Small = *W.FirstCluster; + MachineCaseCluster &Big = *W.LastCluster; - if (Small.Low == Small.High && Big.Low == Big.High && + if (Small.getLow() == Small.getHigh() && Big.getLow() == Big.getHigh() && Small.MBB == Big.MBB) { - const APInt &SmallValue = Small.Low->getValue(); - const APInt &BigValue = Big.Low->getValue(); + const APInt &SmallValue = Small.getLow()->getValue(); + const APInt &BigValue = Big.getLow()->getValue(); // Check that there is only one bit different. APInt CommonBit = BigValue ^ SmallValue; @@ -9147,17 +8612,17 @@ if (TM.getOptLevel() != CodeGenOpt::None) { // Order cases by probability so the most likely case will be checked first. std::sort(W.FirstCluster, W.LastCluster + 1, - [](const CaseCluster &a, const CaseCluster &b) { - return a.Prob > b.Prob; - }); + [](const MachineCaseCluster &a, const MachineCaseCluster &b) { + return a.Prob > b.Prob; + }); // Rearrange the case blocks so that the last one falls through if possible // without without changing the order of probabilities. - for (CaseClusterIt I = W.LastCluster; I > W.FirstCluster; ) { + for (MachineCaseClusterIt I = W.LastCluster; I > W.FirstCluster;) { --I; if (I->Prob > W.LastCluster->Prob) break; - if (I->Kind == CC_Range && I->MBB == NextMBB) { + if (I->getKind() == CC_Range && I->MBB == NextMBB) { std::swap(*I, *W.LastCluster); break; } @@ -9167,11 +8632,12 @@ // Compute total probability. 
BranchProbability DefaultProb = W.DefaultProb; BranchProbability UnhandledProbs = DefaultProb; - for (CaseClusterIt I = W.FirstCluster; I <= W.LastCluster; ++I) + for (MachineCaseClusterIt I = W.FirstCluster; I <= W.LastCluster; ++I) UnhandledProbs += I->Prob; MachineBasicBlock *CurMBB = W.MBB; - for (CaseClusterIt I = W.FirstCluster, E = W.LastCluster; I <= E; ++I) { + for (MachineCaseClusterIt I = W.FirstCluster, E = W.LastCluster; I <= E; + ++I) { MachineBasicBlock *Fallthrough; if (I == W.LastCluster) { // For the last cluster, fall through to the default destination. @@ -9184,7 +8650,7 @@ } UnhandledProbs -= I->Prob; - switch (I->Kind) { + switch (I->getKind()) { case CC_JumpTable: { // FIXME: Optimize away range check based on pivot comparisons. JumpTableHeader *JTH = &JTCases[I->JTCasesIndex].first; @@ -9259,18 +8725,18 @@ case CC_Range: { const Value *RHS, *LHS, *MHS; ISD::CondCode CC; - if (I->Low == I->High) { - // Check Cond == I->Low. + if (I->getLow() == I->getHigh()) { + // Check Cond == I->getLow(). CC = ISD::SETEQ; LHS = Cond; - RHS=I->Low; + RHS=I->getLow(); MHS = nullptr; } else { - // Check I->Low <= Cond <= I->High. + // Check I->getLow() <= Cond <= I->getHigh(). CC = ISD::SETLE; - LHS = I->Low; + LHS = I->getLow(); MHS = Cond; - RHS = I->High; + RHS = I->getHigh(); } // The false probability is the sum of all unhandled cases. @@ -9289,15 +8755,15 @@ } } -unsigned SelectionDAGBuilder::caseClusterRank(const CaseCluster &CC, - CaseClusterIt First, - CaseClusterIt Last) { - return std::count_if(First, Last + 1, [&](const CaseCluster &X) { +unsigned SelectionDAGBuilder::caseClusterRank(const MachineCaseCluster &CC, + MachineCaseClusterIt First, + MachineCaseClusterIt Last) { + return std::count_if(First, Last + 1, [&](const MachineCaseCluster &X) { if (X.Prob != CC.Prob) return X.Prob > CC.Prob; // Ties are broken by comparing the case value. 
- return X.Low->getValue().slt(CC.Low->getValue()); + return X.getLow()->getValue().slt(CC.getLow()->getValue()); }); } @@ -9305,7 +8771,8 @@ const SwitchWorkListItem &W, Value *Cond, MachineBasicBlock *SwitchMBB) { - assert(W.FirstCluster->Low->getValue().slt(W.LastCluster->Low->getValue()) && + assert(W.FirstCluster->getLow()->getValue().slt( + W.LastCluster->getLow()->getValue()) && "Clusters not sorted?"); assert(W.LastCluster - W.FirstCluster + 1 >= 2 && "Too small to split!"); @@ -9313,8 +8780,8 @@ // Balance the tree based on branch probabilities to create a near-optimal (in // terms of search time given key frequency) binary search tree. See e.g. Kurt // Mehlhorn "Nearly Optimal Binary Search Trees" (1975). - CaseClusterIt LastLeft = W.FirstCluster; - CaseClusterIt FirstRight = W.LastCluster; + MachineCaseClusterIt LastLeft = W.FirstCluster; + MachineCaseClusterIt FirstRight = W.LastCluster; auto LeftProb = LastLeft->Prob + W.DefaultProb / 2; auto RightProb = FirstRight->Prob + W.DefaultProb / 2; @@ -9346,7 +8813,7 @@ if (NumLeft < NumRight) { // Consider moving the first cluster on the right to the left side. - CaseCluster &CC = *FirstRight; + MachineCaseCluster &CC = *FirstRight; unsigned RightSideRank = caseClusterRank(CC, FirstRight, W.LastCluster); unsigned LeftSideRank = caseClusterRank(CC, W.FirstCluster, LastLeft); if (LeftSideRank <= RightSideRank) { @@ -9358,7 +8825,7 @@ } else { assert(NumRight < NumLeft); // Consider moving the last element on the left to the right side. - CaseCluster &CC = *LastLeft; + MachineCaseCluster &CC = *LastLeft; unsigned LeftSideRank = caseClusterRank(CC, W.FirstCluster, LastLeft); unsigned RightSideRank = caseClusterRank(CC, FirstRight, W.LastCluster); if (RightSideRank <= LeftSideRank) { @@ -9378,14 +8845,14 @@ // Use the first element on the right as pivot since we will make less-than // comparisons against it. 
- CaseClusterIt PivotCluster = FirstRight; + MachineCaseClusterIt PivotCluster = FirstRight; assert(PivotCluster > W.FirstCluster); assert(PivotCluster <= W.LastCluster); - CaseClusterIt FirstLeft = W.FirstCluster; - CaseClusterIt LastRight = W.LastCluster; + MachineCaseClusterIt FirstLeft = W.FirstCluster; + MachineCaseClusterIt LastRight = W.LastCluster; - const ConstantInt *Pivot = PivotCluster->Low; + const ConstantInt *Pivot = PivotCluster->getLow(); // New blocks will be inserted immediately after the current one. MachineFunction::iterator BBI(W.MBB); @@ -9395,9 +8862,9 @@ // we can branch to its destination directly if it's squeezed exactly in // between the known lower bound and Pivot - 1. MachineBasicBlock *LeftMBB; - if (FirstLeft == LastLeft && FirstLeft->Kind == CC_Range && - FirstLeft->Low == W.GE && - (FirstLeft->High->getValue() + 1LL) == Pivot->getValue()) { + if (FirstLeft == LastLeft && FirstLeft->getKind() == CC_Range && + FirstLeft->getLow() == W.GE && + (FirstLeft->getHigh()->getValue() + 1LL) == Pivot->getValue()) { LeftMBB = FirstLeft->MBB; } else { LeftMBB = FuncInfo.MF->CreateMachineBasicBlock(W.MBB->getBasicBlock()); @@ -9412,8 +8879,8 @@ // single cluster, RHS.Low == Pivot, and we can branch to its destination // directly if RHS.High equals the current upper bound. MachineBasicBlock *RightMBB; - if (FirstRight == LastRight && FirstRight->Kind == CC_Range && - W.LT && (FirstRight->High->getValue() + 1ULL) == W.LT->getValue()) { + if (FirstRight == LastRight && FirstRight->getKind() == CC_Range && W.LT && + (FirstRight->getHigh()->getValue() + 1ULL) == W.LT->getValue()) { RightMBB = FirstRight->MBB; } else { RightMBB = FuncInfo.MF->CreateMachineBasicBlock(W.MBB->getBasicBlock()); @@ -9434,62 +8901,177 @@ SwitchCases.push_back(CB); } -void SelectionDAGBuilder::visitSwitch(const SwitchInst &SI) { - // Extract cases from the switch. 
+unsigned SelectionDAGBuilder::prepareBitTests(const SwitchInst *SI, + CaseCluster &BTCluster, + MachineBasicBlock *DefaultMBB, + BranchProbability TotalProb) { + assert(BTCluster.Kind == CC_BitTests && BTCluster.getNumerOfCases() > 0 && + "Invalid bit test cluster"); + const APInt &Low = BTCluster.Low->getValue(); + const APInt &High = BTCluster.High->getValue(); + assert(Low.slt(High)); + + APInt LowBound; + APInt CmpRange; + + const int BitWidth = DAG.getTargetLoweringInfo() + .getPointerTy(DAG.getDataLayout()) + .getSizeInBits(); + assert(SwitchCaseClusterFinder::rangeFitsInWord(Low, High, + DAG.getDataLayout()) && + "Case range must fit in bit mask!"); + + unsigned First = 0; + unsigned Last = BTCluster.getNumerOfCases(); BranchProbabilityInfo *BPI = FuncInfo.BPI; - CaseClusterVector Clusters; - Clusters.reserve(SI.getNumCases()); - for (auto I : SI.cases()) { - MachineBasicBlock *Succ = FuncInfo.MBBMap[I.getCaseSuccessor()]; - const ConstantInt *CaseVal = I.getCaseValue(); - BranchProbability Prob = - BPI ? BPI->getEdgeProbability(SI.getParent(), I.getSuccessorIndex()) - : BranchProbability(1, SI.getNumCases() + 1); - Clusters.push_back(CaseCluster::range(CaseVal, CaseVal, Succ, Prob)); + + // Check if the clusters cover a contiguous range such that no value in the + // range will jump to the default statement. + bool ContiguousRange = true; + for (int64_t I = First + 1; I < Last; ++I) { + if (BTCluster.getCaseValueAt(SI, I)->getValue() != + BTCluster.getCaseValueAt(SI, I - 1)->getValue() + 1) { + ContiguousRange = false; + break; + } + } + + if (Low.isStrictlyPositive() && High.slt(BitWidth)) { + // Optimize the case where all the case values fit in a word without having + // to subtract minValue. In this case, we can optimize away the subtraction. 
+ LowBound = APInt::getNullValue(Low.getBitWidth()); + CmpRange = High; + ContiguousRange = false; + } else { + LowBound = Low; + CmpRange = High - Low; } - MachineBasicBlock *DefaultMBB = FuncInfo.MBBMap[SI.getDefaultDest()]; + CaseBitsVector CBV; + for (unsigned I = First; I < Last; ++I) { + auto CI = BTCluster.getCase(SI, I); + MachineBasicBlock *MBB = FuncInfo.MBBMap[CI.getCaseSuccessor()]; + // Find the CaseBits for this destination. + unsigned J; + for (J = 0; J < CBV.size(); ++J) + if (CBV[J].BB == MBB) + break; + if (J == CBV.size()) + CBV.push_back(CaseBits(0, MBB, 0, BranchProbability::getZero())); + CaseBits *CB = &CBV[J]; - // Cluster adjacent cases with the same destination. We do this at all - // optimization levels because it's cheap to do and will make codegen faster - // if there are many clusters. - sortAndRangeify(Clusters); + // Update Mask, Bits and ExtraProb. + uint64_t Val = (CI.getCaseValue()->getValue() - LowBound).getZExtValue(); + CB->Mask |= (1ULL) << Val; + CB->Bits++; + CB->ExtraProb += + BPI ? BPI->getEdgeProbability(SI->getParent(), CI.getSuccessorIndex()) + : BranchProbability(1, SI->getNumCases() + 1); + } - if (TM.getOptLevel() != CodeGenOpt::None) { - // Replace an unreachable default with the most popular destination. - // FIXME: Exploit unreachable default more aggressively. - bool UnreachableDefault = - isa(SI.getDefaultDest()->getFirstNonPHIOrDbg()); - if (UnreachableDefault && !Clusters.empty()) { - DenseMap Popularity; - unsigned MaxPop = 0; - const BasicBlock *MaxBB = nullptr; - for (auto I : SI.cases()) { - const BasicBlock *BB = I.getCaseSuccessor(); - if (++Popularity[BB] > MaxPop) { - MaxPop = Popularity[BB]; - MaxBB = BB; - } - } - // Set new default. - assert(MaxPop > 0 && MaxBB); - DefaultMBB = FuncInfo.MBBMap[MaxBB]; - - // Remove cases that were pointing to the destination that is now the - // default. 
- CaseClusterVector New; - New.reserve(Clusters.size()); - for (CaseCluster &CC : Clusters) { - if (CC.MBB != DefaultMBB) - New.push_back(CC); - } - Clusters = std::move(New); + BitTestInfo BTI; + std::sort(CBV.begin(), CBV.end(), [](const CaseBits &a, const CaseBits &b) { + // Sort by probability first, number of bits second. + if (a.ExtraProb != b.ExtraProb) + return a.ExtraProb > b.ExtraProb; + return a.Bits > b.Bits; + }); + + for (auto &CB : CBV) { + MachineBasicBlock *BitTestBB = + FuncInfo.MF->CreateMachineBasicBlock(SI->getParent()); + BTI.push_back(BitTestCase(CB.Mask, BitTestBB, CB.BB, CB.ExtraProb)); + } + BitTestCases.emplace_back(std::move(LowBound), std::move(CmpRange), + SI->getCondition(), -1U, MVT::Other, false, + ContiguousRange, nullptr, nullptr, std::move(BTI), + TotalProb); + + return BitTestCases.size() - 1; +} + +unsigned SelectionDAGBuilder::prepareJumpTable(const SwitchInst *SI, + CaseCluster &JTCluster, + MachineBasicBlock *DefaultMBB) { + assert(JTCluster.Kind == CC_JumpTable && JTCluster.getNumerOfCases() > 0 && + "Invalid jump table cluster"); + + BranchProbabilityInfo *BPI = FuncInfo.BPI; + std::vector Table; + DenseMap JTProbs; + unsigned ClusterSize = JTCluster.getNumerOfCases(); + // Initialize probabilities in JTProbs. + for (unsigned I = 0; I < ClusterSize; ++I) + JTProbs[FuncInfo.MBBMap[JTCluster.getCaseSuccessorAt(SI, I)]] = + BranchProbability::getZero(); + + for (unsigned I = 0; I < ClusterSize; ++I) { + auto CI = JTCluster.getCase(SI, I); + MachineBasicBlock *CurMBB = FuncInfo.MBBMap[CI.getCaseSuccessor()]; + const APInt &CurVal = CI.getCaseValue()->getValue(); + if (I != 0) { + const APInt &PreviousVal = + JTCluster.getCaseValueAt(SI, I - 1)->getValue(); + // Fill the gap between this and the previous case. 
+ assert(PreviousVal.slt(CurVal)); + uint64_t Gap = (CurVal - PreviousVal).getLimitedValue() - 1; + for (uint64_t J = 0; J < Gap; J++) + Table.push_back(DefaultMBB); } + Table.push_back(CurMBB); + BranchProbability Prob = + BPI ? BPI->getEdgeProbability(SI->getParent(), CI.getSuccessorIndex()) + : BranchProbability(1, SI->getNumCases() + 1); + JTProbs[CurMBB] += Prob; } + // Create the MBB that will load from and jump through the table. + // Note: We create it here, but it's not inserted into the function yet. + MachineFunction *CurMF = FuncInfo.MF; + MachineBasicBlock *JumpTableMBB = + CurMF->CreateMachineBasicBlock(SI->getParent()); + + // Add successors. Note: use table order for determinism. + SmallPtrSet Done; + for (MachineBasicBlock *Succ : Table) { + if (Done.count(Succ)) + continue; + addSuccessorWithProb(JumpTableMBB, Succ, JTProbs[Succ]); + Done.insert(Succ); + } + JumpTableMBB->normalizeSuccProbs(); + + const TargetLowering &TLI = DAG.getTargetLoweringInfo(); + unsigned JTI = CurMF->getOrCreateJumpTableInfo(TLI.getJumpTableEncoding()) + ->createJumpTableIndex(Table); + + // Set up the jump table info. + SelectionDAGBuilder::JumpTable JT(-1U, JTI, JumpTableMBB, nullptr); + + JumpTableHeader JTH(JTCluster.getCaseValueAt(SI, 0)->getValue(), + JTCluster.getCaseValueAt(SI, ClusterSize - 1)->getValue(), + SI->getCondition(), nullptr, false); + + JTCases.emplace_back(std::move(JTH), std::move(JT)); + return JTCases.size() - 1; +} + +void SelectionDAGBuilder::visitSwitch(const SwitchInst &SI) { + BranchProbabilityInfo *BPI = FuncInfo.BPI; + CaseClusterVector CaseClusters; + + if (!CaseClusterFinder) + CaseClusterFinder = new SwitchCaseClusterFinder( + DAG.getDataLayout(), DAG.getTargetLoweringInfo(), TM.getOptLevel()); + + const BasicBlock *DefaultBB = + CaseClusterFinder->findClusters(SI, CaseClusters); + // If there is only the default destination, jump there directly. 
MachineBasicBlock *SwitchMBB = FuncInfo.MBB; - if (Clusters.empty()) { + MachineBasicBlock *DefaultMBB = FuncInfo.MBBMap[DefaultBB]; + + if (CaseClusters.empty()) { SwitchMBB->addSuccessor(DefaultMBB); if (DefaultMBB != NextBlock(SwitchMBB)) { DAG.setRoot(DAG.getNode(ISD::BR, getCurSDLoc(), MVT::Other, @@ -9498,29 +9080,57 @@ return; } - findJumpTables(Clusters, &SI, DefaultMBB); - findBitTestClusters(Clusters, &SI); + // Populate clusters for lowering. + MachineCaseClusterVector MachineClusters; + MachineClusters.reserve(CaseClusters.size()); - DEBUG({ - dbgs() << "Case clusters: "; - for (const CaseCluster &C : Clusters) { - if (C.Kind == CC_JumpTable) dbgs() << "JT:"; - if (C.Kind == CC_BitTests) dbgs() << "BT:"; + for (CaseCluster &C : CaseClusters) { + BranchProbability ClusterProb = BranchProbability::getZero(); + for (unsigned I = 0, E = C.getNumerOfCases(); I != E; ++I) { + auto CI = C.getCase(&SI, I); + ClusterProb += + BPI ? BPI->getEdgeProbability(SI.getParent(), CI.getSuccessorIndex()) + : BranchProbability(1, SI.getNumCases() + 1); + } + if (C.Kind == CC_Range) { + MachineClusters.push_back(MachineCaseCluster::range( + &C, FuncInfo.MBBMap[C.getSingleSuccessor(&SI)], ClusterProb)); + } else if (C.Kind == CC_JumpTable) { + unsigned JTCasesIndex = prepareJumpTable(&SI, C, DefaultMBB); + MachineClusters.push_back( + MachineCaseCluster::jumpTable(&C, JTCasesIndex, ClusterProb)); + } else if (C.Kind == CC_BitTests) { + unsigned BTCasesIndex = prepareBitTests(&SI, C, DefaultMBB, ClusterProb); + MachineClusters.push_back( + MachineCaseCluster::bitTests(&C, BTCasesIndex, ClusterProb)); + } else + llvm_unreachable("Unknown case cluster"); + } - C.Low->getValue().print(dbgs(), true); - if (C.Low != C.High) { + DEBUG({ + dbgs() << "Machine case clusters: "; + for (const MachineCaseCluster &C : MachineClusters) { + if (C.getKind() == CC_JumpTable) + dbgs() << "JT:["; + if (C.getKind() == CC_BitTests) + dbgs() << "BT:["; + if (C.getKind() == CC_Range) + dbgs() << 
"Range:["; + + C.getLow()->getValue().print(dbgs(), true); + if (C.getLow() != C.getHigh()) { dbgs() << '-'; - C.High->getValue().print(dbgs(), true); + C.getHigh()->getValue().print(dbgs(), true); } - dbgs() << ' '; + dbgs() << "] "; } dbgs() << '\n'; }); - assert(!Clusters.empty()); + assert(!MachineClusters.empty()); SwitchWorkList WorkList; - CaseClusterIt First = Clusters.begin(); - CaseClusterIt Last = Clusters.end() - 1; + MachineCaseClusterIt First = MachineClusters.begin(); + MachineCaseClusterIt Last = MachineClusters.end() - 1; auto DefaultProb = getEdgeProbability(SwitchMBB, DefaultMBB); WorkList.push_back({SwitchMBB, First, Last, nullptr, nullptr, DefaultProb}); Index: lib/CodeGen/SelectionDAG/SwitchCaseCluster.cpp =================================================================== --- /dev/null +++ lib/CodeGen/SelectionDAG/SwitchCaseCluster.cpp @@ -0,0 +1,530 @@ +//===-- SwitchLoweringCaseCluster.cpp -----------------------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This implements routines for forming case clusters for SwitchInst. +// +//===----------------------------------------------------------------------===// + +#include "llvm/CodeGen/SwitchCaseCluster.h" +#include "llvm/Support/Debug.h" +#include +using namespace llvm; + +/// Minimum jump table density for normal functions. +static cl::opt + JumpTableDensity("jump-table-density", cl::init(10), cl::Hidden, + cl::desc("Minimum density for building a jump table in " + "a normal function")); + +/// Minimum jump table density for -Os or -Oz functions. 
+static cl::opt OptsizeJumpTableDensity(
+    "optsize-jump-table-density", cl::init(40), cl::Hidden,
+    cl::desc("Minimum density for building a jump table in "
+             "an optsize function"));
+
+// NOTE(review): in this copy of the patch the template argument lists appear
+// to have been stripped (cl::opt, SmallVector, SmallVectorImpl, SmallPtrSet,
+// DenseMap and isa are written without <...> in this file). Restore them
+// before applying; the code as shown will not compile.
+
+/// Return true if NumCases case values cover at least Density percent of a
+/// value range containing Range values (NumCases * 100 >= Range * Density).
+static inline bool isDense(uint64_t Range, uint64_t NumCases,
+                           unsigned Density) {
+  // NumCases * 100 must not overflow, and a range must contain at least as
+  // many values as there are cases in it.
+  assert(NumCases < UINT64_MAX / 100);
+  assert(Range >= NumCases);
+  return NumCases * 100 >= Range * Density;
+}
+
+/// Return true if jump tables may be used to lower SI. The enclosing function
+/// must not carry the "no-jump-tables"="true" attribute, and the target must
+/// be able to lower BR_JT or BRIND (legal or custom).
+static inline bool areJTsAllowed(const TargetLowering &TLI,
+                                 const SwitchInst *SI) {
+  const Function *Fn = SI->getParent()->getParent();
+  if (Fn->getFnAttribute("no-jump-tables").getValueAsString() == "true")
+    return false;
+
+  return TLI.isOperationLegalOrCustom(ISD::BR_JT, MVT::Other) ||
+         TLI.isOperationLegalOrCustom(ISD::BRIND, MVT::Other);
+}
+
+/// Return true if ClusterSize clusters are too few for a jump table: fewer
+/// than two, or fewer than MinJumpTableEntries.
+static inline bool isTooSmallForJumptable(const unsigned ClusterSize,
+                                          const unsigned MinJumpTableEntries) {
+  return (ClusterSize < 2 || ClusterSize < MinJumpTableEntries);
+}
+
+/// Return the maximum number of jump table entries. This is unlimited
+/// (UINT_MAX) when optimizing for size or when the target reports a maximum
+/// of 0; otherwise it is the target's maximum.
+static inline unsigned getMaxJumpTableSize(const bool OptForSize,
+                                           const TargetLowering &TLI) {
+  return OptForSize || TLI.getMaximumJumpTableSize() == 0
+             ? UINT_MAX
+             : TLI.getMaximumJumpTableSize();
+}
+
+/// Return the minimum jump table density in percent. Optsize functions use
+/// -optsize-jump-table-density; all others use -jump-table-density.
+static inline unsigned getJumptableMinDensity(const bool OptForSize) {
+  return OptForSize ? OptsizeJumpTableDensity : JumpTableDensity;
+}
+
+/// Return true if the cases in Clusters[First..Last] reach at least Density
+/// percent of the value span Clusters[First].Low..Clusters[Last].High.
+/// TotalCases holds running totals: TotalCases[i] is the number of cases in
+/// Clusters[0..i].
+bool SwitchCaseClusterFinder::isDense(
+    const CaseClusterVector &Clusters,
+    const SmallVectorImpl &TotalCases, unsigned First, unsigned Last,
+    unsigned Density) const {
+  assert(Last >= First);
+  assert(TotalCases[Last] >= TotalCases[First]);
+
+  const APInt &LowCase = Clusters[First].Low->getValue();
+  const APInt &HighCase = Clusters[Last].High->getValue();
+  assert(LowCase.getBitWidth() == HighCase.getBitWidth());
+
+  // FIXME: A range of consecutive cases has 100% density, but only requires one
+  // comparison to lower. We should discriminate against such consecutive ranges
+  // in jump tables.
+
+  // The span is clamped to (UINT64_MAX - 1) / 100. Presumably this keeps the
+  // percentage arithmetic in ::isDense from overflowing. TODO confirm for
+  // Density values above 100.
+  uint64_t Diff = (HighCase - LowCase).getLimitedValue((UINT64_MAX - 1) / 100);
+  uint64_t Range = Diff + 1;
+
+  // Number of cases in Clusters[First..Last], taken from the running totals.
+  uint64_t NumCases =
+      TotalCases[Last] - (First == 0 ? 0 : TotalCases[First - 1]);
+
+  assert(NumCases < UINT64_MAX / 100);
+  assert(Range >= NumCases);
+
+  return ::isDense(Range, NumCases, Density);
+}
+
+/// Return true if the CC_Range clusters Clusters[First..Last] are suitable for
+/// lowering together as bit tests (see isSuitableForBitTests). A single
+/// cluster (First == Last) is never accepted. BTCluster is neither read nor
+/// written here; the caller builds it with createBitTestCluster.
+bool SwitchCaseClusterFinder::canBuildBitTest(CaseClusterVector &Clusters,
+                                              unsigned First, unsigned Last,
+                                              const SwitchInst *SI,
+                                              CaseCluster &BTCluster) {
+  assert(First <= Last);
+  if (First == Last)
+    return false;
+
+  // Count the distinct destinations, and the comparisons plain range checks
+  // would need: one for a single value, two for a [Low, High] range.
+  SmallPtrSet Dests;
+  unsigned NumCmps = 0;
+  for (int64_t I = First; I <= Last; ++I) {
+    assert(Clusters[I].Kind == CC_Range);
+    Dests.insert(Clusters[I].getSingleSuccessor(SI));
+    NumCmps += (Clusters[I].Low == Clusters[I].High) ? 1 : 2;
+  }
+  unsigned NumDests = Dests.size();
+
+  APInt Low = Clusters[First].Low->getValue();
+  APInt High = Clusters[Last].High->getValue();
+  assert(Low.slt(High));
+
+  return isSuitableForBitTests(NumDests, NumCmps, Low, High);
+}
+
+/// Sort the single-case CC_Range clusters by signed case value. Then merge
+/// each run of consecutive values that share a successor into one CC_Range
+/// cluster [Low, High]. The merged cluster's Cases list holds the case indexes
+/// in ascending value order.
+void SwitchCaseClusterFinder::sortAndRangeify(const SwitchInst *SI,
+                                              CaseClusterVector &Clusters) {
+#ifndef NDEBUG
+  for (const CaseCluster &CC : Clusters)
+    assert(CC.Low == CC.High && CC.Cases.size() == 1 &&
+           "Input clusters must be single-case");
+#endif
+  std::sort(Clusters.begin(), Clusters.end(),
+            [](const CaseCluster &a, const CaseCluster &b) {
+              return a.Low->getValue().slt(b.Low->getValue());
+            });
+  // Merge adjacent clusters with the same destination.
+  // The merge compacts in place: DstIndex counts the merged clusters kept so
+  // far, and Clusters[DstIndex - 1] is the cluster currently being extended.
+  const unsigned N = Clusters.size();
+  unsigned DstIndex = 0;
+  for (unsigned SrcIndex = 0; SrcIndex < N; ++SrcIndex) {
+    CaseCluster &CC = Clusters[SrcIndex];
+    const ConstantInt *CaseVal = CC.Low;
+    const BasicBlock *Succ = CC.getSingleSuccessor(SI);
+    if (DstIndex != 0 &&
+        Clusters[DstIndex - 1].getSingleSuccessor(SI) == Succ &&
+        (CaseVal->getValue() - Clusters[DstIndex - 1].High->getValue()) == 1) {
+      // If this case has the same successor and is a neighbour, merge it into
+      // the previous cluster.
+      Clusters[DstIndex - 1].High = CaseVal;
+      Clusters[DstIndex - 1].Cases.push_back(Clusters[SrcIndex].Cases.back());
+    } else
+      Clusters[DstIndex++] = Clusters[SrcIndex];
+  }
+  Clusters.resize(DstIndex);
+}
+
+/// Handle an unreachable default when optimizing. The default counts as
+/// unreachable according to the isa check on its first non-PHI/debug
+/// instruction. In that case the most common case successor becomes the new
+/// default, and the clusters that branch to it are removed. Returns the
+/// default block to use, which is SI.getDefaultDest() when nothing changes.
+const BasicBlock *SwitchCaseClusterFinder::replaceUnreachableDefault(
+    const SwitchInst &SI, CaseClusterVector &Clusters) {
+  const BasicBlock *DefaultBB = SI.getDefaultDest();
+  if (OptLevel != CodeGenOpt::None) {
+    // FIXME: Exploit unreachable default more aggressively.
+    bool UnreachableDefault =
+        isa(SI.getDefaultDest()->getFirstNonPHIOrDbg());
+    if (UnreachableDefault && !Clusters.empty()) {
+      // Count the cases of SI that target each successor. The comparison is
+      // strict, so on a tie the successor that reached the maximum first wins.
+      DenseMap Popularity;
+      unsigned MaxPop = 0;
+      const BasicBlock *MaxBB = nullptr;
+      for (auto I : SI.cases()) {
+        const BasicBlock *BB = I.getCaseSuccessor();
+        if (++Popularity[BB] > MaxPop) {
+          MaxPop = Popularity[BB];
+          MaxBB = BB;
+        }
+      }
+      // Set new default.
+      assert(MaxPop > 0 && MaxBB);
+      DefaultBB = MaxBB;
+
+      // Remove cases that were pointing to the destination that is now the
+      // default.
+      CaseClusterVector New;
+      New.reserve(Clusters.size());
+      for (CaseCluster &CC : Clusters) {
+        if (CC.getSingleSuccessor(&SI) != DefaultBB)
+          New.push_back(CC);
+      }
+      Clusters = std::move(New);
+    }
+  }
+  return DefaultBB;
+}
+
+/// Append one single-value CC_Range cluster per case of SI, recording the
+/// case index. Then sort and merge them with sortAndRangeify.
+void SwitchCaseClusterFinder::formInitalCaseClusers(
+    const SwitchInst &SI, CaseClusterVector &Clusters) {
+  Clusters.reserve(SI.getNumCases());
+  for (auto I : SI.cases()) {
+    const ConstantInt *CaseVal = I.getCaseValue();
+    Clusters.push_back(CaseCluster::range(CaseVal, CaseVal, I.getCaseIndex()));
+  }
+  // Cluster adjacent cases with the same destination. We do this at all
+  // optimization levels because it's cheap to do and will make codegen faster
+  // if there are many clusters.
+  sortAndRangeify(&SI, Clusters);
+}
+
+/// Return true if the value range [Low, High] should be lowered with bit
+/// tests. NumDests is the number of distinct destinations, and NumCmps is
+/// the number of comparisons plain range checks would need. The range must
+/// fit in a machine word (rangeFitsInWord). The profitability thresholds are:
+/// 1 destination with >= 3 comparisons, 2 with >= 5, or 3 with >= 6.
+bool SwitchCaseClusterFinder::isSuitableForBitTests(unsigned NumDests,
+                                                    unsigned NumCmps,
+                                                    const APInt &Low,
+                                                    const APInt &High) {
+  // FIXME: I don't think NumCmps is the correct metric: a single case and a
+  // range of cases both require only one branch to lower. Just looking at the
+  // number of clusters and destinations should be enough to decide whether to
+  // build bit tests.
+
+  // To lower a range with bit tests, the range must fit the bitwidth of a
+  // machine word.
+  if (!rangeFitsInWord(Low, High, DL))
+    return false;
+
+  // Decide whether it's profitable to lower this range with bit tests. Each
+  // destination requires a bit test and branch, and there is an overall range
+  // check branch. For a small number of clusters, separate comparisons might be
+  // cheaper, and for many destinations, splitting the range might be better.
+  return (NumDests == 1 && NumCmps >= 3) || (NumDests == 2 && NumCmps >= 5) ||
+         (NumDests == 3 && NumCmps >= 6);
+}
+
+/// Return true if the CC_Range clusters Clusters[First..Last] may become a
+/// jump table, that is, unless they are better suited to bit tests. Density
+/// and table size are not checked here, and JTCluster is neither read nor
+/// written.
+bool SwitchCaseClusterFinder::canBuildJumpTable(
+    const CaseClusterVector &Clusters, unsigned First, unsigned Last,
+    const SwitchInst *SI, CaseCluster &JTCluster) {
+  assert(First <= Last);
+  // Same counting as canBuildBitTest: one comparison for a single value, two
+  // for a [Low, High] range, plus the set of distinct destinations.
+  unsigned NumCmps = 0;
+  SmallPtrSet Dests;
+  for (unsigned I = First; I <= Last; ++I) {
+    assert(Clusters[I].Kind == CC_Range);
+    const APInt &Low = Clusters[I].Low->getValue();
+    const APInt &High = Clusters[I].High->getValue();
+    NumCmps += (Low == High) ? 1 : 2;
+    Dests.insert(Clusters[I].getSingleSuccessor(SI));
+  }
+  return !(isSuitableForBitTests(Dests.size(), NumCmps,
+                                 Clusters[First].Low->getValue(),
+                                 Clusters[Last].High->getValue()));
+}
+
+/// Replace runs of CC_Range clusters in the sorted, non-empty Clusters with
+/// CC_JumpTable clusters, in place.
+/// - If the whole switch is dense enough, within the maximum table size, and
+///   not better suited to bit tests, it becomes a single jump table.
+/// - Otherwise, except at -O0, a dynamic program splits the clusters into the
+///   fewest dense partitions. Partitions with at least MinJumpTableEntries
+///   clusters become jump tables.
+/// Nothing happens if jump tables are disallowed or there are too few clusters.
+void SwitchCaseClusterFinder::findJumpTables(CaseClusterVector &Clusters,
+                                             const SwitchInst *SI) {
+
+#ifndef NDEBUG
+  // Clusters must be non-empty, sorted, and only contain Range clusters.
+  assert(!Clusters.empty());
+  for (CaseCluster &C : Clusters)
+    assert(C.Kind == CC_Range);
+  for (unsigned i = 1, e = Clusters.size(); i < e; ++i) {
+    assert(Clusters[i - 1].High->getValue().slt(Clusters[i].Low->getValue()));
+    // Case values in a cluster must be sorted.
+    for (unsigned I = 1, E = Clusters[i].getNumerOfCases(); I != E; ++I) {
+      const APInt &PreValue = Clusters[i].getCaseValueAt(SI, I - 1)->getValue();
+      const APInt &CurValue = Clusters[i].getCaseValueAt(SI, I)->getValue();
+      assert(PreValue.slt(CurValue));
+    }
+  }
+#endif
+
+  if (!areJTsAllowed(TLI, SI))
+    return;
+
+  const bool OptForSize = SI->getParent()->getParent()->optForSize();
+  const int64_t N = Clusters.size();
+  const unsigned MinJumpTableEntries = TLI.getMinimumJumpTableEntries();
+  const unsigned SmallNumberOfEntries = MinJumpTableEntries / 2;
+  const unsigned MaxJumpTableSize = getMaxJumpTableSize(OptForSize, TLI);
+
+  if (isTooSmallForJumptable(N, MinJumpTableEntries))
+    return;
+
+  // TotalCases[i]: Total nbr of cases in Clusters[0..i].
+  SmallVector TotalCases(N);
+  for (unsigned i = 0; i < N; ++i) {
+    TotalCases[i] = Clusters[i].getNumerOfCases();
+    if (i != 0)
+      TotalCases[i] += TotalCases[i - 1];
+  }
+
+  const unsigned MinDensity = getJumptableMinDensity(OptForSize);
+
+  // Cheap case: the whole range may be suitable for jump table.
+  unsigned JumpTableSize =
+      (Clusters[N - 1].High->getValue() - Clusters[0].Low->getValue())
+          .getLimitedValue(UINT_MAX - 1) +
+      1;
+
+  if (JumpTableSize <= MaxJumpTableSize &&
+      isDense(Clusters, TotalCases, 0, N - 1, MinDensity)) {
+    CaseCluster JTCluster;
+    if (canBuildJumpTable(Clusters, 0, N - 1, SI, JTCluster)) {
+      createJumpTableCluster(Clusters, 0, N - 1, SI, JTCluster);
+      Clusters[0] = JTCluster;
+      Clusters.resize(1);
+      return;
+    }
+  }
+
+  // The algorithm below is not suitable for -O0.
+  if (OptLevel == CodeGenOpt::None)
+    return;
+
+  // Split Clusters into minimum number of dense partitions. The algorithm uses
+  // the same idea as Kannan & Proebsting "Correction to 'Producing Good Code
+  // for the Case Statement'" (1994), but builds the MinPartitions array in
+  // reverse order to make it easier to reconstruct the partitions in ascending
+  // order. In the choice between two optimal partitionings, it picks the one
+  // which yields more jump tables.
+
+  // MinPartitions[i] is the minimum nbr of partitions of Clusters[i..N-1].
+  SmallVector MinPartitions(N);
+  // LastElement[i] is the last element of the partition starting at i.
+  SmallVector LastElement(N);
+  // PartitionsScore[i] is used to break ties when choosing between two
+  // partitionings resulting in the same number of partitions.
+  SmallVector PartitionsScore(N);
+  // For PartitionsScore, a small number of comparisons is considered as good as
+  // a jump table and a single comparison is considered better than a jump
+  // table.
+  enum PartitionScores : unsigned {
+    NoTable = 0,
+    Table = 1,
+    FewCases = 1,
+    SingleCase = 2
+  };
+
+  // Base case: There is only one way to partition Clusters[N-1].
+  MinPartitions[N - 1] = 1;
+  LastElement[N - 1] = N - 1;
+  PartitionsScore[N - 1] = PartitionScores::SingleCase;
+
+  // Note: loop indexes are signed to avoid underflow.
+  for (int64_t i = N - 2; i >= 0; i--) {
+    // Find optimal partitioning of Clusters[i..N-1].
+    // Baseline: Put Clusters[i] into a partition on its own.
+    MinPartitions[i] = MinPartitions[i + 1] + 1;
+    LastElement[i] = i;
+    PartitionsScore[i] = PartitionsScore[i + 1] + PartitionScores::SingleCase;
+
+    // Search for a solution that results in fewer partitions.
+    for (int64_t j = N - 1; j > i; j--) {
+      // Try building a partition from Clusters[i..j].
+      JumpTableSize =
+          (Clusters[j].High->getValue() - Clusters[i].Low->getValue())
+              .getLimitedValue(UINT_MAX - 1) +
+          1;
+      if (JumpTableSize <= MaxJumpTableSize &&
+          isDense(Clusters, TotalCases, i, j, MinDensity)) {
+        unsigned NumPartitions = 1 + (j == N - 1 ? 0 : MinPartitions[j + 1]);
+        unsigned Score = j == N - 1 ? 0 : PartitionsScore[j + 1];
+        int64_t NumEntries = j - i + 1;
+
+        if (NumEntries == 1)
+          Score += PartitionScores::SingleCase;
+        else if (NumEntries <= SmallNumberOfEntries)
+          Score += PartitionScores::FewCases;
+        else if (NumEntries >= MinJumpTableEntries)
+          Score += PartitionScores::Table;
+
+        // If this leads to fewer partitions, or to the same number of
+        // partitions with better score, it is a better partitioning.
+        if (NumPartitions < MinPartitions[i] ||
+            (NumPartitions == MinPartitions[i] && Score > PartitionsScore[i])) {
+          MinPartitions[i] = NumPartitions;
+          LastElement[i] = j;
+          PartitionsScore[i] = Score;
+        }
+      }
+    }
+  }
+
+  // Iterate over the partitions, replacing some with jump tables in-place.
+  // Each partition either collapses into one jump-table cluster or is copied
+  // through unchanged. DstIndex never passes First, so the in-place copy does
+  // not overwrite clusters that have not been read yet.
+  unsigned DstIndex = 0;
+  for (unsigned First = 0, Last; First < N; First = Last + 1) {
+    Last = LastElement[First];
+    assert(Last >= First);
+    assert(DstIndex <= First);
+    unsigned NumClusters = Last - First + 1;
+    CaseCluster JTCluster;
+    if (NumClusters >= MinJumpTableEntries &&
+        canBuildJumpTable(Clusters, First, Last, SI, JTCluster)) {
+      createJumpTableCluster(Clusters, First, Last, SI, JTCluster);
+      Clusters[DstIndex++] = JTCluster;
+    } else {
+      for (unsigned I = First; I <= Last; ++I)
+        Clusters[DstIndex++] = Clusters[I];
+    }
+  }
+  Clusters.resize(DstIndex);
+}
+
+/// Replace runs of CC_Range clusters with CC_BitTests clusters, in place.
+/// Does nothing at -O0 or when SHL is not legal for the pointer type.
+void SwitchCaseClusterFinder::findBitTestClusters(CaseClusterVector &Clusters,
+                                                  const SwitchInst *SI) {
+// Partition Clusters into as few subsets as possible, where each subset has a
+// range that fits in a machine word and has <= 3 unique destinations.
+#ifndef NDEBUG
+  // Clusters must be sorted and contain Range or JumpTable clusters.
+  assert(!Clusters.empty());
+  assert(Clusters[0].Kind == CC_Range || Clusters[0].Kind == CC_JumpTable);
+  for (const CaseCluster &C : Clusters)
+    assert(C.Kind == CC_Range || C.Kind == CC_JumpTable);
+
+  for (unsigned i = 1; i < Clusters.size(); ++i) {
+    assert(Clusters[i - 1].High->getValue().slt(Clusters[i].Low->getValue()));
+    // Case values in a cluster must be sorted.
+    for (unsigned I = 1, E = Clusters[i].getNumerOfCases(); I != E; ++I) {
+      const APInt &PreValue = Clusters[i].getCaseValueAt(SI, I - 1)->getValue();
+      const APInt &CurValue = Clusters[i].getCaseValueAt(SI, I)->getValue();
+      assert(PreValue.slt(CurValue));
+    }
+  }
+#endif
+
+  // The algorithm below is not suitable for -O0.
+  if (OptLevel == CodeGenOpt::None)
+    return;
+
+  // If target does not have legal shift left, do not emit bit tests at all.
+  EVT PTy = TLI.getPointerTy(DL);
+  if (!TLI.isOperationLegal(ISD::SHL, PTy))
+    return;
+
+  int BitWidth = PTy.getSizeInBits();
+  const int64_t N = Clusters.size();
+
+  // MinPartitions[i] is the minimum nbr of partitions of Clusters[i..N-1].
+  SmallVector MinPartitions(N);
+  // LastElement[i] is the last element of the partition starting at i.
+  SmallVector LastElement(N);
+
+  // FIXME: This might not be the best algorithm for finding bit test clusters.
+
+  // Base case: There is only one way to partition Clusters[N-1].
+  MinPartitions[N - 1] = 1;
+  LastElement[N - 1] = N - 1;
+
+  // Note: loop indexes are signed to avoid underflow.
+  for (int64_t i = N - 2; i >= 0; --i) {
+    // Find optimal partitioning of Clusters[i..N-1].
+    // Baseline: Put Clusters[i] into a partition on its own.
+    MinPartitions[i] = MinPartitions[i + 1] + 1;
+    LastElement[i] = i;
+
+    // Search for a solution that results in fewer partitions.
+    // Note: the search is limited by BitWidth, reducing time complexity.
+    for (int64_t j = std::min(N - 1, i + BitWidth - 1); j > i; --j) {
+      // Try building a partition from Clusters[i..j].
+
+      // Check the range.
+      if (!rangeFitsInWord(Clusters[i].Low->getValue(),
+                           Clusters[j].High->getValue(), DL))
+        continue;
+
+      // Check nbr of destinations and cluster types.
+      // FIXME: This works, but doesn't seem very efficient.
+      bool RangesOnly = true;
+      SmallPtrSet Dests;
+      for (int64_t k = i; k <= j; k++) {
+        if (Clusters[k].Kind != CC_Range) {
+          RangesOnly = false;
+          break;
+        }
+        Dests.insert(Clusters[k].getSingleSuccessor(SI));
+      }
+      // j is scanned downwards, so a partition that fails here would also
+      // fail for... NOTE(review): the loop breaks rather than continuing with
+      // a smaller j; this mirrors the pre-patch SelectionDAGBuilder code.
+      if (!RangesOnly || Dests.size() > 3)
+        break;
+
+      // Check if it's a better partition.
+      unsigned NumPartitions = 1 + (j == N - 1 ? 0 : MinPartitions[j + 1]);
+      if (NumPartitions < MinPartitions[i]) {
+        // Found a better partition.
+        MinPartitions[i] = NumPartitions;
+        LastElement[i] = j;
+      }
+    }
+  }
+
+  // Iterate over the partitions, replacing with bit-test clusters in-place.
+  // DstIndex never passes First, so the in-place copy does not overwrite
+  // clusters that have not been read yet.
+  unsigned DstIndex = 0;
+  for (unsigned First = 0, Last; First < N; First = Last + 1) {
+    Last = LastElement[First];
+    assert(First <= Last);
+    assert(DstIndex <= First);
+
+    CaseCluster BitTestCluster;
+    if (canBuildBitTest(Clusters, First, Last, SI, BitTestCluster)) {
+      createBitTestCluster(Clusters, First, Last, SI, BitTestCluster);
+      Clusters[DstIndex++] = BitTestCluster;
+    } else {
+      size_t NumClusters = Last - First + 1;
+      for (unsigned I = 0; I != NumClusters; ++I)
+        Clusters[DstIndex++] = Clusters[First + I];
+    }
+  }
+  Clusters.resize(DstIndex);
+}
+
+/// Set JTCluster to a CC_JumpTable cluster spanning Clusters[First].Low to
+/// Clusters[Last].High. Its Cases list is the concatenation of the case
+/// indexes of the CC_Range clusters Clusters[First..Last]. SI is not used.
+void SwitchCaseClusterFinder::createJumpTableCluster(
+    const CaseClusterVector &Clusters, unsigned First, unsigned Last,
+    const SwitchInst *SI, CaseCluster &JTCluster) {
+  assert(First <= Last);
+  JTCluster = CaseCluster::jumpTable(Clusters[First].Low, Clusters[Last].High);
+
+  for (unsigned I = First; I <= Last; ++I) {
+    assert(Clusters[I].Kind == CC_Range);
+    JTCluster.Cases.insert(JTCluster.Cases.end(), Clusters[I].Cases.begin(),
+                           Clusters[I].Cases.end());
+  }
+}
+
+/// Set BTCluster to a CC_BitTests cluster spanning Clusters[First].Low to
+/// Clusters[Last].High. Its Cases list is the concatenation of the case
+/// indexes of the CC_Range clusters Clusters[First..Last]. SI is not used.
+void SwitchCaseClusterFinder::createBitTestCluster(CaseClusterVector &Clusters,
+                                                   unsigned First,
+                                                   unsigned Last,
+                                                   const SwitchInst *SI,
+                                                   CaseCluster &BTCluster) {
+  BTCluster = CaseCluster::bitTests(Clusters[First].Low, Clusters[Last].High);
+
+  for (unsigned I = First; I <= Last; ++I) {
+    assert(Clusters[I].Kind == CC_Range);
+    BTCluster.Cases.insert(BTCluster.Cases.end(), Clusters[I].Cases.begin(),
+                           Clusters[I].Cases.end());
+  }
+}
+
+/// Compute the case clusters for SI and append them to Clusters (presumably
+/// empty on entry; TODO confirm for all callers). The steps are:
+/// 1. build, sort and merge single-value range clusters;
+/// 2. possibly replace an unreachable default;
+/// 3. form jump-table clusters, then bit-test clusters.
+/// Returns the default destination to use. It may differ from
+/// SI.getDefaultDest() when the original default is unreachable.
+const BasicBlock *
+SwitchCaseClusterFinder::findClusters(const SwitchInst &SI,
+                                      CaseClusterVector &Clusters) {
+  formInitalCaseClusers(SI, Clusters);
+  const BasicBlock *DefaultBB = replaceUnreachableDefault(SI, Clusters);
+
+  // Every case may have been folded into the new default, leaving nothing to
+  // cluster.
+  if (!Clusters.empty()) {
+    findJumpTables(Clusters, &SI);
+    findBitTestClusters(Clusters, &SI);
+  }
+  return DefaultBB;
+}