Index: cfe/trunk/lib/CodeGen/CGCoroutine.cpp =================================================================== --- cfe/trunk/lib/CodeGen/CGCoroutine.cpp +++ cfe/trunk/lib/CodeGen/CGCoroutine.cpp @@ -181,10 +181,8 @@ auto *SaveCall = Builder.CreateCall(CoroSave, {NullPtr}); auto *SuspendRet = CGF.EmitScalarExpr(S.getSuspendExpr()); - if (SuspendRet != nullptr) { + if (SuspendRet != nullptr && SuspendRet->getType()->isIntegerTy(1)) { // Veto suspension if requested by bool returning await_suspend. - assert(SuspendRet->getType()->isIntegerTy(1) && - "Sema should have already checked that it is void or bool"); BasicBlock *RealSuspendBlock = CGF.createBasicBlock(Prefix + Twine(".suspend.bool")); CGF.Builder.CreateCondBr(SuspendRet, RealSuspendBlock, ReadyBlock); Index: cfe/trunk/lib/Sema/SemaCoroutine.cpp =================================================================== --- cfe/trunk/lib/Sema/SemaCoroutine.cpp +++ cfe/trunk/lib/Sema/SemaCoroutine.cpp @@ -363,6 +363,32 @@ return S.ActOnCallExpr(nullptr, Result.get(), Loc, Args, Loc, nullptr); } +// See if return type is coroutine-handle and if so, invoke builtin coro-resume +// on its address. This is to enable experimental support for coroutine-handle +// returning await_suspend that results in a guaranteed tail call to the target +// coroutine. +static Expr *maybeTailCall(Sema &S, QualType RetType, Expr *E, + SourceLocation Loc) { + if (RetType->isReferenceType()) + return nullptr; + Type const *T = RetType.getTypePtr(); + if (!T->isClassType() && !T->isStructureType()) + return nullptr; + + // FIXME: Add convertibility check to coroutine_handle<>. Possibly via + // EvaluateBinaryTypeTrait(BTT_IsConvertible, ...)
which is at the moment + // a private function in SemaExprCXX.cpp + + ExprResult AddressExpr = buildMemberCall(S, E, Loc, "address", None); + if (AddressExpr.isInvalid()) + return nullptr; + + Expr *JustAddress = AddressExpr.get(); + // FIXME: Check that the type of AddressExpr is void* + return buildBuiltinCall(S, Loc, Builtin::BI__builtin_coro_resume, + JustAddress); +} + /// Build calls to await_ready, await_suspend, and await_resume for a co_await /// expression. static ReadySuspendResumeResult buildCoawaitCalls(Sema &S, VarDecl *CoroPromise, @@ -412,16 +438,21 @@ // - await-suspend is the expression e.await_suspend(h), which shall be // a prvalue of type void or bool. QualType RetType = AwaitSuspend->getCallReturnType(S.Context); - // non-class prvalues always have cv-unqualified types - QualType AdjRetType = RetType.getUnqualifiedType(); - if (RetType->isReferenceType() || - (AdjRetType != S.Context.BoolTy && AdjRetType != S.Context.VoidTy)) { - S.Diag(AwaitSuspend->getCalleeDecl()->getLocation(), - diag::err_await_suspend_invalid_return_type) - << RetType; - S.Diag(Loc, diag::note_coroutine_promise_call_implicitly_required) - << AwaitSuspend->getDirectCallee(); - Calls.IsInvalid = true; + // Experimental support for coroutine_handle returning await_suspend. 
+ if (Expr *TailCallSuspend = maybeTailCall(S, RetType, AwaitSuspend, Loc)) + Calls.Results[ACT::ACT_Suspend] = TailCallSuspend; + else { + // non-class prvalues always have cv-unqualified types + QualType AdjRetType = RetType.getUnqualifiedType(); + if (RetType->isReferenceType() || + (AdjRetType != S.Context.BoolTy && AdjRetType != S.Context.VoidTy)) { + S.Diag(AwaitSuspend->getCalleeDecl()->getLocation(), + diag::err_await_suspend_invalid_return_type) + << RetType; + S.Diag(Loc, diag::note_coroutine_promise_call_implicitly_required) + << AwaitSuspend->getDirectCallee(); + Calls.IsInvalid = true; + } } } Index: cfe/trunk/test/CodeGenCoroutines/coro-await.cpp =================================================================== --- cfe/trunk/test/CodeGenCoroutines/coro-await.cpp +++ cfe/trunk/test/CodeGenCoroutines/coro-await.cpp @@ -12,6 +12,7 @@ struct coroutine_handle { void *ptr; static coroutine_handle from_address(void *); + void *address(); }; template @@ -326,3 +327,20 @@ // CHECK-NEXT: store %struct.RefTag* %[[RES3]], %struct.RefTag** %[[ZVAR]], RefTag& z = co_yield 42; } + +struct TailCallAwait { + bool await_ready(); + std::experimental::coroutine_handle<> await_suspend(std::experimental::coroutine_handle<>); + void await_resume(); +}; + +// CHECK-LABEL: @TestTailcall( +extern "C" void TestTailcall() { + co_await TailCallAwait{}; + + // CHECK: %[[RESULT:.+]] = call i8* @_ZN13TailCallAwait13await_suspendENSt12experimental16coroutine_handleIvEE(%struct.TailCallAwait* + // CHECK: %[[COERCE:.+]] = getelementptr inbounds %"struct.std::experimental::coroutine_handle", %"struct.std::experimental::coroutine_handle"* %[[TMP:.+]], i32 0, i32 0 + // CHECK: store i8* %[[RESULT]], i8** %[[COERCE]] + // CHECK: %[[ADDR:.+]] = call i8* @_ZNSt12experimental16coroutine_handleIvE7addressEv(%"struct.std::experimental::coroutine_handle"* %[[TMP]]) + // CHECK: call void @llvm.coro.resume(i8* %[[ADDR]]) +}