diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyLowerEmscriptenEHSjLj.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyLowerEmscriptenEHSjLj.cpp --- a/llvm/lib/Target/WebAssembly/WebAssemblyLowerEmscriptenEHSjLj.cpp +++ b/llvm/lib/Target/WebAssembly/WebAssemblyLowerEmscriptenEHSjLj.cpp @@ -263,7 +263,10 @@ Value *wrapInvoke(CallBase *CI); void wrapTestSetjmp(BasicBlock *BB, DebugLoc DL, Value *Threw, Value *SetjmpTable, Value *SetjmpTableSize, Value *&Label, - Value *&LongjmpResult, BasicBlock *&EndBB); + Value *&LongjmpResult, BasicBlock *&CallEmLongjmpBB, + PHINode *&CallEmLongjmpBBThrewPHI, + PHINode *&CallEmLongjmpBBThrewValuePHI, + BasicBlock *&EndBB); Function *getInvokeWrapper(CallBase *CI); bool areAllExceptionsAllowed() const { return EHAllowlistSet.empty(); } @@ -585,7 +588,8 @@ void WebAssemblyLowerEmscriptenEHSjLj::wrapTestSetjmp( BasicBlock *BB, DebugLoc DL, Value *Threw, Value *SetjmpTable, Value *SetjmpTableSize, Value *&Label, Value *&LongjmpResult, - BasicBlock *&EndBB) { + BasicBlock *&CallEmLongjmpBB, PHINode *&CallEmLongjmpBBThrewPHI, + PHINode *&CallEmLongjmpBBThrewValuePHI, BasicBlock *&EndBB) { Function *F = BB->getParent(); Module *M = F->getParent(); LLVMContext &C = M->getContext(); @@ -604,10 +608,27 @@ Value *Cmp1 = IRB.CreateAnd(ThrewCmp, ThrewValueCmp, "cmp1"); IRB.CreateCondBr(Cmp1, ThenBB1, ElseBB1); + // Generate call.em.longjmp BB once and share it within the function + if (!CallEmLongjmpBB) { + // emscripten_longjmp(%__THREW__.val, %__threwValue.val); + CallEmLongjmpBB = BasicBlock::Create(C, "call.em.longjmp", F); + IRB.SetInsertPoint(CallEmLongjmpBB); + CallEmLongjmpBBThrewPHI = IRB.CreatePHI(getAddrIntType(M), 4, "threw.phi"); + CallEmLongjmpBBThrewValuePHI = + IRB.CreatePHI(IRB.getInt32Ty(), 4, "threwvalue.phi"); + CallEmLongjmpBBThrewPHI->addIncoming(Threw, ThenBB1); + CallEmLongjmpBBThrewValuePHI->addIncoming(ThrewValue, ThenBB1); + IRB.CreateCall(EmLongjmpF, + {CallEmLongjmpBBThrewPHI, CallEmLongjmpBBThrewValuePHI}); + 
IRB.CreateUnreachable(); + } else { + CallEmLongjmpBBThrewPHI->addIncoming(Threw, ThenBB1); + CallEmLongjmpBBThrewValuePHI->addIncoming(ThrewValue, ThenBB1); + } + // %label = testSetjmp(mem[%__THREW__.val], setjmpTable, setjmpTableSize); // if (%label == 0) IRB.SetInsertPoint(ThenBB1); - BasicBlock *ThenBB2 = BasicBlock::Create(C, "if.then2", F); BasicBlock *EndBB2 = BasicBlock::Create(C, "if.end2", F); Value *ThrewPtr = IRB.CreateIntToPtr(Threw, getAddrPtrType(M), Threw->getName() + ".p"); @@ -616,12 +637,7 @@ Value *ThenLabel = IRB.CreateCall( TestSetjmpF, {LoadedThrew, SetjmpTable, SetjmpTableSize}, "label"); Value *Cmp2 = IRB.CreateICmpEQ(ThenLabel, IRB.getInt32(0)); - IRB.CreateCondBr(Cmp2, ThenBB2, EndBB2); - - // emscripten_longjmp(%__THREW__.val, %__threwValue.val); - IRB.SetInsertPoint(ThenBB2); - IRB.CreateCall(EmLongjmpF, {Threw, ThrewValue}); - IRB.CreateUnreachable(); + IRB.CreateCondBr(Cmp2, CallEmLongjmpBB, EndBB2); // setTempRet0(%__threwValue.val); IRB.SetInsertPoint(EndBB2); @@ -840,6 +856,12 @@ SmallVector ToErase; SmallPtrSet LandingPads; + // rethrow.longjmp BB that will be shared within the function. + BasicBlock *RethrowLongjmpBB = nullptr; + // PHI node for the loaded value of __THREW__ global variable in + // rethrow.longjmp BB + PHINode *RethrowLongjmpBBThrewPHI = nullptr; + for (BasicBlock &BB : F) { auto *II = dyn_cast(BB.getTerminator()); if (!II) @@ -869,27 +891,36 @@ // else // goto %longjmp.rethrow // - // longjmp.rethrow: ;; This is longjmp. Rethrow it + // rethrow.longjmp: ;; This is longjmp. Rethrow it // %__threwValue.val = __threwValue // emscripten_longjmp(%__THREW__.val, %__threwValue.val); // // tail: ;; Nothing happened or an exception is thrown // ... Continue exception handling ... 
if (DoSjLj && !SetjmpUsers.count(&F) && canLongjmp(Callee)) { + // Create rethrow.longjmp BB once and share it within the function + if (!RethrowLongjmpBB) { + RethrowLongjmpBB = BasicBlock::Create(C, "rethrow.longjmp", &F); + IRB.SetInsertPoint(RethrowLongjmpBB); + RethrowLongjmpBBThrewPHI = + IRB.CreatePHI(getAddrIntType(&M), 4, "threw.phi"); + RethrowLongjmpBBThrewPHI->addIncoming(Threw, &BB); + Value *ThrewValue = IRB.CreateLoad(IRB.getInt32Ty(), ThrewValueGV, + ThrewValueGV->getName() + ".val"); + IRB.CreateCall(EmLongjmpF, {RethrowLongjmpBBThrewPHI, ThrewValue}); + IRB.CreateUnreachable(); + } else { + RethrowLongjmpBBThrewPHI->addIncoming(Threw, &BB); + } + + IRB.SetInsertPoint(II); // Restore the insert point back + BasicBlock *Tail = BasicBlock::Create(C, "tail", &F); - BasicBlock *RethrowBB = BasicBlock::Create(C, "longjmp.rethrow", &F); Value *CmpEqOne = IRB.CreateICmpEQ(Threw, getAddrSizeInt(&M, 1), "cmp.eq.one"); Value *CmpEqZero = IRB.CreateICmpEQ(Threw, getAddrSizeInt(&M, 0), "cmp.eq.zero"); Value *Or = IRB.CreateOr(CmpEqZero, CmpEqOne, "or"); - IRB.CreateCondBr(Or, Tail, RethrowBB); - IRB.SetInsertPoint(RethrowBB); - Value *ThrewValue = IRB.CreateLoad(IRB.getInt32Ty(), ThrewValueGV, - ThrewValueGV->getName() + ".val"); - IRB.CreateCall(EmLongjmpF, {Threw, ThrewValue}); - - IRB.CreateUnreachable(); + IRB.CreateCondBr(Or, Tail, RethrowLongjmpBB); IRB.SetInsertPoint(Tail); BB.replaceSuccessorsPhiUsesWith(&BB, Tail); } @@ -1204,6 +1235,17 @@ Instruction *SetjmpTable = *SetjmpTableInsts.begin(); Instruction *SetjmpTableSize = *SetjmpTableSizeInsts.begin(); + // call.em.longjmp BB that will be shared within the function.
+ BasicBlock *CallEmLongjmpBB = nullptr; + // PHI node for the loaded value of __THREW__ global variable in + // call.em.longjmp BB + PHINode *CallEmLongjmpBBThrewPHI = nullptr; + // PHI node for the loaded value of __threwValue global variable in + // call.em.longjmp BB + PHINode *CallEmLongjmpBBThrewValuePHI = nullptr; + // rethrow.exn BB that will be shared within the function. + BasicBlock *RethrowExnBB = nullptr; + // Because we are creating new BBs while processing and don't want to make // all these newly created BBs candidates again for longjmp processing, we // first make the vector of candidate BBs. @@ -1297,19 +1339,26 @@ // tail: // ... if (supportsException(&F) && canThrow(Callee)) { - IRB.SetInsertPoint(CI); // We will add a new conditional branch. So remove the branch created // when we split the BB ToErase.push_back(BB->getTerminator()); + + // Generate rethrow.exn BB once and share it within the function + if (!RethrowExnBB) { + RethrowExnBB = BasicBlock::Create(C, "rethrow.exn", &F); + IRB.SetInsertPoint(RethrowExnBB); + CallInst *Exn = + IRB.CreateCall(getFindMatchingCatch(M, 0), {}, "exn"); + IRB.CreateCall(ResumeF, {Exn}); + IRB.CreateUnreachable(); + } + + IRB.SetInsertPoint(CI); BasicBlock *NormalBB = BasicBlock::Create(C, "normal", &F); - BasicBlock *RethrowBB = BasicBlock::Create(C, "eh.rethrow", &F); Value *CmpEqOne = IRB.CreateICmpEQ(Threw, getAddrSizeInt(&M, 1), "cmp.eq.one"); - IRB.CreateCondBr(CmpEqOne, RethrowBB, NormalBB); - IRB.SetInsertPoint(RethrowBB); - CallInst *Exn = IRB.CreateCall(getFindMatchingCatch(M, 0), {}, "exn"); - IRB.CreateCall(ResumeF, {Exn}); - IRB.CreateUnreachable(); + IRB.CreateCondBr(CmpEqOne, RethrowExnBB, NormalBB); + IRB.SetInsertPoint(NormalBB); IRB.CreateBr(Tail); BB = NormalBB; // New insertion point to insert testSetjmp() @@ -1328,7 +1377,9 @@ Value *LongjmpResult = nullptr; BasicBlock *EndBB = nullptr; wrapTestSetjmp(BB, CI->getDebugLoc(), Threw, SetjmpTable, SetjmpTableSize, - Label, LongjmpResult, 
EndBB); + Label, LongjmpResult, CallEmLongjmpBB, + CallEmLongjmpBBThrewPHI, CallEmLongjmpBBThrewValuePHI, + EndBB); assert(Label && LongjmpResult && EndBB); // Create switch instruction diff --git a/llvm/test/CodeGen/WebAssembly/lower-em-ehsjlj.ll b/llvm/test/CodeGen/WebAssembly/lower-em-ehsjlj.ll --- a/llvm/test/CodeGen/WebAssembly/lower-em-ehsjlj.ll +++ b/llvm/test/CodeGen/WebAssembly/lower-em-ehsjlj.ll @@ -46,7 +46,7 @@ call void @__cxa_end_catch() br label %try.cont -try.cont: ; preds = %entry, %lpad +try.cont: ; preds = %lpad, %entry ret void } @@ -63,21 +63,22 @@ ; CHECK: %cmp.eq.one = icmp eq i32 %__THREW__.val, 1 ; CHECK-NEXT: %cmp.eq.zero = icmp eq i32 %__THREW__.val, 0 ; CHECK-NEXT: %or = or i1 %cmp.eq.zero, %cmp.eq.one -; CHECK-NEXT: br i1 %or, label %tail, label %longjmp.rethrow +; CHECK-NEXT: br i1 %or, label %tail, label %rethrow.longjmp ; CHECK: try.cont: ; CHECK-NEXT: %phi = phi i32 [ undef, %tail ], [ undef, %lpad ] ; CHECK-NEXT: ret void +; CHECK: rethrow.longjmp: +; CHECK-NEXT: %threw.phi = phi i32 [ %__THREW__.val, %entry ] +; CHECK-NEXT: %__threwValue.val = load i32, i32* @__threwValue, align 4 +; CHECK-NEXT: call void @emscripten_longjmp(i32 %threw.phi, i32 %__threwValue.val +; CHECK-NEXT: unreachable + ; CHECK: tail: ; CHECK-NEXT: %cmp = icmp eq i32 %__THREW__.val, 1 ; CHECK-NEXT: br i1 %cmp, label %lpad, label %try.cont -; CHECK: longjmp.rethrow: -; CHECK-NEXT: %__threwValue.val = load i32, i32* @__threwValue, align 4 -; CHECK-NEXT: call void @emscripten_longjmp(i32 %__THREW__.val, i32 %__threwValue.val) -; CHECK-NEXT: unreachable - lpad: ; preds = %entry %0 = landingpad { i8*, i32 } catch i8* null @@ -87,7 +88,7 @@ call void @__cxa_end_catch() br label %try.cont -try.cont: ; preds = %entry, %lpad +try.cont: ; preds = %lpad, %entry %phi = phi i32 [ undef, %entry ], [ undef, %lpad ] ret void } @@ -111,19 +112,19 @@ ; CHECK: if.end: ; CHECK: %cmp.eq.one = icmp eq i32 %__THREW__.val, 1 -; CHECK-NEXT: br i1 %cmp.eq.one, label %eh.rethrow, label 
%normal - -; CHECK: normal: -; CHECK-NEXT: icmp ne i32 %__THREW__.val, 0 +; CHECK-NEXT: br i1 %cmp.eq.one, label %rethrow.exn, label %normal -; CHECK: eh.rethrow: +; CHECK: rethrow.exn: ; CHECK-NEXT: %exn = call i8* @__cxa_find_matching_catch_2() -; CHECK-NEXT: %[[BUF:.*]] = bitcast i32* %setjmpTable1 to i8* +; CHECK-NEXT: %[[BUF:.*]] = bitcast i32* %setjmpTable{{.*}} to i8* ; CHECK-NEXT: call void @free(i8* %[[BUF]]) ; CHECK-NEXT: call void @__resumeException(i8* %exn) ; CHECK-NEXT: unreachable -return: ; preds = %entry, %if.end +; CHECK: normal: +; CHECK-NEXT: icmp ne i32 %__THREW__.val, 0 + +return: ; preds = %if.end, %entry ret void } @@ -142,17 +143,81 @@ call void @foo() br label %throw -throw: ; preds = %entry, %if.end +throw: ; preds = %if.end, %entry call void @__cxa_throw(i8* null, i8* null, i8* null) #1 unreachable -; CHECK: throw: -; CHECK: %[[BUF:.*]] = bitcast i32* %setjmpTable5 to i8* +; CHECK: throw: +; CHECK: %[[BUF:.*]] = bitcast i32* %setjmpTable{{.*}} to i8* ; CHECK-NEXT: call void @free(i8* %[[BUF]]) ; CHECK-NEXT: call void @__cxa_throw(i8* null, i8* null, i8* null) ; CHECK-NEXT: unreachable } +; The same case as @rethrow_longjmp, but there are multiple function calls +; that can possibly longjmp (instead of throwing an exception) so we have to +; rethrow them. Here we test if we correctly generate only one 'rethrow.longjmp' +; BB and share it for multiple calls.
+define void @rethrow_longjmp_multi() personality i8* bitcast (i32 (...)* @__gxx_personality_v0 to i8*) { +; CHECK-LABEL: @rethrow_longjmp_multi +entry: + invoke void @foo() + to label %bb unwind label %lpad + +bb: ; preds = %entry + invoke void @foo() + to label %try.cont unwind label %lpad + +lpad: ; preds = %bb, %entry + %0 = landingpad { i8*, i32 } + catch i8* null + %1 = extractvalue { i8*, i32 } %0, 0 + %2 = extractvalue { i8*, i32 } %0, 1 + %3 = call i8* @__cxa_begin_catch(i8* %1) #5 + call void @__cxa_end_catch() + br label %try.cont + +try.cont: ; preds = %lpad, %bb + %phi = phi i32 [ undef, %bb ], [ undef, %lpad ] + ret void + +; CHECK: rethrow.longjmp: +; CHECK-NEXT: %threw.phi = phi i32 [ %__THREW__.val, %entry ], [ %__THREW__.val1, %bb ] +; CHECK-NEXT: %__threwValue.val = load i32, i32* @__threwValue, align 4 +; CHECK-NEXT: call void @emscripten_longjmp(i32 %threw.phi, i32 %__threwValue.val) +; CHECK-NEXT: unreachable +} + +; The same case as @rethrow_exception, but there are multiple function calls +; that can possibly throw (instead of longjmping) so we have to rethrow them. +; Here we test if we correctly generate only one 'rethrow.exn' BB and share it +; for multiple calls.
+define void @rethrow_exception_multi() { +; CHECK-LABEL: @rethrow_exception_multi +entry: + %buf = alloca [1 x %struct.__jmp_buf_tag], align 16 + %arraydecay = getelementptr inbounds [1 x %struct.__jmp_buf_tag], [1 x %struct.__jmp_buf_tag]* %buf, i32 0, i32 0 + %call = call i32 @setjmp(%struct.__jmp_buf_tag* %arraydecay) #0 + %cmp = icmp ne i32 %call, 0 + br i1 %cmp, label %return, label %if.end + +if.end: ; preds = %entry + call void @foo() + call void @foo() + br label %return + +return: ; preds = %entry, %if.end + ret void + +; CHECK: rethrow.exn: +; CHECK-NEXT: %setjmpTable{{.*}} = phi i32* [ %setjmpTable{{.*}}, %if.end.split ], [ %setjmpTable{{.*}}, %if.end ] +; CHECK-NEXT: %exn = call i8* @__cxa_find_matching_catch_2() +; CHECK-NEXT: %{{.*}} = bitcast i32* %setjmpTable{{.*}} to i8* +; CHECK-NEXT: tail call void @free(i8* %{{.*}}) +; CHECK-NEXT: call void @__resumeException(i8* %exn) +; CHECK-NEXT: unreachable +} + declare void @foo() ; Function Attrs: returns_twice declare i32 @setjmp(%struct.__jmp_buf_tag*) diff --git a/llvm/test/CodeGen/WebAssembly/lower-em-sjlj-debuginfo.ll b/llvm/test/CodeGen/WebAssembly/lower-em-sjlj-debuginfo.ll --- a/llvm/test/CodeGen/WebAssembly/lower-em-sjlj-debuginfo.ll +++ b/llvm/test/CodeGen/WebAssembly/lower-em-sjlj-debuginfo.ll @@ -38,7 +38,7 @@ ; CHECK: if.end: ; CHECK: call i32 @getTempRet0{{.*}}, !dbg ![[DL2]] -; CHECK: if.then2: +; CHECK: call.em.longjmp: ; CHECK: call void @emscripten_longjmp{{.*}}, !dbg ![[DL2]] ; CHECK: if.end2: diff --git a/llvm/test/CodeGen/WebAssembly/lower-em-sjlj.ll b/llvm/test/CodeGen/WebAssembly/lower-em-sjlj.ll --- a/llvm/test/CodeGen/WebAssembly/lower-em-sjlj.ll +++ b/llvm/test/CodeGen/WebAssembly/lower-em-sjlj.ll @@ -61,7 +61,7 @@ ; CHECK-NEXT: %[[__THREW__VAL_P_LOADED:.*]] = load [[PTR]], [[PTR]]* %[[__THREW__VAL_P]] ; CHECK-NEXT: %[[LABEL:.*]] = call i32 @testSetjmp([[PTR]] %[[__THREW__VAL_P_LOADED]], i32* %[[SETJMP_TABLE1]], i32 %[[SETJMP_TABLE_SIZE1]]) ; CHECK-NEXT: %[[CMP:.*]] = icmp eq 
i32 %[[LABEL]], 0 -; CHECK-NEXT: br i1 %[[CMP]], label %if.then2, label %if.end2 +; CHECK-NEXT: br i1 %[[CMP]], label %call.em.longjmp, label %if.end2 ; CHECK: if.else1: ; CHECK-NEXT: br label %if.end @@ -73,10 +73,12 @@ ; CHECK-NEXT: i32 1, label %entry.split.split ; CHECK-NEXT: ] -; CHECK: if.then2: -; CHECK-NEXT: %[[BUF:.*]] = bitcast i32* %[[SETJMP_TABLE1]] to i8* -; CHECK-NEXT: call void @free(i8* %[[BUF]]) -; CHECK-NEXT: call void @emscripten_longjmp([[PTR]] %[[__THREW__VAL]], i32 %[[THREWVALUE_VAL]]) +; CHECK: call.em.longjmp: +; CHECK-NEXT: %threw.phi = phi [[PTR]] [ %[[__THREW__VAL]], %if.then1 ] +; CHECK-NEXT: %threwvalue.phi = phi i32 [ %[[THREWVALUE_VAL]], %if.then1 ] +; CHECK-NEXT: %{{.*}} = bitcast i32* %[[SETJMP_TABLE1]] to i8* +; CHECK-NEXT: tail call void @free(i8* %{{.*}}) +; CHECK-NEXT: call void @emscripten_longjmp([[PTR]] %threw.phi, i32 %threwvalue.phi) ; CHECK-NEXT: unreachable ; CHECK: if.end2: @@ -85,8 +87,8 @@ } ; Test a case of a function call (which is not longjmp) after a setjmp -define void @setjmp_other() { -; CHECK-LABEL: @setjmp_other +define void @setjmp_longjmpable_call() { +; CHECK-LABEL: @setjmp_longjmpable_call entry: %buf = alloca [1 x %struct.__jmp_buf_tag], align 16 %arraydecay = getelementptr inbounds [1 x %struct.__jmp_buf_tag], [1 x %struct.__jmp_buf_tag]* %buf, i32 0, i32 0 @@ -105,6 +107,28 @@ ; CHECK-NEXT: ret void } +; Test a case with multiple longjmpable calls after setjmp. In this test we +; specifically check if 'call.em.longjmp' BB, which rethrows longjmps by calling +; emscripten_longjmp for longjmps not targeting this function's setjmp, is +; correctly created for multiple predecessors.
+define void @setjmp_multiple_longjmpable_calls() { +; CHECK-LABEL: @setjmp_multiple_longjmpable_calls +entry: + %buf = alloca [1 x %struct.__jmp_buf_tag], align 16 + %arraydecay = getelementptr inbounds [1 x %struct.__jmp_buf_tag], [1 x %struct.__jmp_buf_tag]* %buf, i32 0, i32 0 + %call = call i32 @setjmp(%struct.__jmp_buf_tag* %arraydecay) #0 + call void @foo() + call void @foo() + ret void +; CHECK: call.em.longjmp: +; CHECK-NEXT: %threw.phi = phi [[PTR]] [ %__THREW__.val, %if.then1 ], [ %__THREW__.val4, %if.then15 ] +; CHECK-NEXT: %threwvalue.phi = phi i32 [ %__threwValue.val, %if.then1 ], [ %__threwValue.val8, %if.then15 ] +; CHECK-NEXT: %{{.*}} = bitcast i32* %[[SETJMP_TABLE1]] to i8* +; CHECK-NEXT: tail call void @free(i8* %{{.*}}) +; CHECK-NEXT: call void @emscripten_longjmp([[PTR]] %threw.phi, i32 %threwvalue.phi) +; CHECK-NEXT: unreachable +} + ; Test a case where a function has a setjmp call but no other calls that can ; longjmp. We don't need to do any transformation in this case. define void @setjmp_only(i8* %ptr) {