diff --git a/llvm/docs/Coroutines.rst b/llvm/docs/Coroutines.rst --- a/llvm/docs/Coroutines.rst +++ b/llvm/docs/Coroutines.rst @@ -1389,6 +1389,48 @@ | | Landingpad | nothing | nothing | +------------+-------------+-------------------+-------------------------------+ + +'llvm.coro.end.async' Intrinsic +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +:: + + declare i1 @llvm.coro.end.async(i8* , i1 , ...) + +Overview: +""""""""" + +The '``llvm.coro.end.async``' marks the point where execution of the resume part +of the coroutine should end and control should return to the caller. As part of +its variable tail arguments this instruction allows one to specify a function and +the function's arguments that are to be tail called as the last action before +returning. + + +Arguments: +"""""""""" + +The first argument should refer to the coroutine handle of the enclosing +coroutine. A frontend is allowed to supply null as the first parameter; in this +case the `coro-early` pass will replace the null with an appropriate coroutine +handle value. + +The second argument should be `true` if this coro.end is in the block that is +part of the unwind sequence leaving the coroutine body due to an exception and +`false` otherwise. + +The third argument, if present, should specify a function to be called. + +If the third argument is present, the remaining arguments are the arguments to +the function call. + +.. code-block:: llvm + + call i1 (i8*, i1, ...) @llvm.coro.end.async( + i8* %hdl, i1 0, + void (i8*, %async.task*, %async.actor*)* @must_tail_call_return, + i8* %ctxt, %async.task* %task, %async.actor* %actor) + unreachable + .. _coro.suspend: ..
_suspend points: diff --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td --- a/llvm/include/llvm/IR/Intrinsics.td +++ b/llvm/include/llvm/IR/Intrinsics.td @@ -1213,6 +1213,8 @@ ReadOnly>, NoCapture>]>; def int_coro_end : Intrinsic<[llvm_i1_ty], [llvm_ptr_ty, llvm_i1_ty], []>; +def int_coro_end_async + : Intrinsic<[llvm_i1_ty], [llvm_ptr_ty, llvm_i1_ty, llvm_vararg_ty], []>; def int_coro_frame : Intrinsic<[llvm_ptr_ty], [], [IntrNoMem]>; def int_coro_noop : Intrinsic<[llvm_ptr_ty], [], [IntrNoMem]>; diff --git a/llvm/lib/Transforms/Coroutines/CoroEarly.cpp b/llvm/lib/Transforms/Coroutines/CoroEarly.cpp --- a/llvm/lib/Transforms/Coroutines/CoroEarly.cpp +++ b/llvm/lib/Transforms/Coroutines/CoroEarly.cpp @@ -164,10 +164,11 @@ if (cast(&I)->isFinal()) CB->setCannotDuplicate(); break; + case Intrinsic::coro_end_async: case Intrinsic::coro_end: // Make sure that fallthrough coro.end is not duplicated as CoroSplit // pass expects that there is at most one fallthrough coro.end. - if (cast(&I)->isFallthrough()) + if (cast(&I)->isFallthrough()) CB->setCannotDuplicate(); break; case Intrinsic::coro_noop: @@ -219,8 +220,8 @@ return coro::declaresIntrinsics( M, {"llvm.coro.id", "llvm.coro.id.retcon", "llvm.coro.id.retcon.once", "llvm.coro.destroy", "llvm.coro.done", "llvm.coro.end", - "llvm.coro.noop", "llvm.coro.free", "llvm.coro.promise", - "llvm.coro.resume", "llvm.coro.suspend"}); + "llvm.coro.end.async", "llvm.coro.noop", "llvm.coro.free", + "llvm.coro.promise", "llvm.coro.resume", "llvm.coro.suspend"}); } PreservedAnalyses CoroEarlyPass::run(Function &F, FunctionAnalysisManager &) { diff --git a/llvm/lib/Transforms/Coroutines/CoroFrame.cpp b/llvm/lib/Transforms/Coroutines/CoroFrame.cpp --- a/llvm/lib/Transforms/Coroutines/CoroFrame.cpp +++ b/llvm/lib/Transforms/Coroutines/CoroFrame.cpp @@ -2170,9 +2170,26 @@ } // Put CoroEnds into their own blocks. 
- for (CoroEndInst *CE : Shape.CoroEnds) + for (AnyCoroEndInst *CE : Shape.CoroEnds) { splitAround(CE, "CoroEnd"); + // Emit the musttail call function in a new block before the CoroEnd. + // We do this here so that the right suspend crossing info is computed for + // the uses of the musttail call function call. (Arguments to the coro.end + // instructions would be ignored) + if (auto *AsyncEnd = dyn_cast(CE)) { + auto *MustTailCallFn = AsyncEnd->getMustTailCallFunction(); + if (!MustTailCallFn) + continue; + IRBuilder<> Builder(AsyncEnd); + SmallVector Args(AsyncEnd->args()); + auto Arguments = ArrayRef(Args).drop_front(3); + auto *Call = createMustTailCall(AsyncEnd->getDebugLoc(), MustTailCallFn, + Arguments, Builder); + splitAround(Call, "MustTailCall.Before.CoroEnd"); + } + } + // Transforms multi-edge PHI Nodes, so that any value feeding into a PHI will // never has its definition separated from the PHI by the suspend point. rewritePHIs(F); diff --git a/llvm/lib/Transforms/Coroutines/CoroInstr.h b/llvm/lib/Transforms/Coroutines/CoroInstr.h --- a/llvm/lib/Transforms/Coroutines/CoroInstr.h +++ b/llvm/lib/Transforms/Coroutines/CoroInstr.h @@ -577,8 +577,7 @@ } }; -/// This represents the llvm.coro.end instruction. -class LLVM_LIBRARY_VISIBILITY CoroEndInst : public IntrinsicInst { +class LLVM_LIBRARY_VISIBILITY AnyCoroEndInst : public IntrinsicInst { enum { FrameArg, UnwindArg }; public: @@ -587,6 +586,19 @@ return cast(getArgOperand(UnwindArg))->isOneValue(); } + // Methods to support type inquiry through isa, cast, and dyn_cast: + static bool classof(const IntrinsicInst *I) { + auto ID = I->getIntrinsicID(); + return ID == Intrinsic::coro_end || ID == Intrinsic::coro_end_async; + } + static bool classof(const Value *V) { + return isa(V) && classof(cast(V)); + } +}; + +/// This represents the llvm.coro.end instruction. 
+class LLVM_LIBRARY_VISIBILITY CoroEndInst : public AnyCoroEndInst { +public: // Methods to support type inquiry through isa, cast, and dyn_cast: static bool classof(const IntrinsicInst *I) { return I->getIntrinsicID() == Intrinsic::coro_end; @@ -596,6 +608,30 @@ } }; +/// This represents the llvm.coro.end.async instruction. +class LLVM_LIBRARY_VISIBILITY CoroAsyncEndInst : public AnyCoroEndInst { + enum { FrameArg, UnwindArg, MustTailCallFuncArg }; + +public: + void checkWellFormed() const; + + Function *getMustTailCallFunction() const { + if (getNumArgOperands() < 3) + return nullptr; + + return cast( + getArgOperand(MustTailCallFuncArg)->stripPointerCasts()); + } + + // Methods to support type inquiry through isa, cast, and dyn_cast: + static bool classof(const IntrinsicInst *I) { + return I->getIntrinsicID() == Intrinsic::coro_end_async; + } + static bool classof(const Value *V) { + return isa(V) && classof(cast(V)); + } +}; + /// This represents the llvm.coro.alloca.alloc instruction. class LLVM_LIBRARY_VISIBILITY CoroAllocaAllocInst : public IntrinsicInst { enum { SizeArg, AlignArg }; diff --git a/llvm/lib/Transforms/Coroutines/CoroInternal.h b/llvm/lib/Transforms/Coroutines/CoroInternal.h --- a/llvm/lib/Transforms/Coroutines/CoroInternal.h +++ b/llvm/lib/Transforms/Coroutines/CoroInternal.h @@ -92,7 +92,7 @@ // values used during CoroSplit pass. struct LLVM_LIBRARY_VISIBILITY Shape { CoroBeginInst *CoroBegin; - SmallVector CoroEnds; + SmallVector CoroEnds; SmallVector CoroSizes; SmallVector CoroSuspends; SmallVector SwiftErrorOps; @@ -270,6 +270,8 @@ }; void buildCoroutineFrame(Function &F, Shape &Shape); +CallInst *createMustTailCall(DebugLoc Loc, Function *MustTailCallFn, + ArrayRef Arguments, IRBuilder<> &); } // End namespace coro.
} // End namespace llvm diff --git a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp --- a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp +++ b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp @@ -173,8 +173,53 @@ Shape.emitDealloc(Builder, FramePtr, CG); } +/// Replace an llvm.coro.end.async. +/// Will inline the must tail call function call if there is one. +/// \returns true if cleanup of the coro.end block is needed, false otherwise. +static bool replaceCoroEndAsync(AnyCoroEndInst *End) { + IRBuilder<> Builder(End); + + auto *EndAsync = dyn_cast(End); + if (!EndAsync) { + Builder.CreateRetVoid(); + return true /*needs cleanup of coro.end block*/; + } + + auto *MustTailCallFunc = EndAsync->getMustTailCallFunction(); + if (!MustTailCallFunc) { + Builder.CreateRetVoid(); + return true /*needs cleanup of coro.end block*/; + } + + // Move the must tail call from the predecessor block into the end block. + auto *CoroEndBlock = End->getParent(); + auto *MustTailCallFuncBlock = CoroEndBlock->getSinglePredecessor(); + assert(MustTailCallFuncBlock && "Must have a single predecessor block"); + auto It = MustTailCallFuncBlock->getTerminator()->getIterator(); + auto *MustTailCall = cast(&*std::prev(It)); + CoroEndBlock->getInstList().splice( + End->getIterator(), MustTailCallFuncBlock->getInstList(), MustTailCall); + + // Insert the return instruction. + Builder.SetInsertPoint(End); + Builder.CreateRetVoid(); + InlineFunctionInfo FnInfo; + + // Remove the rest of the block, by splitting it into an unreachable block. + auto *BB = End->getParent(); + BB->splitBasicBlock(End); + BB->getTerminator()->eraseFromParent(); + + auto InlineRes = InlineFunction(*MustTailCall, FnInfo); + assert(InlineRes.isSuccess() && "Expected inlining to succeed"); + (void)InlineRes; + + // We have cleaned up the coro.end block above. + return false; +} + /// Replace a non-unwind call to llvm.coro.end. 
-static void replaceFallthroughCoroEnd(CoroEndInst *End, +static void replaceFallthroughCoroEnd(AnyCoroEndInst *End, const coro::Shape &Shape, Value *FramePtr, bool InResume, CallGraph *CG) { // Start inserting right before the coro.end. @@ -192,9 +237,12 @@ break; // In async lowering this returns. - case coro::ABI::Async: - Builder.CreateRetVoid(); + case coro::ABI::Async: { + bool CoroEndBlockNeedsCleanup = replaceCoroEndAsync(End); + if (!CoroEndBlockNeedsCleanup) + return; break; + } // In unique continuation lowering, the continuations always return void. // But we may have implicitly allocated storage. @@ -229,8 +277,9 @@ } /// Replace an unwind call to llvm.coro.end. -static void replaceUnwindCoroEnd(CoroEndInst *End, const coro::Shape &Shape, - Value *FramePtr, bool InResume, CallGraph *CG){ +static void replaceUnwindCoroEnd(AnyCoroEndInst *End, const coro::Shape &Shape, + Value *FramePtr, bool InResume, + CallGraph *CG) { IRBuilder<> Builder(End); switch (Shape.ABI) { @@ -258,7 +307,7 @@ } } -static void replaceCoroEnd(CoroEndInst *End, const coro::Shape &Shape, +static void replaceCoroEnd(AnyCoroEndInst *End, const coro::Shape &Shape, Value *FramePtr, bool InResume, CallGraph *CG) { if (End->isUnwind()) replaceUnwindCoroEnd(End, Shape, FramePtr, InResume, CG); @@ -511,10 +560,10 @@ } void CoroCloner::replaceCoroEnds() { - for (CoroEndInst *CE : Shape.CoroEnds) { + for (AnyCoroEndInst *CE : Shape.CoroEnds) { // We use a null call graph because there's no call graph node for // the cloned function yet. We'll just be rebuilding that later. 
- auto NewCE = cast(VMap[CE]); + auto NewCE = cast(VMap[CE]); replaceCoroEnd(NewCE, Shape, NewFramePtr, /*in resume*/ true, nullptr); } } @@ -1385,6 +1434,23 @@ } } +CallInst *coro::createMustTailCall(DebugLoc Loc, Function *MustTailCallFn, + ArrayRef Arguments, + IRBuilder<> &Builder) { + auto FnTy = + cast(MustTailCallFn->getType()->getPointerElementType()); + // Coerce the arguments; LLVM optimizations seem to ignore the types in + // vaarg functions and throw away casts in optimized mode. + SmallVector CallArgs; + coerceArguments(Builder, FnTy, Arguments, CallArgs); + + auto *TailCall = Builder.CreateCall(FnTy, MustTailCallFn, CallArgs); + TailCall->setTailCallKind(CallInst::TCK_MustTail); + TailCall->setDebugLoc(Loc); + TailCall->setCallingConv(MustTailCallFn->getCallingConv()); + return TailCall; +} + static void splitAsyncCoroutine(Function &F, coro::Shape &Shape, SmallVectorImpl &Clones) { assert(Shape.ABI == coro::ABI::Async); @@ -1443,18 +1509,10 @@ // Insert the call to the tail call function and inline it. auto *Fn = Suspend->getMustTailCallFunction(); - auto DbgLoc = Suspend->getDebugLoc(); - SmallVector Args(Suspend->operand_values()); - auto FnArgs = ArrayRef(Args).drop_front(3).drop_back(1); - auto FnTy = cast(Fn->getType()->getPointerElementType()); - // Coerce the arguments, llvm optimizations seem to ignore the types in - // vaarg functions and throws away casts in optimized mode.
- SmallVector CallArgs; - coerceArguments(Builder, FnTy, FnArgs, CallArgs); - auto *TailCall = Builder.CreateCall(FnTy, Fn, CallArgs); - TailCall->setDebugLoc(DbgLoc); - TailCall->setTailCall(); - TailCall->setCallingConv(Fn->getCallingConv()); + SmallVector Args(Suspend->args()); + auto FnArgs = ArrayRef(Args).drop_front(3); + auto *TailCall = + coro::createMustTailCall(Suspend->getDebugLoc(), Fn, FnArgs, Builder); Builder.CreateRetVoid(); InlineFunctionInfo FnInfo; auto InlineRes = InlineFunction(*TailCall, FnInfo); @@ -1683,7 +1741,7 @@ if (!Shape.CoroBegin) return; - for (llvm::CoroEndInst *End : Shape.CoroEnds) + for (llvm::AnyCoroEndInst *End : Shape.CoroEnds) { auto &Context = End->getContext(); End->replaceAllUsesWith(ConstantInt::getFalse(Context)); End->eraseFromParent(); diff --git a/llvm/lib/Transforms/Coroutines/Coroutines.cpp b/llvm/lib/Transforms/Coroutines/Coroutines.cpp --- a/llvm/lib/Transforms/Coroutines/Coroutines.cpp +++ b/llvm/lib/Transforms/Coroutines/Coroutines.cpp @@ -131,6 +131,7 @@ "llvm.coro.destroy", "llvm.coro.done", "llvm.coro.end", + "llvm.coro.end.async", "llvm.coro.frame", "llvm.coro.free", "llvm.coro.id", @@ -316,11 +317,16 @@ CoroBegin = CB; break; } + case Intrinsic::coro_end_async: case Intrinsic::coro_end: - CoroEnds.push_back(cast(II)); - if (CoroEnds.back()->isFallthrough()) { + CoroEnds.push_back(cast(II)); + if (auto *AsyncEnd = dyn_cast(II)) { + AsyncEnd->checkWellFormed(); + } + if (CoroEnds.back()->isFallthrough() && isa(II)) { // Make sure that the fallthrough coro.end is the first element in the // CoroEnds vector. + // Note: I don't think this is necessary anymore. if (CoroEnds.size() > 1) { if (CoroEnds.front()->isFallthrough()) report_fatal_error( @@ -353,7 +359,7 @@ } // Replace all coro.ends with unreachable instruction. 
- for (CoroEndInst *CE : CoroEnds) + for (AnyCoroEndInst *CE : CoroEnds) changeToUnreachable(CE, /*UseLLVMTrap=*/false); return; @@ -713,6 +719,19 @@ checkAsyncContextProjectFunction(this, getAsyncContextProjectionFunction()); } +void CoroAsyncEndInst::checkWellFormed() const { + auto *MustTailCallFunc = getMustTailCallFunction(); + if (!MustTailCallFunc) + return; + auto *FnTy = + cast(MustTailCallFunc->getType()->getPointerElementType()); + if (FnTy->getNumParams() != (getNumArgOperands() - 3)) + fail(this, + "llvm.coro.end.async must tail call function argument type must " + "match the tail arguments", + MustTailCallFunc); +} + void LLVMAddCoroEarlyPass(LLVMPassManagerRef PM) { unwrap(PM)->add(createCoroEarlyLegacyPass()); } diff --git a/llvm/test/Transforms/Coroutines/coro-async.ll b/llvm/test/Transforms/Coroutines/coro-async.ll --- a/llvm/test/Transforms/Coroutines/coro-async.ll +++ b/llvm/test/Transforms/Coroutines/coro-async.ll @@ -94,7 +94,7 @@ call void @some_user(i64 %val.2) tail call swiftcc void @asyncReturn(i8* %async.ctxt, %async.task* %task.2, %async.actor* %actor) - call i1 @llvm.coro.end(i8* %hdl, i1 0) + call i1 (i8*, i1, ...) @llvm.coro.end.async(i8* %hdl, i1 0) unreachable } @@ -311,12 +311,75 @@ %continuation_task_arg = extractvalue {i8*, i8*, i8*} %res, 1 %task.2 = bitcast i8* %continuation_task_arg to %async.task* tail call swiftcc void @asyncReturn(i8* %async.ctxt, %async.task* %task.2, %async.actor* %actor) - call i1 @llvm.coro.end(i8* %hdl, i1 0) + call i1 (i8*, i1, ...) 
@llvm.coro.end.async(i8* %hdl, i1 0) unreachable } + +@multiple_coro_end_async_fp = constant <{ i32, i32 }> + <{ i32 trunc ( ; Relative pointer to async function + i64 sub ( + i64 ptrtoint (void (i8*, %async.task*, %async.actor*)* @multiple_coro_end_async to i64), + i64 ptrtoint (i32* getelementptr inbounds (<{ i32, i32 }>, <{ i32, i32 }>* @multiple_coro_end_async_fp, i32 0, i32 1) to i64) + ) + to i32), + i32 128 ; Initial async context size without space for frame +}> + +define swiftcc void @must_tail_call_return(i8* %async.ctxt, %async.task* %task, %async.actor* %actor) { + musttail call swiftcc void @asyncReturn(i8* %async.ctxt, %async.task* %task, %async.actor* %actor) + ret void +} + +define swiftcc void @multiple_coro_end_async(i8* %async.ctxt, %async.task* %task, %async.actor* %actor) { +entry: + %id = call token @llvm.coro.id.async(i32 128, i32 16, i32 0, + i8* bitcast (<{i32, i32}>* @dont_crash_on_cf_fp to i8*)) + %hdl = call i8* @llvm.coro.begin(token %id, i8* null) + %arg0 = bitcast %async.task* %task to i8* + %arg1 = bitcast <{ i32, i32}>* @my_other_async_function_fp to i8* + %callee_context = call i8* @llvm.coro.async.context.alloc(i8* %arg0, i8* %arg1) + %callee_context.0 = bitcast i8* %callee_context to %async.ctxt* + %callee_context.return_to_caller.addr = getelementptr inbounds %async.ctxt, %async.ctxt* %callee_context.0, i32 0, i32 1 + %return_to_caller.addr = bitcast void(i8*, %async.task*, %async.actor*)** %callee_context.return_to_caller.addr to i8** + %resume.func_ptr = call i8* @llvm.coro.async.resume() + store i8* %resume.func_ptr, i8** %return_to_caller.addr + %callee_context.caller_context.addr = getelementptr inbounds %async.ctxt, %async.ctxt* %callee_context.0, i32 0, i32 0 + store i8* %async.ctxt, i8** %callee_context.caller_context.addr + %resume_proj_fun = bitcast i8*(i8*)* @resume_context_projection to i8* + %callee = bitcast void(i8*, %async.task*, %async.actor*)* @asyncSuspend to i8* + %res = call {i8*, i8*, i8*} (i8*, i8*, ...) 
@llvm.coro.suspend.async( + i8* %resume.func_ptr, + i8* %resume_proj_fun, + void (i8*, i8*, %async.task*, %async.actor*)* @dont_crash_on_cf_dispatch, + i8* %callee, i8* %callee_context, %async.task* %task, %async.actor *%actor) + + call void @llvm.coro.async.context.dealloc(i8* %callee_context) + %continuation_task_arg = extractvalue {i8*, i8*, i8*} %res, 1 + %task.2 = bitcast i8* %continuation_task_arg to %async.task* + %eq = icmp eq i8 * %continuation_task_arg, null + br i1 %eq, label %is_equal, label %is_not_equal + +is_equal: + tail call swiftcc void @asyncReturn(i8* %async.ctxt, %async.task* %task.2, %async.actor* %actor) + call i1 (i8*, i1, ...) @llvm.coro.end.async(i8* %hdl, i1 0) + unreachable + +is_not_equal: + call i1 (i8*, i1, ...) @llvm.coro.end.async( + i8* %hdl, i1 0, + void (i8*, %async.task*, %async.actor*)* @must_tail_call_return, + i8* %async.ctxt, %async.task* %task.2, %async.actor* null) + unreachable +} + +; CHECK-LABEL: define internal swiftcc void @multiple_coro_end_async.resume.0( +; CHECK: musttail call swiftcc void @asyncReturn( +; CHECK: ret void + declare i8* @llvm.coro.prepare.async(i8*) declare token @llvm.coro.id.async(i32, i32, i32, i8*) declare i8* @llvm.coro.begin(token, i8*) +declare i1 @llvm.coro.end.async(i8*, i1, ...) declare i1 @llvm.coro.end(i8*, i1) declare {i8*, i8*, i8*} @llvm.coro.suspend.async(i8*, i8*, ...) declare i8* @llvm.coro.async.context.alloc(i8*, i8*)