diff --git a/llvm/lib/Target/WebAssembly/CMakeLists.txt b/llvm/lib/Target/WebAssembly/CMakeLists.txt
--- a/llvm/lib/Target/WebAssembly/CMakeLists.txt
+++ b/llvm/lib/Target/WebAssembly/CMakeLists.txt
@@ -24,6 +24,7 @@
   WebAssemblyExceptionInfo.cpp
   WebAssemblyExplicitLocals.cpp
   WebAssemblyFastISel.cpp
+  WebAssemblyFixBrTableDefaults.cpp
   WebAssemblyFixIrreducibleControlFlow.cpp
   WebAssemblyFixFunctionBitcasts.cpp
   WebAssemblyFrameLowering.cpp
diff --git a/llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyMCTargetDesc.h b/llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyMCTargetDesc.h
--- a/llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyMCTargetDesc.h
+++ b/llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyMCTargetDesc.h
@@ -459,6 +459,18 @@
   }
 }
 
+inline bool isBrTable(unsigned Opc) {
+  switch (Opc) {
+  case WebAssembly::BR_TABLE_I32:
+  case WebAssembly::BR_TABLE_I32_S:
+  case WebAssembly::BR_TABLE_I64:
+  case WebAssembly::BR_TABLE_I64_S:
+    return true;
+  default:
+    return false;
+  }
+}
+
 inline bool isMarker(unsigned Opc) {
   switch (Opc) {
   case WebAssembly::BLOCK:
diff --git a/llvm/lib/Target/WebAssembly/WebAssembly.h b/llvm/lib/Target/WebAssembly/WebAssembly.h
--- a/llvm/lib/Target/WebAssembly/WebAssembly.h
+++ b/llvm/lib/Target/WebAssembly/WebAssembly.h
@@ -44,6 +44,7 @@
 FunctionPass *createWebAssemblyMemIntrinsicResults();
 FunctionPass *createWebAssemblyRegStackify();
 FunctionPass *createWebAssemblyRegColoring();
+FunctionPass *createWebAssemblyFixBrTableDefaults();
 FunctionPass *createWebAssemblyFixIrreducibleControlFlow();
 FunctionPass *createWebAssemblyLateEHPrepare();
 FunctionPass *createWebAssemblyCFGSort();
@@ -68,6 +69,7 @@
 void initializeWebAssemblyMemIntrinsicResultsPass(PassRegistry &);
 void initializeWebAssemblyRegStackifyPass(PassRegistry &);
 void initializeWebAssemblyRegColoringPass(PassRegistry &);
+void initializeWebAssemblyFixBrTableDefaultsPass(PassRegistry &);
 void initializeWebAssemblyFixIrreducibleControlFlowPass(PassRegistry &);
 void initializeWebAssemblyLateEHPreparePass(PassRegistry &);
 void initializeWebAssemblyExceptionInfoPass(PassRegistry &);
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyFixBrTableDefaults.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyFixBrTableDefaults.cpp
new file mode 100644
--- /dev/null
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyFixBrTableDefaults.cpp
@@ -0,0 +1,173 @@
+//=- WebAssemblyFixBrTableDefaults.cpp - Fix br_table default branch targets -//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file This file implements a pass that eliminates redundant range checks
+/// guarding br_table instructions. Since jump tables on most targets cannot
+/// handle out of range indices, LLVM emits these checks before most jump
+/// tables. But br_table takes a default branch target as an argument, so it
+/// does not need the range checks.
+///
+/// Instruction selection emits br_tables without a default target; this pass
+/// adds one to each of them. A range check is only removed if it is an
+/// i32.gt_u: an i64 switch value is truncated to the i32 br_table index, so an
+/// i64 range check is not made redundant by the br_table default.
+///
+//===----------------------------------------------------------------------===//
+
+#include "MCTargetDesc/WebAssemblyMCTargetDesc.h"
+#include "WebAssembly.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/CodeGen/MachineFunction.h"
+#include "llvm/CodeGen/MachineFunctionPass.h"
+#include "llvm/CodeGen/MachineRegisterInfo.h"
+#include "llvm/Pass.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "wasm-fix-br-table-defaults"
+
+namespace {
+
+class WebAssemblyFixBrTableDefaults final : public MachineFunctionPass {
+  StringRef getPassName() const override {
+    return "WebAssembly Fix br_table Defaults";
+  }
+
+  bool runOnMachineFunction(MachineFunction &MF) override;
+
+public:
+  static char ID; // Pass identification, replacement for typeid
+  WebAssemblyFixBrTableDefaults() : MachineFunctionPass(ID) {}
+};
+
+char WebAssemblyFixBrTableDefaults::ID = 0;
+
+// `HeaderMBB` is the only predecessor of `MBB`, which holds a br_table.
+// Returns true if every terminator of `HeaderMBB` may be deleted when the
+// br_table is merged into it. On success, `RangeJump` is set to the header's
+// conditional jump to the switch's default target, or to nullptr if there is
+// no range check because the default target is unreachable.
+bool canMergeIntoHeader(MachineBasicBlock *HeaderMBB, MachineBasicBlock *MBB,
+                        MachineFunction &MF, MachineInstr *&RangeJump) {
+  RangeJump = nullptr;
+  for (MachineInstr &Term :
+       make_range(HeaderMBB->getFirstTerminator(), HeaderMBB->end())) {
+    switch (Term.getOpcode()) {
+    case WebAssembly::BR:
+      // An explicit jump to the br_table block is expected.
+      if (Term.getOperand(0).getMBB() != MBB)
+        return false;
+      break;
+    case WebAssembly::BR_IF:
+      // Expect at most one conditional jump, going to the default target.
+      if (RangeJump || Term.getOperand(0).getMBB() == MBB)
+        return false;
+      RangeJump = &Term;
+      break;
+    default:
+      // An unexpected terminator; leave the header alone.
+      return false;
+    }
+  }
+  if (!RangeJump)
+    return true;
+
+  // An i64 switch value is truncated to the i32 br_table index, so without an
+  // i64 range check, values such as 2^32 would alias in-range entries instead
+  // of reaching the default target. Conservatively only remove the check in
+  // the common case where it is an i32.gt_u.
+  const MachineOperand &Cond = RangeJump->getOperand(1);
+  if (!Cond.isReg() || !Register::isVirtualRegister(Cond.getReg()))
+    return false;
+  const MachineInstr *RangeCheck = MF.getRegInfo().getVRegDef(Cond.getReg());
+  return RangeCheck && RangeCheck->getOpcode() == WebAssembly::GT_U_I32;
+}
+
+// `MI` is a br_table instruction in `MBB` that is missing its default target
+// argument. This function adds the default target argument and, when the
+// range check preceding the br_table is redundant, removes it by merging
+// `MBB` into its header block. Returns the block that now holds the br_table.
+MachineBasicBlock *fixBrTable(MachineInstr &MI, MachineBasicBlock *MBB,
+                              MachineFunction &MF) {
+  // Get the header block, which contains the range check.
+  assert(MBB->pred_size() == 1 && "Expected a single guard predecessor");
+  auto *HeaderMBB = *MBB->pred_begin();
+
+  // Operand 0 is the index and operand 1 is the first jump target. Any jump
+  // target is an acceptable default when the default is never taken.
+  auto *SomeMBB = MI.getOperand(1).getMBB();
+
+  MachineInstr *RangeJump = nullptr;
+  if (!canMergeIntoHeader(HeaderMBB, MBB, MF, RangeJump)) {
+    // Keep the range check. It ensures the index is in range, so the default
+    // target of the br_table is never taken.
+    MI.addOperand(MF, MachineOperand::CreateMBB(SomeMBB));
+    return MBB;
+  }
+
+  // Install the default target. Without a range check the default target is
+  // unreachable anyway, so we can choose anything.
+  auto *DefaultMBB = RangeJump ? RangeJump->getOperand(0).getMBB() : SomeMBB;
+  MI.addOperand(MF, MachineOperand::CreateMBB(DefaultMBB));
+
+  // Remove the jumps in the header and splice the jump table into it.
+  HeaderMBB->erase(HeaderMBB->getFirstTerminator(), HeaderMBB->end());
+  HeaderMBB->splice(HeaderMBB->end(), MBB, MBB->begin(), MBB->end());
+
+  // Update CFG to skip the old jump table block. Remove shared successors
+  // before transferring to avoid duplicated successors.
+  HeaderMBB->removeSuccessor(MBB);
+  for (auto &Succ : MBB->successors())
+    if (HeaderMBB->isSuccessor(Succ))
+      HeaderMBB->removeSuccessor(Succ);
+  HeaderMBB->transferSuccessorsAndUpdatePHIs(MBB);
+
+  // Remove the old jump table block from the function.
+  MF.erase(MBB);
+
+  return HeaderMBB;
+}
+
+bool WebAssemblyFixBrTableDefaults::runOnMachineFunction(MachineFunction &MF) {
+  LLVM_DEBUG(dbgs() << "********** Fixing br_table Default Targets **********\n"
+                       "********** Function: "
+                    << MF.getName() << '\n');
+
+  // Collect the br_tables up front. Fixing one only merges its own block into
+  // that block's header, so the fixes do not interfere with each other and
+  // can be applied in layout order.
+  SmallVector<std::pair<MachineInstr *, MachineBasicBlock *>, 4> BrTables;
+  for (auto &MBB : MF) {
+    for (auto &MI : MBB) {
+      if (WebAssembly::isBrTable(MI.getOpcode())) {
+        BrTables.push_back({&MI, &MBB});
+        break;
+      }
+    }
+  }
+
+  for (auto &BrTable : BrTables)
+    fixBrTable(*BrTable.first, BrTable.second, MF);
+
+  if (BrTables.empty())
+    return false;
+
+  // We rewrote part of the function; recompute relevant things.
+  MF.RenumberBlocks();
+  return true;
+}
+
+} // end anonymous namespace
+
+INITIALIZE_PASS(WebAssemblyFixBrTableDefaults, DEBUG_TYPE,
+                "Removes range checks and sets br_table default targets", false,
+                false);
+
+FunctionPass *llvm::createWebAssemblyFixBrTableDefaults() {
+  return new WebAssemblyFixBrTableDefaults();
+}
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
@@ -1279,11 +1279,8 @@
   for (auto MBB : MBBs)
     Ops.push_back(DAG.getBasicBlock(MBB));
 
-  // TODO: For now, we just pick something arbitrary for a default case for now.
-  // We really want to sniff out the guard and put in the real default case (and
-  // delete the guard).
-  Ops.push_back(DAG.getBasicBlock(MBBs[0]));
-
+  // Do not add the default case for now. It will be added in
+  // WebAssemblyFixBrTableDefaults.
   return DAG.getNode(WebAssemblyISD::BR_TABLE, DL, MVT::Other, Ops);
 }
 
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyTargetMachine.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyTargetMachine.cpp
--- a/llvm/lib/Target/WebAssembly/WebAssemblyTargetMachine.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyTargetMachine.cpp
@@ -406,6 +406,10 @@
   // it's inconvenient to collect. Collect it now, and update the immediate
   // operands.
   addPass(createWebAssemblySetP2AlignOperands());
+
+  // Eliminate range checks and add default targets to br_table instructions.
+  addPass(createWebAssemblyFixBrTableDefaults());
+
   return false;
 }
 
diff --git a/llvm/test/CodeGen/WebAssembly/cfg-stackify.ll b/llvm/test/CodeGen/WebAssembly/cfg-stackify.ll
--- a/llvm/test/CodeGen/WebAssembly/cfg-stackify.ll
+++ b/llvm/test/CodeGen/WebAssembly/cfg-stackify.ll
@@ -382,14 +382,14 @@
 
 ; CHECK-LABEL: test4:
 ; CHECK-NEXT: .functype test4 (i32) -> (){{$}}
-; CHECK: block {{$}}
 ; CHECK-NEXT: block {{$}}
-; CHECK: br_if 0, $pop{{[0-9]+}}{{$}}
-; CHECK: br 1{{$}}
-; CHECK-NEXT: .LBB{{[0-9]+}}_2:
+; CHECK-NEXT: block {{$}}
+; CHECK-NEXT: br_table $0, 1, 1, 1, 1, 1, 0{{$}}
+; CHECK-NEXT: .LBB{{[0-9]+}}_1:
 ; CHECK-NEXT: end_block{{$}}
-; CHECK-NEXT: br_table $0, 0, 0, 0, 0, 0, 0{{$}}
-; CHECK-NEXT: .LBB{{[0-9]+}}_3:
+; CHECK-NEXT: i32.const $push[[C:[0-9]+]]=, 622{{$}}
+; CHECK-NEXT: i32.eq $drop=, $0, $pop[[C]]{{$}}
+; CHECK-NEXT: .LBB{{[0-9]+}}_2:
 ; CHECK-NEXT: end_block{{$}}
 ; CHECK-NEXT: return{{$}}
 define void @test4(i32 %t) {
@@ -649,20 +649,16 @@
 ; CHECK: br_if 0, {{[^,]+}}{{$}}
 ; CHECK-NEXT: end_loop{{$}}
 ; CHECK-NEXT: block {{$}}
-; CHECK: br_if 0, {{[^,]+}}{{$}}
-; CHECK: br 3{{$}}
-; CHECK-NEXT: .LBB{{[0-9]+}}_7:
-; CHECK-NEXT: end_block{{$}}
-; CHECK: block {{$}}
-; CHECK-NEXT: br_table $0, 0, 3, 1, 2, 0
-; CHECK-NEXT: .LBB{{[0-9]+}}_8:
+; CHECK-NOT: br_if
+; CHECK: br_table $pop{{[^,]+}}, 0, 3, 1, 2, 3
+; CHECK-NEXT: .LBB{{[0-9]+}}_6:
 ; CHECK-NEXT: end_block{{$}}
 ; CHECK-NEXT: end_loop{{$}}
 ; CHECK-NEXT: return{{$}}
-; CHECK-NEXT: .LBB{{[0-9]+}}_9:
+; CHECK-NEXT: .LBB{{[0-9]+}}_7:
 ; CHECK-NEXT: end_block{{$}}
 ; CHECK: br 0{{$}}
-; CHECK-NEXT: .LBB{{[0-9]+}}_10:
+; CHECK-NEXT: .LBB{{[0-9]+}}_8:
 ; CHECK-NEXT: end_loop{{$}}
 define void @test10() {
 bb0:
@@ -767,25 +763,22 @@
 
 ; CHECK-LABEL: test12:
 ; CHECK: .LBB{{[0-9]+}}_1:
-; CHECK-NEXT: loop {{$}}
 ; CHECK-NEXT: block {{$}}
+; CHECK-NEXT: loop {{$}}
 ; CHECK-NEXT: block {{$}}
 ; CHECK-NEXT: block {{$}}
+; CHECK: br_table {{[^,]+}}, 1, 3, 3, 3, 1, 0{{$}}
+; CHECK-NEXT: .LBB{{[0-9]+}}_2:
+; CHECK-NEXT: end_block{{$}}
 ; CHECK: br_if 0, {{[^,]+}}{{$}}
 ; CHECK: br_if 2, {{[^,]+}}{{$}}
-; CHECK: br_if 1, {{[^,]+}}{{$}}
-; CHECK-NEXT: br 2{{$}}
 ; CHECK-NEXT: .LBB{{[0-9]+}}_4:
 ; CHECK-NEXT: end_block{{$}}
-; CHECK-NEXT: br_table $2, 1, 0, 0, 0, 1, 1{{$}}
+; CHECK: br 0{{$}}
 ; CHECK-NEXT: .LBB{{[0-9]+}}_5:
+; CHECK-NEXT: end_loop{{$}}
 ; CHECK-NEXT: end_block{{$}}
 ; CHECK-NEXT: return{{$}}
-; CHECK-NEXT: .LBB{{[0-9]+}}_6:
-; CHECK-NEXT: end_block{{$}}
-; CHECK: br 0{{$}}
-; CHECK-NEXT: .LBB{{[0-9]+}}_7:
-; CHECK-NEXT: end_loop{{$}}
 define void @test12(i8* %arg) {
 bb:
   br label %bb1
diff --git a/llvm/test/CodeGen/WebAssembly/indirectbr.ll b/llvm/test/CodeGen/WebAssembly/indirectbr.ll
--- a/llvm/test/CodeGen/WebAssembly/indirectbr.ll
+++ b/llvm/test/CodeGen/WebAssembly/indirectbr.ll
@@ -13,20 +13,36 @@
 
 ; Just check the barest skeleton of the structure
 ; CHECK-LABEL: test1:
+; CHECK: block
+; CHECK: block
+; CHECK: block
+; CHECK: block
 ; CHECK: i32.load
-; CHECK: i32.load $[[DEST:.+]]=
+; CHECK: i32.load
+; CHECK: i32.const
+; CHECK: i32.add $push[[DEST:.+]]=
+; CHECK: br_table $pop[[DEST]]
+; CHECK: end_block
+; CHECK: end_block
+; CHECK: end_block
+; CHECK: end_block
 ; CHECK: loop
 ; CHECK: block
 ; CHECK: block
+; CHECK: block
+; CHECK: block
+; CHECK: br_table ${{[^,]+}}, 0, 1, 2, 2
+; CHECK: end_block
 ; CHECK: end_block
+; CHECK: end_block
+; CHECK: block
 ; CHECK: block
 ; CHECK: block
-; CHECK: br_table $[[DEST]]
+; CHECK: br_table ${{[^,]+}}, 1, 2, 0
+; CHECK: end_block
 ; CHECK: end_block
 ; CHECK: end_block
-; CHECK: i32.load $[[DEST]]=
 ; CHECK: end_loop
-
 ; CHECK: test1.targets:
 ; CHECK-NEXT: .int32
 ; CHECK-NEXT: .int32
diff --git a/llvm/test/CodeGen/WebAssembly/stack-insts.ll b/llvm/test/CodeGen/WebAssembly/stack-insts.ll
--- a/llvm/test/CodeGen/WebAssembly/stack-insts.ll
+++ b/llvm/test/CodeGen/WebAssembly/stack-insts.ll
@@ -8,7 +8,7 @@
 
 ; Tests if br_table is printed correctly with a tab.
 ; CHECK-LABEL: test0:
-; CHECK: br_table {0, 1, 0, 1, 0}
+; CHECK: br_table {0, 1, 0, 1, 2}
 define void @test0(i32 %n) {
 entry:
   switch i32 %n, label %sw.epilog [
diff --git a/llvm/test/CodeGen/WebAssembly/switch-unreachable-default.ll b/llvm/test/CodeGen/WebAssembly/switch-unreachable-default.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/CodeGen/WebAssembly/switch-unreachable-default.ll
@@ -0,0 +1,38 @@
+; RUN: llc < %s -asm-verbose=false -verify-machineinstrs | FileCheck %s
+
+target datalayout = "e-m:e-p:32:32-i64:64-n32:64-S128"
+target triple = "wasm32-unknown-unknown"
+
+; Test that switches are lowered correctly in the presence of an
+; unreachable default branch target.
+
+; CHECK-LABEL: foo:
+; CHECK-NEXT: .functype foo (i32) -> (i32)
+; CHECK-NEXT: block
+; CHECK-NEXT: block
+; CHECK-NEXT: local.get 0
+; CHECK-NEXT: br_table {0, 1, 0}
+; CHECK-NEXT: .LBB0_1:
+; CHECK-NEXT: end_block
+; CHECK-NEXT: i32.const 0
+; CHECK-NEXT: return
+; CHECK-NEXT: .LBB0_2:
+; CHECK-NEXT: end_block
+; CHECK-NEXT: i32.const 1
+; CHECK-NEXT: end_function
+define i32 @foo(i32 %x) {
+entry:
+  switch i32 %x, label %unreachable [
+    i32 0, label %bb0
+    i32 1, label %bb1
+  ]
+
+bb0:
+  ret i32 0
+
+bb1:
+  ret i32 1
+
+unreachable:
+  unreachable
+}
diff --git a/llvm/test/CodeGen/WebAssembly/switch.ll b/llvm/test/CodeGen/WebAssembly/switch.ll
--- a/llvm/test/CodeGen/WebAssembly/switch.ll
+++ b/llvm/test/CodeGen/WebAssembly/switch.ll
@@ -21,20 +21,20 @@
 ; CHECK: block {{$}}
 ; CHECK: block {{$}}
 ; CHECK: block {{$}}
-; CHECK: br_table {{[^,]+}}, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 4, 5, 0{{$}}
-; CHECK: .LBB{{[0-9]+}}_2:
+; CHECK: br_table {{[^,]+}}, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 4, 5, 6{{$}}
+; CHECK: .LBB{{[0-9]+}}_1:
 ; CHECK: call foo0{{$}}
-; CHECK: .LBB{{[0-9]+}}_3:
+; CHECK: .LBB{{[0-9]+}}_2:
 ; CHECK: call foo1{{$}}
-; CHECK: .LBB{{[0-9]+}}_4:
+; CHECK: .LBB{{[0-9]+}}_3:
 ; CHECK: call foo2{{$}}
-; CHECK: .LBB{{[0-9]+}}_5:
+; CHECK: .LBB{{[0-9]+}}_4:
 ; CHECK: call foo3{{$}}
-; CHECK: .LBB{{[0-9]+}}_6:
+; CHECK: .LBB{{[0-9]+}}_5:
 ; CHECK: call foo4{{$}}
-; CHECK: .LBB{{[0-9]+}}_7:
+; CHECK: .LBB{{[0-9]+}}_6:
 ; CHECK: call foo5{{$}}
-; CHECK: .LBB{{[0-9]+}}_8:
+; CHECK: .LBB{{[0-9]+}}_7:
 ; CHECK: return{{$}}
 define void @bar32(i32 %n) {
 entry: