diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyLowerEmscriptenEHSjLj.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyLowerEmscriptenEHSjLj.cpp
--- a/llvm/lib/Target/WebAssembly/WebAssemblyLowerEmscriptenEHSjLj.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyLowerEmscriptenEHSjLj.cpp
@@ -644,19 +644,30 @@
   SmallVector<CallInst *, 8> ToErase;
   LLVMContext &C = LongjmpF->getParent()->getContext();
   IRBuilder<> IRB(C);
+
+  // For calls to longjmp, replace it with emscripten_longjmp and cast its first
+  // argument (jmp_buf*) to int.
   for (User *U : LongjmpF->users()) {
     auto *CI = dyn_cast<CallInst>(U);
-    if (!CI)
-      report_fatal_error("Does not support indirect calls to longjmp");
-    IRB.SetInsertPoint(CI);
-    Value *Jmpbuf =
-        IRB.CreatePtrToInt(CI->getArgOperand(0), IRB.getInt32Ty(), "jmpbuf");
-    IRB.CreateCall(EmLongjmpF, {Jmpbuf, CI->getArgOperand(1)});
-    ToErase.push_back(CI);
+    if (CI && CI->getCalledFunction() == LongjmpF) {
+      IRB.SetInsertPoint(CI);
+      Value *Jmpbuf =
+          IRB.CreatePtrToInt(CI->getArgOperand(0), IRB.getInt32Ty(), "jmpbuf");
+      IRB.CreateCall(EmLongjmpF, {Jmpbuf, CI->getArgOperand(1)});
+      ToErase.push_back(CI);
+    }
   }
-
   for (auto *I : ToErase)
     I->eraseFromParent();
+
+  // If we have any remaining uses of longjmp's function pointer, replace it
+  // with (int(*)(jmp_buf*, int))emscripten_longjmp. The cast is a constant
+  // expression, so it is valid in global initializers too.
+  if (!LongjmpF->use_empty()) {
+    Value *EmLongjmp =
+        IRB.CreateBitCast(EmLongjmpF, LongjmpF->getType(), "em_longjmp");
+    LongjmpF->replaceAllUsesWith(EmLongjmp);
+  }
 }
 
 bool WebAssemblyLowerEmscriptenEHSjLj::runOnModule(Module &M) {
diff --git a/llvm/test/CodeGen/WebAssembly/lower-em-sjlj.ll b/llvm/test/CodeGen/WebAssembly/lower-em-sjlj.ll
--- a/llvm/test/CodeGen/WebAssembly/lower-em-sjlj.ll
+++ b/llvm/test/CodeGen/WebAssembly/lower-em-sjlj.ll
@@ -11,6 +11,8 @@
 ; NO-TLS-DAG: __threwValue = external global i32
 ; TLS-DAG: __THREW__ = external thread_local(localexec) global i32
 ; TLS-DAG: __threwValue = external thread_local(localexec) global i32
+@global_longjmp_ptr = global void (%struct.__jmp_buf_tag*, i32)* @longjmp, align 4
+; CHECK-DAG: @global_longjmp_ptr = global void (%struct.__jmp_buf_tag*, i32)* bitcast (void (i32, i32)* @emscripten_longjmp to void (%struct.__jmp_buf_tag*, i32)*)
 
 ; Test a simple setjmp - longjmp sequence
 define void @setjmp_longjmp() {
@@ -250,6 +252,36 @@
 ; CHECK: %var[[VARNO]] = phi i32 [ %var, %for.inc ]
 }
 
+; Test cases where the longjmp function pointer is used in ways other than
+; direct calls. Every such use of longjmp should be replaced with
+; (int(*)(jmp_buf*, int))emscripten_longjmp.
+declare void @take_longjmp(void (%struct.__jmp_buf_tag*, i32)* %arg_ptr)
+define void @indirect_longjmp() {
+; CHECK-LABEL: @indirect_longjmp
+entry:
+  %local_longjmp_ptr = alloca void (%struct.__jmp_buf_tag*, i32)*, align 4
+  %buf0 = alloca [1 x %struct.__jmp_buf_tag], align 16
+  %buf1 = alloca [1 x %struct.__jmp_buf_tag], align 16
+
+  ; Store longjmp in a local variable, load it, and call it
+  store void (%struct.__jmp_buf_tag*, i32)* @longjmp, void (%struct.__jmp_buf_tag*, i32)** %local_longjmp_ptr, align 4
+  ; CHECK: store void (%struct.__jmp_buf_tag*, i32)* bitcast (void (i32, i32)* @emscripten_longjmp to void (%struct.__jmp_buf_tag*, i32)*), void (%struct.__jmp_buf_tag*, i32)** %local_longjmp_ptr, align 4
+  %longjmp_from_local_ptr = load void (%struct.__jmp_buf_tag*, i32)*, void (%struct.__jmp_buf_tag*, i32)** %local_longjmp_ptr, align 4
+  %arraydecay = getelementptr inbounds [1 x %struct.__jmp_buf_tag], [1 x %struct.__jmp_buf_tag]* %buf0, i32 0, i32 0
+  call void %longjmp_from_local_ptr(%struct.__jmp_buf_tag* %arraydecay, i32 0)
+
+  ; Load longjmp from a global variable and call it
+  %longjmp_from_global_ptr = load void (%struct.__jmp_buf_tag*, i32)*, void (%struct.__jmp_buf_tag*, i32)** @global_longjmp_ptr, align 4
+  %arraydecay1 = getelementptr inbounds [1 x %struct.__jmp_buf_tag], [1 x %struct.__jmp_buf_tag]* %buf1, i32 0, i32 0
+  call void %longjmp_from_global_ptr(%struct.__jmp_buf_tag* %arraydecay1, i32 0)
+
+  ; Pass longjmp as a function argument. This is a call, but longjmp is an
+  ; argument here rather than the callee.
+  call void @take_longjmp(void (%struct.__jmp_buf_tag*, i32)* @longjmp)
+  ; CHECK: call void @take_longjmp(void (%struct.__jmp_buf_tag*, i32)* bitcast (void (i32, i32)* @emscripten_longjmp to void (%struct.__jmp_buf_tag*, i32)*))
+  ret void
+}
+
 declare void @foo()
 ; Function Attrs: returns_twice
 declare i32 @setjmp(%struct.__jmp_buf_tag*) #0