Index: lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h
===================================================================
--- lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h
+++ lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h
@@ -260,13 +260,16 @@
   };
   struct JumpTableHeader {
     JumpTableHeader(APInt F, APInt L, const Value *SV, MachineBasicBlock *H,
-                    bool E = false):
-      First(F), Last(L), SValue(SV), HeaderBB(H), Emitted(E) {}
+                    bool E, APInt B, unsigned S)
+        : First(F), Last(L), SValue(SV), HeaderBB(H), Emitted(E),
+          Base(S == 0 ? First : B), Shift(S) {}
     APInt First;
     APInt Last;
     const Value *SValue;
     MachineBasicBlock *HeaderBB;
     bool Emitted;
+    APInt Base;     // Subtracted from the switch value; First if Shift == 0.
+    unsigned Shift; // Rotate-right amount applied after the subtraction.
   };
   typedef std::pair<JumpTableHeader, JumpTable> JumpTableBlock;
 
@@ -311,7 +314,16 @@
   /// decides it's not a good idea.
   bool buildJumpTable(CaseClusterVector &Clusters, unsigned First,
                       unsigned Last, const SwitchInst *SI,
-                      MachineBasicBlock *DefaultMBB, CaseCluster &JTCluster);
+                      MachineBasicBlock *DefaultMBB, CaseCluster &JTCluster,
+                      uint64_t Base, unsigned Shift);
+
+  /// If possible, factor a power of two out of the distance of each case value
+  /// from the lowest one and build a jump table indexed by
+  /// rotr(SI->getCondition() - Base, Shift). Return false if not profitable.
+  bool buildShiftedJumpTable(CaseClusterVector &Clusters, unsigned *TotalCases,
+                             unsigned Density, const SwitchInst *SI,
+                             MachineBasicBlock *DefaultMBB,
+                             CaseCluster &JTCluster);
 
   /// Find clusters of cases suitable for jump table lowering.
   void findJumpTables(CaseClusterVector &Clusters, const SwitchInst *SI,
Index: lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
===================================================================
--- lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -1955,14 +1955,26 @@
                                               JumpTableHeader &JTH,
                                               MachineBasicBlock *SwitchBB) {
   SDLoc dl = getCurSDLoc();
-
   // Subtract the lowest switch case value from the value being switched on and
   // conditional branch to default mbb if the result is greater than the
   // difference between smallest and largest cases.
   SDValue SwitchOp = getValue(JTH.SValue);
   EVT VT = SwitchOp.getValueType();
   SDValue Sub = DAG.getNode(ISD::SUB, dl, VT, SwitchOp,
-                            DAG.getConstant(JTH.First, dl, VT));
+                            DAG.getConstant(JTH.Base, dl, VT));
+  if (JTH.Shift != 0) {
+    // We need to shift the subtracted value right before calculating the jump
+    // index. This is equivalent to dividing by 2**Shift. We also need to ensure
+    // that the subtracted value divides cleanly by 2**Shift. This can be folded
+    // into the jump table bounds check by rotating the subtracted value right:
+    //
+    //   Idx = rotr(Sub, Shift)
+    //
+    // If any of the low bits of Sub are set, these become high bits and so the
+    // table index comparison will fail.
+    Sub =
+        DAG.getNode(ISD::ROTR, dl, VT, Sub, DAG.getConstant(JTH.Shift, dl, VT));
+  }
 
   // The SDNode we just created, which holds the value being switched on minus
   // the smallest case value, needs to be copied to a virtual register so it
@@ -2350,7 +2362,6 @@
   for (const CaseCluster &CC : Clusters)
     assert(CC.Low == CC.High && "Input clusters must be single-case");
 #endif
-
   std::sort(Clusters.begin(), Clusters.end(),
             [](const CaseCluster &a, const CaseCluster &b) {
     return a.Low->getValue().slt(b.Low->getValue());
@@ -8205,7 +8216,8 @@
                                          unsigned First, unsigned Last,
                                          const SwitchInst *SI,
                                          MachineBasicBlock *DefaultMBB,
-                                         CaseCluster &JTCluster) {
+                                         CaseCluster &JTCluster,
+                                         uint64_t Base, unsigned Shift) {
   assert(First <= Last);
 
   auto Prob = BranchProbability::getZero();
@@ -8266,17 +8278,69 @@
                      ->createJumpTableIndex(Table);
 
   // Set up the jump table info.
+  auto BW = Clusters[First].Low->getValue().getBitWidth();
   JumpTable JT(-1U, JTI, JumpTableMBB, nullptr);
   JumpTableHeader JTH(Clusters[First].Low->getValue(),
                       Clusters[Last].High->getValue(), SI->getCondition(),
-                      nullptr, false);
+                      nullptr, false, APInt(BW, Base), Shift);
   JTCases.emplace_back(std::move(JTH), std::move(JT));
-
   JTCluster = CaseCluster::jumpTable(Clusters[First].Low, Clusters[Last].High,
                                      JTCases.size() - 1, Prob);
   return true;
 }
 
+bool SelectionDAGBuilder::buildShiftedJumpTable(CaseClusterVector &Clusters,
+                                                unsigned *TotalCases,
+                                                unsigned Density,
+                                                const SwitchInst *SI,
+                                                MachineBasicBlock *DefaultMBB,
+                                                CaseCluster &JTCluster) {
+  auto BW = Clusters[0].High->getValue().getBitWidth();
+  if (BW > 64)
+    return false;
+  // This transform is agnostic to the signedness of the input or case values. We
+  // can treat the case values as signed or unsigned. We can optimize more common
+  // cases such as a sequence crossing zero {-4,0,4,8} if we interpret case values
+  // as signed.
+  SmallVector<int64_t, 8> Values;
+  for (auto &C : Clusters) {
+    if (C.High != C.Low)
+      // Consecutive values mean the greatest common divisor is one.
+      return false;
+    Values.push_back(C.High->getValue().getSExtValue());
+  }
+
+  // Rebase so the values start at zero; unsigned subtraction cannot overflow.
+  uint64_t Base = Values[0];
+  for (auto &V : Values)
+    V = static_cast<int64_t>(static_cast<uint64_t>(V) - Base);
+
+  // Now we have signed numbers that have been shifted so that, given enough
+  // precision, there are no negative values. Since the rest of the transform
+  // is bitwise only, we switch now to an unsigned representation.
+  uint64_t GCD = 0;
+  for (auto &V : Values)
+    GCD = llvm::GreatestCommonDivisor64(GCD, static_cast<uint64_t>(V));
+  if (GCD <= 1 || !llvm::isPowerOf2_64(GCD))
+    return false;
+  unsigned Shift = llvm::Log2_64(GCD);
+
+  // Index the table by (Value - Base) >> Shift, matching the emitted header.
+  auto XFormedClusters = Clusters;
+  for (auto &C : XFormedClusters) {
+    APInt Rebased = C.High->getValue() - APInt(BW, Base);
+    auto *CI = ConstantInt::get(C.High->getType(), Rebased.lshr(Shift));
+    C.Low = C.High = cast<ConstantInt>(CI);
+  }
+
+  if (!isDense(XFormedClusters, TotalCases, 0, Clusters.size() - 1, Density))
+    // Transform didn't create a dense switch.
+    return false;
+
+  return buildJumpTable(XFormedClusters, 0, Clusters.size() - 1, SI, DefaultMBB,
+                        JTCluster, Base, Shift);
+}
+
 void SelectionDAGBuilder::findJumpTables(CaseClusterVector &Clusters,
                                          const SwitchInst *SI,
                                          MachineBasicBlock *DefaultMBB) {
@@ -8295,6 +8359,8 @@
 
   const int64_t N = Clusters.size();
   const unsigned MinJumpTableSize = TLI.getMinimumJumpTableEntries();
+  if (N < MinJumpTableSize)
+    return;
 
   // TotalCases[i]: Total nbr of cases in Clusters[0..i].
   SmallVector<unsigned, 8> TotalCases(N);
@@ -8310,11 +8376,10 @@
   unsigned MinDensity = JumpTableDensity;
   if (DefaultMBB->getParent()->getFunction()->optForSize())
     MinDensity = OptsizeJumpTableDensity;
-  if (N >= MinJumpTableSize
-      && isDense(Clusters, &TotalCases[0], 0, N - 1, MinDensity)) {
+  if (isDense(Clusters, &TotalCases[0], 0, N - 1, MinDensity)) {
     // Cheap case: the whole range might be suitable for jump table.
CaseCluster JTCluster; - if (buildJumpTable(Clusters, 0, N - 1, SI, DefaultMBB, JTCluster)) { + if (buildJumpTable(Clusters, 0, N - 1, SI, DefaultMBB, JTCluster, 0, 0)) { Clusters[0] = JTCluster; Clusters.resize(1); return; @@ -8325,6 +8390,16 @@ if (TM.getOptLevel() == CodeGenOpt::None) return; + // If the cases are in arithmetic progression, we may be able to factorize + // to make the table dense. + CaseCluster JTCluster; + if (buildShiftedJumpTable(Clusters, &TotalCases[0], MinDensity, SI, DefaultMBB, + JTCluster)) { + Clusters[0] = JTCluster; + Clusters.resize(1); + return; + } + // Split Clusters into minimum number of dense partitions. The algorithm uses // the same idea as Kannan & Proebsting "Correction to 'Producing Good Code // for the Case Statement'" (1994), but builds the MinPartitions array in @@ -8383,7 +8458,7 @@ CaseCluster JTCluster; if (NumClusters >= MinJumpTableSize && - buildJumpTable(Clusters, First, Last, SI, DefaultMBB, JTCluster)) { + buildJumpTable(Clusters, First, Last, SI, DefaultMBB, JTCluster, 0, 0)) { Clusters[DstIndex++] = JTCluster; } else { for (unsigned I = First; I <= Last; ++I) Index: test/CodeGen/ARM/rangereduce-switch.ll =================================================================== --- /dev/null +++ test/CodeGen/ARM/rangereduce-switch.ll @@ -0,0 +1,192 @@ +; RUN: llc -mtriple=thumbv7-linux-gnu < %s -stop-after machine-cp -o /dev/null 2>&1 | FileCheck %s + +; CHECK-LABEL: name: test1 +; CHECK: %[[x:.*]] = t2SUBri killed %r0, 97 +; CHECK: %[[y:.*]] = t2RORri killed %[[x]], 2 +; CHECK: t2CMPri %[[y]], 3 +; CHECK: t2Bcc +; CHECK: t2BR_JT +define i32 @test1(i32 %a) optsize { + switch i32 %a, label %def [ + i32 97, label %one + i32 101, label %two + i32 105, label %three + i32 109, label %three + ] + +def: + ret i32 8867 + +one: + ret i32 11984 +two: + ret i32 1143 +three: + ret i32 99783 +} + +; Optimization shouldn't trigger; bitwidth > 64 +; CHECK-LABEL: name: test2 +; CHECK-NOT: t2BR_JT +define i128 @test2(i128 %a) 
optsize { + switch i128 %a, label %def [ + i128 97, label %one + i128 101, label %two + i128 105, label %three + i128 109, label %three + ] + +def: + ret i128 8867 + +one: + ret i128 11984 +two: + ret i128 1143 +three: + ret i128 99783 +} + + +; Optimization shouldn't trigger; no holes present +; CHECK-LABEL: name: test3 +; CHECK-NOT: t2BR_JT +define i32 @test3(i32 %a) optsize { + switch i32 %a, label %def [ + i32 97, label %one + i32 98, label %two + i32 99, label %three + i32 100, label %three + ] + +def: + ret i32 8867 + +one: + ret i32 11984 +two: + ret i32 1143 +three: + ret i32 99783 +} + +; Optimization shouldn't trigger; not an arithmetic progression +; CHECK-LABEL: name: test4 +; CHECK-NOT: t2BR_JT +define i32 @test4(i32 %a) optsize { + switch i32 %a, label %def [ + i32 97, label %one + i32 102, label %two + i32 105, label %three + i32 109, label %three + ] + +def: + ret i32 8867 + +one: + ret i32 11984 +two: + ret i32 1143 +three: + ret i32 99783 +} + +; Optimization shouldn't trigger; not a power of two +; CHECK-LABEL: name: test5 +; CHECK-NOT: t2BR_JT +define i32 @test5(i32 %a) optsize { + switch i32 %a, label %def [ + i32 97, label %one + i32 102, label %two + i32 107, label %three + i32 112, label %three + ] + +def: + ret i32 8867 + +one: + ret i32 11984 +two: + ret i32 1143 +three: + ret i32 99783 +} + +; CHECK-LABEL: name: test6 +; CHECK: %[[x:.*]] = t2ADDri killed %r0, 109 +; CHECK: %[[y:.*]] = t2RORri killed %[[x]], 2 +; CHECK: t2CMPri %[[y]], 3 +; CHECK: t2Bcc +; CHECK: t2BR_JT +define i32 @test6(i32 %a) optsize { + switch i32 %a, label %def [ + i32 -97, label %one + i32 -101, label %two + i32 -105, label %three + i32 -109, label %three + ] + +def: + ret i32 8867 + +one: + ret i32 11984 +two: + ret i32 1143 +three: + ret i32 99783 +} + +; CHECK-LABEL: name: test7 +; CHECK: %[[z:.*]] = t2UXTB killed %r0, 0 +; CHECK: %[[x:.*]] = t2SUBri killed %[[z]], 220 +; CHECK: %[[y:.*]] = t2RORri killed %[[x]], 2 +; CHECK: t2CMPri %[[y]], 3 +; CHECK: t2Bcc +; 
CHECK: t2BR_JT +define i8 @test7(i8 %a) optsize { + switch i8 %a, label %def [ + i8 220, label %one + i8 224, label %two + i8 228, label %three + i8 232, label %three + ] + +def: + ret i8 8867 + +one: + ret i8 11984 +two: + ret i8 1143 +three: + ret i8 99783 +} + + +; CHECK-LABEL: test8 +; CHECK: %[[x:.*]] = t2SUBri killed %r0, 97 +; CHECK: %[[y:.*]] = t2RORri killed %[[x]], 2 +; CHECK: t2CMPri %[[y]], 4 +; CHECK: t2Bcc +; CHECK: t2BR_JT +define i32 @test8(i32 %a) optsize { + switch i32 %a, label %def [ + i32 97, label %one + i32 101, label %two + i32 105, label %three + i32 113, label %three + ] + +def: + ret i32 8867 + +one: + ret i32 11984 +two: + ret i32 1143 +three: + ret i32 99783 +} Index: test/CodeGen/X86/switch-edge-weight.ll =================================================================== --- test/CodeGen/X86/switch-edge-weight.ll +++ test/CodeGen/X86/switch-edge-weight.ll @@ -237,7 +237,7 @@ i32 20, label %sw.bb2 i32 28, label %sw.bb3 i32 36, label %sw.bb4 - i32 124, label %sw.bb5 + i32 125, label %sw.bb5 ], !prof !2 sw.bb: