[PGO][PGSO] Pass ProfileSummaryInfo/BlockFrequencyInfo to jump table checks.

Adds optional PSI/BFI parameters to TTI::getEstimatedNumberOfCaseClusters,
the isSuitableForJumpTable hook in TargetLowering.h and
SwitchLowering::findJumpTables, so that llvm::shouldOptimizeForSize() on the
switch's block can select the optimize-for-size minimum jump table density.
The SelectionDAG and GlobalISel callers pass null for now; the inline cost
analysis passes its PSI and the callee's BFI (when available).

NOTE(review): this patch was reconstructed from a whitespace-collapsed copy
(newlines, indentation and template arguments had been stripped). Confirm
the context lines with `git apply --check` before committing.

diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -40,12 +40,14 @@
 }
 
 class AssumptionCache;
+class BlockFrequencyInfo;
 class BranchInst;
 class Function;
 class GlobalValue;
 class IntrinsicInst;
 class LoadInst;
 class Loop;
+class ProfileSummaryInfo;
 class SCEV;
 class ScalarEvolution;
 class StoreInst;
@@ -297,7 +299,12 @@
   /// \p JTSize Set a jump table size only when \p SI is suitable for a jump
   /// table.
+  /// \p PSI and \p BFI are optional profile information (either may be
+  /// null) that an implementation may consult when deciding whether \p SI is
+  /// suitable for a jump table.
   unsigned getEstimatedNumberOfCaseClusters(const SwitchInst &SI,
-                                            unsigned &JTSize) const;
+                                            unsigned &JTSize,
+                                            ProfileSummaryInfo *PSI,
+                                            BlockFrequencyInfo *BFI) const;
 
   /// Estimate the cost of a given IR user when lowered.
   ///
@@ -1177,7 +1184,9 @@
                                const User *U) = 0;
   virtual int getMemcpyCost(const Instruction *I) = 0;
-  virtual unsigned getEstimatedNumberOfCaseClusters(const SwitchInst &SI,
-                                                    unsigned &JTSize) = 0;
+  virtual unsigned
+  getEstimatedNumberOfCaseClusters(const SwitchInst &SI, unsigned &JTSize,
+                                   ProfileSummaryInfo *PSI,
+                                   BlockFrequencyInfo *BFI) = 0;
   virtual int
   getUserCost(const User *U, ArrayRef<const Value *> Operands) = 0;
   virtual bool hasBranchDivergence() = 0;
@@ -1678,8 +1687,10 @@
     return Impl.getMaxInterleaveFactor(VF);
   }
   unsigned getEstimatedNumberOfCaseClusters(const SwitchInst &SI,
-                                            unsigned &JTSize) override {
-    return Impl.getEstimatedNumberOfCaseClusters(SI, JTSize);
+                                            unsigned &JTSize,
+                                            ProfileSummaryInfo *PSI,
+                                            BlockFrequencyInfo *BFI) override {
+    return Impl.getEstimatedNumberOfCaseClusters(SI, JTSize, PSI, BFI);
   }
   unsigned getArithmeticInstrCost(unsigned Opcode, Type *Ty,
                                   OperandValueKind Opd1Info,
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -114,7 +114,11 @@
   }
 
   unsigned getEstimatedNumberOfCaseClusters(const SwitchInst &SI,
-                                            unsigned &JTSize) {
+                                            unsigned &JTSize,
+                                            ProfileSummaryInfo *PSI,
+                                            BlockFrequencyInfo *BFI) {
+    (void)PSI;
+    (void)BFI;
     JTSize = 0;
     return SI.getNumCases();
   }
diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -326,7 +326,9 @@
   }
 
   unsigned getEstimatedNumberOfCaseClusters(const SwitchInst &SI,
-                                            unsigned &JumpTableSize) {
+                                            unsigned &JumpTableSize,
+                                            ProfileSummaryInfo *PSI,
+                                            BlockFrequencyInfo *BFI) {
     /// Try to find the estimated number of clusters. Note that the number of
     /// clusters identified in this function could be different from the actual
     /// numbers found in lowering. This function ignore switches that are
@@ -374,7 +376,7 @@
           (MaxCaseVal - MinCaseVal)
               .getLimitedValue(std::numeric_limits<uint64_t>::max() - 1) + 1;
       // Check whether a range of clusters is dense enough for a jump table
-      if (TLI->isSuitableForJumpTable(&SI, N, Range)) {
+      if (TLI->isSuitableForJumpTable(&SI, N, Range, PSI, BFI)) {
         JumpTableSize = Range;
         return 1;
       }
diff --git a/llvm/include/llvm/CodeGen/SwitchLoweringUtils.h b/llvm/include/llvm/CodeGen/SwitchLoweringUtils.h
--- a/llvm/include/llvm/CodeGen/SwitchLoweringUtils.h
+++ b/llvm/include/llvm/CodeGen/SwitchLoweringUtils.h
@@ -19,6 +19,8 @@
 
+class BlockFrequencyInfo;
 class FunctionLoweringInfo;
 class MachineBasicBlock;
+class ProfileSummaryInfo;
 
 namespace SwitchCG {
 
@@ -264,7 +266,8 @@
   std::vector<BitTestBlock> BitTestCases;
 
   void findJumpTables(CaseClusterVector &Clusters, const SwitchInst *SI,
-                      MachineBasicBlock *DefaultMBB);
+                      MachineBasicBlock *DefaultMBB,
+                      ProfileSummaryInfo *PSI, BlockFrequencyInfo *BFI);
 
   bool buildJumpTable(const CaseClusterVector &Clusters, unsigned First,
                       unsigned Last, const SwitchInst *SI,
@@ -295,4 +298,3 @@
 } // namespace llvm
 
 #endif // LLVM_CODEGEN_SWITCHLOWERINGUTILS_H
-
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -28,6 +28,7 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringRef.h"
+#include "llvm/Analysis/ProfileSummaryInfo.h"
 #include "llvm/CodeGen/DAGCombine.h"
 #include "llvm/CodeGen/ISDOpcodes.h"
 #include "llvm/CodeGen/RuntimeLibcalls.h"
@@ -53,6 +54,7 @@
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/MachineValueType.h"
 #include "llvm/Target/TargetMachine.h"
+#include "llvm/Transforms/Utils/SizeOpts.h"
 #include <algorithm>
 #include <cassert>
 #include <climits>
@@ -1030,13 +1032,19 @@
   /// Return true if lowering to a jump table is suitable for a set of case
   /// clusters which may contain \p NumCases cases, \p Range range of values.
+  /// \p PSI and \p BFI (either may be null) are forwarded to
+  /// llvm::shouldOptimizeForSize(); when it returns true for \p SI's block,
+  /// the optimize-for-size minimum density is used, as for optsize functions.
   virtual bool isSuitableForJumpTable(const SwitchInst *SI, uint64_t NumCases,
-                                      uint64_t Range) const {
+                                      uint64_t Range, ProfileSummaryInfo *PSI,
+                                      BlockFrequencyInfo *BFI) const {
     // FIXME: This function check the maximum table size and density, but the
     // minimum size is not checked. It would be nice if the minimum size is
     // also combined within this function. Currently, the minimum size check is
     // performed in findJumpTable() in SelectionDAGBuiler and
     // getEstimatedNumberOfCaseClusters() in BasicTTIImpl.
-    const bool OptForSize = SI->getParent()->getParent()->hasOptSize();
+    const bool OptForSize =
+        SI->getParent()->getParent()->hasOptSize() ||
+        llvm::shouldOptimizeForSize(SI->getParent(), PSI, BFI);
     const unsigned MinDensity = getMinimumJumpTableDensity(OptForSize);
     const unsigned MaxJumpTableSize = getMaximumJumpTableSize();
 
diff --git a/llvm/lib/Analysis/InlineCost.cpp b/llvm/lib/Analysis/InlineCost.cpp
--- a/llvm/lib/Analysis/InlineCost.cpp
+++ b/llvm/lib/Analysis/InlineCost.cpp
@@ -1456,8 +1456,9 @@
   int CostUpperBound = INT_MAX - InlineConstants::InstrCost - 1;
 
   unsigned JumpTableSize = 0;
+  BlockFrequencyInfo *BFI = GetBFI ? &((*GetBFI)(F)) : nullptr;
   unsigned NumCaseCluster =
-      TTI.getEstimatedNumberOfCaseClusters(SI, JumpTableSize);
+      TTI.getEstimatedNumberOfCaseClusters(SI, JumpTableSize, PSI, BFI);
 
   // If suitable for a jump table, consider the cost for the table size and
   // branch to destination.
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -194,9 +194,10 @@
 }
 
 unsigned
-TargetTransformInfo::getEstimatedNumberOfCaseClusters(const SwitchInst &SI,
-                                                      unsigned &JTSize) const {
-  return TTIImpl->getEstimatedNumberOfCaseClusters(SI, JTSize);
+TargetTransformInfo::getEstimatedNumberOfCaseClusters(
+    const SwitchInst &SI, unsigned &JTSize, ProfileSummaryInfo *PSI,
+    BlockFrequencyInfo *BFI) const {
+  return TTIImpl->getEstimatedNumberOfCaseClusters(SI, JTSize, PSI, BFI);
 }
 
 int TargetTransformInfo::getUserCost(const User *U,
diff --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
--- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
@@ -466,7 +466,7 @@
     return true;
   }
 
-  SL->findJumpTables(Clusters, &SI, DefaultMBB);
+  SL->findJumpTables(Clusters, &SI, DefaultMBB, nullptr, nullptr);
 
   LLVM_DEBUG({
     dbgs() << "Case clusters: ";
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -10543,7 +10543,7 @@
     return;
   }
 
-  SL->findJumpTables(Clusters, &SI, DefaultMBB);
+  SL->findJumpTables(Clusters, &SI, DefaultMBB, nullptr, nullptr);
   SL->findBitTestClusters(Clusters, &SI);
 
   LLVM_DEBUG({
diff --git a/llvm/lib/CodeGen/SwitchLoweringUtils.cpp b/llvm/lib/CodeGen/SwitchLoweringUtils.cpp
--- a/llvm/lib/CodeGen/SwitchLoweringUtils.cpp
+++ b/llvm/lib/CodeGen/SwitchLoweringUtils.cpp
@@ -42,7 +42,9 @@
 
 void SwitchCG::SwitchLowering::findJumpTables(CaseClusterVector &Clusters,
                                               const SwitchInst *SI,
-                                              MachineBasicBlock *DefaultMBB) {
+                                              MachineBasicBlock *DefaultMBB,
+                                              ProfileSummaryInfo *PSI,
+                                              BlockFrequencyInfo *BFI) {
 #ifndef NDEBUG
   // Clusters must be non-empty, sorted, and only contain Range clusters.
   assert(!Clusters.empty());
@@ -80,7 +82,7 @@
   assert(Range >= NumCases);
 
   // Cheap case: the whole range may be suitable for jump table.
-  if (TLI->isSuitableForJumpTable(SI, NumCases, Range)) {
+  if (TLI->isSuitableForJumpTable(SI, NumCases, Range, PSI, BFI)) {
     CaseCluster JTCluster;
     if (buildJumpTable(Clusters, 0, N - 1, SI, DefaultMBB, JTCluster)) {
       Clusters[0] = JTCluster;
@@ -138,7 +140,7 @@
       assert(NumCases < UINT64_MAX / 100);
       assert(Range >= NumCases);
 
-      if (TLI->isSuitableForJumpTable(SI, NumCases, Range)) {
+      if (TLI->isSuitableForJumpTable(SI, NumCases, Range, PSI, BFI)) {
         unsigned NumPartitions = 1 + (j == N - 1 ? 0 : MinPartitions[j + 1]);
         unsigned Score = j == N - 1 ? 0 : PartitionsScore[j + 1];
         int64_t NumEntries = j - i + 1;