diff --git a/llvm/lib/CodeGen/AsmPrinter/WasmException.cpp b/llvm/lib/CodeGen/AsmPrinter/WasmException.cpp --- a/llvm/lib/CodeGen/AsmPrinter/WasmException.cpp +++ b/llvm/lib/CodeGen/AsmPrinter/WasmException.cpp @@ -18,16 +18,18 @@ using namespace llvm; void WasmException::endModule() { - // This is the symbol used in 'throw' and 'catch' instruction to denote this - // is a C++ exception. This symbol has to be emitted somewhere once in the - // module. Check if the symbol has already been created, i.e., we have at - // least one 'throw' or 'catch' instruction in the module, and emit the symbol - // only if so. - SmallString<60> NameStr; - Mangler::getNameWithPrefix(NameStr, "__cpp_exception", Asm->getDataLayout()); - if (Asm->OutContext.lookupSymbol(NameStr)) { - MCSymbol *ExceptionSym = Asm->GetExternalSymbolSymbol("__cpp_exception"); - Asm->OutStreamer->emitLabel(ExceptionSym); + // These are symbols used to throw/catch C++ exceptions and C longjmps. These + // symbols have to be emitted somewhere once in the module. Check if each of + // the symbols has already been created, i.e., we have at least one 'throw' or + // 'catch' instruction with the symbol in the module, and emit the symbol only + // if so.
+ for (const char *SymName : {"__cpp_exception", "__c_longjmp"}) { + SmallString<60> NameStr; + Mangler::getNameWithPrefix(NameStr, SymName, Asm->getDataLayout()); + if (Asm->OutContext.lookupSymbol(NameStr)) { + MCSymbol *ExceptionSym = Asm->GetExternalSymbolSymbol(SymName); + Asm->OutStreamer->emitLabel(ExceptionSym); + } } } diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp @@ -1226,7 +1226,10 @@ bool IsSingleCatchAllClause = CPI->getNumArgOperands() == 1 && cast(CPI->getArgOperand(0))->isNullValue(); - if (!IsSingleCatchAllClause) { + // catchpads for longjmp use an empty type list, e.g. catchpad within %0 [] + // and they don't need LSDA info + bool IsCatchLongjmp = CPI->getNumArgOperands() == 0; + if (!IsSingleCatchAllClause && !IsCatchLongjmp) { // Create a mapping from landing pad label to landing pad index.
bool IntrFound = false; for (const User *U : CPI->users()) { diff --git a/llvm/lib/MC/WasmObjectWriter.cpp b/llvm/lib/MC/WasmObjectWriter.cpp --- a/llvm/lib/MC/WasmObjectWriter.cpp +++ b/llvm/lib/MC/WasmObjectWriter.cpp @@ -1644,7 +1644,8 @@ LLVM_DEBUG(dbgs() << " -> table index: " << WasmIndices.find(&WS)->second << "\n"); } else if (WS.isTag()) { - // C++ exception symbol (__cpp_exception) + // C++ exception symbol (__cpp_exception) or longjmp symbol + // (__c_longjmp) unsigned Index; if (WS.isDefined()) { Index = NumTagImports + Tags.size(); diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyAsmPrinter.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyAsmPrinter.cpp --- a/llvm/lib/Target/WebAssembly/WebAssemblyAsmPrinter.cpp +++ b/llvm/lib/Target/WebAssembly/WebAssemblyAsmPrinter.cpp @@ -236,7 +236,7 @@ SmallVector Returns; SmallVector Params; - if (Name == "__cpp_exception") { + if (Name == "__cpp_exception" || Name == "__c_longjmp") { WasmSym->setType(wasm::WASM_SYMBOL_TYPE_TAG); // We can't confirm its signature index for now because there can be // imported exceptions. Set it to be 0 for now. @@ -248,12 +248,14 @@ WasmSym->setWeak(true); WasmSym->setExternal(true); - // All C++ exceptions are assumed to have a single i32 (for wasm32) or i64 - // (for wasm64) param type and void return type. The reaon is, all C++ - // exception values are pointers, and to share the type section with - // functions, exceptions are assumed to have void return type. - Params.push_back(Subtarget.hasAddr64() ? wasm::ValType::I64 - : wasm::ValType::I32); + // Currently both C++ exceptions and C longjmps have a single pointer type + // param. For C++ exceptions it is a pointer to an exception object, and for + // C longjmps it is a pointer to a struct that contains a setjmp buffer and a + // longjmp return value. We may consider using multiple value parameters for + // longjmps later when multivalue support is ready. + wasm::ValType AddrType = + Subtarget.hasAddr64() ?
wasm::ValType::I64 : wasm::ValType::I32; + Params.push_back(AddrType); } else { // Function symbols WasmSym->setType(wasm::WASM_SYMBOL_TYPE_FUNCTION); getLibcallSignature(Subtarget, Name, Returns, Params); diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelDAGToDAG.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelDAGToDAG.cpp --- a/llvm/lib/Target/WebAssembly/WebAssemblyISelDAGToDAG.cpp +++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelDAGToDAG.cpp @@ -89,11 +89,13 @@ } static SDValue getTagSymNode(int Tag, SelectionDAG *DAG) { - assert(Tag == WebAssembly::CPP_EXCEPTION); + assert(Tag == WebAssembly::CPP_EXCEPTION || Tag == WebAssembly::C_LONGJMP); auto &MF = DAG->getMachineFunction(); const auto &TLI = DAG->getTargetLoweringInfo(); MVT PtrVT = TLI.getPointerTy(DAG->getDataLayout()); - const char *SymName = MF.createExternalSymbolName("__cpp_exception"); + const char *SymName = Tag == WebAssembly::CPP_EXCEPTION + ? MF.createExternalSymbolName("__cpp_exception") + : MF.createExternalSymbolName("__c_longjmp"); return DAG->getTargetExternalSymbol(SymName, PtrVT); } diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyLowerEmscriptenEHSjLj.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyLowerEmscriptenEHSjLj.cpp --- a/llvm/lib/Target/WebAssembly/WebAssemblyLowerEmscriptenEHSjLj.cpp +++ b/llvm/lib/Target/WebAssembly/WebAssemblyLowerEmscriptenEHSjLj.cpp @@ -7,15 +7,12 @@ //===----------------------------------------------------------------------===// /// /// \file -/// This file lowers exception-related instructions and setjmp/longjmp -/// function calls in order to use Emscripten's JavaScript try and catch -/// mechanism. +/// This file lowers exception-related instructions and setjmp/longjmp function +/// calls to use Emscripten's library functions. The pass uses JavaScript's try +/// and catch mechanism in case of Emscripten EH/SjLj and Wasm EH intrinsics in +/// case of Wasm SjLj.
/// -/// To handle exceptions and setjmp/longjmps, this scheme relies on JavaScript's -/// try and catch syntax and relevant exception-related libraries implemented -/// in JavaScript glue code that will be produced by Emscripten. -/// -/// * Exception handling +/// * Emscripten exception handling /// This pass lowers invokes and landingpads into library functions in JS glue /// code. Invokes are lowered into function wrappers called invoke wrappers that /// exist in JS side, which wraps the original function call with JS try-catch. @@ -23,7 +20,7 @@ /// variables (see below) so we can check whether an exception occurred from /// wasm code and handle it appropriately. /// -/// * Setjmp-longjmp handling +/// * Emscripten setjmp-longjmp handling /// This pass lowers setjmp to a reasonably-performant approach for emscripten. /// The idea is that each block with a setjmp is broken up into two parts: the /// part containing setjmp and the part right after the setjmp. The latter part @@ -52,7 +49,7 @@ /// __threwValue is 0 for exceptions, and the argument to longjmp in case of /// longjmp. /// -/// * Exception handling +/// * Emscripten exception handling /// /// 2) We assume the existence of setThrew and setTempRet0/getTempRet0 functions /// at link time. setThrew exists in Emscripten's compiler-rt: @@ -121,16 +118,16 @@ /// call @llvm_eh_typeid_for(type) /// llvm_eh_typeid_for function will be generated in JS glue code. /// -/// * Setjmp / Longjmp handling +/// * Emscripten setjmp / longjmp handling /// -/// In case calls to longjmp() exists +/// If there are calls to longjmp() /// /// 1) Lower /// longjmp(env, val) /// into /// emscripten_longjmp(env, val) /// -/// In case calls to setjmp() exists +/// If there are calls to setjmp() /// /// 2) In the function entry that calls setjmp, initialize setjmpTable and /// sejmpTableSize as follows: @@ -154,7 +151,6 @@ /// the buffer 'env'. 
A BB with setjmp is split into two after setjmp call in /// order to make the post-setjmp BB the possible destination of longjmp BB. /// -/// /// 4) Lower every call that might longjmp into /// __THREW__ = 0; /// call @__invoke_SIG(func, arg1, arg2) @@ -171,7 +167,7 @@ /// %label = -1; /// } /// longjmp_result = getTempRet0(); -/// switch label { +/// switch %label { /// label 1: goto post-setjmp BB 1 /// label 2: goto post-setjmp BB 2 /// ... @@ -188,15 +184,98 @@ /// occurred. Otherwise we jump to the right post-setjmp BB based on the /// label. /// +/// * Wasm setjmp / longjmp handling +/// This mode still uses some Emscripten library functions but not JavaScript's +/// try-catch mechanism. It instead uses Wasm exception handling intrinsics, +/// which will be lowered to exception handling instructions. +/// +/// If there are calls to longjmp() +/// +/// 1) Lower +/// longjmp(env, val) +/// into +/// __wasm_longjmp(env, val) +/// +/// If there are calls to setjmp() +/// +/// 2) and 3): The same as 2) and 3) in Emscripten SjLj. +/// (setjmpTable/setjmpTableSize initialization + setjmp callsite +/// transformation) +/// +/// 4) Create a catchpad with a wasm.catch() intrinsic, which returns the value +/// thrown by __wasm_longjmp function. In Emscripten library, we have this +/// struct: +/// +/// struct __WasmLongjmpArgs { +/// void *env; +/// int val; +/// }; +/// struct __WasmLongjmpArgs __wasm_longjmp_args; +/// +/// The thrown value here is a pointer to __wasm_longjmp_args struct object. We +/// use this struct to transfer two values by throwing a single value. Wasm +/// throw and catch instructions are capable of throwing and catching multiple +/// values, but it also requires multivalue support that is currently not very +/// reliable. +/// TODO Switch to throwing and catching two values without using the struct +/// +/// All longjmpable function calls will be converted to an invoke that will +/// unwind to this catchpad in case a longjmp occurs. 
Within the catchpad, we +/// test the thrown values using testSetjmp function as we do for Emscripten +/// SjLj. The main difference is, in Emscripten SjLj, we need to transform every +/// longjmpable callsite into a sequence of code including testSetjmp() call; in +/// Wasm SjLj we do the testing in only one place, in this catchpad. +/// +/// After testing calling testSetjmp(), if the longjmp does not correspond to +/// one of the setjmps within the current function, it rethrows the longjmp +/// by calling __wasm_longjmp(). If it corresponds to one of setjmps in the +/// function, we jump to the beginning of the function, which contains a switch +/// to each post-setjmp BB. Again, in Emscripten SjLj, this switch is added for +/// every longjmpable callsite; in Wasm SjLj we do this only once at the top of +/// the function. (after setjmpTable/setjmpTableSize initialization) +/// +/// The below is the pseudocode for what we have described +/// +/// entry: +/// Initialize setjmpTable and setjmpTableSize +/// +/// setjmp.dispatch: +/// switch %label { +/// label 1: goto post-setjmp BB 1 +/// label 2: goto post-setjmp BB 2 +/// ... +/// default: goto splitted next BB +/// } +/// ... +/// +/// bb: +/// invoke void @foo() ;; foo is a longjmpable function +/// to label %next unwind label %catch.dispatch.longjmp +/// ... 
+/// +/// catch.dispatch.longjmp: +/// %0 = catchswitch within none [label %catch.longjmp] unwind to caller +/// +/// catch.longjmp: +/// %longjmp.args = wasm.catch() ;; struct __WasmLongjmpArgs +/// %env = load 'env' field from __WasmLongjmpArgs +/// %val = load 'val' field from __WasmLongjmpArgs +/// %label = testSetjmp(mem[%env], setjmpTable, setjmpTableSize); +/// if (%label == 0) +/// __wasm_longjmp(%env, %val) +/// catchret to %setjmp.dispatch +/// ///===----------------------------------------------------------------------===// #include "WebAssembly.h" #include "WebAssemblyTargetMachine.h" #include "llvm/ADT/StringExtras.h" #include "llvm/CodeGen/TargetPassConfig.h" +#include "llvm/CodeGen/WasmEHFuncInfo.h" #include "llvm/IR/DebugInfoMetadata.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/IRBuilder.h" +#include "llvm/IR/IntrinsicsWebAssembly.h" #include "llvm/Support/CommandLine.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/SSAUpdater.h" @@ -236,6 +315,11 @@ Function *EmLongjmpF = nullptr; // emscripten_longjmp() (Emscripten) Function *SaveSetjmpF = nullptr; // saveSetjmp() (Emscripten) Function *TestSetjmpF = nullptr; // testSetjmp() (Emscripten) + Function *WasmLongjmpF = nullptr; // __wasm_longjmp() (Emscripten) + Function *CatchF = nullptr; // wasm.catch() (intrinsic) + + // Type of 'struct __WasmLongjmpArgs' defined in Emscripten + Type *LongjmpArgsTy = nullptr; // __cxa_find_matching_catch_N functions. // Indexed by the number of clauses in an original landingpad instruction.
@@ -258,6 +342,10 @@ Function &F, InstVector &SetjmpTableInsts, InstVector &SetjmpTableSizeInsts, SmallVectorImpl &SetjmpRetPHIs); + void + handleLongjmpableCallsForWasmSjLj(Function &F, InstVector &SetjmpTableInsts, + InstVector &SetjmpTableSizeInsts, + SmallVectorImpl &SetjmpRetPHIs); Function *getFindMatchingCatch(Module &M, unsigned NumClauses); Value *wrapInvoke(CallBase *CI); @@ -274,6 +362,7 @@ return EnableEmEH && (areAllExceptionsAllowed() || EHAllowlistSet.count(std::string(F->getName()))); } + void replaceLongjmpWith(Function *LongjmpF, Function *NewF); void rebuildSSA(Function &F); @@ -654,11 +743,17 @@ void WebAssemblyLowerEmscriptenEHSjLj::rebuildSSA(Function &F) { DominatorTree &DT = getAnalysis(F).getDomTree(); DT.recalculate(F); // CFG has been changed + SSAUpdaterBulk SSA; for (BasicBlock &BB : F) { for (Instruction &I : BB) { unsigned VarID = SSA.AddVariable(I.getName(), I.getType()); - SSA.AddAvailableValue(VarID, &BB, &I); + // If a value is defined by an invoke instruction, it is only available in + // its normal destination and not in its unwind destination. + if (auto *II = dyn_cast(&I)) + SSA.AddAvailableValue(VarID, II->getNormalDest(), II); + else + SSA.AddAvailableValue(VarID, &BB, &I); for (auto &U : I.uses()) { auto *User = cast(U.getUser()); if (auto *UserPN = dyn_cast(User)) @@ -673,26 +768,36 @@ SSA.RewriteAllUses(&DT); } -// Replace uses of longjmp with emscripten_longjmp. emscripten_longjmp takes -// arguments of type {i32, i32} (wasm32) / {i64, i32} (wasm64) and longjmp takes -// {jmp_buf*, i32}, so we need a ptrtoint instruction here to make the type -// match. jmp_buf* will eventually be lowered to i32/i64 in the wasm backend. -static void replaceLongjmpWithEmscriptenLongjmp(Function *LongjmpF, - Function *EmLongjmpF) { +// Replace uses of longjmp with a new longjmp function in Emscripten library.
+// In Emscripten SjLj, the new function is +// void emscripten_longjmp(uintptr_t, i32) +// In Wasm SjLj, the new function is +// void __wasm_longjmp(i8*, i32) +// Because the original libc longjmp function takes (jmp_buf*, i32), we need a +// ptrtoint/bitcast instruction here to make the type match. jmp_buf* will +// eventually be lowered to i32/i64 in the wasm backend. +void WebAssemblyLowerEmscriptenEHSjLj::replaceLongjmpWith(Function *LongjmpF, + Function *NewF) { + assert(NewF == EmLongjmpF || NewF == WasmLongjmpF); Module *M = LongjmpF->getParent(); SmallVector ToErase; LLVMContext &C = LongjmpF->getParent()->getContext(); IRBuilder<> IRB(C); - // For calls to longjmp, replace it with emscripten_longjmp and cast its first - // argument (jmp_buf*) to int + // For calls to longjmp, replace it with emscripten_longjmp/__wasm_longjmp and + // cast its first argument (jmp_buf*) appropriately for (User *U : LongjmpF->users()) { auto *CI = dyn_cast(U); if (CI && CI->getCalledFunction() == LongjmpF) { IRB.SetInsertPoint(CI); - Value *Env = - IRB.CreatePtrToInt(CI->getArgOperand(0), getAddrIntType(M), "env"); - IRB.CreateCall(EmLongjmpF, {Env, CI->getArgOperand(1)}); + Value *Env = nullptr; + if (NewF == EmLongjmpF) + Env = + IRB.CreatePtrToInt(CI->getArgOperand(0), getAddrIntType(M), "env"); + else // WasmLongjmpF + Env = + IRB.CreateBitCast(CI->getArgOperand(0), IRB.getInt8PtrTy(), "env"); + IRB.CreateCall(NewF, {Env, CI->getArgOperand(1)}); ToErase.push_back(CI); } } @@ -700,11 +805,11 @@ I->eraseFromParent(); // If we have any remaining uses of longjmp's function pointer, replace it - // with (int(*)(jmp_buf*, int))emscripten_longjmp. + // with (void(*)(jmp_buf*, int))emscripten_longjmp / __wasm_longjmp.
if (!LongjmpF->uses().empty()) { - Value *EmLongjmp = - IRB.CreateBitCast(EmLongjmpF, LongjmpF->getType(), "em_longjmp"); - LongjmpF->replaceAllUsesWith(EmLongjmp); + Value *NewLongjmp = + IRB.CreateBitCast(NewF, LongjmpF->getType(), "longjmp.cast"); + LongjmpF->replaceAllUsesWith(NewLongjmp); } } @@ -759,38 +864,49 @@ EHTypeIDF = getEmscriptenFunction(EHTypeIDTy, "llvm_eh_typeid_for", &M); } - if (EnableEmSjLj && SetjmpF) { + if ((EnableEmSjLj || EnableWasmSjLj) && SetjmpF) { // Precompute setjmp users for (User *U : SetjmpF->users()) { - Function *UserF = cast(U)->getFunction(); - // If a function that calls setjmp does not contain any other calls that - // can longjmp, we don't need to do any transformation on that function, - // so can ignore it - if (containsLongjmpableCalls(UserF)) - SetjmpUsers.insert(UserF); + if (auto *UI = dyn_cast(U)) { + auto *UserF = UI->getFunction(); + // If a function that calls setjmp does not contain any other calls that + // can longjmp, we don't need to do any transformation on that function, + // so can ignore it + if (containsLongjmpableCalls(UserF)) + SetjmpUsers.insert(UserF); + } } } bool SetjmpUsed = SetjmpF && !SetjmpUsers.empty(); bool LongjmpUsed = LongjmpF && !LongjmpF->use_empty(); - DoSjLj = EnableEmSjLj && (SetjmpUsed || LongjmpUsed); + DoSjLj = (EnableEmSjLj || EnableWasmSjLj) && (SetjmpUsed || LongjmpUsed); // Function registration and data pre-gathering for setjmp/longjmp handling if (DoSjLj) { assert(EnableEmSjLj || EnableWasmSjLj); - // Register emscripten_longjmp function - FunctionType *FTy = FunctionType::get( - IRB.getVoidTy(), {getAddrIntType(&M), IRB.getInt32Ty()}, false); - EmLongjmpF = getEmscriptenFunction(FTy, "emscripten_longjmp", &M); - EmLongjmpF->addFnAttr(Attribute::NoReturn); + if (EnableEmSjLj) { + // Register emscripten_longjmp function + FunctionType *FTy = FunctionType::get( + IRB.getVoidTy(), {getAddrIntType(&M), IRB.getInt32Ty()}, false); + EmLongjmpF = getEmscriptenFunction(FTy,
"emscripten_longjmp", &M); + EmLongjmpF->addFnAttr(Attribute::NoReturn); + } else { // EnableWasmSjLj + // Register __wasm_longjmp function, which calls __builtin_wasm_longjmp. + FunctionType *FTy = FunctionType::get( + IRB.getVoidTy(), {IRB.getInt8PtrTy(), IRB.getInt32Ty()}, false); + WasmLongjmpF = getEmscriptenFunction(FTy, "__wasm_longjmp", &M); + WasmLongjmpF->addFnAttr(Attribute::NoReturn); + } if (SetjmpF) { // Register saveSetjmp function FunctionType *SetjmpFTy = SetjmpF->getFunctionType(); - FTy = FunctionType::get(Type::getInt32PtrTy(C), - {SetjmpFTy->getParamType(0), IRB.getInt32Ty(), - Type::getInt32PtrTy(C), IRB.getInt32Ty()}, - false); + FunctionType *FTy = + FunctionType::get(Type::getInt32PtrTy(C), + {SetjmpFTy->getParamType(0), IRB.getInt32Ty(), + Type::getInt32PtrTy(C), IRB.getInt32Ty()}, + false); SaveSetjmpF = getEmscriptenFunction(FTy, "saveSetjmp", &M); // Register testSetjmp function @@ -799,6 +915,17 @@ {getAddrIntType(&M), Type::getInt32PtrTy(C), IRB.getInt32Ty()}, false); TestSetjmpF = getEmscriptenFunction(FTy, "testSetjmp", &M); + + // wasm.catch() and 'struct __WasmLongjmpArgs' are only used by Wasm SjLj. + if (EnableWasmSjLj) { + // wasm.catch() will be lowered down to wasm 'catch' instruction in + // instruction selection. + CatchF = Intrinsic::getDeclaration(&M, Intrinsic::wasm_catch); + // Type for struct __WasmLongjmpArgs + LongjmpArgsTy = StructType::get(IRB.getInt8PtrTy(), // env + IRB.getInt32Ty() // val + ); + } } } @@ -815,7 +942,7 @@ if (DoSjLj) { Changed = true; // We have setjmp or longjmp somewhere if (LongjmpF) - replaceLongjmpWithEmscriptenLongjmp(LongjmpF, EmLongjmpF); + replaceLongjmpWith(LongjmpF, EnableEmSjLj ? EmLongjmpF : WasmLongjmpF); // Only traverse functions that uses setjmp in order not to insert // unnecessary prep / cleanup code in every function if (SetjmpF) @@ -890,7 +1017,8 @@ // // tail: ;; Nothing happened or an exception is thrown // ... Continue exception handling ...
- if (DoSjLj && !SetjmpUsers.count(&F) && canLongjmp(Callee)) { + if (DoSjLj && EnableEmSjLj && !SetjmpUsers.count(&F) && + canLongjmp(Callee)) { // Create longjmp.rethrow BB once and share it within the function if (!RethrowLongjmpBB) { RethrowLongjmpBB = BasicBlock::Create(C, "rethrow.longjmp", &F); @@ -1130,9 +1258,13 @@ ToErase.push_back(CI); } - // Handle longjmp calls. - handleLongjmpableCallsForEmscriptenSjLj(F, SetjmpTableInsts, - SetjmpTableSizeInsts, SetjmpRetPHIs); + // Handle longjmpable calls. + if (EnableEmSjLj) + handleLongjmpableCallsForEmscriptenSjLj( + F, SetjmpTableInsts, SetjmpTableSizeInsts, SetjmpRetPHIs); + else // EnableWasmSjLj + handleLongjmpableCallsForWasmSjLj(F, SetjmpTableInsts, SetjmpTableSizeInsts, + SetjmpRetPHIs); // Erase everything we no longer need in this function for (Instruction *I : ToErase) @@ -1152,7 +1284,7 @@ for (auto &I : BB) { if (auto *CI = dyn_cast(&I)) { bool IsNoReturn = CI->hasFnAttr(Attribute::NoReturn); - if (auto *CalleeF = dyn_cast(CI->getCalledOperand())) + if (Function *CalleeF = CI->getCalledFunction()) IsNoReturn |= CalleeF->hasFnAttribute(Attribute::NoReturn); if (IsNoReturn) ExitingInsts.push_back(&I); @@ -1161,7 +1293,13 @@ } for (auto *I : ExitingInsts) { DebugLoc DL = getOrCreateDebugLoc(I, F.getSubprogram()); - auto *Free = CallInst::CreateFree(SetjmpTable, I); + // If this exiting instruction is a call within a catchpad, we should add + // it as "funclet" to the operand bundle of the 'free' call + SmallVector Bundles; + if (auto *CB = dyn_cast(I)) + if (auto Bundle = CB->getOperandBundle(LLVMContext::OB_funclet)) + Bundles.push_back(OperandBundleDef(*Bundle)); + auto *Free = CallInst::CreateFree(SetjmpTable, Bundles, I); Free->setDebugLoc(DL); // CallInst::CreateFree may create a bitcast instruction if its argument // types mismatch. We need to set the debug loc for the bitcast too.
@@ -1215,8 +1350,9 @@ return true; } -// Update each call that can longjmp so it can return to a setjmp where -// relevant. +// Update each call that can longjmp so it can return to the corresponding +// setjmp. Refer to 4) of "Emscripten setjmp/longjmp handling" section in the +// comments at top of the file for details. void WebAssemblyLowerEmscriptenEHSjLj::handleLongjmpableCallsForEmscriptenSjLj( Function &F, InstVector &SetjmpTableInsts, InstVector &SetjmpTableSizeInsts, SmallVectorImpl &SetjmpRetPHIs) { @@ -1402,3 +1538,181 @@ for (Instruction *I : ToErase) I->eraseFromParent(); } + +// Create a catchpad in which we catch a longjmp's env and val arguments, test +// if the longjmp corresponds to one of setjmps in the current function, and if +// so, jump to the setjmp dispatch BB from which we go to one of post-setjmp +// BBs. Refer to 4) of "Wasm setjmp/longjmp handling" section in the comments at +// top of the file for details. +void WebAssemblyLowerEmscriptenEHSjLj::handleLongjmpableCallsForWasmSjLj( + Function &F, InstVector &SetjmpTableInsts, InstVector &SetjmpTableSizeInsts, + SmallVectorImpl &SetjmpRetPHIs) { + Module &M = *F.getParent(); + LLVMContext &C = F.getContext(); + IRBuilder<> IRB(C); + + // A function with catchswitch/catchpad instruction should have a personality + // function attached to it. Search for the wasm personality function, and if + // it exists, use it, and if it doesn't, create a dummy personality function. + // (SjLj is not going to call it anyway.)
+ if (!F.hasPersonalityFn()) { + StringRef PersName = getEHPersonalityName(EHPersonality::Wasm_CXX); + FunctionType *PersType = + FunctionType::get(IRB.getInt32Ty(), /* isVarArg */ true); + Value *PersF = M.getOrInsertFunction(PersName, PersType).getCallee(); + F.setPersonalityFn( + cast(IRB.CreateBitCast(PersF, IRB.getInt8PtrTy()))); + } + + // Use the entry BB's debugloc as a fallback + BasicBlock *Entry = &F.getEntryBlock(); + DebugLoc FirstDL = getOrCreateDebugLoc(&*Entry->begin(), F.getSubprogram()); + IRB.SetCurrentDebugLocation(FirstDL); + + // Arbitrarily use the ones defined in the beginning of the function. + // SSAUpdater will later update them to the correct values. + Instruction *SetjmpTable = *SetjmpTableInsts.begin(); + Instruction *SetjmpTableSize = *SetjmpTableSizeInsts.begin(); + + // Add setjmp.dispatch BB right after the entry block. Because we have + // initialized setjmpTable/setjmpTableSize in the entry block and split the + // rest into another BB, here 'OrigEntry' is the function's original entry + // block before the transformation.
+ // + // entry: + // setjmpTable / setjmpTableSize initialization + // setjmp.dispatch: + // switch will be inserted here later + // entry.split: (OrigEntry) + // the original function starts here + BasicBlock *OrigEntry = Entry->getNextNode(); + BasicBlock *SetjmpDispatchBB = + BasicBlock::Create(C, "setjmp.dispatch", &F, OrigEntry); + cast(Entry->getTerminator())->setSuccessor(0, SetjmpDispatchBB); + + // Create catch.dispatch.longjmp BB and a catchswitch instruction + BasicBlock *CatchSwitchBB = + BasicBlock::Create(C, "catch.dispatch.longjmp", &F); + IRB.SetInsertPoint(CatchSwitchBB); + CatchSwitchInst *CatchSwitch = + IRB.CreateCatchSwitch(ConstantTokenNone::get(C), nullptr, 1); + + // Create catch.longjmp BB and a catchpad instruction + BasicBlock *CatchLongjmpBB = BasicBlock::Create(C, "catch.longjmp", &F); + CatchSwitch->addHandler(CatchLongjmpBB); + IRB.SetInsertPoint(CatchLongjmpBB); + CatchPadInst *CatchPad = IRB.CreateCatchPad(CatchSwitch, {}); + + // Wasm throw and catch instructions can throw and catch multiple values, but + // that requires multivalue support in the toolchain, which is currently not + // very reliable. We instead throw and catch a pointer to a struct value of + // type 'struct __WasmLongjmpArgs', which is defined in Emscripten.
+ Instruction *CatchCI = + IRB.CreateCall(CatchF, {IRB.getInt32(WebAssembly::C_LONGJMP)}, "thrown"); + Value *LongjmpArgs = + IRB.CreateBitCast(CatchCI, LongjmpArgsTy->getPointerTo(), "longjmp.args"); + Value *EnvField = + IRB.CreateConstGEP2_32(LongjmpArgsTy, LongjmpArgs, 0, 0, "env_gep"); + Value *ValField = + IRB.CreateConstGEP2_32(LongjmpArgsTy, LongjmpArgs, 0, 1, "val_gep"); + // void *env = __wasm_longjmp_args.env; + Instruction *Env = IRB.CreateLoad(IRB.getInt8PtrTy(), EnvField, "env"); + // int val = __wasm_longjmp_args.val; + Instruction *Val = IRB.CreateLoad(IRB.getInt32Ty(), ValField, "val"); + + // %label = testSetjmp(mem[%env], setjmpTable, setjmpTableSize); + // if (%label == 0) + // __wasm_longjmp(%env, %val) + // catchret to %setjmp.dispatch + BasicBlock *ThenBB = BasicBlock::Create(C, "if.then", &F); + BasicBlock *EndBB = BasicBlock::Create(C, "if.end", &F); + Value *EnvP = IRB.CreateBitCast(Env, getAddrPtrType(&M), "env.p"); + Value *SetjmpID = IRB.CreateLoad(getAddrIntType(&M), EnvP, "setjmp.id"); + Value *Label = + IRB.CreateCall(TestSetjmpF, {SetjmpID, SetjmpTable, SetjmpTableSize}, + OperandBundleDef("funclet", CatchPad), "label"); + Value *Cmp = IRB.CreateICmpEQ(Label, IRB.getInt32(0)); + IRB.CreateCondBr(Cmp, ThenBB, EndBB); + + IRB.SetInsertPoint(ThenBB); + CallInst *WasmLongjmpCI = IRB.CreateCall( + WasmLongjmpF, {Env, Val}, OperandBundleDef("funclet", CatchPad)); + IRB.CreateUnreachable(); + + IRB.SetInsertPoint(EndBB); + // Jump to setjmp.dispatch block + IRB.CreateCatchRet(CatchPad, SetjmpDispatchBB); + + // Go back to setjmp.dispatch BB + // setjmp.dispatch: + // switch %label { + // label 1: goto post-setjmp BB 1 + // label 2: goto post-setjmp BB 2 + // ...
+ // default: goto splitted next BB + // } + IRB.SetInsertPoint(SetjmpDispatchBB); + PHINode *LabelPHI = IRB.CreatePHI(IRB.getInt32Ty(), 2, "label.phi"); + LabelPHI->addIncoming(Label, EndBB); + LabelPHI->addIncoming(IRB.getInt32(-1), Entry); + SwitchInst *SI = IRB.CreateSwitch(LabelPHI, OrigEntry, SetjmpRetPHIs.size()); + // -1 means no longjmp happened, continue normally (will hit the default + // switch case). 0 means a longjmp that is not ours to handle; that case is + // rethrown in catch.longjmp and never reaches here. Otherwise the label is + // the index in SetjmpRetPHIs plus 1 (to avoid 0). + for (unsigned I = 0; I < SetjmpRetPHIs.size(); I++) { + SI->addCase(IRB.getInt32(I + 1), SetjmpRetPHIs[I]->getParent()); + SetjmpRetPHIs[I]->addIncoming(Val, SetjmpDispatchBB); + } + + // Convert all longjmpable call instructions to invokes that unwind to the + // newly created catch.dispatch.longjmp BB. + SmallVector ToErase; + for (auto *BB = &*F.begin(); BB; BB = BB->getNextNode()) { + for (Instruction &I : *BB) { + auto *CI = dyn_cast(&I); + if (!CI) + continue; + const Value *Callee = CI->getCalledOperand(); + if (!canLongjmp(Callee)) + continue; + if (isEmAsmCall(Callee)) + report_fatal_error("Cannot use EM_ASM* alongside setjmp/longjmp in " + + F.getName() + + ". Please consider using EM_JS, or move the " + "EM_ASM into another function.", + false); + // This is the __wasm_longjmp() call we inserted in this function, which + // rethrows the longjmp when the longjmp does not correspond to one of + // setjmps in this function. We should not convert this call to an invoke. + if (CI == WasmLongjmpCI) + continue; + ToErase.push_back(CI); + + // Even if the callee function has attribute 'nounwind', which is typically + // true for C functions, it can longjmp, which means it can throw a Wasm + // exception now.
+ CI->removeFnAttr(Attribute::NoUnwind); + if (Function *CalleeF = CI->getCalledFunction()) { + CalleeF->removeFnAttr(Attribute::NoUnwind); + } + + IRB.SetInsertPoint(CI); + BasicBlock *Tail = SplitBlock(BB, CI->getNextNode()); + // We will add a new invoke. So remove the branch created when we split + // the BB + ToErase.push_back(BB->getTerminator()); + SmallVector Args(CI->args()); + InvokeInst *II = + IRB.CreateInvoke(CI->getFunctionType(), CI->getCalledOperand(), Tail, + CatchSwitchBB, Args); + II->takeName(CI); + II->setDebugLoc(CI->getDebugLoc()); + II->setAttributes(CI->getAttributes()); + CI->replaceAllUsesWith(II); + } + } + + for (Instruction *I : ToErase) + I->eraseFromParent(); +} diff --git a/llvm/test/CodeGen/WebAssembly/lower-em-sjlj.ll b/llvm/test/CodeGen/WebAssembly/lower-em-sjlj.ll --- a/llvm/test/CodeGen/WebAssembly/lower-em-sjlj.ll +++ b/llvm/test/CodeGen/WebAssembly/lower-em-sjlj.ll @@ -254,7 +254,7 @@ ; Tests cases where longjmp function pointer is used in other ways than direct ; calls. longjmps should be replaced with -; (int(*)(jmp_buf*, int))emscripten_longjmp. +; (void(*)(jmp_buf*, int))emscripten_longjmp.
declare void @take_longjmp(void (%struct.__jmp_buf_tag*, i32)* %arg_ptr) define void @indirect_longjmp() { ; CHECK-LABEL: @indirect_longjmp diff --git a/llvm/test/CodeGen/WebAssembly/lower-wasm-sjlj.ll b/llvm/test/CodeGen/WebAssembly/lower-wasm-sjlj.ll new file mode 100644 --- /dev/null +++ b/llvm/test/CodeGen/WebAssembly/lower-wasm-sjlj.ll @@ -0,0 +1,161 @@ +; RUN: opt < %s -wasm-lower-em-ehsjlj -wasm-enable-sjlj -S | FileCheck %s -DPTR=i32 +; RUN: opt < %s -wasm-lower-em-ehsjlj -wasm-enable-sjlj --mtriple=wasm64-unknown-unknown -data-layout="e-m:e-p:64:64-i64:64-n32:64-S128" -S | FileCheck %s -DPTR=i64 + +target datalayout = "e-m:e-p:32:32-i64:64-n32:64-S128" +target triple = "wasm32-unknown-unknown" + +%struct.__jmp_buf_tag = type { [6 x i32], i32, [32 x i32] } + +@global_longjmp_ptr = global void (%struct.__jmp_buf_tag*, i32)* @longjmp, align 4 +; CHECK-DAG: @global_longjmp_ptr = global void (%struct.__jmp_buf_tag*, i32)* bitcast (void (i8*, i32)* @__wasm_longjmp to void (%struct.__jmp_buf_tag*, i32)*) + +; Test a simple setjmp - longjmp sequence +define void @setjmp_longjmp() { +; CHECK-LABEL: @setjmp_longjmp() personality {{.*}} @__gxx_wasm_personality_v0 +entry: + %buf = alloca [1 x %struct.__jmp_buf_tag], align 16 + %arraydecay = getelementptr inbounds [1 x %struct.__jmp_buf_tag], [1 x %struct.__jmp_buf_tag]* %buf, i32 0, i32 0 + %call = call i32 @setjmp(%struct.__jmp_buf_tag* %arraydecay) #0 + %arraydecay1 = getelementptr inbounds [1 x %struct.__jmp_buf_tag], [1 x %struct.__jmp_buf_tag]* %buf, i32 0, i32 0 + call void @longjmp(%struct.__jmp_buf_tag* %arraydecay1, i32 1) #1 + unreachable + +; CHECK: entry: +; CHECK-NEXT: %malloccall = tail call i8* @malloc(i32 40) +; CHECK-NEXT: %setjmpTable = bitcast i8* %malloccall to i32* +; CHECK-NEXT: store i32 0, i32* %setjmpTable, align 4 +; CHECK-NEXT: %setjmpTableSize = add i32 4, 0 +; CHECK-NEXT: br label %setjmp.dispatch + +; CHECK: setjmp.dispatch: +; CHECK-NEXT: %val10 = phi i32 [ %val, %if.end ], [ undef,
%entry ] +; CHECK-NEXT: %buf9 = phi [1 x %struct.__jmp_buf_tag]* [ %buf8, %if.end ], [ undef, %entry ] +; CHECK-NEXT: %setjmpTableSize6 = phi i32 [ %setjmpTableSize7, %if.end ], [ %setjmpTableSize, %entry ] +; CHECK-NEXT: %setjmpTable4 = phi i32* [ %setjmpTable5, %if.end ], [ %setjmpTable, %entry ] +; CHECK-NEXT: %label.phi = phi i32 [ %label, %if.end ], [ -1, %entry ] +; CHECK-NEXT: switch i32 %label.phi, label %entry.split [ +; CHECK-NEXT: i32 1, label %entry.split.split +; CHECK-NEXT: ] + +; CHECK: entry.split: +; CHECK-NEXT: %buf = alloca [1 x %struct.__jmp_buf_tag], align 16 +; CHECK-NEXT: %arraydecay = getelementptr inbounds [1 x %struct.__jmp_buf_tag], [1 x %struct.__jmp_buf_tag]* %buf, i32 0, i32 0 +; CHECK-NEXT: %setjmpTable1 = call i32* @saveSetjmp(%struct.__jmp_buf_tag* %arraydecay, i32 1, i32* %setjmpTable4, i32 %setjmpTableSize6) +; CHECK-NEXT: %setjmpTableSize2 = call i32 @getTempRet0() +; CHECK-NEXT: br label %entry.split.split + +; CHECK: entry.split.split: +; CHECK-NEXT: %buf8 = phi [1 x %struct.__jmp_buf_tag]* [ %buf9, %setjmp.dispatch ], [ %buf, %entry.split ] +; CHECK-NEXT: %setjmpTableSize7 = phi i32 [ %setjmpTableSize2, %entry.split ], [ %setjmpTableSize6, %setjmp.dispatch ] +; CHECK-NEXT: %setjmpTable5 = phi i32* [ %setjmpTable1, %entry.split ], [ %setjmpTable4, %setjmp.dispatch ] +; CHECK-NEXT: %setjmp.ret = phi i32 [ 0, %entry.split ], [ %val10, %setjmp.dispatch ] +; CHECK-NEXT: %arraydecay1 = getelementptr inbounds [1 x %struct.__jmp_buf_tag], [1 x %struct.__jmp_buf_tag]* %buf8, i32 0, i32 0 +; CHECK-NEXT: %env = bitcast %struct.__jmp_buf_tag* %arraydecay1 to i8* +; CHECK-NEXT: invoke void @__wasm_longjmp(i8* %env, i32 1) +; CHECK-NEXT: to label %entry.split.split.split unwind label %catch.dispatch.longjmp + +; CHECK: entry.split.split.split: +; CHECK-NEXT: unreachable + +; CHECK: catch.dispatch.longjmp: +; CHECK-NEXT: %0 = catchswitch within none [label %catch.longjmp] unwind to caller + +; CHECK: catch.longjmp: +; CHECK-NEXT: %1 =
catchpad within %0 [] +; CHECK-NEXT: %thrown = call i8* @llvm.wasm.catch(i32 1) +; CHECK-NEXT: %longjmp.args = bitcast i8* %thrown to { i8*, i32 }* +; CHECK-NEXT: %env_gep = getelementptr { i8*, i32 }, { i8*, i32 }* %longjmp.args, i32 0, i32 0 +; CHECK-NEXT: %val_gep = getelementptr { i8*, i32 }, { i8*, i32 }* %longjmp.args, i32 0, i32 1 +; CHECK-NEXT: %env3 = load i8*, i8** %env_gep, align {{.*}} +; CHECK-NEXT: %val = load i32, i32* %val_gep, align 4 +; CHECK-NEXT: %env.p = bitcast i8* %env3 to [[PTR]]* +; CHECK-NEXT: %setjmp.id = load [[PTR]], [[PTR]]* %env.p, align {{.*}} +; CHECK-NEXT: %label = call i32 @testSetjmp([[PTR]] %setjmp.id, i32* %setjmpTable5, i32 %setjmpTableSize7) [ "funclet"(token %1) ] +; CHECK-NEXT: %2 = icmp eq i32 %label, 0 +; CHECK-NEXT: br i1 %2, label %if.then, label %if.end + +; CHECK: if.then: +; CHECK-NEXT: %3 = bitcast i32* %setjmpTable5 to i8* +; CHECK-NEXT: tail call void @free(i8* %3) [ "funclet"(token %1) ] +; CHECK-NEXT: call void @__wasm_longjmp(i8* %env3, i32 %val) [ "funclet"(token %1) ] +; CHECK-NEXT: unreachable + +; CHECK: if.end: +; CHECK-NEXT: catchret from %1 to label %setjmp.dispatch +} + +; When there are multiple longjmpable calls after setjmp, each of them is +; turned into an invoke whose unwind destination is the +; 'catch.dispatch.longjmp' BB.
+define void @setjmp_multiple_longjmpable_calls() { +; CHECK-LABEL: @setjmp_multiple_longjmpable_calls +entry: + %buf = alloca [1 x %struct.__jmp_buf_tag], align 16 + %arraydecay = getelementptr inbounds [1 x %struct.__jmp_buf_tag], [1 x %struct.__jmp_buf_tag]* %buf, i32 0, i32 0 + %call = call i32 @setjmp(%struct.__jmp_buf_tag* %arraydecay) #0 + call void @foo() + call void @foo() + ret void + +; CHECK: entry.split.split: +; CHECK: invoke void @foo() +; CHECK: to label %{{.*}} unwind label %catch.dispatch.longjmp + +; CHECK: entry.split.split.split: +; CHECK: invoke void @foo() +; CHECK: to label %{{.*}} unwind label %catch.dispatch.longjmp +} + +; Test cases where the longjmp function pointer is used in ways other than +; direct calls. longjmps should be replaced with (void(*)(jmp_buf*, int))__wasm_longjmp. +declare void @take_longjmp(void (%struct.__jmp_buf_tag*, i32)* %arg_ptr) +define void @indirect_longjmp() { +; CHECK-LABEL: @indirect_longjmp +entry: + %local_longjmp_ptr = alloca void (%struct.__jmp_buf_tag*, i32)*, align 4 + %buf0 = alloca [1 x %struct.__jmp_buf_tag], align 16 + %buf1 = alloca [1 x %struct.__jmp_buf_tag], align 16 + + ; Store longjmp in a local variable, load it, and call it + store void (%struct.__jmp_buf_tag*, i32)* @longjmp, void (%struct.__jmp_buf_tag*, i32)** %local_longjmp_ptr, align 4 + ; CHECK: store void (%struct.__jmp_buf_tag*, i32)* bitcast (void (i8*, i32)* @__wasm_longjmp to void (%struct.__jmp_buf_tag*, i32)*), void (%struct.__jmp_buf_tag*, i32)** %local_longjmp_ptr, align 4 + %longjmp_from_local_ptr = load void (%struct.__jmp_buf_tag*, i32)*, void (%struct.__jmp_buf_tag*, i32)** %local_longjmp_ptr, align 4 + %arraydecay = getelementptr inbounds [1 x %struct.__jmp_buf_tag], [1 x %struct.__jmp_buf_tag]* %buf0, i32 0, i32 0 + call void %longjmp_from_local_ptr(%struct.__jmp_buf_tag* %arraydecay, i32 0) + + ; Load longjmp from a global variable and call it + %longjmp_from_global_ptr = load void (%struct.__jmp_buf_tag*, i32)*, void
(%struct.__jmp_buf_tag*, i32)** @global_longjmp_ptr, align 4 + %arraydecay1 = getelementptr inbounds [1 x %struct.__jmp_buf_tag], [1 x %struct.__jmp_buf_tag]* %buf1, i32 0, i32 0 + call void %longjmp_from_global_ptr(%struct.__jmp_buf_tag* %arraydecay1, i32 0) + + ; Pass longjmp as a function argument. This is a call but longjmp is not a + ; callee but an argument. + call void @take_longjmp(void (%struct.__jmp_buf_tag*, i32)* @longjmp) + ; CHECK: call void @take_longjmp(void (%struct.__jmp_buf_tag*, i32)* bitcast (void (i8*, i32)* @__wasm_longjmp to void (%struct.__jmp_buf_tag*, i32)*)) + ret void +} + +; Function Attrs: nounwind +declare void @foo() #2 +; The pass removes the 'nounwind' attribute, so there should be no attributes +; CHECK-NOT: declare void @foo() #{{.*}} +; Function Attrs: returns_twice +declare i32 @setjmp(%struct.__jmp_buf_tag*) #0 +; Function Attrs: noreturn +declare void @longjmp(%struct.__jmp_buf_tag*, i32) #1 +declare i32 @__gxx_personality_v0(...) +declare i8* @__cxa_begin_catch(i8*) +declare void @__cxa_end_catch() +declare i8* @malloc(i32) +declare void @free(i8*) + +; JS glue function declarations +; CHECK-DAG: declare i32 @getTempRet0() +; CHECK-DAG: declare void @setTempRet0(i32) +; CHECK-DAG: declare i32* @saveSetjmp(%struct.__jmp_buf_tag*, i32, i32*, i32) +; CHECK-DAG: declare i32 @testSetjmp([[PTR]], i32*, i32) +; CHECK-DAG: declare void @__wasm_longjmp(i8*, i32) + +attributes #0 = { returns_twice } +attributes #1 = { noreturn } +attributes #2 = { nounwind }