diff --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td --- a/llvm/include/llvm/IR/Intrinsics.td +++ b/llvm/include/llvm/IR/Intrinsics.td @@ -1188,7 +1188,7 @@ []>; def int_coro_alloc : Intrinsic<[llvm_i1_ty], [llvm_token_ty], []>; def int_coro_id_async : Intrinsic<[llvm_token_ty], - [llvm_i32_ty, llvm_i32_ty, llvm_ptr_ty, llvm_ptr_ty], + [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty, llvm_ptr_ty], []>; def int_coro_async_context_alloc : Intrinsic<[llvm_ptr_ty], [llvm_ptr_ty, llvm_ptr_ty], diff --git a/llvm/lib/Transforms/Coroutines/CoroFrame.cpp b/llvm/lib/Transforms/Coroutines/CoroFrame.cpp --- a/llvm/lib/Transforms/Coroutines/CoroFrame.cpp +++ b/llvm/lib/Transforms/Coroutines/CoroFrame.cpp @@ -1911,8 +1911,7 @@ for (User *U : Def->users()) { auto Inst = cast(U); if (Inst->getParent() != CoroBegin->getParent() || - Dom.dominates(CoroBegin, Inst) || - isa(Inst) /*'fake' use of async context argument*/) + Dom.dominates(CoroBegin, Inst)) continue; if (ToMove.insert(Inst)) Worklist.push_back(Inst); diff --git a/llvm/lib/Transforms/Coroutines/CoroInstr.h b/llvm/lib/Transforms/Coroutines/CoroInstr.h --- a/llvm/lib/Transforms/Coroutines/CoroInstr.h +++ b/llvm/lib/Transforms/Coroutines/CoroInstr.h @@ -293,11 +293,13 @@ } /// The async context parameter. - Value *getStorage() const { return getArgOperand(StorageArg); } + Value *getStorage() const { + return getFunction()->getArg(getStorageArgumentIndex()); + } unsigned getStorageArgumentIndex() const { - auto *Arg = cast(getArgOperand(StorageArg)->stripPointerCasts()); - return Arg->getArgNo(); + auto *Arg = cast(getArgOperand(StorageArg)); + return Arg->getZExtValue(); } /// Return the async function pointer address.
This should be the address of diff --git a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp --- a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp +++ b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp @@ -668,16 +668,24 @@ auto *FramePtrTy = Shape.FrameTy->getPointerTo(); auto *ProjectionFunc = cast(ActiveSuspend) ->getAsyncContextProjectionFunction(); + auto DbgLoc = + cast(VMap[ActiveSuspend])->getDebugLoc(); // Calling i8* (i8*) auto *CallerContext = Builder.CreateCall( cast(ProjectionFunc->getType()->getPointerElementType()), ProjectionFunc, CalleeContext); CallerContext->setCallingConv(ProjectionFunc->getCallingConv()); + CallerContext->setDebugLoc(DbgLoc); // The frame is located after the async_context header. auto &Context = Builder.getContext(); auto *FramePtrAddr = Builder.CreateConstInBoundsGEP1_32( Type::getInt8Ty(Context), CallerContext, Shape.AsyncLowering.FrameOffset, "async.ctx.frameptr"); + // Inline the projection function. + InlineFunctionInfo InlineInfo; + auto InlineRes = InlineFunction(*CallerContext, InlineInfo); + assert(InlineRes.isSuccess()); + (void)InlineRes; return Builder.CreateBitCast(FramePtrAddr, FramePtrTy); } // In continuation-lowering, the argument is the opaque storage. @@ -1364,6 +1372,22 @@ Suspend->setOperand(0, UndefValue::get(Int8PtrTy)); } +/// Coerce \p FnArgs to \p FnTy's parameter types, appending to \p CallArgs.
+static void coerceArguments(IRBuilder<> &Builder, FunctionType *FnTy, + ArrayRef FnArgs, + SmallVectorImpl &CallArgs) { + size_t ArgIdx = 0; + for (Type *ParamTy : FnTy->params()) { + assert(ArgIdx < FnArgs.size()); + if (ParamTy != FnArgs[ArgIdx]->getType()) + CallArgs.push_back( + Builder.CreateBitOrPointerCast(FnArgs[ArgIdx], ParamTy)); + else + CallArgs.push_back(FnArgs[ArgIdx]); + ++ArgIdx; + } +} + static void splitAsyncCoroutine(Function &F, coro::Shape &Shape, SmallVectorImpl &Clones) { assert(Shape.ABI == coro::ABI::Async); @@ -1420,14 +1444,24 @@ IRBuilder<> Builder(ReturnBB); - // Insert the call to the tail call function. - auto *Fun = Suspend->getMustTailCallFunction(); + // Insert the call to the tail call function and inline it. + auto *Fn = Suspend->getMustTailCallFunction(); + auto DbgLoc = Suspend->getDebugLoc(); SmallVector Args(Suspend->operand_values()); - auto *TailCall = Builder.CreateCall( - cast(Fun->getType()->getPointerElementType()), Fun, - ArrayRef(Args).drop_front(3).drop_back(1)); - TailCall->setTailCallKind(CallInst::TCK_MustTail); - TailCall->setCallingConv(Fun->getCallingConv()); + auto FnArgs = ArrayRef(Args).drop_front(3).drop_back(1); + auto FnTy = cast(Fn->getType()->getPointerElementType()); + // Coerce the arguments; LLVM optimizations seem to ignore the types in + // vararg functions and throw away casts in optimized mode. + SmallVector CallArgs; + coerceArguments(Builder, FnTy, FnArgs, CallArgs); + auto *TailCall = Builder.CreateCall(FnTy, Fn, CallArgs); + TailCall->setDebugLoc(DbgLoc); + TailCall->setTailCall(); + TailCall->setCallingConv(Fn->getCallingConv()); + InlineFunctionInfo FnInfo; + auto InlineRes = InlineFunction(*TailCall, FnInfo); + assert(InlineRes.isSuccess() && "Expected inlining to succeed"); + (void)InlineRes; Builder.CreateRetVoid(); // Replace the lvm.coro.async.resume intrisic call.
diff --git a/llvm/lib/Transforms/Coroutines/Coroutines.cpp b/llvm/lib/Transforms/Coroutines/Coroutines.cpp --- a/llvm/lib/Transforms/Coroutines/Coroutines.cpp +++ b/llvm/lib/Transforms/Coroutines/Coroutines.cpp @@ -683,11 +683,15 @@ } void CoroIdAsyncInst::checkWellFormed() const { - // TODO: check that the StorageArg is a parameter of this function. checkConstantInt(this, getArgOperand(SizeArg), "size argument to coro.id.async must be constant"); checkConstantInt(this, getArgOperand(AlignArg), "alignment argument to coro.id.async must be constant"); + checkConstantInt(this, getArgOperand(StorageArg), + "storage argument offset to coro.id.async must be constant"); + if (getStorageArgumentIndex() >= getFunction()->arg_size()) + fail(this, "storage argument offset to coro.id.async is out of range", + getArgOperand(StorageArg)); checkAsyncFuncPointer(this, getArgOperand(AsyncFuncPtrArg)); } diff --git a/llvm/test/Transforms/Coroutines/coro-async.ll b/llvm/test/Transforms/Coroutines/coro-async.ll --- a/llvm/test/Transforms/Coroutines/coro-async.ll +++ b/llvm/test/Transforms/Coroutines/coro-async.ll @@ -28,8 +28,9 @@ }> ; Function that implements the dispatch to the callee function.
-define swiftcc void @my_async_function.my_other_async_function_fp.apply(i8* %async.ctxt, %async.task* %task, %async.actor* %actor) { - musttail call swiftcc void @asyncSuspend(i8* %async.ctxt, %async.task* %task, %async.actor* %actor) +define swiftcc void @my_async_function.my_other_async_function_fp.apply(i8* %fnPtr, i8* %async.ctxt, %async.task* %task, %async.actor* %actor) { + %callee = bitcast i8* %fnPtr to void(i8*, %async.task*, %async.actor*)* + tail call swiftcc void %callee(i8* %async.ctxt, %async.task* %task, %async.actor* %actor) ret void } @@ -50,8 +51,7 @@ %proj.1 = getelementptr inbounds { i64, i64 }, { i64, i64 }* %tmp, i64 0, i32 0 %proj.2 = getelementptr inbounds { i64, i64 }, { i64, i64 }* %tmp, i64 0, i32 1 - %id = call token @llvm.coro.id.async(i32 128, i32 16, - i8* %async.ctxt, + %id = call token @llvm.coro.id.async(i32 128, i32 16, i32 0, i8* bitcast (<{i32, i32}>* @my_async_function_fp to i8*)) %hdl = call i8* @llvm.coro.begin(token %id, i8* null) store i64 0, i64* %proj.1, align 8 @@ -78,11 +78,12 @@ %callee_context.caller_context.addr = getelementptr inbounds %async.ctxt, %async.ctxt* %callee_context.0, i32 0, i32 0 store i8* %async.ctxt, i8** %callee_context.caller_context.addr %resume_proj_fun = bitcast i8*(i8*)* @resume_context_projection to i8* + %callee = bitcast void(i8*, %async.task*, %async.actor*)* @asyncSuspend to i8* %res = call {i8*, i8*, i8*} (i8*, i8*, ...)
@llvm.coro.suspend.async( i8* %resume.func_ptr, i8* %resume_proj_fun, - void (i8*, %async.task*, %async.actor*)* @my_async_function.my_other_async_function_fp.apply, - i8* %callee_context, %async.task* %task, %async.actor *%actor) + void (i8*, i8*, %async.task*, %async.actor*)* @my_async_function.my_other_async_function_fp.apply, + i8* %callee, i8* %callee_context, %async.task* %task, %async.actor *%actor) call void @llvm.coro.async.context.dealloc(i8* %callee_context) %continuation_task_arg = extractvalue {i8*, i8*, i8*} %res, 1 @@ -126,7 +127,7 @@ ; CHECK: store i8* bitcast (void (i8*, i8*, i8*)* @my_async_function.resume.0 to i8*), i8** [[RETURN_TO_CALLER_ADDR]] ; CHECK: [[CALLER_CONTEXT_ADDR:%.*]] = bitcast i8* [[CALLEE_CTXT]] to i8** ; CHECK: store i8* %async.ctxt, i8** [[CALLER_CONTEXT_ADDR]] -; CHECK: musttail call swiftcc void @asyncSuspend(i8* [[CALLEE_CTXT]], %async.task* %task, %async.actor* %actor) +; CHECK: tail call swiftcc void @asyncSuspend(i8* [[CALLEE_CTXT]], %async.task* %task, %async.actor* %actor) ; CHECK: ret void ; CHECK: } @@ -170,7 +171,7 @@ define swiftcc void @my_async_function2(%async.task* %task, %async.actor* %actor, i8* %async.ctxt) { entry: - %id = call token @llvm.coro.id.async(i32 128, i32 16, i8* %async.ctxt, i8* bitcast (<{i32, i32}>* @my_async_function2_fp to i8*)) + %id = call token @llvm.coro.id.async(i32 128, i32 16, i32 2, i8* bitcast (<{i32, i32}>* @my_async_function2_fp to i8*)) %hdl = call i8* @llvm.coro.begin(token %id, i8* null) ; setup callee context %arg0 = bitcast %async.task* %task to i8* @@ -185,11 +186,12 @@ %callee_context.caller_context.addr = getelementptr inbounds %async.ctxt, %async.ctxt* %callee_context.0, i32 0, i32 0 store i8* %async.ctxt, i8** %callee_context.caller_context.addr %resume_proj_fun = bitcast i8*(i8*)* @resume_context_projection to i8* + %callee = bitcast void(i8*, %async.task*, %async.actor*)* @asyncSuspend to i8* %res = call {i8*, i8*, i8*} (i8*, i8*, ...)
@llvm.coro.suspend.async( i8* %resume.func_ptr, i8* %resume_proj_fun, - void (i8*, %async.task*, %async.actor*)* @my_async_function.my_other_async_function_fp.apply, - i8* %callee_context, %async.task* %task, %async.actor *%actor) + void (i8*, i8*, %async.task*, %async.actor*)* @my_async_function.my_other_async_function_fp.apply, + i8* %callee, i8* %callee_context, %async.task* %task, %async.actor *%actor) %continuation_task_arg = extractvalue {i8*, i8*, i8*} %res, 0 %task.2 = bitcast i8* %continuation_task_arg to %async.task* @@ -202,11 +204,12 @@ %callee_context.caller_context.addr.1 = getelementptr inbounds %async.ctxt, %async.ctxt* %callee_context.0.1, i32 0, i32 0 store i8* %async.ctxt, i8** %callee_context.caller_context.addr.1 %resume_proj_fun.2 = bitcast i8*(i8*)* @resume_context_projection to i8* + %callee.2 = bitcast void(i8*, %async.task*, %async.actor*)* @asyncSuspend to i8* %res.2 = call {i8*, i8*, i8*} (i8*, i8*, ...) @llvm.coro.suspend.async( i8* %resume.func_ptr.1, i8* %resume_proj_fun.2, - void (i8*, %async.task*, %async.actor*)* @my_async_function.my_other_async_function_fp.apply, - i8* %callee_context, %async.task* %task, %async.actor *%actor) + void (i8*, i8*, %async.task*, %async.actor*)* @my_async_function.my_other_async_function_fp.apply, + i8* %callee.2, i8* %callee_context, %async.task* %task, %async.actor *%actor) call void @llvm.coro.async.context.dealloc(i8* %callee_context) %continuation_actor_arg = extractvalue {i8*, i8*, i8*} %res.2, 1 @@ -225,7 +228,7 @@ ; CHECK: store i8* [[CALLEE_CTXT]], ; CHECK: store i8* bitcast (void (i8*, i8*, i8*)* @my_async_function2.resume.0 to i8*), ; CHECK: store i8* %async.ctxt, -; CHECK: musttail call swiftcc void @asyncSuspend(i8* [[CALLEE_CTXT]], %async.task* %task, %async.actor* %actor) +; CHECK: tail call swiftcc void @asyncSuspend(i8* [[CALLEE_CTXT]], %async.task* %task, %async.actor* %actor) ; CHECK: ret void ; CHECK-LABEL: define internal swiftcc void @my_async_function2.resume.0(i8* %0, i8*
nocapture readnone %1, i8* nocapture readonly %2) { @@ -235,7 +238,7 @@ ; CHECK: [[CALLEE_CTXT_SPILL_ADDR2:%.*]] = bitcast i8* [[CALLEE_CTXT_SPILL_ADDR]] to i8** ; CHECK: store i8* bitcast (void (i8*, i8*, i8*)* @my_async_function2.resume.1 to i8*), ; CHECK: [[CALLLE_CTXT_RELOAD:%.*]] = load i8*, i8** [[CALLEE_CTXT_SPILL_ADDR2]] -; CHECK: musttail call swiftcc void @asyncSuspend(i8* [[CALLEE_CTXT_RELOAD]] +; CHECK: tail call swiftcc void @asyncSuspend(i8* [[CALLLE_CTXT_RELOAD]] ; CHECK: ret void ; CHECK-LABEL: define internal swiftcc void @my_async_function2.resume.1(i8* nocapture readnone %0, i8* %1, i8* nocapture readonly %2) { @@ -258,7 +261,7 @@ ; CHECK: ret void declare i8* @llvm.coro.prepare.async(i8*) -declare token @llvm.coro.id.async(i32, i32, i8*, i8*) +declare token @llvm.coro.id.async(i32, i32, i32, i8*) declare i8* @llvm.coro.begin(token, i8*) declare i1 @llvm.coro.end(i8*, i1) declare {i8*, i8*, i8*} @llvm.coro.suspend.async(i8*, i8*, ...)