diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
@@ -644,13 +644,37 @@
   if (CLI.IsPatchPoint)
     fail(DL, DAG, "WebAssembly doesn't support patch point yet");
 
-  // Fail if tail calls are required but not enabled
-  if (!Subtarget->hasTailCall()) {
-    if ((CallConv == CallingConv::Fast && CLI.IsTailCall &&
-         MF.getTarget().Options.GuaranteedTailCallOpt) ||
-        (CLI.CS && CLI.CS.isMustTailCall()))
-      fail(DL, DAG, "WebAssembly 'tail-call' feature not enabled");
-    CLI.IsTailCall = false;
+  if (CLI.IsTailCall) {
+    bool MustTail = (CallConv == CallingConv::Fast &&
+                     MF.getTarget().Options.GuaranteedTailCallOpt) ||
+                    (CLI.CS && CLI.CS.isMustTailCall());
+    if (Subtarget->hasTailCall()) {
+      // Do not tail call unless caller and callee return types match
+      const Function &F = MF.getFunction();
+      const TargetMachine &TM = getTargetMachine();
+      Type *RetTy = F.getReturnType();
+      SmallVector<MVT, 4> CallerRetTys;
+      SmallVector<MVT, 4> CalleeRetTys;
+      computeLegalValueVTs(F, TM, RetTy, CallerRetTys);
+      computeLegalValueVTs(F, TM, CLI.RetTy, CalleeRetTys);
+      bool TypesMatch = CallerRetTys.size() == CalleeRetTys.size() &&
+                        std::equal(CallerRetTys.begin(), CallerRetTys.end(),
+                                   CalleeRetTys.begin());
+      if (!TypesMatch) {
+        // A guaranteed tail call must not be silently demoted. The IR
+        // verifier only enforces matching prototypes for musttail, so a
+        // fastcc call under GuaranteedTailCallOpt can reach this point from
+        // valid IR: diagnose it rather than asserting.
+        if (MustTail)
+          fail(DL, DAG, "WebAssembly tail call requires caller and callee "
+                        "return types to match");
+        CLI.IsTailCall = false;
+      }
+    } else {
+      CLI.IsTailCall = false;
+      if (MustTail)
+        fail(DL, DAG, "WebAssembly 'tail-call' feature not enabled");
+    }
   }
 
   SmallVectorImpl<ISD::InputArg> &Ins = CLI.Ins;
diff --git a/llvm/test/CodeGen/WebAssembly/tailcall.ll b/llvm/test/CodeGen/WebAssembly/tailcall.ll
--- a/llvm/test/CodeGen/WebAssembly/tailcall.ll
+++ b/llvm/test/CodeGen/WebAssembly/tailcall.ll
@@ -124,6 +124,23 @@
   ret i32 %v
 }
 
+; CHECK-LABEL: mismatched_return_void:
+; CHECK: i32.call $drop=, baz, $pop{{[0-9]+}}, $pop{{[0-9]+}}, $pop{{[0-9]+}}{{$}}
+; CHECK: return{{$}}
+define void @mismatched_return_void() {
+  %v = tail call i32 @baz(i32 0, i32 42, i32 6)
+  ret void
+}
+
+; CHECK-LABEL: mismatched_return_f32:
+; CHECK: i32.call $drop=, baz, $pop{{[0-9]+}}, $pop{{[0-9]+}}, $pop{{[0-9]+}}{{$}}
+; CHECK: f32.const $push[[L:[0-9]+]]=, 0x1p0{{$}}
+; CHECK: return $pop[[L]]{{$}}
+define float @mismatched_return_f32() {
+  %v = tail call i32 @baz(i32 0, i32 42, i32 6)
+  ret float 1.
+}
+
 ; CHECK-LABEL: mismatched_byval:
 ; CHECK: i32.store
 ; CHECK: return_call quux, $pop{{[0-9]+}}{{$}}