diff --git a/llvm/docs/CodeGenerator.rst b/llvm/docs/CodeGenerator.rst
--- a/llvm/docs/CodeGenerator.rst
+++ b/llvm/docs/CodeGenerator.rst
@@ -2090,9 +2090,14 @@
 * On ppc32/64 GOT/PIC only module-local calls (visibility = hidden or protected)
   are supported.

-On WebAssembly, tail calls are lowered to ``return_call`` and
-``return_call_indirect`` instructions whenever the 'tail-call' target attribute
-is enabled.
+WebAssembly constraints:
+
+* No variable argument lists are used.
+
+* The 'tail-call' target attribute is enabled.
+
+* The caller and callee's return types must match. The caller cannot
+  be void unless the callee is, too.

 Example:

diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
@@ -644,13 +644,34 @@
   if (CLI.IsPatchPoint)
     fail(DL, DAG, "WebAssembly doesn't support patch point yet");

-  // Fail if tail calls are required but not enabled
-  if (!Subtarget->hasTailCall()) {
-    if ((CallConv == CallingConv::Fast && CLI.IsTailCall &&
-         MF.getTarget().Options.GuaranteedTailCallOpt) ||
-        (CLI.CS && CLI.CS.isMustTailCall()))
-      fail(DL, DAG, "WebAssembly 'tail-call' feature not enabled");
-    CLI.IsTailCall = false;
+  if (CLI.IsTailCall) {
+    bool MustTail = CLI.CS && CLI.CS.isMustTailCall();
+    if (Subtarget->hasTailCall() && !CLI.IsVarArg) {
+      // Do not tail call unless caller and callee return types match
+      const Function &F = MF.getFunction();
+      const TargetMachine &TM = getTargetMachine();
+      Type *RetTy = F.getReturnType();
+      SmallVector<MVT, 4> CallerRetTys;
+      SmallVector<MVT, 4> CalleeRetTys;
+      computeLegalValueVTs(F, TM, RetTy, CallerRetTys);
+      computeLegalValueVTs(F, TM, CLI.RetTy, CalleeRetTys);
+      bool TypesMatch = CallerRetTys == CalleeRetTys;
+      if (!TypesMatch) {
+        // musttail in this case would be an LLVM IR validation failure
+        assert(!MustTail && "musttail requires matching return types");
+        CLI.IsTailCall = false;
+      }
+    } else {
+      CLI.IsTailCall = false;
+      if (MustTail) {
+        if (CLI.IsVarArg) {
+          // The return would pop the argument buffer
+          fail(DL, DAG, "WebAssembly does not support varargs tail calls");
+        } else {
+          fail(DL, DAG, "WebAssembly 'tail-call' feature not enabled");
+        }
+      }
+    }
   }

   SmallVectorImpl<ISD::InputArg> &Ins = CLI.Ins;
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.cpp
--- a/llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.cpp
@@ -226,6 +226,17 @@
           if (WebAssembly::isCallIndirect(MI->getOpcode()))
             Params.pop_back();

+          // return_call_indirect instructions have the return type of the
+          // caller, which must match the callee's return type
+          if (MI->getOpcode() == WebAssembly::RET_CALL_INDIRECT) {
+            const Function &F = MI->getMF()->getFunction();
+            const TargetMachine &TM = MI->getMF()->getTarget();
+            Type *RetTy = F.getReturnType();
+            SmallVector<MVT, 4> CallerRetTys;
+            computeLegalValueVTs(F, TM, RetTy, CallerRetTys);
+            valTypesFromMVTs(CallerRetTys, Returns);
+          }
+
           auto *WasmSym = cast<MCSymbolWasm>(Sym);
           auto Signature = make_unique<wasm::WasmSignature>(std::move(Returns),
                                                             std::move(Params));
diff --git a/llvm/test/CodeGen/WebAssembly/tailcall.ll b/llvm/test/CodeGen/WebAssembly/tailcall.ll
--- a/llvm/test/CodeGen/WebAssembly/tailcall.ll
+++ b/llvm/test/CodeGen/WebAssembly/tailcall.ll
@@ -1,7 +1,8 @@
 ; RUN: llc < %s -asm-verbose=false -verify-machineinstrs -disable-wasm-fallthrough-return-opt -wasm-disable-explicit-locals -wasm-keep-registers -mattr=+tail-call | FileCheck --check-prefixes=CHECK,SLOW %s
 ; RUN: llc < %s -asm-verbose=false -verify-machineinstrs -disable-wasm-fallthrough-return-opt -wasm-disable-explicit-locals -wasm-keep-registers -fast-isel -mattr=+tail-call | FileCheck --check-prefixes=CHECK,FAST %s
+; RUN: llc < %s --filetype=obj -mattr=+tail-call | obj2yaml | FileCheck --check-prefix=YAML %s

-; Test that the tail-call attribute is accepted
+; Test that tail calls lower correctly

 target datalayout = "e-m:e-p:32:32-i64:64-n32:64-S128"
 target triple = "wasm32-unknown-unknown"
@@ -124,6 +125,44 @@
   ret i32 %v
 }

+; CHECK-LABEL: mismatched_return_void:
+; CHECK: i32.call $drop=, baz, $pop{{[0-9]+}}, $pop{{[0-9]+}}, $pop{{[0-9]+}}{{$}}
+; CHECK: return{{$}}
+define void @mismatched_return_void() {
+  %v = tail call i32 @baz(i32 0, i32 42, i32 6)
+  ret void
+}
+
+; CHECK-LABEL: mismatched_return_f32:
+; CHECK: i32.call $push[[L:[0-9]+]]=, baz, $pop{{[0-9]+}}, $pop{{[0-9]+}}, $pop{{[0-9]+}}{{$}}
+; CHECK: f32.reinterpret_i32 $push[[L1:[0-9]+]]=, $pop[[L]]{{$}}
+; CHECK: return $pop[[L1]]{{$}}
+define float @mismatched_return_f32() {
+  %v = tail call i32 @baz(i32 0, i32 42, i32 6)
+  %u = bitcast i32 %v to float
+  ret float %u
+}
+
+; CHECK-LABEL: mismatched_indirect_void:
+; CHECK: i32.call_indirect $drop=, $0, $1, $2, $0{{$}}
+; CHECK: return{{$}}
+define void @mismatched_indirect_void(%fn %f, i32 %x, i32 %y) {
+  %p = extractvalue %fn %f, 0
+  %v = tail call i32 %p(%fn %f, i32 %x, i32 %y)
+  ret void
+}
+
+; CHECK-LABEL: mismatched_indirect_f32:
+; CHECK: i32.call_indirect $push[[L:[0-9]+]]=, $0, $1, $2, $0{{$}}
+; CHECK: f32.reinterpret_i32 $push[[L1:[0-9]+]]=, $pop[[L]]{{$}}
+; CHECK: return $pop[[L1]]{{$}}
+define float @mismatched_indirect_f32(%fn %f, i32 %x, i32 %y) {
+  %p = extractvalue %fn %f, 0
+  %v = tail call i32 %p(%fn %f, i32 %x, i32 %y)
+  %u = bitcast i32 %v to float
+  ret float %u
+}
+
 ; CHECK-LABEL: mismatched_byval:
 ; CHECK: i32.store
 ; CHECK: return_call quux, $pop{{[0-9]+}}{{$}}
@@ -135,13 +174,57 @@

 ; CHECK-LABEL: varargs:
 ; CHECK: i32.store
-; CHECK: return_call var, $1{{$}}
+; CHECK: i32.call $0=, var, $1{{$}}
+; CHECK: return $0{{$}}
 declare i32 @var(...)
 define i32 @varargs(i32 %x) {
   %v = tail call i32 (...) @var(i32 %x)
   ret i32 %v
 }

+; Type transformations inhibit tail calls, even when they are no-ops
+
+; CHECK-LABEL: mismatched_return_zext:
+; CHECK: i32.call
+define i32 @mismatched_return_zext() {
+  %v = tail call i1 @foo(i1 1)
+  %u = zext i1 %v to i32
+  ret i32 %u
+}
+
+; CHECK-LABEL: mismatched_return_sext:
+; CHECK: i32.call
+define i32 @mismatched_return_sext() {
+  %v = tail call i1 @foo(i1 1)
+  %u = sext i1 %v to i32
+  ret i32 %u
+}
+
+; CHECK-LABEL: mismatched_return_trunc:
+; CHECK: i32.call
+declare i32 @int()
+define i1 @mismatched_return_trunc() {
+  %v = tail call i32 @int()
+  %u = trunc i32 %v to i1
+  ret i1 %u
+}
+
+; Check that the signatures generated for external indirectly
+; return-called functions include the proper return types
+
+; YAML-LABEL: - Index:           8
+; YAML-NEXT:    ReturnType:      I32
+; YAML-NEXT:    ParamTypes:
+; YAML-NEXT:      - I32
+; YAML-NEXT:      - F32
+; YAML-NEXT:      - I64
+; YAML-NEXT:      - F64
+define i32 @unique_caller(i32 (i32, float, i64, double)** %p) {
+  %f = load i32 (i32, float, i64, double)*, i32 (i32, float, i64, double)** %p
+  %v = tail call i32 %f(i32 0, float 0., i64 0, double 0.)
+  ret i32 %v
+}
+
 ; CHECK-LABEL: .section .custom_section.target_features
 ; CHECK-NEXT: .int8 1
 ; CHECK-NEXT: .int8 43