diff --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td --- a/llvm/include/llvm/IR/Intrinsics.td +++ b/llvm/include/llvm/IR/Intrinsics.td @@ -1213,9 +1213,8 @@ def int_coro_async_resume : Intrinsic<[llvm_ptr_ty], [], []>; -def int_coro_suspend_async : Intrinsic<[llvm_ptr_ty, llvm_ptr_ty, llvm_ptr_ty], - [llvm_ptr_ty, llvm_ptr_ty, llvm_vararg_ty], - []>; +def int_coro_suspend_async + : Intrinsic<[llvm_any_ty], [llvm_ptr_ty, llvm_ptr_ty, llvm_vararg_ty], []>; def int_coro_prepare_async : Intrinsic<[llvm_ptr_ty], [llvm_ptr_ty], [IntrNoMem]>; def int_coro_begin : Intrinsic<[llvm_ptr_ty], [llvm_token_ty, llvm_ptr_ty], diff --git a/llvm/lib/Transforms/Coroutines/CoroInternal.h b/llvm/lib/Transforms/Coroutines/CoroInternal.h --- a/llvm/lib/Transforms/Coroutines/CoroInternal.h +++ b/llvm/lib/Transforms/Coroutines/CoroInternal.h @@ -146,6 +146,6 @@ struct AsyncLoweringStorage { - FunctionType *AsyncFuncTy; Value *Context; + CallingConv::ID AsyncCC; // Calling convention of the async coroutine function (and of its resume functions). unsigned ContextArgNo; uint64_t ContextHeaderSize; uint64_t ContextAlignment; @@ -208,7 +208,8 @@ case coro::ABI::RetconOnce: return RetconLowering.ResumePrototype->getFunctionType(); case coro::ABI::Async: - return AsyncLowering.AsyncFuncTy; + // Not used. The function type depends on the active suspend.
+ return nullptr; } llvm_unreachable("Unknown coro::ABI enum"); @@ -245,7 +246,7 @@ case coro::ABI::RetconOnce: return RetconLowering.ResumePrototype->getCallingConv(); case coro::ABI::Async: - return CallingConv::Swift; + return AsyncLowering.AsyncCC; } llvm_unreachable("Unknown coro::ABI enum"); } diff --git a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp --- a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp +++ b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp @@ -454,11 +454,24 @@ } } +/// Resume-function type for an async suspend: void(<elements of the suspend's struct result type>). +static FunctionType * +getFunctionTypeFromAsyncSuspend(AnyCoroSuspendInst *Suspend) { + auto *AsyncSuspend = cast<CoroSuspendAsyncInst>(Suspend); + auto *StructTy = cast<StructType>(AsyncSuspend->getType()); + auto &Context = Suspend->getContext(); + auto *VoidTy = Type::getVoidTy(Context); + return FunctionType::get(VoidTy, StructTy->elements(), false); +} + static Function *createCloneDeclaration(Function &OrigF, coro::Shape &Shape, const Twine &Suffix, - Module::iterator InsertBefore) { + Module::iterator InsertBefore, + AnyCoroSuspendInst *ActiveSuspend) { Module *M = OrigF.getParent(); - auto *FnTy = Shape.getResumeFunctionType(); + auto *FnTy = (Shape.ABI != coro::ABI::Async) + ? Shape.getResumeFunctionType() + : getFunctionTypeFromAsyncSuspend(ActiveSuspend); Function *NewF = Function::Create(FnTy, GlobalValue::LinkageTypes::InternalLinkage, @@ -805,7 +818,7 @@ // Create the new function if we don't already have one. if (!NewF) { NewF = createCloneDeclaration(OrigF, Shape, Suffix, - OrigF.getParent()->end()); + OrigF.getParent()->end(), ActiveSuspend); } // Replace all args with undefs. The buildCoroutineFrame algorithm already @@ -1528,8 +1541,8 @@ auto *Suspend = cast<CoroSuspendAsyncInst>(Shape.CoroSuspends[Idx]); // Create the clone declaration. - auto *Continuation = - createCloneDeclaration(F, Shape, ".resume." + Twine(Idx), NextF); + auto *Continuation = createCloneDeclaration( + F, Shape, ".resume."
+ Twine(Idx), NextF, Suspend); Clones.push_back(Continuation); // Insert a branch to a new return block immediately before the suspend @@ -1629,7 +1642,7 @@ // Create the clone declaration. auto Continuation = - createCloneDeclaration(F, Shape, ".resume." + Twine(i), NextF); + createCloneDeclaration(F, Shape, ".resume." + Twine(i), NextF, nullptr); Clones.push_back(Continuation); // Insert a branch to the unified return block immediately before diff --git a/llvm/lib/Transforms/Coroutines/Coroutines.cpp b/llvm/lib/Transforms/Coroutines/Coroutines.cpp --- a/llvm/lib/Transforms/Coroutines/Coroutines.cpp +++ b/llvm/lib/Transforms/Coroutines/Coroutines.cpp @@ -399,11 +399,7 @@ this->AsyncLowering.ContextAlignment = AsyncId->getStorageAlignment().value(); this->AsyncLowering.AsyncFuncPointer = AsyncId->getAsyncFunctionPointer(); - auto &Context = F.getContext(); - auto *Int8PtrTy = Type::getInt8PtrTy(Context); - auto *VoidTy = Type::getVoidTy(Context); - this->AsyncLowering.AsyncFuncTy = - FunctionType::get(VoidTy, {Int8PtrTy, Int8PtrTy, Int8PtrTy}, false); + this->AsyncLowering.AsyncCC = F.getCallingConv(); break; }; case Intrinsic::coro_id_retcon: diff --git a/llvm/test/Transforms/Coroutines/coro-async.ll b/llvm/test/Transforms/Coroutines/coro-async.ll --- a/llvm/test/Transforms/Coroutines/coro-async.ll +++ b/llvm/test/Transforms/Coroutines/coro-async.ll @@ -377,6 +377,75 @@ ; CHECK: musttail call swiftcc void @asyncReturn( ; CHECK: ret void +@polymorphic_suspend_return_fp = constant <{ i32, i32 }> + <{ i32 trunc ( ; Relative pointer to async function + i64 sub ( + i64 ptrtoint (void (i8*, %async.task*, %async.actor*)* @polymorphic_suspend_return to i64), + i64 ptrtoint (i32* getelementptr inbounds (<{ i32, i32 }>, <{ i32, i32 }>* @polymorphic_suspend_return_fp, i32 0, i32 1) to i64) + ) + to i32), + i32 64 ; Initial async context size without space for frame +}> + +define swiftcc void @polymorphic_suspend_return(i8* %async.ctxt, %async.task* %task, %async.actor*
%actor) { +entry: + %tmp = alloca { i64, i64 }, align 8 + %proj.1 = getelementptr inbounds { i64, i64 }, { i64, i64 }* %tmp, i64 0, i32 0 + %proj.2 = getelementptr inbounds { i64, i64 }, { i64, i64 }* %tmp, i64 0, i32 1 + + %id = call token @llvm.coro.id.async(i32 128, i32 16, i32 0, + i8* bitcast (<{i32, i32}>* @polymorphic_suspend_return_fp to i8*)) + %hdl = call i8* @llvm.coro.begin(token %id, i8* null) + store i64 0, i64* %proj.1, align 8 + store i64 1, i64* %proj.2, align 8 + call void @some_may_write(i64* %proj.1) + + ; Begin lowering: apply %my_other_async_function(%args...) + + ; setup callee context + %arg0 = bitcast %async.task* %task to i8* + %arg1 = bitcast <{ i32, i32}>* @my_other_async_function_fp to i8* + %callee_context = call i8* @llvm.coro.async.context.alloc(i8* %arg0, i8* %arg1) + %callee_context.0 = bitcast i8* %callee_context to %async.ctxt* + ; store arguments ... + ; ... (omitted) + + ; store the return continuation + %callee_context.return_to_caller.addr = getelementptr inbounds %async.ctxt, %async.ctxt* %callee_context.0, i32 0, i32 1 + %return_to_caller.addr = bitcast void(i8*, %async.task*, %async.actor*)** %callee_context.return_to_caller.addr to i8** + %resume.func_ptr = call i8* @llvm.coro.async.resume() + store i8* %resume.func_ptr, i8** %return_to_caller.addr + + ; store caller context into callee context + %callee_context.caller_context.addr = getelementptr inbounds %async.ctxt, %async.ctxt* %callee_context.0, i32 0, i32 0 + store i8* %async.ctxt, i8** %callee_context.caller_context.addr + %resume_proj_fun = bitcast i8*(i8*)* @resume_context_projection to i8* + %callee = bitcast void(i8*, %async.task*, %async.actor*)* @asyncSuspend to i8* + %res = call {i8*, i8*, i8*, i8*} (i8*, i8*, ...)
@llvm.coro.suspend.async.sl_p0i8p0i8p0i8p0i8s( + i8* %resume.func_ptr, + i8* %resume_proj_fun, + void (i8*, i8*, %async.task*, %async.actor*)* @my_async_function.my_other_async_function_fp.apply, + i8* %callee, i8* %callee_context, %async.task* %task, %async.actor *%actor) + + call void @llvm.coro.async.context.dealloc(i8* %callee_context) + %continuation_task_arg = extractvalue {i8*, i8*, i8*, i8*} %res, 3 + %task.2 = bitcast i8* %continuation_task_arg to %async.task* + %val = load i64, i64* %proj.1 + call void @some_user(i64 %val) + %val.2 = load i64, i64* %proj.2 + call void @some_user(i64 %val.2) + + tail call swiftcc void @asyncReturn(i8* %async.ctxt, %async.task* %task.2, %async.actor* %actor) + call i1 (i8*, i1, ...) @llvm.coro.end.async(i8* %hdl, i1 0) + unreachable +} + +; CHECK-LABEL: define swiftcc void @polymorphic_suspend_return(i8* %async.ctxt, %async.task* %task, %async.actor* %actor) +; CHECK-LABEL: define internal swiftcc void @polymorphic_suspend_return.resume.0(i8* {{.*}}%0, i8* {{.*}}%1, i8* {{.*}}%2, i8* {{.*}}%3) +; CHECK: bitcast i8* %3 to %async.task* +; CHECK: } + +declare { i8*, i8*, i8*, i8* } @llvm.coro.suspend.async.sl_p0i8p0i8p0i8p0i8s(i8*, i8*, ...) declare i8* @llvm.coro.prepare.async(i8*) declare token @llvm.coro.id.async(i32, i32, i32, i8*) declare i8* @llvm.coro.begin(token, i8*)