Index: lib/Transforms/Utils/SimplifyCFG.cpp =================================================================== --- lib/Transforms/Utils/SimplifyCFG.cpp +++ lib/Transforms/Utils/SimplifyCFG.cpp @@ -5492,6 +5492,77 @@ return true; } +/// A cheap transform: investigate using ctlz and/or cttz as a key to the switch +/// FIXME: this transform would benefit from proper range analysis in +/// SwitchToLookupTable +static bool ReduceSwitchRangeWithCtlzOrCttz(SwitchInst &SI, + SmallVectorImpl &Values, + bool &UseCtlz, bool &UseCttz, + bool &UseInvert) { + unsigned BitWidth = SI.getCondition()->getType()->getIntegerBitWidth(); + + // Avoid applying this transform twice. + if (Values.back() == BitWidth) + return false; + + // This case is not worth it, as it (usually) requires an additional branch + // to check if popcount is 1, and the number of cases is limited to 64, so + // it is already fast. + uint64_t MaxTableSize = BitWidth > 63 ? UINT64_MAX : 1ULL << BitWidth; + const bool DefaultIsReachable = + !isa(SI.getDefaultDest()->getFirstNonPHIOrDbg()); + const bool GeneratingCoveredLookupTable = (MaxTableSize == SI.getNumCases()); + if (!DefaultIsReachable && !GeneratingCoveredLookupTable) + return false; + + bool Invert = false; + bool Cttz = false; + while (true) { + bool GotCollision = false; + uint64_t Got = 0; + APInt Prev; + for (auto &V : Values) { + APInt Int = APInt(BitWidth, Invert ? ~V : V); + // No need to correct for the bit-width here, as we are just checking for + // collisions. + unsigned Ctz = + (Cttz ? Int.countTrailingZeros() : Int.countLeadingZeros()); + std::array Prev; + if (Got & (1ULL << Ctz)) { + if (&*SI.findCaseValue(ConstantInt::get( + SI.getContext(), Invert ? ~Int : Int)) == Prev[Ctz]) + continue; + GotCollision = true; + break; + } + Got |= (1ULL << Ctz); + Prev[Ctz] = &*SI.findCaseValue( + ConstantInt::get(SI.getContext(), Invert ? ~Int : Int)); + } + if (!GotCollision) { + for (auto &V : Values) + V = Cttz ? APInt(BitWidth, Invert ? 
~V : V).countTrailingZeros() + : APInt(BitWidth, Invert ? ~V : V).countLeadingZeros(); + if (Cttz) + UseCttz = true; + else + UseCtlz = true; + UseInvert = Invert; + llvm::sort(Values); + return true; + } + // We do this loop up to 4 times, trying different parameters + if (!Cttz) { + Cttz = true; + } else if (!Invert) { + Cttz = false; + Invert = true; + } else + break; + }; + return false; +} + /// Try to transform a switch that has "holes" in it to a contiguous sequence /// of cases. /// @@ -5524,11 +5595,17 @@ llvm::sort(Values); bool MadeChanges = false; + + bool UseCtlz = false, UseCttz = false, UseInvert = false; + if (ReduceSwitchRangeWithCtlzOrCttz(*SI, Values, UseCtlz, UseCttz, UseInvert)) + MadeChanges = true; + // We must first look find the best start point, for example if we have a // series that crosses zero: -2, -1, 0, 1, 2. uint64_t BestDistance = APInt::getMaxValue(CondTy->getIntegerBitWidth()).getLimitedValue() - Values.back() + Values.front() + 1; + unsigned BestIndex = 0; for (unsigned I = 1; I != Values.size(); I++) { if (Values[I] - Values[I - 1] > BestDistance) { @@ -5551,9 +5628,6 @@ // This transform can be done speculatively because it is so cheap - it // results in a single rotate operation being inserted. - // FIXME: It's possible that optimizing a switch on powers of two might also - // be beneficial - flag values are often powers of two and we could use a CLZ - // as the key function. // countTrailingZeros(0) returns 64. 
As Values is guaranteed to have more than // one element and LLVM disallows duplicate cases, Shift is guaranteed to be @@ -5584,9 +5658,27 @@ auto *Ty = cast(SI->getCondition()->getType()); Builder.SetInsertPoint(SI); + auto *ZeroNotUndef = Builder.getInt1(false); Value *Key = SI->getCondition(); if (Base > 0) - Key = Builder.CreateSub(Key, ConstantInt::get(Ty, Base)); + Key = Builder.CreateSub(Key, ConstantInt::get(Ty, Base), + "switch.rangereduce"); + if (UseInvert) { + auto *AllOnes = ConstantInt::getAllOnesValue(Ty); + Key = Builder.CreateXor(Key, AllOnes, "switch.rangereduce"); + } + if (UseCtlz) { + Function *Ctlz = + Intrinsic::getDeclaration(SI->getModule(), Intrinsic::ctlz, Ty); + Key = Builder.CreateCall(Ctlz, {Key, ZeroNotUndef}, "switch.rangereduce"); + } else if (UseCttz) { + Function *Cttz = + Intrinsic::getDeclaration(SI->getModule(), Intrinsic::cttz, Ty); + Key = Builder.CreateCall(Cttz, {Key, ZeroNotUndef}, "switch.rangereduce"); + } + if (Base > 0) + Key = Builder.CreateSub(Key, ConstantInt::get(Ty, Base), + "switch.rangereduce"); if (Shift > 0) { // FIXME replace with fshr? 
auto *ShiftC = ConstantInt::get(Ty, Shift); @@ -5598,7 +5690,14 @@ for (auto Case : SI->cases()) { auto *Orig = Case.getCaseValue(); - auto Sub = Orig->getValue() - APInt(Ty->getBitWidth(), Base); + uint64_t Zeros; + if (UseCtlz) + Zeros = Orig->getValue().countLeadingZeros(); + else if (UseCttz) + Zeros = Orig->getValue().countTrailingZeros(); + else + Zeros = Orig->getValue().getLimitedValue(); + auto Sub = (APInt(BitWidth, Zeros) - Base); Case.setValue(cast(ConstantInt::get(Ty, Sub.lshr(Shift)))); } return true; Index: test/Transforms/SimplifyCFG/switch-simplify-range.ll =================================================================== --- test/Transforms/SimplifyCFG/switch-simplify-range.ll +++ test/Transforms/SimplifyCFG/switch-simplify-range.ll @@ -8,12 +8,12 @@ define i64 @switch_common_right_bits(i8 %a) #0 { ; CHECK-LABEL: @switch_common_right_bits( ; CHECK-NEXT: Entry: -; CHECK-NEXT: [[TMP0:%.*]] = sub i8 [[A:%.*]], 123 -; CHECK-NEXT: [[TMP1:%.*]] = call i8 @llvm.fshr.i8(i8 [[TMP0]], i8 [[TMP0]], i8 1) -; CHECK-NEXT: [[TMP2:%.*]] = icmp ult i8 [[TMP1]], 5 -; CHECK-NEXT: br i1 [[TMP2]], label [[SWITCH_LOOKUP:%.*]], label [[SWITCHELSE:%.*]] +; CHECK-NEXT: [[SWITCH_RANGEREDUCE:%.*]] = sub i8 [[A:%.*]], 123 +; CHECK-NEXT: [[SWITCH_RANGEREDUCE1:%.*]] = call i8 @llvm.fshr.i8(i8 [[SWITCH_RANGEREDUCE]], i8 [[SWITCH_RANGEREDUCE]], i8 1) +; CHECK-NEXT: [[TMP0:%.*]] = icmp ult i8 [[SWITCH_RANGEREDUCE1]], 5 +; CHECK-NEXT: br i1 [[TMP0]], label [[SWITCH_LOOKUP:%.*]], label [[SWITCHELSE:%.*]] ; CHECK: switch.lookup: -; CHECK-NEXT: [[SWITCH_GEP:%.*]] = getelementptr inbounds [5 x i64], [5 x i64]* @switch.table.switch_common_right_bits, i32 0, i8 [[TMP1]] +; CHECK-NEXT: [[SWITCH_GEP:%.*]] = getelementptr inbounds [5 x i64], [5 x i64]* @switch.table.switch_common_right_bits, i32 0, i8 [[SWITCH_RANGEREDUCE1]] ; CHECK-NEXT: [[SWITCH_LOAD:%.*]] = load i64, i64* [[SWITCH_GEP]] ; CHECK-NEXT: ret i64 [[SWITCH_LOAD]] ; CHECK: SwitchElse: @@ -44,35 +44,16 @@ define i64 
@switch_clz1(i16 %a) optsize #0 { ; CHECK-LABEL: @switch_clz1( ; CHECK-NEXT: Entry: -; CHECK-NEXT: switch i16 [[A:%.*]], label [[SWITCHELSE:%.*]] [ -; CHECK-NEXT: i16 2, label [[SWITCHPRONG:%.*]] -; CHECK-NEXT: i16 4, label [[SWITCHPRONG1:%.*]] -; CHECK-NEXT: i16 8, label [[SWITCHPRONG2:%.*]] -; CHECK-NEXT: i16 1, label [[SWITCHPRONG3:%.*]] -; CHECK-NEXT: i16 64, label [[SWITCHPRONG6:%.*]] -; CHECK-NEXT: i16 128, label [[SWITCHPRONG7:%.*]] -; CHECK-NEXT: i16 16, label [[SWITCHPRONG5:%.*]] -; CHECK-NEXT: i16 32, label [[SWITCHPRONG4:%.*]] -; CHECK-NEXT: ] +; CHECK-NEXT: [[SWITCH_RANGEREDUCE:%.*]] = call i16 @llvm.ctlz.i16(i16 [[A:%.*]], i1 false) +; CHECK-NEXT: [[SWITCH_RANGEREDUCE1:%.*]] = sub i16 [[SWITCH_RANGEREDUCE]], 8 +; CHECK-NEXT: [[TMP0:%.*]] = icmp ult i16 [[SWITCH_RANGEREDUCE1]], 8 +; CHECK-NEXT: br i1 [[TMP0]], label [[SWITCH_LOOKUP:%.*]], label [[SWITCHELSE:%.*]] +; CHECK: switch.lookup: +; CHECK-NEXT: [[SWITCH_GEP:%.*]] = getelementptr inbounds [8 x i64], [8 x i64]* @switch.table.switch_clz1, i32 0, i16 [[SWITCH_RANGEREDUCE1]] +; CHECK-NEXT: [[SWITCH_LOAD:%.*]] = load i64, i64* [[SWITCH_GEP]] +; CHECK-NEXT: ret i64 [[SWITCH_LOAD]] ; CHECK: SwitchElse: -; CHECK-NEXT: [[MERGE:%.*]] = phi i64 [ 10, [[ENTRY:%.*]] ], [ 6, [[SWITCHPRONG]] ], [ 3, [[SWITCHPRONG1]] ], [ 35, [[SWITCHPRONG2]] ], [ 31, [[SWITCHPRONG3]] ], [ 53, [[SWITCHPRONG4]] ], [ 51, [[SWITCHPRONG5]] ], [ 41, [[SWITCHPRONG6]] ], [ 34, [[SWITCHPRONG7]] ] -; CHECK-NEXT: ret i64 [[MERGE]] -; CHECK: SwitchProng: -; CHECK-NEXT: br label [[SWITCHELSE]] -; CHECK: SwitchProng1: -; CHECK-NEXT: br label [[SWITCHELSE]] -; CHECK: SwitchProng2: -; CHECK-NEXT: br label [[SWITCHELSE]] -; CHECK: SwitchProng3: -; CHECK-NEXT: br label [[SWITCHELSE]] -; CHECK: SwitchProng4: -; CHECK-NEXT: br label [[SWITCHELSE]] -; CHECK: SwitchProng5: -; CHECK-NEXT: br label [[SWITCHELSE]] -; CHECK: SwitchProng6: -; CHECK-NEXT: br label [[SWITCHELSE]] -; CHECK: SwitchProng7: -; CHECK-NEXT: br label [[SWITCHELSE]] +; CHECK-NEXT: 
ret i64 undef ; Entry: switch i16 %a, label %SwitchElse [ @@ -86,7 +67,7 @@ i16 32, label %SwitchProng4 ] SwitchElse: ; preds = %Entry - ret i64 10 + ret i64 undef SwitchProng: ; preds = %Entry ret i64 6 SwitchProng1: ; preds = %Entry @@ -108,26 +89,51 @@ define i64 @switch_clz2(i8 %a) optsize #0 { ; CHECK-LABEL: @switch_clz2( ; CHECK-NEXT: Entry: -; CHECK-NEXT: switch i8 [[A:%.*]], label [[SWITCHELSE:%.*]] [ -; CHECK-NEXT: i8 -128, label [[SWITCHPRONG2:%.*]] -; CHECK-NEXT: i8 64, label [[SWITCHPRONG3:%.*]] -; CHECK-NEXT: i8 32, label [[SWITCHPRONG6:%.*]] -; CHECK-NEXT: i8 1, label [[SWITCHPRONG7:%.*]] -; CHECK-NEXT: i8 0, label [[SWITCHPRONG5:%.*]] -; CHECK-NEXT: ] +; CHECK-NEXT: [[SWITCH_RANGEREDUCE:%.*]] = call i8 @llvm.ctlz.i8(i8 [[A:%.*]], i1 false) +; CHECK-NEXT: [[TMP0:%.*]] = icmp ult i8 [[SWITCH_RANGEREDUCE]], 9 +; CHECK-NEXT: br i1 [[TMP0]], label [[SWITCH_LOOKUP:%.*]], label [[SWITCHPRONG2:%.*]] +; CHECK: switch.lookup: +; CHECK-NEXT: [[SWITCH_GEP:%.*]] = getelementptr inbounds [9 x i64], [9 x i64]* @switch.table.switch_clz2, i32 0, i8 [[SWITCH_RANGEREDUCE]] +; CHECK-NEXT: [[SWITCH_LOAD:%.*]] = load i64, i64* [[SWITCH_GEP]] +; CHECK-NEXT: ret i64 [[SWITCH_LOAD]] ; CHECK: SwitchProng2: -; CHECK-NEXT: [[MERGE:%.*]] = phi i64 [ 35, [[ENTRY:%.*]] ], [ 31, [[SWITCHPRONG3]] ], [ 41, [[SWITCHPRONG6]] ], [ 40, [[SWITCHPRONG5]] ], [ 43, [[SWITCHPRONG7]] ], [ 12, [[SWITCHELSE]] ] -; CHECK-NEXT: ret i64 [[MERGE]] -; CHECK: SwitchProng3: -; CHECK-NEXT: br label [[SWITCHPRONG2]] -; CHECK: SwitchProng6: -; CHECK-NEXT: br label [[SWITCHPRONG2]] -; CHECK: SwitchProng5: -; CHECK-NEXT: br label [[SWITCHPRONG2]] -; CHECK: SwitchProng7: -; CHECK-NEXT: br label [[SWITCHPRONG2]] -; CHECK: SwitchElse: -; CHECK-NEXT: br label [[SWITCHPRONG2]] +; CHECK-NEXT: ret i64 undef +; +Entry: + switch i8 %a, label %SwitchElse [ + i8 128, label %SwitchProng2 + i8 64, label %SwitchProng3 + i8 32, label %SwitchProng6 + i8 1, label %SwitchProng7 + i8 0, label %SwitchProng5 + ] +SwitchProng2: 
; preds = %Entry + ret i64 35 +SwitchProng3: ; preds = %Entry + ret i64 31 +SwitchProng6: ; preds = %Entry + ret i64 41 +SwitchProng5: ; preds = %Entry + ret i64 40 +SwitchProng7: ; preds = %Entry + ret i64 43 +SwitchElse: ; preds = %Entry + ret i64 undef +} + +;can't be optimized as has a default case +define i64 @switch_clz3(i8 %a) optsize #0 { +; CHECK-LABEL: @switch_clz3( +; CHECK-NEXT: Entry: +; CHECK-NEXT: [[SWITCH_RANGEREDUCE:%.*]] = call i8 @llvm.ctlz.i8(i8 [[A:%.*]], i1 false) +; CHECK-NEXT: [[TMP0:%.*]] = icmp ult i8 [[SWITCH_RANGEREDUCE]], 9 +; CHECK-NEXT: br i1 [[TMP0]], label [[SWITCH_LOOKUP:%.*]], label [[SWITCHPRONG2:%.*]] +; CHECK: switch.lookup: +; CHECK-NEXT: [[SWITCH_GEP:%.*]] = getelementptr inbounds [9 x i64], [9 x i64]* @switch.table.switch_clz3, i32 0, i8 [[SWITCH_RANGEREDUCE]] +; CHECK-NEXT: [[SWITCH_LOAD:%.*]] = load i64, i64* [[SWITCH_GEP]] +; CHECK-NEXT: ret i64 [[SWITCH_LOAD]] +; CHECK: SwitchProng2: +; CHECK-NEXT: ret i64 undef ; Entry: switch i8 %a, label %SwitchElse [ @@ -148,7 +154,7 @@ SwitchProng7: ; preds = %Entry ret i64 43 SwitchElse: ; preds = %Entry - ret i64 12 + ret i64 undef } ;Must check that the default was filled in at index 0 as happened here: