diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyFixBrTableDefaults.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyFixBrTableDefaults.cpp --- a/llvm/lib/Target/WebAssembly/WebAssemblyFixBrTableDefaults.cpp +++ b/llvm/lib/Target/WebAssembly/WebAssemblyFixBrTableDefaults.cpp @@ -41,9 +41,10 @@ char WebAssemblyFixBrTableDefaults::ID = 0; -// `MI` is a br_table instruction missing its default target argument. This +// `MI` is a br_table instruction with a dummy default target argument. This // function finds and adds the default target argument and removes any redundant -// range check preceding the br_table. +// range check preceding the br_table. Returns the MBB that the br_table is +// moved into so it can be removed from further consideration. MachineBasicBlock *fixBrTable(MachineInstr &MI, MachineBasicBlock *MBB, MachineFunction &MF) { // Get the header block, which contains the redundant range check. @@ -51,7 +52,8 @@ auto *HeaderMBB = *MBB->pred_begin(); // Find the conditional jump to the default target. If it doesn't exist, the - // default target is unreachable anyway, so we can choose anything. + // default target is unreachable anyway, so we can keep the existing dummy + // target. MachineBasicBlock *TBB = nullptr, *FBB = nullptr; SmallVector Cond; const auto &TII = *MF.getSubtarget().getInstrInfo(); @@ -66,14 +68,11 @@ // D | _ | Header jumps to the default and falls through to the jump table // D | J | Header jumps to the default and also to the jump table if (TBB && TBB != MBB) { - // Install the default target. assert((FBB == nullptr || FBB == MBB) && "Expected jump or fallthrough to br_table block"); + // Remove the dummy default target and install the real one. + MI.RemoveOperand(MI.getNumExplicitOperands() - 1); MI.addOperand(MF, MachineOperand::CreateMBB(TBB)); - } else { - // Arbitrarily choose the first jump target as the default. - auto *SomeMBB = MI.getOperand(1).getMBB(); - MI.addOperand(MachineOperand::CreateMBB(SomeMBB)); } // Remove any branches from the header and splice in the jump table instead @@ -98,6 +97,7 @@ LLVM_DEBUG(dbgs() << "********** Fixing br_table Default Targets **********\n" "********** Function: " << MF.getName() << '\n'); + MachineRegisterInfo &MRI = MF.getRegInfo(); bool Changed = false; SmallPtrSet MBBSet; @@ -109,6 +109,16 @@ MBBSet.erase(MBB); for (auto &MI : *MBB) { if (WebAssembly::isBrTable(MI)) { + // If the table index is truncated from an i64, we can't get rid of the + // range check because numbers just over 2^32 would otherwise be + // indistinguishable from small numbers. + auto IndexReg = MI.getOperand(0).getReg(); + auto *IndexSrc = MRI.getVRegDef(IndexReg); + if (IndexSrc->getOpcode() == WebAssembly::I32_WRAP_I64) + continue; + + // Otherwise, remove the range check and remove the block containing the + // fixed br_table from consideration. auto *Fixed = fixBrTable(MI, MBB, MF); MBBSet.erase(Fixed); Changed = true; diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp --- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp +++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp @@ -1285,8 +1285,10 @@ for (auto MBB : MBBs) Ops.push_back(DAG.getBasicBlock(MBB)); - // Do not add the default case for now. It will be added in - // WebAssemblyFixBrTableDefaults. + // Add the first MBB as a dummy default target for now. This will be replaced + // with the proper default target (and the preceding range check eliminated) + // if possible by WebAssemblyFixBrTableDefaults. + Ops.push_back(DAG.getBasicBlock(*MBBs.begin())); return DAG.getNode(WebAssemblyISD::BR_TABLE, DL, MVT::Other, Ops); } diff --git a/llvm/test/CodeGen/WebAssembly/switch.ll b/llvm/test/CodeGen/WebAssembly/switch.ll --- a/llvm/test/CodeGen/WebAssembly/switch.ll +++ b/llvm/test/CodeGen/WebAssembly/switch.ll @@ -95,26 +95,30 @@ ; CHECK-LABEL: bar64: ; CHECK: block {{$}} +; CHECK: i64.const +; CHECK: i64.gt_u +; CHECK: br_if 0 ; CHECK: block {{$}} ; CHECK: block {{$}} ; CHECK: block {{$}} ; CHECK: block {{$}} ; CHECK: block {{$}} ; CHECK: block {{$}} -; CHECK: br_table {{[^,]+}}, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 4, 5, 6{{$}} -; CHECK: .LBB{{[0-9]+}}_1: -; CHECK: call foo0{{$}} +; CHECK: i32.wrap_i64 +; CHECK: br_table {{[^,]+}}, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 4, 5, 0{{$}} ; CHECK: .LBB{{[0-9]+}}_2: -; CHECK: call foo1{{$}} +; CHECK: call foo0{{$}} ; CHECK: .LBB{{[0-9]+}}_3: -; CHECK: call foo2{{$}} +; CHECK: call foo1{{$}} ; CHECK: .LBB{{[0-9]+}}_4: -; CHECK: call foo3{{$}} +; CHECK: call foo2{{$}} ; CHECK: .LBB{{[0-9]+}}_5: -; CHECK: call foo4{{$}} +; CHECK: call foo3{{$}} ; CHECK: .LBB{{[0-9]+}}_6: -; CHECK: call foo5{{$}} +; CHECK: call foo4{{$}} ; CHECK: .LBB{{[0-9]+}}_7: +; CHECK: call foo5{{$}} +; CHECK: .LBB{{[0-9]+}}_8: ; CHECK: return{{$}} define void @bar64(i64 %n) { entry: