Index: lib/Transforms/Utils/SimplifyCFG.cpp =================================================================== --- lib/Transforms/Utils/SimplifyCFG.cpp +++ lib/Transforms/Utils/SimplifyCFG.cpp @@ -4882,7 +4882,7 @@ /// Create a lookup table to use as a switch replacement with the contents /// of Values, using DefaultValue to fill any holes in the table. SwitchLookupTable( - Module &M, uint64_t TableSize, ConstantInt *Offset, + Module &M, uint64_t TableSize, const SmallVectorImpl> &Values, Constant *DefaultValue, const DataLayout &DL, const StringRef &FuncName); @@ -4936,7 +4936,7 @@ } // end anonymous namespace SwitchLookupTable::SwitchLookupTable( - Module &M, uint64_t TableSize, ConstantInt *Offset, + Module &M, uint64_t TableSize, const SmallVectorImpl> &Values, Constant *DefaultValue, const DataLayout &DL, const StringRef &FuncName) { assert(Values.size() && "Can't build lookup table without values!"); @@ -4954,7 +4954,7 @@ Constant *CaseRes = Values[I].second; assert(CaseRes->getType() == ValueType); - uint64_t Idx = (CaseVal->getValue() - Offset->getValue()).getLimitedValue(); + uint64_t Idx = CaseVal->getValue().getLimitedValue(); TableContents[Idx] = CaseRes; if (CaseRes != SingleValue) @@ -5279,7 +5279,6 @@ // common destination, as well as the min and max case values. assert(!empty(SI->cases())); SwitchInst::CaseIt CI = SI->case_begin(); - ConstantInt *MinCaseVal = CI->getCaseValue(); ConstantInt *MaxCaseVal = CI->getCaseValue(); BasicBlock *CommonDest = nullptr; @@ -5293,9 +5292,7 @@ for (SwitchInst::CaseIt E = SI->case_end(); CI != E; ++CI) { ConstantInt *CaseVal = CI->getCaseValue(); - if (CaseVal->getValue().slt(MinCaseVal->getValue())) - MinCaseVal = CaseVal; - if (CaseVal->getValue().sgt(MaxCaseVal->getValue())) + if (CaseVal->getValue().ugt(MaxCaseVal->getValue())) MaxCaseVal = CaseVal; // Resulting value at phi nodes for this case value. 
@@ -5321,8 +5318,7 @@ } uint64_t NumResults = ResultLists[PHIs[0]].size(); - APInt RangeSpread = MaxCaseVal->getValue() - MinCaseVal->getValue(); - uint64_t TableSize = RangeSpread.getLimitedValue() + 1; + uint64_t TableSize = MaxCaseVal->getValue().getLimitedValue() + 1; bool TableHasHoles = (NumResults < TableSize); // If the table has holes, we need a constant result for the default case @@ -5357,16 +5353,11 @@ // Compute the table index value. Builder.SetInsertPoint(SI); - Value *TableIndex; - if (MinCaseVal->isNullValue()) - TableIndex = SI->getCondition(); - else - TableIndex = Builder.CreateSub(SI->getCondition(), MinCaseVal, - "switch.tableidx"); + Value *TableIndex = SI->getCondition(); // Compute the maximum table size representable by the integer type we are // switching upon. - unsigned CaseSize = MinCaseVal->getType()->getPrimitiveSizeInBits(); + unsigned CaseSize = MaxCaseVal->getType()->getPrimitiveSizeInBits(); uint64_t MaxTableSize = CaseSize > 63 ? UINT64_MAX : 1ULL << CaseSize; assert(MaxTableSize >= TableSize && "It is impossible for a switch to have more entries than the max " @@ -5386,7 +5377,7 @@ // PHI value for the default case in case we're using a bit mask. } else { Value *Cmp = Builder.CreateICmpULT( - TableIndex, ConstantInt::get(MinCaseVal->getType(), TableSize)); + TableIndex, ConstantInt::get(MaxCaseVal->getType(), TableSize)); RangeCheckBranch = Builder.CreateCondBr(Cmp, LookupBB, SI->getDefaultDest()); } @@ -5395,40 +5386,8 @@ Builder.SetInsertPoint(LookupBB); if (NeedMask) { - // Before doing the lookup, we do the hole check. The LookupBB is therefore - // re-purposed to do the hole check, and we create a new LookupBB. - BasicBlock *MaskBB = LookupBB; - MaskBB->setName("switch.hole_check"); - LookupBB = BasicBlock::Create(Mod.getContext(), "switch.lookup", - CommonDest->getParent(), CommonDest); - - // Make the mask's bitwidth at least 8-bit and a power-of-2 to avoid - // unnecessary illegal types. 
- uint64_t TableSizePowOf2 = NextPowerOf2(std::max(7ULL, TableSize - 1ULL)); - APInt MaskInt(TableSizePowOf2, 0); - APInt One(TableSizePowOf2, 1); - // Build bitmask; fill in a 1 bit for every case. - const ResultListTy &ResultList = ResultLists[PHIs[0]]; - for (size_t I = 0, E = ResultList.size(); I != E; ++I) { - uint64_t Idx = (ResultList[I].first->getValue() - MinCaseVal->getValue()) - .getLimitedValue(); - MaskInt |= One << Idx; - } - ConstantInt *TableMask = ConstantInt::get(Mod.getContext(), MaskInt); - - // Get the TableIndex'th bit of the bitmask. - // If this bit is 0 (meaning hole) jump to the default destination, - // else continue with table lookup. - IntegerType *MapTy = TableMask->getType(); - Value *MaskIndex = - Builder.CreateZExtOrTrunc(TableIndex, MapTy, "switch.maskindex"); - Value *Shifted = Builder.CreateLShr(TableMask, MaskIndex, "switch.shifted"); - Value *LoBit = Builder.CreateTrunc( - Shifted, Type::getInt1Ty(Mod.getContext()), "switch.lobit"); - Builder.CreateCondBr(LoBit, LookupBB, SI->getDefaultDest()); - - Builder.SetInsertPoint(LookupBB); - AddPredecessorToBlock(SI->getDefaultDest(), MaskBB, SI->getParent()); + // Re-written in a later patch + return false; } if (!DefaultIsReachable || GeneratingCoveredLookupTable) { @@ -5445,7 +5404,7 @@ // If using a bitmask, use any value to fill the lookup table holes. Constant *DV = NeedMask ? ResultLists[PHI][0].second : DefaultResults[PHI]; StringRef FuncName = Fn->getName(); - SwitchLookupTable Table(Mod, TableSize, MinCaseVal, ResultList, DV, DL, + SwitchLookupTable Table(Mod, TableSize, ResultList, DV, DL, FuncName); Value *Result = Table.BuildLookup(TableIndex, Builder); @@ -5491,17 +5450,6 @@ return true; } -static bool isSwitchDense(ArrayRef Values) { - // See also SelectionDAGBuilder::isDense(), which this function was based on. 
- uint64_t Diff = (uint64_t)Values.back() - (uint64_t)Values.front(); - uint64_t Range = Diff + 1; - uint64_t NumCases = Values.size(); - // 40% is the default density for building a jump table in optsize/minsize mode. - uint64_t MinDensity = 40; - - return NumCases * 100 >= Range * MinDensity; -} - /// Try to transform a switch that has "holes" in it to a contiguous sequence /// of cases. /// @@ -5513,58 +5461,107 @@ static bool ReduceSwitchRange(SwitchInst *SI, IRBuilder<> &Builder, const DataLayout &DL, const TargetTransformInfo &TTI) { + // The number of cases that need to be removed by a subtraction operation + // to make it worth using. + const unsigned SubThreshold = (SI->getFunction()->hasOptSize() ? 2 : 8); + bool MadeChanges = false; auto *CondTy = cast(SI->getCondition()->getType()); - if (CondTy->getIntegerBitWidth() > 64 || - !DL.fitsInLegalInteger(CondTy->getIntegerBitWidth())) - return false; - // Only bother with this optimization if there are more than 3 switch cases; - // SDAG will only bother creating jump tables for 4 or more cases. - if (SI->getNumCases() < 4) + unsigned BitWidth = CondTy->getIntegerBitWidth(); + if (BitWidth > 64 || + !DL.fitsInLegalInteger(BitWidth)) return false; - // This transform is agnostic to the signedness of the input or case values. We - // can treat the case values as signed or unsigned. We can optimize more common - // cases such as a sequence crossing zero {-4,0,4,8} if we interpret case values - // as signed. - SmallVector Values; + SmallVector Values; for (auto &C : SI->cases()) - Values.push_back(C.getCaseValue()->getValue().getSExtValue()); + Values.push_back(C.getCaseValue()->getLimitedValue()); llvm::sort(Values); - // If the switch is already dense, there's nothing useful to do here. - if (isSwitchDense(Values)) - return false; + // A cheap speculative transform: handle power-of-two flags using + // @clz or @ctz as the key function. 
+ // TODO: this transform would benefit from proper range analysis in SwitchToLookupTable + // BitWidth > 8 could be relaxed after this is fixed, but not eliminated, because + // this transforms 0 => BitWidth. + bool UseClz = false; + bool UseCtz = false; + if (BitWidth > 8) { + unsigned MaxPopCount = 0; + unsigned MinClz = 64, MaxClz = 0, MinCtz = 64, MaxCtz = 0; + for (auto &V : Values) { + MaxPopCount = countPopulation(V); + if (MaxPopCount > 1) + break; + unsigned Clz = APInt(BitWidth, V).countLeadingZeros(); + unsigned Ctz = APInt(BitWidth, V).countTrailingZeros(); + if (Clz < MinClz) MinClz = Clz; + if (Clz > MaxClz) MaxClz = Clz; + if (Ctz < MinCtz) MinCtz = Ctz; + if (Ctz > MaxCtz) MaxCtz = Ctz; + } + // Without this check we might do clz followed by ctz, and if Values contains 0, + // the result might be worse. + if (MaxPopCount == 1 && Values.back() > 64) { + MadeChanges = true; + // Prefer clz because it is one instruction cheaper + // on ARM, but they cost the same on x86 so if we only need + // the subtraction on clz use ctz. + UseClz = (MinClz < SubThreshold) || (MinCtz >= SubThreshold); + if (UseClz) + for (auto &V : Values) + V = APInt(BitWidth, V).countLeadingZeros(); + else { + UseCtz = true; + for (auto &V : Values) { + V = APInt(BitWidth, V).countTrailingZeros(); + } + } + // 0 will suddenly become the largest (BitWidth), so we need to sort again. + llvm::sort(Values); + } + } - // First, transform the values such that they start at zero and ascend. - int64_t Base = Values[0]; - for (auto &V : Values) - V -= (uint64_t)(Base); + // Find the element that has the most distance from its previous, wrapping around.
+ uint64_t BestDistance = APInt::getMaxValue(CondTy->getIntegerBitWidth()).getLimitedValue() - + Values.back() + Values.front() + 1; + unsigned BestIndex = 0; + for (unsigned i = 1;i != Values.size();i++) { + if (Values[i] - Values[i-1] > BestDistance) { + BestIndex = i; + BestDistance = Values[i] - Values[i-1]; + } + } - // Now we have signed numbers that have been shifted so that, given enough - // precision, there are no negative values. Since the rest of the transform - // is bitwise only, we switch now to an unsigned representation. - uint64_t GCD = 0; - for (auto &V : Values) - GCD = GreatestCommonDivisor64(GCD, (uint64_t)V); + uint64_t Base = 0; + // Now transform the values such that they start at zero and ascend. + if ((BestDistance > SubThreshold) && + (BestIndex != 0 || (Values[0] >= SubThreshold))) { + Base = Values[BestIndex]; + MadeChanges = true; + for (auto &V : Values) + V = (APInt(BitWidth, V) - Base).getLimitedValue(); + } // This transform can be done speculatively because it is so cheap - it results - // in a single rotate operation being inserted. This can only happen if the - // factor extracted is a power of 2. - // FIXME: If the GCD is an odd number we can multiply by the multiplicative - // inverse of GCD and then perform this transform. - // FIXME: It's possible that optimizing a switch on powers of two might also - // be beneficial - flag values are often powers of two and we could use a CLZ - // as the key function. - if (GCD <= 1 || !isPowerOf2_64(GCD)) - // No common divisor found or too expensive to compute key function. - return false; - - unsigned Shift = Log2_64(GCD); - for (auto &V : Values) - V = (int64_t)((uint64_t)V >> Shift); + // in a single rotate operation being inserted. + unsigned Shift = 64; + for (auto &V : Values) { + // There is no edge condition when the BitWidth is less than 64, because if + // 0 is the only value then a shift does nothing, and LLVM requires + // well-formed IR to not have duplicate cases. 
+ unsigned TZ = countTrailingZeros(V); + if (TZ < Shift) { + Shift = TZ; + if (Shift == 0) + break; + } + } + if (Shift) { + MadeChanges = true; + for (auto &V : Values) + V = V >> Shift; + } - if (!isSwitchDense(Values)) - // Transform didn't create a dense switch. + if (!MadeChanges) + // Didn't do anything return false; // The obvious transform is to shift the switch condition right and emit a @@ -5578,20 +5575,44 @@ // default case. auto *Ty = cast(SI->getCondition()->getType()); - Builder.SetInsertPoint(SI); - auto *ShiftC = ConstantInt::get(Ty, Shift); - auto *Sub = Builder.CreateSub(SI->getCondition(), ConstantInt::get(Ty, Base)); - auto *LShr = Builder.CreateLShr(Sub, ShiftC); - auto *Shl = Builder.CreateShl(Sub, Ty->getBitWidth() - Shift); - auto *Rot = Builder.CreateOr(LShr, Shl); - SI->replaceUsesOfWith(SI->getCondition(), Rot); + { + auto Zero = ConstantInt::get(IntegerType::get(Ty->getContext(), 1), 0); + Builder.SetInsertPoint(SI); + Value *ZerosTransform; + if (UseClz) { + Function *Ctlz = Intrinsic::getDeclaration(SI->getModule(), Intrinsic::ctlz, Ty); + ZerosTransform = Builder.Insert(CallInst::Create(Ctlz, {SI->getCondition(), Zero})); + } else if (UseCtz) { + Function *Cttz = Intrinsic::getDeclaration(SI->getModule(), Intrinsic::cttz, Ty); + ZerosTransform = Builder.Insert(CallInst::Create(Cttz, {SI->getCondition(), Zero})); + } else + ZerosTransform = SI->getCondition(); + + auto *Sub = Builder.CreateSub(ZerosTransform, ConstantInt::get(Ty, Base)); + Value *Key; + if (Shift) { + Function *Fshr = Intrinsic::getDeclaration(SI->getModule(), Intrinsic::fshr, Ty); + auto *ShiftC = ConstantInt::get(Ty, Shift); + Key = Builder.Insert(CallInst::Create(Fshr, {Sub, Sub, ShiftC})); + } else + Key = Sub; + SI->replaceUsesOfWith(SI->getCondition(), Key); + } for (auto Case : SI->cases()) { auto *Orig = Case.getCaseValue(); - auto Sub = Orig->getValue() - APInt(Ty->getBitWidth(), Base); + uint64_t Zeros; + if (UseClz) + Zeros = 
Orig->getValue().countLeadingZeros(); + else if (UseCtz) + Zeros = Orig->getValue().countTrailingZeros(); + else + Zeros = Orig->getValue().getLimitedValue(); + auto Sub = (APInt(BitWidth, Zeros) - Base).getLimitedValue(); Case.setValue( - cast(ConstantInt::get(Ty, Sub.lshr(ShiftC->getValue())))); + cast(ConstantInt::get(Ty, APInt(BitWidth, Sub >> Shift)))); } + return true; } @@ -5631,6 +5652,9 @@ if (Options.ForwardSwitchCondToPhi && ForwardSwitchConditionToPHI(SI)) return requestResimplify(); + if (ReduceSwitchRange(SI, Builder, DL, TTI)) + return requestResimplify(); + // The conversion from switch to lookup tables results in difficult-to-analyze // code and makes pruning branches much harder. This is a problem if the // switch expression itself can still be restricted as a result of inlining or @@ -5640,9 +5664,6 @@ SwitchToLookupTable(SI, Builder, DL, TTI)) return requestResimplify(); - if (ReduceSwitchRange(SI, Builder, DL, TTI)) - return requestResimplify(); - return false; } Index: test/Transforms/SimplifyCFG/switch-simplify-range.ll =================================================================== --- /dev/null +++ test/Transforms/SimplifyCFG/switch-simplify-range.ll @@ -0,0 +1,155 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py +; RUN: opt -S -passes='simplify-cfg' < %s | FileCheck %s +target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128" +target triple = "x86_64-pc-linux-gnu" + +attributes #0 = { "no-jump-tables"="false" } + +define i64 @switch_common_right_bits(i8 %a) #0 { +; CHECK-LABEL: @switch_common_right_bits( +; CHECK-NEXT: Entry: +; CHECK-NEXT: [[TMP0:%.*]] = add i8 [[A:%.*]], 3 +; CHECK-NEXT: [[TMP1:%.*]] = tail call i8 @llvm.fshl.i8(i8 [[TMP0]], i8 [[TMP0]], i8 7) +; CHECK-NEXT: [[TMP2:%.*]] = icmp ult i8 [[TMP1]], 4 +; CHECK-NEXT: br i1 [[TMP2]], label [[SWITCH_LOOKUP:%.*]], label [[SWITCHELSE:%.*]] +; CHECK: switch.lookup: +; CHECK-NEXT: [[TMP3:%.*]] = sext i8 [[TMP1]] to i64 +; CHECK-NEXT: 
[[SWITCH_GEP:%.*]] = getelementptr inbounds [4 x i64], [4 x i64]* @switch.table.switch_common_right_bits, i64 0, i64 [[TMP3]] +; CHECK-NEXT: [[SWITCH_LOAD:%.*]] = load i64, i64* [[SWITCH_GEP]], align 8 +; CHECK-NEXT: ret i64 [[SWITCH_LOAD]] +; CHECK: SwitchElse: +; CHECK-NEXT: ret i64 10 +; +Entry: + switch i8 %a, label %SwitchElse [ + i8 253, label %SwitchProng + i8 255, label %SwitchProng1 + i8 1, label %SwitchProng2 + i8 3, label %SwitchProng3 + ] +SwitchElse: ; preds = %Entry + ret i64 10 +SwitchProng: ; preds = %Entry + ret i64 6 +SwitchProng1: ; preds = %Entry + ret i64 3 +SwitchProng2: ; preds = %Entry + ret i64 3 +SwitchProng3: ; preds = %Entry + ret i64 3 +} + +define i64 @switch_ctz(i16 %a) optsize #0 { +; CHECK-LABEL: @switch_ctz( +; CHECK-NEXT: Entry: +; CHECK-NEXT: [[TMP0:%.*]] = tail call i16 @llvm.cttz.i16(i16 [[A:%.*]], i1 false), !range !0 +; CHECK-NEXT: [[TMP1:%.*]] = icmp ult i16 [[TMP0]], 8 +; CHECK-NEXT: br i1 [[TMP1]], label [[SWITCH_LOOKUP:%.*]], label [[SWITCHELSE:%.*]] +; CHECK: switch.lookup: +; CHECK-NEXT: [[TMP2:%.*]] = zext i16 [[TMP0]] to i64 +; CHECK-NEXT: [[SWITCH_GEP:%.*]] = getelementptr inbounds [8 x i64], [8 x i64]* @switch.table.switch_ctz, i64 0, i64 [[TMP2]] +; CHECK-NEXT: [[SWITCH_LOAD:%.*]] = load i64, i64* [[SWITCH_GEP]], align 8 +; CHECK-NEXT: ret i64 [[SWITCH_LOAD]] +; CHECK: SwitchElse: +; CHECK-NEXT: ret i64 10 +; +Entry: + switch i16 %a, label %SwitchElse [ + i16 2, label %SwitchProng + i16 4, label %SwitchProng1 + i16 8, label %SwitchProng2 + i16 1, label %SwitchProng3 + i16 64, label %SwitchProng6 + i16 128, label %SwitchProng7 + i16 16, label %SwitchProng5 + i16 32, label %SwitchProng4 + ] +SwitchElse: ; preds = %Entry + ret i64 10 +SwitchProng: ; preds = %Entry + ret i64 6 +SwitchProng1: ; preds = %Entry + ret i64 3 +SwitchProng2: ; preds = %Entry + ret i64 35 +SwitchProng3: ; preds = %Entry + ret i64 31 +SwitchProng4: ; preds = %Entry + ret i64 53 +SwitchProng5: ; preds = %Entry + ret i64 51 +SwitchProng6: ; preds 
= %Entry + ret i64 41 +SwitchProng7: ; preds = %Entry + ret i64 34 +} + +define i64 @switch_clz(i8 %a) optsize #0 { +; CHECK-LABEL: @switch_clz( +; CHECK-NEXT: Entry: +; CHECK-NEXT: [[TMP0:%.*]] = tail call i8 @llvm.fshl.i8(i8 [[A:%.*]], i8 [[A]], i8 3) +; CHECK-NEXT: [[TMP1:%.*]] = icmp ult i8 [[TMP0]], 5 +; CHECK-NEXT: br i1 [[TMP1]], label [[SWITCH_LOOKUP:%.*]], label [[SWITCHPRONG2:%.*]] +; CHECK: switch.lookup: +; CHECK-NEXT: [[TMP2:%.*]] = sext i8 [[TMP0]] to i64 +; CHECK-NEXT: [[SWITCH_GEP:%.*]] = getelementptr inbounds [5 x i64], [5 x i64]* @switch.table.switch_clz, i64 0, i64 [[TMP2]] +; CHECK-NEXT: [[SWITCH_LOAD:%.*]] = load i64, i64* [[SWITCH_GEP]], align 8 +; CHECK-NEXT: ret i64 [[SWITCH_LOAD]] +; CHECK: SwitchProng2: +; CHECK-NEXT: ret i64 12 +; +Entry: + switch i8 %a, label %SwitchElse [ + i8 128, label %SwitchProng2 + i8 64, label %SwitchProng3 + i8 32, label %SwitchProng6 + i8 0, label %SwitchProng5 + ] +SwitchProng2: ; preds = %Entry + ret i64 35 +SwitchProng3: ; preds = %Entry + ret i64 31 +SwitchProng6: ; preds = %Entry + ret i64 41 +SwitchProng5: ; preds = %Entry + ret i64 40 +SwitchElse: ; preds = %Entry + ret i64 12 +} + +;Must check that the default was filled in at index 0 as happened here: +;@switch.table.switch_not_normalized_to_start_at_zero = private unnamed_addr constant [6 x i16] [i16 10, i16 7, i16 3, i16 1, i16 6, i16 8], align 2 +define i16 @switch_not_normalized_to_start_at_zero(i8 %a) #0 { +; CHECK-LABEL: @switch_not_normalized_to_start_at_zero( +; CHECK-NEXT: Entry: +; CHECK-NEXT: [[TMP0:%.*]] = icmp ult i8 [[A:%.*]], 6 +; CHECK-NEXT: br i1 [[TMP0]], label [[SWITCH_LOOKUP:%.*]], label [[SWITCHELSE:%.*]] +; CHECK: switch.lookup: +; CHECK-NEXT: [[TMP1:%.*]] = sext i8 [[A]] to i64 +; CHECK-NEXT: [[SWITCH_GEP:%.*]] = getelementptr inbounds [6 x i16], [6 x i16]* @switch.table.switch_not_normalized_to_start_at_zero, i64 0, i64 [[TMP1]] +; CHECK-NEXT: [[SWITCH_LOAD:%.*]] = load i16, i16* [[SWITCH_GEP]], align 2 +; CHECK-NEXT: ret i16 
[[SWITCH_LOAD]] +; CHECK: SwitchElse: +; CHECK-NEXT: ret i16 10 +; +Entry: + switch i8 %a, label %SwitchElse [ + i8 4, label %SwitchProng + i8 2, label %SwitchProng1 + i8 1, label %SwitchProng2 + i8 3, label %SwitchProng3 + i8 5, label %SwitchProng4 + ] +SwitchElse: ; preds = %Entry + ret i16 10 +SwitchProng: ; preds = %Entry + ret i16 6 +SwitchProng1: ; preds = %Entry + ret i16 3 +SwitchProng2: ; preds = %Entry + ret i16 7 +SwitchProng3: ; preds = %Entry + ret i16 1 +SwitchProng4: ; preds = %Entry + ret i16 8 +}