WebAssembly: do not emit a tail call when a pointer to a local stack value
(an alloca, possibly reached through pointer casts, aliases or GEPs) is passed
as an argument. The tail-call eligibility checks are restructured around a
NoTail helper, RET_CALL is marked as a terminator/barrier, and tests are added.

diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
@@ -648,32 +648,53 @@
     fail(DL, DAG, "WebAssembly doesn't support patch point yet");
 
   if (CLI.IsTailCall) {
-    bool MustTail = CLI.CS && CLI.CS.isMustTailCall();
-    if (Subtarget->hasTailCall() && !CLI.IsVarArg) {
-      // Do not tail call unless caller and callee return types match
-      const Function &F = MF.getFunction();
-      const TargetMachine &TM = getTargetMachine();
-      Type *RetTy = F.getReturnType();
-      SmallVector<MVT, 4> CallerRetTys;
-      SmallVector<MVT, 4> CalleeRetTys;
-      computeLegalValueVTs(F, TM, RetTy, CallerRetTys);
-      computeLegalValueVTs(F, TM, CLI.RetTy, CalleeRetTys);
-      bool TypesMatch = CallerRetTys.size() == CalleeRetTys.size() &&
-                        std::equal(CallerRetTys.begin(), CallerRetTys.end(),
-                                   CalleeRetTys.begin());
-      if (!TypesMatch) {
-        // musttail in this case would be an LLVM IR validation failure
-        assert(!MustTail);
-        CLI.IsTailCall = false;
-      }
-    } else {
+    // Turn the call into a regular call, reporting Msg as an error first if
+    // the call site is marked musttail.
+    auto NoTail = [&](const char *Msg) {
+      if (CLI.CS && CLI.CS.isMustTailCall())
+        fail(DL, DAG, Msg);
       CLI.IsTailCall = false;
-      if (MustTail) {
-        if (CLI.IsVarArg) {
-          // The return would pop the argument buffer
-          fail(DL, DAG, "WebAssembly does not support varargs tail calls");
-        } else {
-          fail(DL, DAG, "WebAssembly 'tail-call' feature not enabled");
+    };
+
+    if (!Subtarget->hasTailCall())
+      NoTail("WebAssembly 'tail-call' feature not enabled");
+
+    // Varargs calls cannot be tail calls because the buffer is on the stack
+    if (CLI.IsVarArg)
+      NoTail("WebAssembly does not support varargs tail calls");
+
+    // Do not tail call unless caller and callee return types match
+    const Function &F = MF.getFunction();
+    const TargetMachine &TM = getTargetMachine();
+    Type *RetTy = F.getReturnType();
+    SmallVector<MVT, 4> CallerRetTys;
+    SmallVector<MVT, 4> CalleeRetTys;
+    computeLegalValueVTs(F, TM, RetTy, CallerRetTys);
+    computeLegalValueVTs(F, TM, CLI.RetTy, CalleeRetTys);
+    bool TypesMatch = CallerRetTys.size() == CalleeRetTys.size() &&
+                      std::equal(CallerRetTys.begin(), CallerRetTys.end(),
+                                 CalleeRetTys.begin());
+    if (!TypesMatch)
+      NoTail("WebAssembly tail call requires caller and callee return types to "
+             "match");
+
+    // If pointers to local stack values are passed, we cannot tail call
+    if (CLI.CS) {
+      for (auto &Arg : CLI.CS.args()) {
+        Value *Val = Arg.get();
+        // Trace the value back through pointer operations
+        while (true) {
+          Value *Src = Val->stripPointerCastsAndAliases();
+          if (auto *GEP = dyn_cast<GetElementPtrInst>(Src))
+            Src = GEP->getPointerOperand();
+          if (Val == Src)
+            break;
+          Val = Src;
+        }
+        if (isa<AllocaInst>(Val)) {
+          NoTail(
+              "WebAssembly does not support tail calling with stack arguments");
+          break;
         }
       }
     }
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyInstrCall.td b/llvm/lib/Target/WebAssembly/WebAssemblyInstrCall.td
--- a/llvm/lib/Target/WebAssembly/WebAssemblyInstrCall.td
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyInstrCall.td
@@ -74,7 +74,7 @@
                    [(WebAssemblycall0 (i32 imm:$callee))],
                    "call \t$callee", "call\t$callee", 0x10>;
 
-let isReturn = 1 in
+let isReturn = 1, isTerminator = 1, hasCtrlDep = 1, isBarrier = 1 in
 defm RET_CALL :
   I<(outs), (ins function32_op:$callee, variable_ops),
     (outs), (ins function32_op:$callee),
diff --git a/llvm/test/CodeGen/WebAssembly/tailcall.ll b/llvm/test/CodeGen/WebAssembly/tailcall.ll
--- a/llvm/test/CodeGen/WebAssembly/tailcall.ll
+++ b/llvm/test/CodeGen/WebAssembly/tailcall.ll
@@ -209,7 +209,39 @@
   ret i1 %u
 }
 
+; Stack-allocated arguments inhibit tail calls
+; CHECK-LABEL: stack_arg:
+; CHECK: i32.call
+define i32 @stack_arg(i32* %x) {
+  %a = alloca i32
+  %v = tail call i32 @stack_arg(i32* %a)
+  ret i32 %v
+}
+
+; CHECK-LABEL: stack_arg_gep:
+; CHECK: i32.call
+define i32 @stack_arg_gep(i32* %x) {
+  %a = alloca { i32, i32 }
+  %p = getelementptr { i32, i32 }, { i32, i32 }* %a, i32 0, i32 1
+  %v = tail call i32 @stack_arg_gep(i32* %p)
+  ret i32 %v
+}
+
+; Known gap: a stack address passed as an integer (via ptrtoint) is not traced
+; back to its alloca, so the SLOW checks below still expect return_call.
+; CHECK-LABEL: stack_arg_cast:
+; CHECK: global.get $push{{[0-9]+}}=, __stack_pointer
+; CHECK: global.set __stack_pointer, $pop{{[0-9]+}}
+; FAST: i32.call ${{[0-9]+}}=, stack_arg_cast, $pop{{[0-9]+}}
+; CHECK: global.set __stack_pointer, $pop{{[0-9]+}}
+; SLOW: return_call stack_arg_cast, ${{[0-9]+}}
+define i32 @stack_arg_cast(i32 %x) {
+  %a = alloca [64 x i32]
+  %i = ptrtoint [64 x i32]* %a to i32
+  %v = tail call i32 @stack_arg_cast(i32 %i)
+  ret i32 %v
+}
 
 ; Check that the signatures generated for external indirectly
 ; return-called functions include the proper return types
 