Lower llvm.coro.done and mark the final suspend point with a null
ResumeFnAddr in the coroutine frame.

* CoroEarly: lower llvm.coro.done(%hdl) to a load of the first pointer-sized
  slot of the coroutine frame (ResumeFnAddr) compared for equality with null,
  and create the Lowerer when llvm.coro.done is declared in the module.
* CoroSplit: at the final suspend point store null into ResumeFnAddr instead
  of writing the suspend index. In the resume clone the final case is dropped
  from the resume switch; in clones created with FnIndex != 0 a null check on
  ResumeFnAddr branches to the final suspend's resume block ahead of the
  switch.
* Coroutines.cpp: coro::Shape::buildFrom keeps the final suspend as the last
  element of CoroSuspends (handleFinalSuspend depends on this ordering) and
  reports a fatal error if more than one suspend point is marked final.
* test/Transforms/Coroutines/ex5.ll: new final-suspend example that resumes
  the coroutine until llvm.coro.done returns true.

Index: lib/Transforms/Coroutines/CoroEarly.cpp
===================================================================
--- lib/Transforms/Coroutines/CoroEarly.cpp
+++ lib/Transforms/Coroutines/CoroEarly.cpp
@@ -26,10 +26,11 @@
 // Created on demand if CoroEarly pass has work to do.
 class Lowerer : public coro::LowererBase {
   IRBuilder<> Builder;
-  PointerType *AnyResumeFnPtrTy;
+  PointerType *const AnyResumeFnPtrTy;
 
   void lowerResumeOrDestroy(CallSite CS, CoroSubFnInst::ResumeKind);
   void lowerCoroPromise(CoroPromiseInst *Intrin);
+  void lowerCoroDone(IntrinsicInst *II);
 
 public:
   Lowerer(Module &M)
@@ -81,6 +82,27 @@
   Intrin->eraseFromParent();
 }
 
+// When a coroutine reaches final suspend point, it zeros out ResumeFnAddr in
+// the coroutine frame (it is UB to resume from a final suspend point).
+// The llvm.coro.done intrinsic is used to check whether a coroutine is
+// suspended at the final suspend point or not.
+void Lowerer::lowerCoroDone(IntrinsicInst *II) {
+  Value *Operand = II->getArgOperand(0);
+
+  // ResumeFnAddr is the first pointer sized element of the coroutine frame.
+  auto *FrameTy = Int8Ptr;
+  PointerType *FramePtrTy = FrameTy->getPointerTo();
+
+  Builder.SetInsertPoint(II);
+  auto *BCI = Builder.CreateBitCast(Operand, FramePtrTy);
+  auto *Gep = Builder.CreateConstInBoundsGEP1_32(FrameTy, BCI, 0);
+  auto *Load = Builder.CreateLoad(Gep);
+  auto *Cond = Builder.CreateICmpEQ(Load, NullPtr);
+
+  II->replaceAllUsesWith(Cond);
+  II->eraseFromParent();
+}
+
 // Prior to CoroSplit, calls to coro.begin needs to be marked as NoDuplicate,
 // as CoroSplit assumes there is exactly one coro.begin. After CoroSplit,
 // NoDuplicate attribute will be removed from coro.begin otherwise, it will
@@ -131,6 +153,9 @@
       case Intrinsic::coro_promise:
         lowerCoroPromise(cast<CoroPromiseInst>(&I));
         break;
+      case Intrinsic::coro_done:
+        lowerCoroDone(cast<IntrinsicInst>(&I));
+        break;
       }
       Changed = true;
     }
@@ -153,9 +178,9 @@
   // This pass has work to do only if we find intrinsics we are going to lower
   // in the module.
   bool doInitialization(Module &M) override {
-    if (coro::declaresIntrinsics(M, {"llvm.coro.begin", "llvm.coro.resume",
-                                     "llvm.coro.destroy", "llvm.coro.suspend",
-                                     "llvm.coro.end"}))
+    if (coro::declaresIntrinsics(M, {"llvm.coro.begin", "llvm.coro.end",
+                                     "llvm.coro.resume", "llvm.coro.destroy",
+                                     "llvm.coro.done", "llvm.coro.suspend"}))
       L = llvm::make_unique<Lowerer>(M);
     return false;
   }
Index: lib/Transforms/Coroutines/CoroSplit.cpp
===================================================================
--- lib/Transforms/Coroutines/CoroSplit.cpp
+++ lib/Transforms/Coroutines/CoroSplit.cpp
@@ -62,8 +62,8 @@
       Builder.CreateSwitch(Index, UnreachBB, Shape.CoroSuspends.size());
   Shape.ResumeSwitch = Switch;
 
-  uint32_t SuspendIndex = 0;
-  for (auto S : Shape.CoroSuspends) {
+  size_t SuspendIndex = 0;
+  for (CoroSuspendInst *S : Shape.CoroSuspends) {
     ConstantInt *IndexVal = Shape.getIndex(SuspendIndex);
 
     // Replace CoroSave with a store to Index:
@@ -71,9 +71,18 @@
     //    store i32 0, i32* %index.addr1
     auto *Save = S->getCoroSave();
     Builder.SetInsertPoint(Save);
-    auto *GepIndex = Builder.CreateConstInBoundsGEP2_32(
-        FrameTy, FramePtr, 0, coro::Shape::IndexField, "index.addr");
-    Builder.CreateStore(IndexVal, GepIndex);
+    if (S->isFinal()) {
+      // Final suspend point is represented by storing zero in ResumeFnAddr.
+      auto *GepIndex = Builder.CreateConstInBoundsGEP2_32(FrameTy, FramePtr, 0,
+                                                          0, "ResumeFn.addr");
+      auto *NullPtr = ConstantPointerNull::get(cast<PointerType>(
+          cast<PointerType>(GepIndex->getType())->getElementType()));
+      Builder.CreateStore(NullPtr, GepIndex);
+    } else {
+      auto *GepIndex = Builder.CreateConstInBoundsGEP2_32(
+          FrameTy, FramePtr, 0, coro::Shape::IndexField, "index.addr");
+      Builder.CreateStore(IndexVal, GepIndex);
+    }
     Save->replaceAllUsesWith(ConstantTokenNone::get(C));
     Save->eraseFromParent();
 
@@ -135,6 +144,37 @@
   BB->getTerminator()->eraseFromParent();
 }
 
+// Rewrite final suspend point handling. We do not use suspend index to
+// represent the final suspend point. Instead we zero-out ResumeFnAddr in the
+// coroutine frame, since it is undefined behavior to resume a coroutine
+// suspended at the final suspend point. Thus, in the resume function, we can
+// simply remove the last case (when coro::Shape is built, the final suspend
+// point (if present) is always the last element of CoroSuspends array).
+// In the destroy function, we add a code sequence to check if ResumeFnAddress
+// is Null, and if so, jump to the appropriate label to handle cleanup from the
+// final suspend point.
+static void handleFinalSuspend(IRBuilder<> &Builder, Value *FramePtr,
+                               coro::Shape &Shape, SwitchInst *Switch,
+                               bool IsDestroy) {
+  assert(Shape.HasFinalSuspend);
+  auto FinalCase = --Switch->case_end();
+  BasicBlock *ResumeBB = FinalCase.getCaseSuccessor();
+  Switch->removeCase(FinalCase);
+  if (IsDestroy) {
+    BasicBlock *OldSwitchBB = Switch->getParent();
+    auto *NewSwitchBB = OldSwitchBB->splitBasicBlock(Switch, "Switch");
+    Builder.SetInsertPoint(OldSwitchBB->getTerminator());
+    auto *GepIndex = Builder.CreateConstInBoundsGEP2_32(Shape.FrameTy, FramePtr,
+                                                        0, 0, "ResumeFn.addr");
+    auto *Load = Builder.CreateLoad(GepIndex);
+    auto *NullPtr =
+        ConstantPointerNull::get(cast<PointerType>(Load->getType()));
+    auto *Cond = Builder.CreateICmpEQ(Load, NullPtr);
+    Builder.CreateCondBr(Cond, ResumeBB, NewSwitchBB);
+    OldSwitchBB->getTerminator()->eraseFromParent();
+  }
+}
+
 // Create a resume clone by cloning the body of the original function, setting
 // new entry block and replacing coro.suspend an appropriate value to force
 // resume or cleanup pass for every suspend point.
@@ -205,6 +245,15 @@
   Value *OldVFrame = cast<Value>(VMap[Shape.CoroBegin]);
   OldVFrame->replaceAllUsesWith(NewVFrame);
 
+  // Rewrite final suspend handling as it is not done via switch (allows to
+  // remove final case from the switch, since it is undefined behavior to resume
+  // the coroutine suspended at the final suspend point).
+  if (Shape.HasFinalSuspend) {
+    auto *Switch = cast<SwitchInst>(VMap[Shape.ResumeSwitch]);
+    bool IsDestroy = FnIndex != 0;
+    handleFinalSuspend(Builder, NewFramePtr, Shape, Switch, IsDestroy);
+  }
+
   // Replace coro suspend with the appropriate resume index.
   // Replacing coro.suspend with (0) will result in control flow proceeding to
   // a resume label associated with a suspend point, replacing it with (1) will
Index: lib/Transforms/Coroutines/Coroutines.cpp
===================================================================
--- lib/Transforms/Coroutines/Coroutines.cpp
+++ lib/Transforms/Coroutines/Coroutines.cpp
@@ -215,6 +215,7 @@
 
 // Collect "interesting" coroutine intrinsics.
 void coro::Shape::buildFrom(Function &F) {
+  size_t FinalSuspendIndex = 0;
   clear(*this);
   SmallVector<CoroFrameInst *, 8> CoroFrames;
   for (Instruction &I : instructions(F)) {
@@ -230,16 +231,12 @@
         break;
       case Intrinsic::coro_suspend:
         CoroSuspends.push_back(cast<CoroSuspendInst>(II));
-        // Make sure that the final suspend is the first suspend point in the
-        // CoroSuspends vector.
         if (CoroSuspends.back()->isFinal()) {
+          if (HasFinalSuspend)
+            report_fatal_error(
+                "Only one suspend point can be marked as final");
           HasFinalSuspend = true;
-          if (CoroSuspends.size() > 1) {
-            if (CoroSuspends.front()->isFinal())
-              report_fatal_error(
-                  "Only one suspend point can be marked as final");
-            std::swap(CoroSuspends.front(), CoroSuspends.back());
-          }
+          FinalSuspendIndex = CoroSuspends.size() - 1;
         }
         break;
       case Intrinsic::coro_begin: {
@@ -309,4 +306,9 @@
   for (CoroSuspendInst *CS : CoroSuspends)
     if (!CS->getCoroSave())
       createCoroSave(CoroBegin, CS);
+
+  // Move final suspend to be the last element in the CoroSuspends vector.
+  if (HasFinalSuspend &&
+      FinalSuspendIndex != CoroSuspends.size() - 1)
+    std::swap(CoroSuspends[FinalSuspendIndex], CoroSuspends.back());
 }
Index: test/Transforms/Coroutines/ex5.ll
===================================================================
--- /dev/null
+++ test/Transforms/Coroutines/ex5.ll
@@ -0,0 +1,73 @@
+; Fifth example from Doc/Coroutines.rst (final suspend)
+; RUN: opt < %s -O2 -enable-coroutines -S | FileCheck %s
+
+define i8* @f(i32 %n) {
+entry:
+  %id = call token @llvm.coro.id(i32 0, i8* null, i8* null, i8* null)
+  %size = call i32 @llvm.coro.size.i32()
+  %alloc = call i8* @malloc(i32 %size)
+  %hdl = call noalias i8* @llvm.coro.begin(token %id, i8* %alloc)
+  br label %while.cond
+while.cond:
+  %n.val = phi i32 [ %n, %entry ], [ %dec, %while.body ]
+  %cmp = icmp sgt i32 %n.val, 0
+  br i1 %cmp, label %while.body, label %while.end
+
+while.body:
+  %dec = add nsw i32 %n.val, -1
+  call void @print(i32 %n.val) #4
+  %s = call i8 @llvm.coro.suspend(token none, i1 false)
+  switch i8 %s, label %suspend [i8 0, label %while.cond
+                                i8 1, label %cleanup]
+while.end:
+  %s.final = call i8 @llvm.coro.suspend(token none, i1 true)
+  switch i8 %s.final, label %suspend [i8 0, label %trap
+                                      i8 1, label %cleanup]
+trap:
+  call void @llvm.trap()
+  unreachable
+cleanup:
+  %mem = call i8* @llvm.coro.free(token %id, i8* %hdl)
+  call void @free(i8* %mem)
+  br label %suspend
+suspend:
+  call void @llvm.coro.end(i8* %hdl, i1 false)
+  ret i8* %hdl
+}
+
+declare noalias i8* @malloc(i32)
+declare void @print(i32)
+declare void @llvm.trap()
+declare void @free(i8* nocapture)
+
+declare token @llvm.coro.id( i32, i8*, i8*, i8*)
+declare i32 @llvm.coro.size.i32()
+declare i8* @llvm.coro.begin(token, i8*)
+declare token @llvm.coro.save(i8*)
+declare i8 @llvm.coro.suspend(token, i1)
+declare i8* @llvm.coro.free(token, i8*)
+declare void @llvm.coro.end(i8*, i1)
+
+; CHECK-LABEL: @main
+define i32 @main() {
+entry:
+  %hdl = call i8* @f(i32 4)
+  br label %while
+while:
+  call void @llvm.coro.resume(i8* %hdl)
+  %done = call i1 @llvm.coro.done(i8* %hdl)
+  br i1 %done, label %end, label %while
+end:
+  call void @llvm.coro.destroy(i8* %hdl)
+  ret i32 0
+
+; CHECK: call void @print(i32 4)
+; CHECK: call void @print(i32 3)
+; CHECK: call void @print(i32 2)
+; CHECK: call void @print(i32 1)
+; CHECK: ret i32 0
+}
+
+declare i1 @llvm.coro.done(i8*)
+declare void @llvm.coro.resume(i8*)
+declare void @llvm.coro.destroy(i8*)