WebAssembly: make Emscripten EH lowering cooperate with Emscripten SjLj lowering.

* createWebAssemblyLowerEmscriptenEHSjLj: parameters renamed to EnableEH /
  EnableSjLj; DoSjLj and SetjmpUsers become pass members, and a
  supportsException(F) helper combines EnableEH with the EH allowlist.
* runOnModule: first registers the EH and SjLj runtime functions and
  precomputes SetjmpUsers, then runs runEHOnFunction over all definitions,
  then replaces longjmp and runs runSjLjOnFunction over SetjmpUsers.
* runEHOnFunction: in a function with no setjmp call, when DoSjLj is set and
  the invoke's callee can longjmp, __THREW__ values other than 0/1 are
  rethrown via emscripten_longjmp (block "longjmp.rethrow").
* runSjLjOnFunction: for a call where supportsException(F) && canThrow(callee),
  __THREW__ == 1 is rethrown via __cxa_find_matching_catch_2 +
  __resumeException (block "eh.rethrow").
* The throw check at hunk "@@ -287,7" (presumably canThrow) now also returns
  false for emscripten_longjmp.
* Tests: new lower-em-ehsjlj.ll; exception_and_longjmp is removed from
  lower-em-sjlj.ll (setjmp_longjmp_exception in the new file has the same IR
  body), and @foo in lower-em-sjlj.ll is marked nounwind.

NOTE(review): this copy of the patch is whitespace-collapsed and has lost its
angle-bracket template arguments (e.g. "SmallPtrSet SetjmpUsers;",
"cast(U)", "getAnalysisIfAvailable()"), so its context lines will not match
the tree. Regenerate it with `git format-patch` before applying.

diff --git a/llvm/lib/Target/WebAssembly/WebAssembly.h b/llvm/lib/Target/WebAssembly/WebAssembly.h --- a/llvm/lib/Target/WebAssembly/WebAssembly.h +++ b/llvm/lib/Target/WebAssembly/WebAssembly.h @@ -25,7 +25,8 @@ class FunctionPass; // LLVM IR passes. -ModulePass *createWebAssemblyLowerEmscriptenEHSjLj(bool DoEH, bool DoSjLj); +ModulePass *createWebAssemblyLowerEmscriptenEHSjLj(bool EnableEH, + bool EnableSjLj); ModulePass *createWebAssemblyLowerGlobalDtors(); ModulePass *createWebAssemblyAddMissingPrototypes(); ModulePass *createWebAssemblyFixFunctionBitcasts(); diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyLowerEmscriptenEHSjLj.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyLowerEmscriptenEHSjLj.cpp --- a/llvm/lib/Target/WebAssembly/WebAssemblyLowerEmscriptenEHSjLj.cpp +++ b/llvm/lib/Target/WebAssembly/WebAssemblyLowerEmscriptenEHSjLj.cpp @@ -216,6 +216,7 @@ class WebAssemblyLowerEmscriptenEHSjLj final : public ModulePass { bool EnableEH; // Enable exception handling bool EnableSjLj; // Enable setjmp/longjmp handling + bool DoSjLj; // Whether we actually perform setjmp/longjmp handling GlobalVariable *ThrewGV = nullptr; GlobalVariable *ThrewValueGV = nullptr; @@ -234,6 +235,8 @@ StringMap InvokeWrappers; // Set of allowed function names for exception handling std::set EHAllowlistSet; + // Functions that contain calls to setjmp + SmallPtrSet SetjmpUsers; StringRef getPassName() const override { return "WebAssembly Lower Emscripten Exceptions"; @@ -252,6 +255,10 @@ bool areAllExceptionsAllowed() const { return EHAllowlistSet.empty(); } bool canLongjmp(Module &M, const Value *Callee) const; bool isEmAsmCall(Module &M, const Value *Callee) const; + bool supportsException(const Function *F) const { + return EnableEH && (areAllExceptionsAllowed() || + EHAllowlistSet.count(std::string(F->getName()))); + } void rebuildSSA(Function &F); @@ -287,7 +294,7 @@ return false; StringRef Name = F->getName(); // leave setjmp and longjmp (mostly) alone, we process them
properly later - if (Name == "setjmp" || Name == "longjmp") + if (Name == "setjmp" || Name == "longjmp" || Name == "emscripten_longjmp") return false; return !F->doesNotThrow(); } @@ -693,7 +700,7 @@ Function *LongjmpF = M.getFunction("longjmp"); bool SetjmpUsed = SetjmpF && !SetjmpF->use_empty(); bool LongjmpUsed = LongjmpF && !LongjmpF->use_empty(); - bool DoSjLj = EnableSjLj && (SetjmpUsed || LongjmpUsed); + DoSjLj = EnableSjLj && (SetjmpUsed || LongjmpUsed); auto *TPC = getAnalysisIfAvailable(); assert(TPC && "Expected a TargetPassConfig"); @@ -718,7 +725,7 @@ bool Changed = false; - // Exception handling + // Function registration for exception handling if (EnableEH) { // Register __resumeException function FunctionType *ResumeFTy = @@ -729,26 +736,15 @@ FunctionType *EHTypeIDTy = FunctionType::get(IRB.getInt32Ty(), IRB.getInt8PtrTy(), false); EHTypeIDF = getEmscriptenFunction(EHTypeIDTy, "llvm_eh_typeid_for", &M); - - for (Function &F : M) { - if (F.isDeclaration()) - continue; - Changed |= runEHOnFunction(F); - } } - // Setjmp/longjmp handling + // Function registration and data pre-gathering for setjmp/longjmp handling if (DoSjLj) { - Changed = true; // We have setjmp or longjmp somewhere - // Register emscripten_longjmp function FunctionType *FTy = FunctionType::get( IRB.getVoidTy(), {getAddrIntType(&M), IRB.getInt32Ty()}, false); EmLongjmpF = getEmscriptenFunction(FTy, "emscripten_longjmp", &M); - if (LongjmpF) - replaceLongjmpWithEmscriptenLongjmp(LongjmpF, EmLongjmpF); - if (SetjmpF) { // Register saveSetjmp function FunctionType *SetjmpFTy = SetjmpF->getFunctionType(); @@ -765,16 +761,33 @@ false); TestSetjmpF = getEmscriptenFunction(FTy, "testSetjmp", &M); - // Only traverse functions that uses setjmp in order not to insert - // unnecessary prep / cleanup code in every function - SmallPtrSet SetjmpUsers; + // Precompute setjmp users for (User *U : SetjmpF->users()) { auto *UI = cast(U); SetjmpUsers.insert(UI->getFunction()); } + } + } + + // Exception
handling transformation + if (EnableEH) { + for (Function &F : M) { + if (F.isDeclaration()) + continue; + Changed |= runEHOnFunction(F); + } + } + + // Setjmp/longjmp handling transformation + if (DoSjLj) { + Changed = true; // We have setjmp or longjmp somewhere + if (LongjmpF) + replaceLongjmpWithEmscriptenLongjmp(LongjmpF, EmLongjmpF); + // Only traverse functions that use setjmp in order not to insert + // unnecessary prep / cleanup code in every function + if (SetjmpF) for (Function *F : SetjmpUsers) runSjLjOnFunction(*F); - } } if (!Changed) { @@ -802,8 +815,6 @@ bool Changed = false; SmallVector ToErase; SmallPtrSet LandingPads; - bool AllowExceptions = areAllExceptionsAllowed() || - EHAllowlistSet.count(std::string(F.getName())); for (BasicBlock &BB : F) { auto *II = dyn_cast(BB.getTerminator()); @@ -813,12 +824,51 @@ LandingPads.insert(II->getLandingPadInst()); IRB.SetInsertPoint(II); - bool NeedInvoke = AllowExceptions && canThrow(II->getCalledOperand()); + const Value *Callee = II->getCalledOperand(); + bool NeedInvoke = supportsException(&F) && canThrow(Callee); if (NeedInvoke) { // Wrap invoke with invoke wrapper and generate preamble/postamble Value *Threw = wrapInvoke(II); ToErase.push_back(II); + // If setjmp/longjmp handling is enabled, the thrown value may be a + // longjmp rather than an exception. If the current function contains + // calls to setjmp, it will be appropriately handled in runSjLjOnFunction. + // But even if the function does not contain setjmp calls, we shouldn't + // silently ignore longjmps; we should rethrow them so they can be + // correctly handled somewhere up the call chain where setjmp is. + // __THREW__'s value is 0 when nothing happened, 1 when an exception is + // thrown, other values when longjmp is thrown. + // + // if (%__THREW__.val == 0 || %__THREW__.val == 1) + // goto %tail + // else + // goto %longjmp.rethrow + // + // longjmp.rethrow: ;; This is longjmp.
Rethrow it + // %__threwValue.val = __threwValue + // emscripten_longjmp(%__THREW__.val, %__threwValue.val); + // + // tail: ;; Nothing happened or an exception is thrown + // ... Continue exception handling ... + if (DoSjLj && !SetjmpUsers.count(&F) && canLongjmp(M, Callee)) { + BasicBlock *Tail = BasicBlock::Create(C, "tail", &F); + BasicBlock *RethrowBB = BasicBlock::Create(C, "longjmp.rethrow", &F); + Value *CmpEqOne = + IRB.CreateICmpEQ(Threw, getAddrSizeInt(&M, 1), "cmp.eq.one"); + Value *CmpEqZero = + IRB.CreateICmpEQ(Threw, getAddrSizeInt(&M, 0), "cmp.eq.zero"); + Value *Or = IRB.CreateOr(CmpEqZero, CmpEqOne, "or"); + IRB.CreateCondBr(Or, Tail, RethrowBB); + IRB.SetInsertPoint(RethrowBB); + Value *ThrewValue = IRB.CreateLoad(IRB.getInt32Ty(), ThrewValueGV, + ThrewValueGV->getName() + ".val"); + IRB.CreateCall(EmLongjmpF, {Threw, ThrewValue}); + + IRB.CreateUnreachable(); + IRB.SetInsertPoint(Tail); + } + // Insert a branch based on __THREW__ variable Value *Cmp = IRB.CreateICmpEQ(Threw, getAddrSizeInt(&M, 1), "cmp"); IRB.CreateCondBr(Cmp, II->getUnwindDest(), II->getNormalDest()); @@ -1098,6 +1148,46 @@ Threw = wrapInvoke(CI); ToErase.push_back(CI); Tail = SplitBlock(BB, CI->getNextNode()); + + // If exception handling is enabled, the thrown value may be an exception + // rather than a longjmp, in which case we shouldn't silently ignore + // exceptions; we should rethrow them. + // __THREW__'s value is 0 when nothing happened, 1 when an exception is + // thrown, other values when longjmp is thrown. + // + // if (%__THREW__.val == 1) + // goto %eh.rethrow + // else + // goto %normal + // + // eh.rethrow: ;; Rethrow exception + // %exn = call @__cxa_find_matching_catch_2() ;; Retrieve thrown ptr + // __resumeException(%exn) + // + // normal: + // <-- Insertion point. Will insert sjlj handling code from here + // goto %tail + // + // tail: + // ...
+ if (supportsException(&F) && canThrow(Callee)) { + IRB.SetInsertPoint(CI); + // We will add a new conditional branch. So remove the branch created + // when we split the BB + ToErase.push_back(BB->getTerminator()); + BasicBlock *NormalBB = BasicBlock::Create(C, "normal", &F); + BasicBlock *RethrowBB = BasicBlock::Create(C, "eh.rethrow", &F); + Value *CmpEqOne = + IRB.CreateICmpEQ(Threw, getAddrSizeInt(&M, 1), "cmp.eq.one"); + IRB.CreateCondBr(CmpEqOne, RethrowBB, NormalBB); + IRB.SetInsertPoint(RethrowBB); + CallInst *Exn = IRB.CreateCall(getFindMatchingCatch(M, 0), {}, "exn"); + IRB.CreateCall(ResumeF, {Exn}); + IRB.CreateUnreachable(); + IRB.SetInsertPoint(NormalBB); + IRB.CreateBr(Tail); + BB = NormalBB; // New insertion point to insert testSetjmp() + } } // We need to replace the terminator in Tail - SplitBlock makes BB go diff --git a/llvm/test/CodeGen/WebAssembly/lower-em-ehsjlj.ll b/llvm/test/CodeGen/WebAssembly/lower-em-ehsjlj.ll new file mode 100644 --- /dev/null +++ b/llvm/test/CodeGen/WebAssembly/lower-em-ehsjlj.ll @@ -0,0 +1,132 @@ +; RUN: opt < %s -wasm-lower-em-ehsjlj -S | FileCheck %s +; RUN: llc < %s + +; Tests for cases when exception handling and setjmp/longjmp handling are mixed. + +target datalayout = "e-m:e-p:32:32-i64:64-n32:64-S128" +target triple = "wasm32-unknown-unknown" + +%struct.__jmp_buf_tag = type { [6 x i32], i32, [32 x i32] } + +; There is a function call (@foo) that can either throw an exception or longjmp +; and there is also a setjmp call. When @foo throws, we have to check both for +; exception and longjmp and jump to exception or longjmp handling BB depending +; on the result.
+define void @setjmp_longjmp_exception() personality i8* bitcast (i32 (...)* @__gxx_personality_v0 to i8*) { +; CHECK-LABEL: @setjmp_longjmp_exception +entry: + %buf = alloca [1 x %struct.__jmp_buf_tag], align 16 + %arraydecay = getelementptr inbounds [1 x %struct.__jmp_buf_tag], [1 x %struct.__jmp_buf_tag]* %buf, i32 0, i32 0 + %call = call i32 @setjmp(%struct.__jmp_buf_tag* %arraydecay) #0 + invoke void @foo() + to label %try.cont unwind label %lpad + +; CHECK: entry.split: +; CHECK: %[[CMP0:.*]] = icmp ne i32 %__THREW__.val, 0 +; CHECK-NEXT: %__threwValue.val = load i32, i32* @__threwValue +; CHECK-NEXT: %[[CMP1:.*]] = icmp ne i32 %__threwValue.val, 0 +; CHECK-NEXT: %[[CMP:.*]] = and i1 %[[CMP0]], %[[CMP1]] +; CHECK-NEXT: br i1 %[[CMP]], label %if.then1, label %if.else1 + +; This is the exception checking part. %if.else1 leads here +; CHECK: entry.split.split: +; CHECK-NEXT: %[[CMP:.*]] = icmp eq i32 %__THREW__.val, 1 +; CHECK-NEXT: br i1 %[[CMP]], label %lpad, label %try.cont + +; longjmp checking part +; CHECK: if.then1: +; CHECK: call i32 @testSetjmp + +lpad: ; preds = %entry + %0 = landingpad { i8*, i32 } + catch i8* null + %1 = extractvalue { i8*, i32 } %0, 0 + %2 = extractvalue { i8*, i32 } %0, 1 + %3 = call i8* @__cxa_begin_catch(i8* %1) + call void @__cxa_end_catch() + br label %try.cont + +try.cont: ; preds = %entry, %lpad + ret void +} + +; @foo can either throw an exception or longjmp. Because this function doesn't +; have any setjmp calls, we only handle exceptions in this function. But because +; sjlj is enabled, we check if the thrown value is a longjmp and if so rethrow it +; by calling @emscripten_longjmp.
+define void @rethrow_longjmp() personality i8* bitcast (i32 (...)* @__gxx_personality_v0 to i8*) { +; CHECK-LABEL: @rethrow_longjmp +entry: + invoke void @foo() + to label %try.cont unwind label %lpad +; CHECK: entry: +; CHECK: %cmp.eq.one = icmp eq i32 %__THREW__.val, 1 +; CHECK-NEXT: %cmp.eq.zero = icmp eq i32 %__THREW__.val, 0 +; CHECK-NEXT: %or = or i1 %cmp.eq.zero, %cmp.eq.one +; CHECK-NEXT: br i1 %or, label %tail, label %longjmp.rethrow + +; CHECK: tail: +; CHECK-NEXT: %cmp = icmp eq i32 %__THREW__.val, 1 +; CHECK-NEXT: br i1 %cmp, label %lpad, label %try.cont + +; CHECK: longjmp.rethrow: +; CHECK-NEXT: %__threwValue.val = load i32, i32* @__threwValue, align 4 +; CHECK-NEXT: call void @emscripten_longjmp(i32 %__THREW__.val, i32 %__threwValue.val) +; CHECK-NEXT: unreachable + +lpad: ; preds = %entry + %0 = landingpad { i8*, i32 } + catch i8* null + %1 = extractvalue { i8*, i32 } %0, 0 + %2 = extractvalue { i8*, i32 } %0, 1 + %3 = call i8* @__cxa_begin_catch(i8* %1) + call void @__cxa_end_catch() + br label %try.cont + +try.cont: ; preds = %entry, %lpad + ret void +} + +; This function contains a setjmp call and no invoke, so we only handle longjmp +; here. But @foo can also throw an exception, so we check if an exception is +; thrown and if so rethrow it by calling @__resumeException.
+define void @rethrow_exception() { +; CHECK-LABEL: @rethrow_exception +entry: + %buf = alloca [1 x %struct.__jmp_buf_tag], align 16 + %arraydecay = getelementptr inbounds [1 x %struct.__jmp_buf_tag], [1 x %struct.__jmp_buf_tag]* %buf, i32 0, i32 0 + %call = call i32 @setjmp(%struct.__jmp_buf_tag* %arraydecay) #0 + %cmp = icmp ne i32 %call, 0 + br i1 %cmp, label %return, label %if.end + +if.end: ; preds = %entry + call void @foo() + br label %return + +; CHECK: if.end: +; CHECK: %cmp.eq.one = icmp eq i32 %__THREW__.val, 1 +; CHECK-NEXT: br i1 %cmp.eq.one, label %eh.rethrow, label %normal + +; CHECK: normal: +; CHECK-NEXT: icmp ne i32 %__THREW__.val, 0 + +; CHECK: eh.rethrow: +; CHECK-NEXT: %exn = call i8* @__cxa_find_matching_catch_2() +; CHECK-NEXT: call void @__resumeException(i8* %exn) +; CHECK-NEXT: unreachable + +return: ; preds = %entry, %if.end + ret void +} + +declare void @foo() +; Function Attrs: returns_twice +declare i32 @setjmp(%struct.__jmp_buf_tag*) #0 +; Function Attrs: noreturn +declare void @longjmp(%struct.__jmp_buf_tag*, i32) #1 +declare i32 @__gxx_personality_v0(...)
+declare i8* @__cxa_begin_catch(i8*) +declare void @__cxa_end_catch() + +attributes #0 = { returns_twice } +attributes #1 = { noreturn } diff --git a/llvm/test/CodeGen/WebAssembly/lower-em-sjlj.ll b/llvm/test/CodeGen/WebAssembly/lower-em-sjlj.ll --- a/llvm/test/CodeGen/WebAssembly/lower-em-sjlj.ll +++ b/llvm/test/CodeGen/WebAssembly/lower-em-sjlj.ll @@ -100,44 +100,6 @@ ; CHECK-NEXT: ret void } -; Test a case when a function call is within try-catch, after a setjmp -define void @exception_and_longjmp() personality i8* bitcast (i32 (...)* @__gxx_personality_v0 to i8*) { -; CHECK-LABEL: @exception_and_longjmp -entry: - %buf = alloca [1 x %struct.__jmp_buf_tag], align 16 - %arraydecay = getelementptr inbounds [1 x %struct.__jmp_buf_tag], [1 x %struct.__jmp_buf_tag]* %buf, i32 0, i32 0 - %call = call i32 @setjmp(%struct.__jmp_buf_tag* %arraydecay) #0 - invoke void @foo() - to label %try.cont unwind label %lpad - -; CHECK: entry.split: -; CHECK: store [[PTR]] 0, [[PTR]]* @__THREW__ -; CHECK-NEXT: call cc{{.*}} void @__invoke_void(void ()* @foo) -; CHECK-NEXT: %[[__THREW__VAL:.*]] = load [[PTR]], [[PTR]]* @__THREW__ -; CHECK-NEXT: store [[PTR]] 0, [[PTR]]* @__THREW__ -; CHECK-NEXT: %[[CMP0:.*]] = icmp ne [[PTR]] %[[__THREW__VAL]], 0 -; CHECK-NEXT: %[[THREWVALUE_VAL:.*]] = load i32, i32* @__threwValue -; CHECK-NEXT: %[[CMP1:.*]] = icmp ne i32 %[[THREWVALUE_VAL]], 0 -; CHECK-NEXT: %[[CMP:.*]] = and i1 %[[CMP0]], %[[CMP1]] -; CHECK-NEXT: br i1 %[[CMP]], label %if.then1, label %if.else1 - -; CHECK: entry.split.split: -; CHECK-NEXT: %[[CMP:.*]] = icmp eq [[PTR]] %[[__THREW__VAL]], 1 -; CHECK-NEXT: br i1 %[[CMP]], label %lpad, label %try.cont - -lpad: ; preds = %entry - %0 = landingpad { i8*, i32 } - catch i8* null - %1 = extractvalue { i8*, i32 } %0, 0 - %2 = extractvalue { i8*, i32 } %0, 1 - %3 = call i8* @__cxa_begin_catch(i8* %1) #2 - call void @__cxa_end_catch() - br label %try.cont - -try.cont: ; preds = %entry, %lpad - ret void -} - ; Test SSA validity define void
@ssa(i32 %n) { ; CHECK-LABEL: @ssa @@ -283,7 +245,8 @@ ret void } -declare void @foo() +; Function Attrs: nounwind +declare void @foo() #2 ; Function Attrs: returns_twice declare i32 @setjmp(%struct.__jmp_buf_tag*) #0 ; Function Attrs: noreturn @@ -311,7 +274,7 @@ ; CHECK-DAG: attributes #{{[0-9]+}} = { "wasm-import-module"="env" "wasm-import-name"="__resumeException" } ; CHECK-DAG: attributes #{{[0-9]+}} = { "wasm-import-module"="env" "wasm-import-name"="llvm_eh_typeid_for" } ; CHECK-DAG: attributes #{{[0-9]+}} = { "wasm-import-module"="env" "wasm-import-name"="__invoke_void" } -; CHECK-DAG: attributes #{{[0-9]+}} = { "wasm-import-module"="env" "wasm-import-name"="__cxa_find_matching_catch_3" } +; CHECK-DAG: attributes #{{[0-9]+}} = { "wasm-import-module"="env" "wasm-import-name"="__cxa_find_matching_catch_2" } ; CHECK-DAG: attributes #{{[0-9]+}} = { "wasm-import-module"="env" "wasm-import-name"="saveSetjmp" } ; CHECK-DAG: attributes #{{[0-9]+}} = { "wasm-import-module"="env" "wasm-import-name"="testSetjmp" } ; CHECK-DAG: attributes #{{[0-9]+}} = { "wasm-import-module"="env" "wasm-import-name"="emscripten_longjmp" }