diff --git a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp --- a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp +++ b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp @@ -477,7 +477,8 @@ Function *NewF = Function::Create(FnTy, GlobalValue::LinkageTypes::InternalLinkage, OrigF.getName() + Suffix); - NewF->addParamAttr(0, Attribute::NonNull); + if (Shape.ABI != coro::ABI::Async) + NewF->addParamAttr(0, Attribute::NonNull); // For the async lowering ABI we can't guarantee that the context argument is // not access via a different pointer not based on the argument. @@ -819,6 +820,13 @@ Attrs = Attrs.addParamAttributes(Context, ParamIndex, ParamAttrs); } +static void addAsyncContextAttrs(AttributeList &Attrs, LLVMContext &Context, + unsigned ParamIndex) { + AttrBuilder ParamAttrs; + ParamAttrs.addAttribute(Attribute::SwiftAsync); + Attrs = Attrs.addParamAttributes(Context, ParamIndex, ParamAttrs); +} + /// Clone the body of the original function into a resume function of /// some sort. void CoroCloner::create() { @@ -934,9 +942,16 @@ // followed by a return. // Don't change returns to unreachable because that will trip up the verifier. // These returns should be unreachable from the clone. - case coro::ABI::Async: + case coro::ABI::Async: { + auto *ActiveAsyncSuspend = cast(ActiveSuspend); + if (OrigF.hasParamAttribute(Shape.AsyncLowering.ContextArgNo, + Attribute::SwiftAsync)) { + auto ContextArgIndex = ActiveAsyncSuspend->getStorageArgumentIndex(); + addAsyncContextAttrs(NewAttrs, Context, ContextArgIndex); + } break; } + } NewF->setAttributes(NewAttrs); NewF->setCallingConv(Shape.getResumeFunctionCC()); diff --git a/llvm/test/Transforms/Coroutines/coro-async.ll b/llvm/test/Transforms/Coroutines/coro-async.ll --- a/llvm/test/Transforms/Coroutines/coro-async.ll +++ b/llvm/test/Transforms/Coroutines/coro-async.ll @@ -55,7 +55,7 @@ } -define swiftcc void @my_async_function(i8* %async.ctxt, %async.task* %task, %async.actor* %actor) !dbg !1 { +define swiftcc void @my_async_function(i8* swiftasync %async.ctxt, %async.task* %task, %async.actor* %actor) !dbg !1 { entry: %tmp = alloca { i64, i64 }, align 8 %proj.1 = getelementptr inbounds { i64, i64 }, { i64, i64 }* %tmp, i64 0, i32 0 @@ -119,7 +119,7 @@ ; CHECK: @my_async_function_pa_fp = constant <{ i32, i32 }> <{ {{.*}}, i32 176 } ; CHECK: @my_async_function2_fp = constant <{ i32, i32 }> <{ {{.*}}, i32 176 } -; CHECK-LABEL: define swiftcc void @my_async_function(i8* %async.ctxt, %async.task* %task, %async.actor* %actor) +; CHECK-LABEL: define swiftcc void @my_async_function(i8* swiftasync %async.ctxt, %async.task* %task, %async.actor* %actor) ; CHECK-SAME: !dbg ![[SP1:[0-9]+]] { ; CHECK: entry: ; CHECK: [[FRAMEPTR:%.*]] = getelementptr inbounds i8, i8* %async.ctxt, i64 128 @@ -149,7 +149,7 @@ ; CHECK: ret void ; CHECK: } -; CHECK-LABEL: define internal swiftcc void @my_async_function.resume.0(i8* nocapture readonly %0, i8* %1, i8* nocapture readnone %2) +; CHECK-LABEL: define internal swiftcc void @my_async_function.resume.0(i8* nocapture readonly swiftasync %0, i8* %1, i8* nocapture readnone %2) ; CHECK-SAME: !dbg ![[SP2:[0-9]+]] { ; CHECK: entryresume.0: ; CHECK: [[CALLER_CONTEXT_ADDR:%.*]] = bitcast i8* %0 to i8**