diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyFixBrTableDefaults.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyFixBrTableDefaults.cpp
--- a/llvm/lib/Target/WebAssembly/WebAssemblyFixBrTableDefaults.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyFixBrTableDefaults.cpp
@@ -41,9 +41,11 @@
 
 char WebAssemblyFixBrTableDefaults::ID = 0;
 
-// `MI` is a br_table instruction missing its default target argument. This
-// function finds and adds the default target argument and removes any redundant
-// range check preceding the br_table.
+// `MI` is a br_table instruction with a dummy default target argument. This
+// function replaces the dummy with the real default target, if there is one,
+// and removes any redundant range check preceding the br_table. Returns the
+// MBB that the br_table is moved into so it can be removed from further
+// consideration, or nullptr if the br_table cannot be optimized.
 MachineBasicBlock *fixBrTable(MachineInstr &MI, MachineBasicBlock *MBB,
                               MachineFunction &MF) {
   // Get the header block, which contains the redundant range check.
@@ -51,11 +53,14 @@
   auto *HeaderMBB = *MBB->pred_begin();
 
   // Find the conditional jump to the default target. If it doesn't exist, the
-  // default target is unreachable anyway, so we can choose anything.
+  // default target is unreachable anyway, so we can keep the existing dummy
+  // target.
   MachineBasicBlock *TBB = nullptr, *FBB = nullptr;
   SmallVector<MachineOperand, 2> Cond;
   const auto &TII = *MF.getSubtarget<WebAssemblySubtarget>().getInstrInfo();
-  TII.analyzeBranch(*HeaderMBB, TBB, FBB, Cond);
+  bool Analyzed = !TII.analyzeBranch(*HeaderMBB, TBB, FBB, Cond);
+  assert(Analyzed && "Could not analyze jump header branches");
+  (void)Analyzed;
 
   // Here are the possible outcomes. '_' is nullptr, `J` is the jump table block
   // aka MBB, 'D' is the default block.
@@ -66,14 +71,28 @@
   //  D  |  _  | Header jumps to the default and falls through to the jump table
   //  D  |  J  | Header jumps to the default and also to the jump table
   if (TBB && TBB != MBB) {
-    // Install the default target.
     assert((FBB == nullptr || FBB == MBB) &&
            "Expected jump or fallthrough to br_table block");
+    assert(Cond.size() == 2 && Cond[1].isReg() && "Unexpected condition info");
+
+    // If the range check checks an i64 value, we cannot optimize it out because
+    // the i64 index is truncated to an i32, making values over 2^32
+    // indistinguishable from small numbers.
+    MachineRegisterInfo &MRI = MF.getRegInfo();
+    auto *RangeCheck = MRI.getVRegDef(Cond[1].getReg());
+    assert(RangeCheck != nullptr && "Could not find the range check");
+    unsigned RangeCheckOp = RangeCheck->getOpcode();
+    assert((RangeCheckOp == WebAssembly::GT_U_I32 ||
+            RangeCheckOp == WebAssembly::GT_U_I64) &&
+           "Unexpected range check opcode");
+    if (RangeCheckOp == WebAssembly::GT_U_I64) {
+      // Bail out and leave the jump table untouched
+      return nullptr;
+    }
+
+    // Remove the dummy default target and install the real one.
+    MI.RemoveOperand(MI.getNumExplicitOperands() - 1);
     MI.addOperand(MF, MachineOperand::CreateMBB(TBB));
-  } else {
-    // Arbitrarily choose the first jump target as the default.
-    auto *SomeMBB = MI.getOperand(1).getMBB();
-    MI.addOperand(MachineOperand::CreateMBB(SomeMBB));
   }
 
   // Remove any branches from the header and splice in the jump table instead
@@ -110,8 +129,10 @@
     for (auto &MI : *MBB) {
       if (WebAssembly::isBrTable(MI)) {
         auto *Fixed = fixBrTable(MI, MBB, MF);
-        MBBSet.erase(Fixed);
-        Changed = true;
+        if (Fixed != nullptr) {
+          MBBSet.erase(Fixed);
+          Changed = true;
+        }
         break;
       }
     }
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
@@ -1285,8 +1285,10 @@
   for (auto MBB : MBBs)
     Ops.push_back(DAG.getBasicBlock(MBB));
 
-  // Do not add the default case for now. It will be added in
-  // WebAssemblyFixBrTableDefaults.
+  // Add the first MBB as a dummy default target for now. This will be replaced
+  // with the proper default target (and the preceding range check eliminated)
+  // if possible by WebAssemblyFixBrTableDefaults.
+  Ops.push_back(DAG.getBasicBlock(*MBBs.begin()));
   return DAG.getNode(WebAssemblyISD::BR_TABLE, DL, MVT::Other, Ops);
 }
 
diff --git a/llvm/test/CodeGen/WebAssembly/switch.ll b/llvm/test/CodeGen/WebAssembly/switch.ll
--- a/llvm/test/CodeGen/WebAssembly/switch.ll
+++ b/llvm/test/CodeGen/WebAssembly/switch.ll
@@ -95,26 +95,30 @@
 
 ; CHECK-LABEL: bar64:
 ; CHECK: block {{$}}
+; CHECK: i64.const
+; CHECK: i64.gt_u
+; CHECK: br_if 0
 ; CHECK: block {{$}}
 ; CHECK: block {{$}}
 ; CHECK: block {{$}}
 ; CHECK: block {{$}}
 ; CHECK: block {{$}}
 ; CHECK: block {{$}}
-; CHECK: br_table {{[^,]+}}, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 4, 5, 6{{$}}
-; CHECK: .LBB{{[0-9]+}}_1:
-; CHECK: call foo0{{$}}
+; CHECK: i32.wrap_i64
+; CHECK: br_table {{[^,]+}}, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 4, 5, 0{{$}}
 ; CHECK: .LBB{{[0-9]+}}_2:
-; CHECK: call foo1{{$}}
+; CHECK: call foo0{{$}}
 ; CHECK: .LBB{{[0-9]+}}_3:
-; CHECK: call foo2{{$}}
+; CHECK: call foo1{{$}}
 ; CHECK: .LBB{{[0-9]+}}_4:
-; CHECK: call foo3{{$}}
+; CHECK: call foo2{{$}}
 ; CHECK: .LBB{{[0-9]+}}_5:
-; CHECK: call foo4{{$}}
+; CHECK: call foo3{{$}}
 ; CHECK: .LBB{{[0-9]+}}_6:
-; CHECK: call foo5{{$}}
+; CHECK: call foo4{{$}}
 ; CHECK: .LBB{{[0-9]+}}_7:
+; CHECK: call foo5{{$}}
+; CHECK: .LBB{{[0-9]+}}_8:
 ; CHECK: return{{$}}
 define void @bar64(i64 %n) {
 entry:
@@ -172,3 +176,43 @@
 sw.epilog:                                        ; preds = %entry, %sw.bb.5, %sw.bb.4, %sw.bb.3, %sw.bb.2, %sw.bb.1, %sw.bb
   ret void
 }
+
+; CHECK-LABEL: truncated:
+; CHECK: block
+; CHECK: block
+; CHECK: block
+; CHECK: i32.wrap_i64
+; CHECK: br_table {{[^,]+}}, 0, 1, 2{{$}}
+; CHECK: .LBB{{[0-9]+}}_1
+; CHECK: end_block
+; CHECK: call foo0{{$}}
+; CHECK: return{{$}}
+; CHECK: .LBB{{[0-9]+}}_2
+; CHECK: end_block
+; CHECK: call foo1{{$}}
+; CHECK: return{{$}}
+; CHECK: .LBB{{[0-9]+}}_3
+; CHECK: end_block
+; CHECK: call foo2{{$}}
+; CHECK: return{{$}}
+; CHECK: end_function
+define void @truncated(i64 %n) {
+entry:
+  %m = trunc i64 %n to i32
+  switch i32 %m, label %default [
+    i32 0, label %bb1
+    i32 1, label %bb2
+  ]
+
+bb1:
+  tail call void @foo0()
+  ret void
+
+bb2:
+  tail call void @foo1()
+  ret void
+
+default:
+  tail call void @foo2()
+  ret void
+}