diff --git a/clang/lib/CodeGen/CGException.cpp b/clang/lib/CodeGen/CGException.cpp
--- a/clang/lib/CodeGen/CGException.cpp
+++ b/clang/lib/CodeGen/CGException.cpp
@@ -1638,7 +1638,8 @@
 }
 
 void CodeGenFunction::EmitSEHTryStmt(const SEHTryStmt &S) {
-  EnterSEHTryStmt(S);
+  bool ContainsRetStmt = false;
+  EnterSEHTryStmt(S, ContainsRetStmt);
   {
     JumpDest TryExit = getJumpDestInCurrentScope("__try.__leave");
 
@@ -1667,7 +1668,7 @@
     else
       delete TryExit.getBlock();
   }
-  ExitSEHTryStmt(S);
+  ExitSEHTryStmt(S, ContainsRetStmt);
 }
 
 // Recursively walk through blocks in a _try
@@ -1702,8 +1703,9 @@
 namespace {
 struct PerformSEHFinally final : EHScopeStack::Cleanup {
   llvm::Function *OutlinedFinally;
-  PerformSEHFinally(llvm::Function *OutlinedFinally)
-      : OutlinedFinally(OutlinedFinally) {}
+  bool RetFromFinally;
+  PerformSEHFinally(llvm::Function *OutlinedFinally, bool RetFromFinally)
+      : OutlinedFinally(OutlinedFinally), RetFromFinally(RetFromFinally) {}
 
   void Emit(CodeGenFunction &CGF, Flags F) override {
     ASTContext &Context = CGF.getContext();
@@ -1747,6 +1749,21 @@
 
     auto Callee = CGCallee::forDirect(OutlinedFinally);
     CGF.EmitCall(FnInfo, Callee, ReturnValueSlot(), Args);
+
+    if (F.isForEHCleanup() && RetFromFinally) {
+      llvm::BasicBlock *AbnormalCont = CGF.createBasicBlock("if.then");
+      llvm::BasicBlock *NormalCont = CGF.createBasicBlock("if.end");
+      llvm::Value *ShouldRetLoad =
+          CGF.Builder.CreateLoad(CGF.SEHRetNowStack.back());
+      llvm::Value *ShouldRet = CGF.Builder.CreateIsNotNull(ShouldRetLoad);
+
+      CGF.Builder.CreateCondBr(ShouldRet, AbnormalCont, NormalCont);
+      CGF.EmitBlock(AbnormalCont);
+      CGF.EmitSEHLocalUnwind();
+      CGF.Builder.CreateUnreachable();
+
+      CGF.EmitBlock(NormalCont);
+    }
   }
 };
 } // end anonymous namespace
@@ -1758,12 +1775,13 @@
   const VarDecl *ParentThis;
   llvm::SmallSetVector<const VarDecl *, 4> Captures;
   Address SEHCodeSlot = Address::invalid();
+  bool ContainsRetStmt = false;
   CaptureFinder(CodeGenFunction &ParentCGF, const VarDecl *ParentThis)
       : ParentCGF(ParentCGF), ParentThis(ParentThis) {}
 
   // Return true if we need to do any capturing work.
   bool foundCaptures() {
-    return !Captures.empty() || SEHCodeSlot.isValid();
+    return !Captures.empty() || SEHCodeSlot.isValid() || ContainsRetStmt;
   }
 
   void Visit(const Stmt *S) {
@@ -1805,6 +1823,25 @@
       break;
     }
   }
+
+  void VisitReturnStmt(const ReturnStmt *) { ContainsRetStmt = true; }
+};
+} // end anonymous namespace
+
+namespace {
+/// Determine whether a statement contains a ReturnStmt at any depth.
+struct ReturnStmtFinder : ConstStmtVisitor<ReturnStmtFinder> {
+  bool ContainsRetStmt = false;
+
+  void Visit(const Stmt *S) {
+    // Dispatch S itself (VisitReturnStmt), then recurse into its children.
+    ConstStmtVisitor<ReturnStmtFinder>::Visit(S);
+    for (const Stmt *Child : S->children())
+      if (Child)
+        Visit(Child);
+  }
+
+  void VisitReturnStmt(const ReturnStmt *) { ContainsRetStmt = true; }
 };
 } // end anonymous namespace
 
@@ -1853,7 +1890,8 @@
                                          bool IsFilter) {
   // Find all captures in the Stmt.
   CaptureFinder Finder(ParentCGF, ParentCGF.CXXABIThisDecl);
-  Finder.Visit(OutlinedStmt);
+  if (OutlinedStmt)
+    Finder.Visit(OutlinedStmt);
 
   // We can exit early on x86_64 when there are no captures. We just have to
   // save the exception code in filters so that __exception_code() works.
@@ -1991,6 +2029,16 @@
 
   if (IsFilter)
     EmitSEHExceptionCodeSave(ParentCGF, ParentFP, EntryFP);
+
+  if (Finder.ContainsRetStmt) {
+    SEHRetNowParent = recoverAddrOfEscapedLocal(
+        ParentCGF, ParentCGF.SEHRetNowStack.back(), ParentFP);
+    Address ParentSEHRetVal =
+        ParentCGF.ParentCGF ? ParentCGF.SEHReturnValue : ParentCGF.ReturnValue;
+    if (ParentSEHRetVal.isValid())
+      SEHReturnValue =
+          recoverAddrOfEscapedLocal(ParentCGF, ParentSEHRetVal, ParentFP);
+  }
 }
 
 /// Arrange a function prototype that can be called by Windows exception
@@ -2150,19 +2198,93 @@
 
 void CodeGenFunction::pushSEHCleanup(CleanupKind Kind,
                                      llvm::Function *FinallyFunc) {
-  EHStack.pushCleanup<PerformSEHFinally>(Kind, FinallyFunc);
+  EHStack.pushCleanup<PerformSEHFinally>(Kind, FinallyFunc, /*RetFromFinally=*/false);
 }
 
-void CodeGenFunction::EnterSEHTryStmt(const SEHTryStmt &S) {
+void CodeGenFunction::EnterSEHTryStmt(const SEHTryStmt &S,
+                                      bool &ContainsRetStmt) {
   CodeGenFunction HelperCGF(CGM, /*suppressNewContext=*/true);
   HelperCGF.ParentCGF = this;
   if (const SEHFinallyStmt *Finally = S.getFinallyHandler()) {
+    ReturnStmtFinder Finder;
+    Finder.Visit(Finally);
+    ContainsRetStmt = Finder.ContainsRetStmt;
+    if (ContainsRetStmt) {
+      // Suppose we have something like:
+      //   __try {
+      //     f1();
+      //   } __finally {
+      //     f2();
+      //     if (z)
+      //       return;
+      //     f3();
+      //   }
+      //
+      // We want to generate code something like this, where "StopUnwinding()"
+      // refers to the operation of aborting the unwind, and jumping back
+      // to normal code.
+      //
+      //   int immediate_return = 0;
+      //   __try {
+      //     f1();
+      //   } __finally {
+      //     f2();
+      //     if (z) {
+      //       immediate_return = 1;
+      //       goto end_of_finally;
+      //     }
+      //     f3();
+      //   end_of_finally:
+      //     if (_abnormal_termination())
+      //       StopUnwinding();
+      //   }
+      //   if (immediate_return) {
+      //     return;
+      //   }
+      //
+      // To handle the non-unwind case, we need to synthesize the
+      // "immediate_return" variable, and use it to change control flow
+      // after the finally block.
+      //
+      // To make "StopUnwinding()" work, we use _local_unwind. This function
+      // tells the SEH unwinder to recompute the unwind action: instead of
+      // using the __except handler that was already computed, stop unwinding
+      // when the unwinder reaches the current function. (The mechanism used
+      // here is unofficially called a "collided unwind".)
+      //
+      // We represent the destination of _local_unwind with a fake CatchPad:
+      // when the backend sees a filter named "__IsLocalUnwind", it arranges
+      // the unwind tables so that _local_unwind stops at that CatchPad, but
+      // other unwinding ignores it.
+      //
+      // Note that this construct could itself be inside an __try or __finally
+      // block.
+      //
+      // If it's inside the __try of a __try/__finally, the outer __finally
+      // executes before the function returns.
+      //
+      // If it's inside a __finally, we need to jump out of that __finally
+      // in a similar way.
+
+      // Allocate and zero this __try's return-from-__finally flag.
+      SEHRetNowStack.push_back(
+          CreateTempAlloca(CGM.Int8Ty, CharUnits::fromQuantity(1), "retnow"));
+      Builder.CreateStore(Builder.getInt8(0), SEHRetNowStack.back());
+
+      // Create the exception filter.
+      EHCatchScope *CatchScope = EHStack.pushCatch(1);
+      llvm::Function *FilterFunc = GenerateSEHIsLocalUnwindFunction();
+      llvm::Constant *OpaqueFunc =
+          llvm::ConstantExpr::getBitCast(FilterFunc, Int8PtrTy);
+      CatchScope->setHandler(0, OpaqueFunc, createBasicBlock("__except.ret"));
+    }
     // Outline the finally block.
     llvm::Function *FinallyFunc =
         HelperCGF.GenerateSEHFinallyFunction(*this, *Finally);
 
     // Push a cleanup for __finally blocks.
-    EHStack.pushCleanup<PerformSEHFinally>(NormalAndEHCleanup, FinallyFunc);
+    EHStack.pushCleanup<PerformSEHFinally>(NormalAndEHCleanup, FinallyFunc,
+                                           ContainsRetStmt);
     return;
   }
 
@@ -2194,10 +2316,75 @@
   CatchScope->setHandler(0, OpaqueFunc, createBasicBlock("__except.ret"));
 }
 
-void CodeGenFunction::ExitSEHTryStmt(const SEHTryStmt &S) {
+llvm::Function *CodeGenFunction::GenerateSEHIsLocalUnwindFunction() {
+  // Declaration only: WinEHPrepare treats a filter with this name as a marker.
+  if (llvm::Function *F = CGM.getModule().getFunction("__IsLocalUnwind"))
+    return F;
+
+  llvm::LLVMContext &Ctx = getLLVMContext();
+  llvm::Type *ArgTys[] = {llvm::Type::getInt8PtrTy(Ctx),
+                          llvm::Type::getInt8PtrTy(Ctx)};
+  return llvm::Function::Create(
+      llvm::FunctionType::get(llvm::Type::getVoidTy(Ctx), ArgTys, false),
+      llvm::GlobalVariable::ExternalWeakLinkage, "__IsLocalUnwind",
+      &CGM.getModule());
+}
+
+void CodeGenFunction::EmitSEHLocalUnwind() {
+  EmitRuntimeCallOrInvoke(CGM.getIntrinsic(llvm::Intrinsic::seh_localunwind));
+}
+
+void CodeGenFunction::ExitSEHTryStmt(const SEHTryStmt &S,
+                                     bool ContainsRetStmt) {
   // Just pop the cleanup if it's a __finally block.
   if (S.getFinallyHandler()) {
     PopCleanupBlock();
+    if (ContainsRetStmt) {
+      // Create __except block and control flow handling for return from
+      // __finally. See comment in EnterSEHTryStmt.
+      //
+      // First, create the point where we check for a return
+      // from the __finally.
+      llvm::BasicBlock *ContBB = createBasicBlock("__finally.cont");
+      if (HaveInsertPoint())
+        Builder.CreateBr(ContBB);
+
+      EmitBlock(ContBB);
+
+      // On the normal path, check if we have a return-from-finally.
+      llvm::BasicBlock *AbnormalCont = createBasicBlock("if.then");
+      llvm::BasicBlock *NormalCont = createBasicBlock("if.end");
+      llvm::Value *ShouldRetLoad = Builder.CreateLoad(SEHRetNowStack.back());
+      llvm::Value *ShouldRet = Builder.CreateIsNotNull(ShouldRetLoad);
+
+      Builder.CreateCondBr(ShouldRet, AbnormalCont, NormalCont);
+
+      // Emit the catchswitch that dispatches to the __IsLocalUnwind catchpad.
+      EHCatchScope &CatchScope = cast<EHCatchScope>(*EHStack.begin());
+      emitCatchDispatchBlock(*this, CatchScope);
+
+      // Grab the block before we pop the handler.
+      llvm::BasicBlock *CatchPadBB = CatchScope.getHandler(0).Block;
+      EHStack.popCatch();
+
+      // The catch block only catches return-from-finally.
+      EmitBlockAfterUses(CatchPadBB);
+      llvm::CatchPadInst *CPI =
+          cast<llvm::CatchPadInst>(CatchPadBB->getFirstNonPHI());
+      Builder.CreateCatchRet(CPI, AbnormalCont);
+      EmitBlock(AbnormalCont);
+
+      // If the try block is nested inside a finally block, forward the
+      // return from __finally to the parent function.
+      if (SEHRetNowParent.isValid())
+        Builder.CreateStore(Builder.getInt8(1), SEHRetNowParent);
+      EmitBranchThroughCleanup(ReturnBlock);
+
+      EmitBlock(NormalCont);
+
+      // Restore the enclosing __try's return-from-__finally flag, if any.
+      SEHRetNowStack.pop_back();
+    }
     return;
   }
 
diff --git a/clang/lib/CodeGen/CGStmt.cpp b/clang/lib/CodeGen/CGStmt.cpp
--- a/clang/lib/CodeGen/CGStmt.cpp
+++ b/clang/lib/CodeGen/CGStmt.cpp
@@ -1269,10 +1269,12 @@
                         ReturnLocation);
   }
 
-  // Returning from an outlined SEH helper is UB, and we already warn on it.
+  // A return inside an outlined SEH helper sets the parent frame's "return
+  // now" flag and targets the root function's return-value slot instead.
+  Address ReturnValue = this->ReturnValue;
   if (IsOutlinedSEHHelper) {
-    Builder.CreateUnreachable();
-    Builder.ClearInsertionPoint();
+    Builder.CreateStore(Builder.getInt8(1), SEHRetNowParent);
+    ReturnValue = SEHReturnValue;
   }
 
   // Emit the result value, even if unused, to evaluate the side effects.
diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -672,6 +672,16 @@
   /// a value from the top of the stack.
   SmallVector<Address, 1> SEHCodeSlotStack;
 
+  /// Stack of "return now" flags, one per __try whose __finally contains a
+  /// return statement; a nonzero flag means that __finally returned.
+  SmallVector<Address, 1> SEHRetNowStack;
+
+  /// Pointer to the parent function's SEHRetNow variable.
+  Address SEHRetNowParent = Address::invalid();
+
+  /// Pointer to the root function's ReturnValue variable.
+  Address SEHReturnValue = Address::invalid();
+
   /// Value returned by __exception_info intrinsic.
   llvm::Value *SEHInfo = nullptr;
 
@@ -3281,11 +3291,14 @@
   void EmitCXXTryStmt(const CXXTryStmt &S);
   void EmitSEHTryStmt(const SEHTryStmt &S);
   void EmitSEHLeaveStmt(const SEHLeaveStmt &S);
-  void EnterSEHTryStmt(const SEHTryStmt &S);
-  void ExitSEHTryStmt(const SEHTryStmt &S);
+  void EnterSEHTryStmt(const SEHTryStmt &S, bool &ContainsRetStmt);
+  void ExitSEHTryStmt(const SEHTryStmt &S, bool ContainsRetStmt);
   void VolatilizeTryBlocks(llvm::BasicBlock *BB,
                            llvm::SmallPtrSet<llvm::BasicBlock *, 10> &V);
 
+  void EmitSEHLocalUnwind();
+  llvm::Function *GenerateSEHIsLocalUnwindFunction();
+
   void pushSEHCleanup(CleanupKind kind,
                       llvm::Function *FinallyFunc);
   void startOutlinedSEHHelper(CodeGenFunction &ParentCGF, bool IsFilter,
diff --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td
--- a/llvm/include/llvm/IR/Intrinsics.td
+++ b/llvm/include/llvm/IR/Intrinsics.td
@@ -541,6 +541,9 @@
 def int_seh_scope_begin : Intrinsic<[], [], [IntrNoMem]>;
 def int_seh_scope_end : Intrinsic<[], [], [IntrNoMem]>;
 
+// Call _local_unwind to unwind to a local catchpad.
+def int_seh_localunwind : Intrinsic<[], [], [IntrNoReturn]>;
+
 // Note: we treat stacksave/stackrestore as writemem because we don't otherwise
 // model their dependencies on allocas.
 def int_stacksave     : DefaultAttrsIntrinsic<[llvm_ptr_ty]>,
diff --git a/llvm/lib/CodeGen/AsmPrinter/WinException.cpp b/llvm/lib/CodeGen/AsmPrinter/WinException.cpp
--- a/llvm/lib/CodeGen/AsmPrinter/WinException.cpp
+++ b/llvm/lib/CodeGen/AsmPrinter/WinException.cpp
@@ -626,6 +626,17 @@
     LastStartLabel = StateChange.NewStartLabel;
     LastEHState = StateChange.NewState;
   }
+  for (const auto &Entry : FuncInfo.SEHUnwindMap) {
+    if (!Entry.IsFinally && Entry.ToState != -1) {
+      // Emit a zero-length range at each nested __except handler so that
+      // _local_unwind stops there rather than unwinding too far. (This also
+      // covers ordinary nested __except handlers, not only __IsLocalUnwind.)
+      // FIXME: Can this overlap with the EH_LABEL for an invoke?
+      auto *Handler = Entry.Handler.get<MachineBasicBlock *>();
+      const MCSymbol *Begin = Handler->getSymbol();
+      emitSEHActionsForRange(FuncInfo, Begin, Begin, Entry.ToState);
+    }
+  }
 
   OS.emitLabel(TableEnd);
 }
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -2937,6 +2937,40 @@
       DAG.setRoot(DAG.getNode(ISD::INTRINSIC_VOID, getCurSDLoc(), VTs, Ops));
       break;
     }
+    case Intrinsic::seh_localunwind: {
+      // Lower to _local_unwind(FrameAddr, HandlerAddr), where HandlerAddr is
+      // the first handler block of the catchswitch this invoke unwinds to.
+      auto *CatchSwitch = dyn_cast<CatchSwitchInst>(EHPadBB->getTerminator());
+      if (!CatchSwitch)
+        report_fatal_error("localunwind doesn't point to catchswitch");
+      if (CatchSwitch->getNumHandlers() == 0)
+        report_fatal_error("catchswitch with no handler");
+      const BasicBlock *HandlerBB = *CatchSwitch->handler_begin();
+
+      const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+      TargetLowering::ArgListEntry SP, DestBB;
+      Type *PtrTy = PointerType::getInt8PtrTy(*DAG.getContext());
+      EVT PtrVT = TLI.getPointerTy(DAG.getDataLayout());
+      SP.Node = DAG.getNode(ISD::FRAMEADDR, getCurSDLoc(), PtrVT,
+                            DAG.getIntPtrConstant(0, getCurSDLoc()));
+      SP.Ty = PtrTy;
+      FuncInfo.MBBMap[HandlerBB]->setHasAddressTaken();
+      DestBB.Node = DAG.getBlockAddress(
+          BlockAddress::get(const_cast<BasicBlock *>(HandlerBB)), PtrVT);
+      DestBB.Ty = PtrTy;
+      TargetLowering::ArgListTy Args{SP, DestBB};
+
+      SDValue Callee = DAG.getExternalSymbol("_local_unwind", PtrVT);
+      TargetLowering::CallLoweringInfo CLI(DAG);
+      CLI.setDebugLoc(getCurSDLoc())
+          .setChain(getRoot())
+          .setCallee(CallingConv::C, Type::getVoidTy(*DAG.getContext()), Callee,
+                     std::move(Args))
+          .setNoReturn();
+      CLI.CB = &I;
+      lowerInvokable(CLI, EHPadBB);
+      break;
+    }
     }
   } else if (I.countOperandBundlesOfType(LLVMContext::OB_deopt)) {
     // Currently we do not lower any intrinsic calls with deopt operand bundles.
diff --git a/llvm/lib/CodeGen/WinEHPrepare.cpp b/llvm/lib/CodeGen/WinEHPrepare.cpp
--- a/llvm/lib/CodeGen/WinEHPrepare.cpp
+++ b/llvm/lib/CodeGen/WinEHPrepare.cpp
@@ -380,6 +380,20 @@
     const Function *Filter = dyn_cast<Function>(FilterOrNull);
     assert((Filter || FilterOrNull->isNullValue()) &&
            "unexpected filter value");
+    // Filters named __IsLocalUnwind (matched by prefix) are treated
+    // specially: we want to catch unwinds from _local_unwind, but not
+    // catchrets in the same funclet. (They both need to point at the same
+    // catchswitch to pass the verifier checks for nesting.) To make this
+    // work, we adjust the state numbering: the "parent" of any cleanupret
+    // pointing to this catchpad is actually this catchpad's parent.
+    //
+    // Note that _local_unwind looks for unwind table entries for the
+    // catchpad; if there aren't any, it assumes the catchpad doesn't have a
+    // parent.
+    bool IsLocalUnwind =
+        Filter && Filter->getName().startswith("__IsLocalUnwind");
+    if (IsLocalUnwind)
+      Filter = nullptr;
     int TryState = addSEHExcept(FuncInfo, ParentState, Filter, CatchPadBB);
 
     // Everything in the __try block uses TryState as its parent state.
@@ -390,7 +404,7 @@
       if ((PredBlock = getEHPadFromPredecessor(PredBlock,
                                                CatchSwitch->getParentPad())))
         calculateSEHStateNumbers(FuncInfo, PredBlock->getFirstNonPHI(),
-                                 TryState);
+                                 IsLocalUnwind ? ParentState : TryState);
 
     // Everything in the __except block unwinds to ParentState, just like code
     // outside the __try.
diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp
--- a/llvm/lib/IR/Verifier.cpp
+++ b/llvm/lib/IR/Verifier.cpp
@@ -4612,6 +4612,7 @@
               F->getIntrinsicID() == Intrinsic::experimental_patchpoint_i64 ||
               F->getIntrinsicID() == Intrinsic::experimental_gc_statepoint ||
               F->getIntrinsicID() == Intrinsic::wasm_rethrow ||
+              F->getIntrinsicID() == Intrinsic::seh_localunwind ||
               IsAttachedCallOperand(F, CBI, i),
           "Cannot invoke an intrinsic other than donothing, patchpoint, "
           "statepoint, coro_resume, coro_destroy or clang.arc.attachedcall",