Index: lib/Transforms/Utils/SimplifyCFG.cpp =================================================================== --- lib/Transforms/Utils/SimplifyCFG.cpp +++ lib/Transforms/Utils/SimplifyCFG.cpp @@ -4882,7 +4882,7 @@ /// Create a lookup table to use as a switch replacement with the contents /// of Values, using DefaultValue to fill any holes in the table. SwitchLookupTable( - Module &M, uint64_t TableSize, ConstantInt *Offset, + Module &M, uint64_t TableSize, const SmallVectorImpl> &Values, Constant *DefaultValue, const DataLayout &DL, const StringRef &FuncName); @@ -4936,7 +4936,7 @@ } // end anonymous namespace SwitchLookupTable::SwitchLookupTable( - Module &M, uint64_t TableSize, ConstantInt *Offset, + Module &M, uint64_t TableSize, const SmallVectorImpl> &Values, Constant *DefaultValue, const DataLayout &DL, const StringRef &FuncName) { assert(Values.size() && "Can't build lookup table without values!"); @@ -4954,7 +4954,7 @@ Constant *CaseRes = Values[I].second; assert(CaseRes->getType() == ValueType); - uint64_t Idx = (CaseVal->getValue() - Offset->getValue()).getLimitedValue(); + uint64_t Idx = CaseVal->getValue().getLimitedValue(); TableContents[Idx] = CaseRes; if (CaseRes != SingleValue) @@ -5128,13 +5128,16 @@ ShouldBuildLookupTable(SwitchInst *SI, uint64_t TableSize, const TargetTransformInfo &TTI, const DataLayout &DL, const SmallDenseMap &ResultTypes) { - if (SI->getNumCases() > TableSize || TableSize >= UINT64_MAX / 10) + if (SI->getNumCases() > TableSize || TableSize >= UINT64_MAX / 24) return false; // TableSize overflowed, or mul below might overflow. bool AllTablesFitInRegister = true; bool HasIllegalType = false; + bool NoBiggerThanI8 = true; + unsigned BiggestTypeSize = 0; for (const auto &I : ResultTypes) { Type *Ty = I.second; + unsigned TySize = DL.getTypeAllocSize(Ty); // Saturate this flag to true. 
HasIllegalType = HasIllegalType || !TTI.isTypeLegal(Ty); @@ -5144,9 +5147,13 @@ AllTablesFitInRegister && SwitchLookupTable::WouldFitInRegister(DL, TableSize, Ty); - // If both flags saturate, we're done. NOTE: This *only* works with - // saturating flags, and all flags have to saturate first due to the - // non-deterministic behavior of iterating over a dense map. + // Saturate this flag to false. + NoBiggerThanI8 = NoBiggerThanI8 && (TySize == 1); + + if (TySize > BiggestTypeSize) + BiggestTypeSize = TySize; + + // If these two flags saturate, we're done. if (HasIllegalType && !AllTablesFitInRegister) break; } @@ -5159,10 +5166,24 @@ if (HasIllegalType) return false; - // The table density should be at least 40%. This is the same criterion as for - // jump tables, see SelectionDAGBuilder::handleJTSwitchCase. + // If the table only contains i8s or smaller, it has a bounded size of + // 256 times the largest legal size, and will be more performant with a lookup table. + if (NoBiggerThanI8 && !SI->getFunction()->hasOptSize()) + return true; + + // If the table is smaller, always use it + if (TableSize * BiggestTypeSize + 14 < + // Table Size, including empty space, plus header size + SI->getNumCases() * 14) // size of cmp jmp mov ret + return true; + + // Space is more important than performance when using -Os + if (SI->getFunction()->hasOptSize()) + return false; + + // The table density should be at least 33% for 64-bit integers. // FIXME: Find the best cut-off. - return SI->getNumCases() * 10 >= TableSize * 4; + return SI->getNumCases() * 3 * 8 >= (TableSize * BiggestTypeSize); } /// Try to reuse the switch table index compare. Following pattern: @@ -5248,6 +5269,12 @@ } } +// TODO Please move this function up here after committing. This makes the patch +// more readable. 
+static bool ReduceSwitchRange(SwitchInst *SI, IRBuilder<> &Builder, + const DataLayout &DL, + const TargetTransformInfo &TTI); + /// If the switch is only used to initialize one or more phi nodes in a common /// successor block with different constant values, replace the switch with /// lookup tables. @@ -5293,9 +5320,9 @@ for (SwitchInst::CaseIt E = SI->case_end(); CI != E; ++CI) { ConstantInt *CaseVal = CI->getCaseValue(); - if (CaseVal->getValue().slt(MinCaseVal->getValue())) + if (CaseVal->getValue().ult(MinCaseVal->getValue())) MinCaseVal = CaseVal; - if (CaseVal->getValue().sgt(MaxCaseVal->getValue())) + if (CaseVal->getValue().ugt(MaxCaseVal->getValue())) MaxCaseVal = CaseVal; // Resulting value at phi nodes for this case value. @@ -5347,6 +5374,36 @@ DefaultResults[PHI] = Result; } + // Compute the maximum table size representable by the integer type we are + // switching upon. + unsigned CaseSize = MinCaseVal->getType()->getPrimitiveSizeInBits(); + uint64_t MaxTableSize = CaseSize > 63 ? UINT64_MAX : 1ULL << CaseSize; + assert(MaxTableSize >= TableSize && + "It is impossible for a switch to have more entries than the max " + "representable value of its input integer type's size."); + + // If the table is only a u8 and we do not have to check for the default case, + // extend the table so we can get rid of the branch. + if (MaxTableSize <= 256 && HasDefaultResults && !SI->getFunction()->hasOptSize()) { + TableSize = MaxTableSize; + // Un-rotate and un-xor now that we are covering the whole range, + ConstantInt *CAdd = nullptr, *CXor = nullptr; + Value *V; + if (match(SI->getCondition(), m_c_Add(m_Value(V), m_ConstantInt(CAdd))) || + match(SI->getCondition(), m_Xor(m_Value(V), m_ConstantInt(CXor)))) { + for (auto Case : SI->cases()) { + auto *Orig = Case.getCaseValue(); + auto Sub = CAdd ? Orig->getValue() - CAdd->getValue() : Orig->getValue(); + auto Xor = (CXor ? 
Sub ^ CXor->getValue() : Sub); + Case.setValue(cast(ConstantInt::get(MinCaseVal->getContext(), Xor))); + } + SI->setCondition(V); + return true; // We will get called again + } + // Call this from in here, because we need the context necessary for this if/else + } else if (ReduceSwitchRange(SI, Builder, DL, TTI)) + return true; // We will get called again + if (!ShouldBuildLookupTable(SI, TableSize, TTI, DL, ResultTypes)) return false; @@ -5355,22 +5412,9 @@ BasicBlock *LookupBB = BasicBlock::Create( Mod.getContext(), "switch.lookup", CommonDest->getParent(), CommonDest); - // Compute the table index value. Builder.SetInsertPoint(SI); Value *TableIndex; - if (MinCaseVal->isNullValue()) - TableIndex = SI->getCondition(); - else - TableIndex = Builder.CreateSub(SI->getCondition(), MinCaseVal, - "switch.tableidx"); - - // Compute the maximum table size representable by the integer type we are - // switching upon. - unsigned CaseSize = MinCaseVal->getType()->getPrimitiveSizeInBits(); - uint64_t MaxTableSize = CaseSize > 63 ? UINT64_MAX : 1ULL << CaseSize; - assert(MaxTableSize >= TableSize && - "It is impossible for a switch to have more entries than the max " - "representable value of its input integer type's size."); + TableIndex = SI->getCondition(); // If the default destination is unreachable, or if the lookup table covers // all values of the conditional variable, branch directly to the lookup table @@ -5445,7 +5489,7 @@ // If using a bitmask, use any value to fill the lookup table holes. Constant *DV = NeedMask ? ResultLists[PHI][0].second : DefaultResults[PHI]; StringRef FuncName = Fn->getName(); - SwitchLookupTable Table(Mod, TableSize, MinCaseVal, ResultList, DV, DL, + SwitchLookupTable Table(Mod, TableSize, ResultList, DV, DL, FuncName); Value *Result = Table.BuildLookup(TableIndex, Builder); @@ -5491,17 +5535,6 @@ return true; } -static bool isSwitchDense(ArrayRef Values) { - // See also SelectionDAGBuilder::isDense(), which this function was based on. 
- uint64_t Diff = (uint64_t)Values.back() - (uint64_t)Values.front(); - uint64_t Range = Diff + 1; - uint64_t NumCases = Values.size(); - // 40% is the default density for building a jump table in optsize/minsize mode. - uint64_t MinDensity = 40; - - return NumCases * 100 >= Range * MinDensity; -} - /// Try to transform a switch that has "holes" in it to a contiguous sequence /// of cases. /// @@ -5513,58 +5546,103 @@ static bool ReduceSwitchRange(SwitchInst *SI, IRBuilder<> &Builder, const DataLayout &DL, const TargetTransformInfo &TTI) { + // The number of cases that need to be removed by a subtraction operation + // to make it worth using. + const unsigned SubThreshold = (SI->getFunction()->hasOptSize() ? 2 : 8); + bool MadeChanges = false; auto *CondTy = cast(SI->getCondition()->getType()); - if (CondTy->getIntegerBitWidth() > 64 || - !DL.fitsInLegalInteger(CondTy->getIntegerBitWidth())) - return false; - // Only bother with this optimization if there are more than 3 switch cases; - // SDAG will only bother creating jump tables for 4 or more cases. - if (SI->getNumCases() < 4) + unsigned BitWidth = CondTy->getIntegerBitWidth(); + if (BitWidth > 64 || + !DL.fitsInLegalInteger(BitWidth)) return false; - // This transform is agnostic to the signedness of the input or case values. We - // can treat the case values as signed or unsigned. We can optimize more common - // cases such as a sequence crossing zero {-4,0,4,8} if we interpret case values - // as signed. - SmallVector Values; + SmallVector Values; for (auto &C : SI->cases()) - Values.push_back(C.getCaseValue()->getValue().getSExtValue()); + Values.push_back(C.getCaseValue()->getLimitedValue()); llvm::sort(Values); - // If the switch is already dense, there's nothing useful to do here. - if (isSwitchDense(Values)) - return false; + // A cheap speculative transform: handle power-of-two flags using + // @clz or @ctz as the key function. 
+ // TODO: this transform would benefit from proper range analysis in SwitchToLookupTable + bool UseClz = false; + bool UseCtz = false; + if (BitWidth > 8 || + SI->getFunction()->hasOptSize()) { + unsigned MaxPopCount = 0; + unsigned MinClz = 64, MaxClz = 0, MinCtz = 64, MaxCtz = 0; + for (auto &V : Values) { + MaxPopCount = countPopulation(V); + if (MaxPopCount > 1) + break; + unsigned Clz = APInt(BitWidth, V).countLeadingZeros(); + unsigned Ctz = APInt(BitWidth, V).countTrailingZeros(); + if (Clz < MinClz) MinClz = Clz; + if (Clz > MaxClz) MaxClz = Clz; + if (Ctz < MinCtz) MinCtz = Ctz; + if (Ctz > MaxCtz) MaxCtz = Ctz; + } + // Without this check we might do clz followed by ctz, and if Values contains 0, + // the result might be worse. + if (MaxPopCount == 1 && Values.back() > 64) { + MadeChanges = true; + // Prefer clz because it is one instruction cheaper + // on ARM, but they cost the same on x86 so if we only need + // the subtraction on clz use ctz. + UseClz = (MinClz < SubThreshold) || (MinCtz >= SubThreshold); + if (UseClz) + for (auto &V : Values) + V = APInt(BitWidth, V).countLeadingZeros(); + else { + UseCtz = true; + for (auto &V : Values) { + V = APInt(BitWidth, V).countTrailingZeros(); + } + } + // 0 will suddenly become the largest (BitWidth), so we need to sort again. + llvm::sort(Values); + } + } - // First, transform the values such that they start at zero and ascend. - int64_t Base = Values[0]; - for (auto &V : Values) - V -= (uint64_t)(Base); + // Find the element that has the most distance from its previous, wrapping around. 
+ uint64_t BestDistance = APInt::getMaxValue(CondTy->getIntegerBitWidth()).getLimitedValue() - + Values.back() + Values.front() + 1; + unsigned BestIndex = 0; + for (unsigned i = 1;i != Values.size();i++) { + if (Values[i] - Values[i-1] > BestDistance) { + BestIndex = i; + BestDistance = Values[i] - Values[i-1]; + } + } - // Now we have signed numbers that have been shifted so that, given enough - // precision, there are no negative values. Since the rest of the transform - // is bitwise only, we switch now to an unsigned representation. - uint64_t GCD = 0; - for (auto &V : Values) - GCD = GreatestCommonDivisor64(GCD, (uint64_t)V); + uint64_t Base = 0; + // Now transform the values such that they start at zero and ascend. + if ((BestDistance > SubThreshold) && + (BestIndex != 0 || (Values[0] >= SubThreshold))) { + Base = Values[BestIndex]; + MadeChanges = true; + for (auto &V : Values) + V = (APInt(BitWidth, V) - Base).getLimitedValue(); + } // This transform can be done speculatively because it is so cheap - it results - // in a single rotate operation being inserted. This can only happen if the - // factor extracted is a power of 2. - // FIXME: If the GCD is an odd number we can multiply by the multiplicative - // inverse of GCD and then perform this transform. - // FIXME: It's possible that optimizing a switch on powers of two might also - // be beneficial - flag values are often powers of two and we could use a CLZ - // as the key function. - if (GCD <= 1 || !isPowerOf2_64(GCD)) - // No common divisor found or too expensive to compute key function. - return false; - - unsigned Shift = Log2_64(GCD); - for (auto &V : Values) - V = (int64_t)((uint64_t)V >> Shift); + // in a single rotate operation being inserted. 
+ unsigned Shift = 64; + for (auto &V : Values) { + unsigned TZ = countTrailingZeros(V); + if (TZ < Shift) { + Shift = TZ; + if (Shift == 0) + break; + } + } + if (Shift) { + MadeChanges = true; + for (auto &V : Values) + V = V >> Shift; + } - if (!isSwitchDense(Values)) - // Transform didn't create a dense switch. + if (!MadeChanges) + // Didn't do anything return false; // The obvious transform is to shift the switch condition right and emit a @@ -5578,20 +5656,44 @@ // default case. auto *Ty = cast(SI->getCondition()->getType()); - Builder.SetInsertPoint(SI); - auto *ShiftC = ConstantInt::get(Ty, Shift); - auto *Sub = Builder.CreateSub(SI->getCondition(), ConstantInt::get(Ty, Base)); - auto *LShr = Builder.CreateLShr(Sub, ShiftC); - auto *Shl = Builder.CreateShl(Sub, Ty->getBitWidth() - Shift); - auto *Rot = Builder.CreateOr(LShr, Shl); - SI->replaceUsesOfWith(SI->getCondition(), Rot); + { + auto Zero = ConstantInt::get(IntegerType::get(Ty->getContext(), 1), 0); + Builder.SetInsertPoint(SI); + Value *ZerosTransform; + if (UseClz) { + Function *Ctlz = Intrinsic::getDeclaration(SI->getModule(), Intrinsic::ctlz, Ty); + ZerosTransform = Builder.Insert(CallInst::Create(Ctlz, {SI->getCondition(), Zero})); + } else if (UseCtz) { + Function *Cttz = Intrinsic::getDeclaration(SI->getModule(), Intrinsic::cttz, Ty); + ZerosTransform = Builder.Insert(CallInst::Create(Cttz, {SI->getCondition(), Zero})); + } else + ZerosTransform = SI->getCondition(); + + auto *Sub = Builder.CreateSub(ZerosTransform, ConstantInt::get(Ty, Base)); + Value *Key; + if (Shift) { + Function *Fshr = Intrinsic::getDeclaration(SI->getModule(), Intrinsic::fshr, Ty); + auto *ShiftC = ConstantInt::get(Ty, Shift); + Key = Builder.Insert(CallInst::Create(Fshr, {Sub, Sub, ShiftC})); + } else + Key = Sub; + SI->replaceUsesOfWith(SI->getCondition(), Key); + } for (auto Case : SI->cases()) { auto *Orig = Case.getCaseValue(); - auto Sub = Orig->getValue() - APInt(Ty->getBitWidth(), Base); + uint64_t Zeros; + if 
(UseClz) + Zeros = Orig->getValue().countLeadingZeros(); + else if (UseCtz) + Zeros = Orig->getValue().countTrailingZeros(); + else + Zeros = Orig->getValue().getLimitedValue(); + auto Sub = (APInt(BitWidth, Zeros) - Base).getLimitedValue(); Case.setValue( - cast(ConstantInt::get(Ty, Sub.lshr(ShiftC->getValue())))); + cast(ConstantInt::get(Ty, APInt(BitWidth, Sub >> Shift)))); } + return true; } @@ -5640,6 +5742,7 @@ SwitchToLookupTable(SI, Builder, DL, TTI)) return requestResimplify(); + // This is also called within SwitchToLookupTable if (ReduceSwitchRange(SI, Builder, DL, TTI)) return requestResimplify(); Index: test/Transforms/SimplifyCFG/switch-genfori8.ll =================================================================== --- /dev/null +++ test/Transforms/SimplifyCFG/switch-genfori8.ll @@ -0,0 +1,99 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py +; RUN: opt -S -simplifycfg -O2 < %s | FileCheck %s +; Using a Zig driver https://gist.github.com/shawnl/8137f62f7dbcfd539f6cf1925387cd38 +;after-patch: 509.8MiB/sec, checksum: 2394975081 +;before-patch: 205.4MiB/sec, checksum: 2394975081 + +; ModuleID = 'chartodigit.c' +source_filename = "chartodigit.c" +target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128" +target triple = "x86_64-pc-linux-gnu" + +; Function Attrs: norecurse nounwind readnone uwtable +define dso_local zeroext i8 @char_to_digit(i8 zeroext) local_unnamed_addr #0 { +; CHECK-LABEL: @char_to_digit( +; CHECK-NEXT: switch.lookup: +; CHECK-NEXT: [[TMP1:%.*]] = zext i8 [[TMP0:%.*]] to i64 +; CHECK-NEXT: [[SWITCH_GEP:%.*]] = getelementptr inbounds [256 x i8], [256 x i8]* @switch.table.char_to_digit, i64 0, i64 [[TMP1]] +; CHECK-NEXT: [[SWITCH_LOAD:%.*]] = load i8, i8* [[SWITCH_GEP]], align 1 +; CHECK-NEXT: ret i8 [[SWITCH_LOAD]] +; + switch i8 %0, label %17 [ + i8 48, label %18 + i8 49, label %2 + i8 50, label %3 + i8 51, label %4 + i8 52, label %5 + i8 53, label %6 + i8 54, label %7 + i8 55, label %8 + i8 56, label %9 + 
i8 57, label %10 + i8 97, label %11 + i8 98, label %12 + i8 99, label %13 + i8 100, label %14 + i8 101, label %15 + i8 102, label %16 + ] + +;