Index: include/clang/Basic/DiagnosticSemaKinds.td
===================================================================
--- include/clang/Basic/DiagnosticSemaKinds.td
+++ include/clang/Basic/DiagnosticSemaKinds.td
@@ -8656,6 +8656,10 @@
   "'std::experimental::coroutine_traits' must be a class template">;
 def err_implied_std_coroutine_traits_promise_type_not_found : Error<
   "this function cannot be a coroutine: %q0 has no member named 'promise_type'">;
+def err_malformed_std_coroutine_handle : Error<
+  "'std::experimental::coroutine_handle' must be a class template">;
+def err_coroutine_handle_missing_member : Error<
+  "std::experimental::coroutine_handle missing a member named '%0'">;
 def err_implied_std_coroutine_traits_promise_type_not_class : Error<
   "this function cannot be a coroutine: %0 is not a class">;
 def err_coroutine_promise_type_incomplete : Error<
Index: lib/Sema/SemaCoroutine.cpp
===================================================================
--- lib/Sema/SemaCoroutine.cpp
+++ lib/Sema/SemaCoroutine.cpp
@@ -116,6 +116,53 @@
   return PromiseType;
 }
 
+/// Look up std::experimental::coroutine_handle<PromiseType>, the type of the
+/// handle passed to await_suspend. Returns a null QualType on failure.
+static QualType lookupCoroutineHandleType(Sema &S, QualType PromiseType,
+                                          SourceLocation Loc) {
+  if (PromiseType.isNull())
+    return QualType();
+
+  NamespaceDecl *StdExp = S.lookupStdExperimentalNamespace();
+  assert(StdExp && "Should already be diagnosed");
+
+  LookupResult Result(S, &S.PP.getIdentifierTable().get("coroutine_handle"),
+                      Loc, Sema::LookupOrdinaryName);
+  if (!S.LookupQualifiedName(Result, StdExp)) {
+    S.Diag(Loc, diag::err_implied_std_coroutine_traits_not_found);
+    return QualType();
+  }
+
+  ClassTemplateDecl *CoroHandle = Result.getAsSingle<ClassTemplateDecl>();
+  if (!CoroHandle) {
+    Result.suppressDiagnostics();
+    // We found something weird. Complain about the first thing we found.
+    NamedDecl *Found = *Result.begin();
+    S.Diag(Found->getLocation(), diag::err_malformed_std_coroutine_handle);
+    return QualType();
+  }
+
+  // Form template argument list for coroutine_handle<PromiseType>.
+  TemplateArgumentListInfo Args(Loc, Loc);
+  Args.addArgument(TemplateArgumentLoc(
+      TemplateArgument(PromiseType),
+      S.Context.getTrivialTypeSourceInfo(PromiseType, Loc)));
+
+  // Build the template-id.
+  QualType CoroHandleType =
+      S.CheckTemplateIdType(TemplateName(CoroHandle), Loc, Args);
+  if (CoroHandleType.isNull())
+    return QualType();
+  if (S.RequireCompleteType(Loc, CoroHandleType,
+                            diag::err_coroutine_traits_missing_specialization))
+    return QualType();
+
+  assert(CoroHandleType->getAsCXXRecordDecl() &&
+         "specialization of class template is not a class?");
+
+  return CoroHandleType;
+}
+
 static bool isValidCoroutineContext(Sema &S, SourceLocation Loc,
                                     StringRef Keyword) {
   // 'co_await' and 'co_yield' are not permitted in unevaluated operands.
@@ -260,20 +307,60 @@
   return S.ActOnCallExpr(nullptr, Result.get(), Loc, Args, Loc, nullptr);
 }
 
+/// Build coroutine_handle<PromiseType>::from_address(__builtin_coro_frame()),
+/// the handle argument that is passed to await_suspend.
+static ExprResult buildCoroutineHandle(Sema &S, QualType PromiseType,
+                                       SourceLocation Loc) {
+  QualType HandleType = lookupCoroutineHandleType(S, PromiseType, Loc);
+  if (HandleType.isNull())
+    return ExprError();
+  auto *RD = HandleType->getAsCXXRecordDecl();
+  assert(RD && "must be class type");
+  DeclarationName DN = S.PP.getIdentifierInfo("from_address");
+  LookupResult LR(S, DN, Loc, Sema::LookupMemberName);
+  if (!S.LookupQualifiedName(LR, RD)) {
+    // Diagnose here: returning ExprError() alone would fail silently.
+    S.Diag(Loc, diag::err_coroutine_handle_missing_member) << "from_address";
+    return ExprError();
+  }
+
+  Expr *FramePtr =
+      buildBuiltinCall(S, Loc, Builtin::BI__builtin_coro_frame, {});
+
+  // FIXME: Fix BuildMemberReferenceExpr to take a const CXXScopeSpec&.
+  CXXScopeSpec SS;
+  ExprResult Result = S.BuildMemberReferenceExpr(
+      /*BaseExpr*/ nullptr, HandleType, Loc, /*IsArrow=*/false, SS,
+      SourceLocation(), nullptr, LR, /*TemplateArgs=*/nullptr,
+      /*Scope=*/nullptr);
+  if (Result.isInvalid())
+    return ExprError();
+
+  return S.ActOnCallExpr(nullptr, Result.get(), Loc, FramePtr, Loc, nullptr);
+}
+
 /// Build calls to await_ready, await_suspend, and await_resume for a co_await
 /// expression.
-static ReadySuspendResumeResult buildCoawaitCalls(Sema &S, SourceLocation Loc,
-                                                  Expr *E) {
+static ReadySuspendResumeResult
+buildCoawaitCalls(Sema &S, SourceLocation Loc, QualType PromiseType, Expr *E) {
   // Assume invalid until we see otherwise.
   ReadySuspendResumeResult Calls = {true, {}};
 
+  ExprResult HandleExprRes = buildCoroutineHandle(S, PromiseType, Loc);
+  if (HandleExprRes.isInvalid())
+    return Calls;
+  Expr *HandleExpr = HandleExprRes.get();
+
   const StringRef Funcs[] = {"await_ready", "await_suspend", "await_resume"};
   for (size_t I = 0, N = llvm::array_lengthof(Funcs); I != N; ++I) {
     Expr *Operand = new (S.Context) OpaqueValueExpr(
         Loc, E->getType(), VK_LValue, E->getObjectKind(), E);
 
-    // FIXME: Pass coroutine handle to await_suspend.
-    ExprResult Result = buildMemberCall(S, Operand, Loc, Funcs[I], None);
+    // Only await_suspend takes an argument: the coroutine handle.
+    MultiExprArg Args = None;
+    if (Funcs[I] == "await_suspend")
+      Args = HandleExpr;
+    ExprResult Result = buildMemberCall(S, Operand, Loc, Funcs[I], Args);
     if (Result.isInvalid())
       return Calls;
     Calls.Results[I] = Result.get();
@@ -475,7 +557,8 @@
   E = CreateMaterializeTemporaryExpr(E->getType(), E, true);
 
   // Build the await_ready, await_suspend, await_resume calls.
-  ReadySuspendResumeResult RSS = buildCoawaitCalls(*this, Loc, E);
+  ReadySuspendResumeResult RSS =
+      buildCoawaitCalls(*this, Loc, Coroutine->CoroutinePromise->getType(), E);
   if (RSS.IsInvalid)
     return ExprError();
 
@@ -528,7 +611,8 @@
   E = CreateMaterializeTemporaryExpr(E->getType(), E, true);
 
   // Build the await_ready, await_suspend, await_resume calls.
-  ReadySuspendResumeResult RSS = buildCoawaitCalls(*this, Loc, E);
+  ReadySuspendResumeResult RSS =
+      buildCoawaitCalls(*this, Loc, Coroutine->CoroutinePromise->getType(), E);
   if (RSS.IsInvalid)
     return ExprError();
 
@@ -869,6 +953,8 @@
   // FIXME: Perform move-initialization of parameters into frame-local copies.
   SmallVector<Expr*, 16> ParamMoves;
 
+  // If we're instantiating a template then we have already replaced Body
+  // with a CoroutineBodyStmt.
   if (Body && !isa<CoroutineBodyStmt>(Body)) {
     StmtResult BodyRes = BuildCoroutineBodyStmt(
         Body, FSI->CoroutinePromise, FSI->CoroutineSuspends.first,
Index: test/CodeGenCoroutines/coro-alloc.cpp
===================================================================
--- test/CodeGenCoroutines/coro-alloc.cpp
+++ test/CodeGenCoroutines/coro-alloc.cpp
@@ -4,12 +4,26 @@
 namespace experimental {
 template <typename... T>
 struct coroutine_traits; // expected-note {{declared here}}
+
+template <typename Promise = void>
+struct coroutine_handle {
+  coroutine_handle() = default;
+  static coroutine_handle from_address(void *) { return {}; }
+};
+
+template <>
+struct coroutine_handle<void> {
+  static coroutine_handle from_address(void *) { return {}; }
+  coroutine_handle() = default;
+  template <class PromiseTy>
+  coroutine_handle(coroutine_handle<PromiseTy>) {}
+};
 }
 }
 
 struct suspend_always {
   bool await_ready() { return false; }
-  void await_suspend() {}
+  void await_suspend(std::experimental::coroutine_handle<>) {}
   void await_resume() {}
 };
 
Index: test/SemaCXX/coroutines.cpp
===================================================================
--- test/SemaCXX/coroutines.cpp
+++ test/SemaCXX/coroutines.cpp
@@ -16,33 +16,26 @@
   // expected-error@-1 {{use of undeclared identifier 'a'}}
 }
 
-
-struct awaitable {
-  bool await_ready();
-  void await_suspend(); // FIXME: coroutine_handle
-  void await_resume();
-} a;
-
-struct suspend_always {
-  bool await_ready() { return false; }
-  void await_suspend() {}
-  void await_resume() {}
-};
-
-struct suspend_never {
-  bool await_ready() { return true; }
-  void await_suspend() {}
-  void await_resume() {}
-};
-
 void no_coroutine_traits() {
-  co_await a; // expected-error {{need to include <experimental/coroutine>}}
+  co_await 4; // expected-error {{need to include <experimental/coroutine>}}
 }
 
 namespace std {
 namespace experimental {
 template <typename... T>
 struct coroutine_traits; // expected-note {{declared here}}
+
+template <class PromiseType = void>
+struct coroutine_handle {
+  static coroutine_handle from_address(void *);
+};
+
+template <>
+struct coroutine_handle<void> {
+  template <class PromiseType>
+  coroutine_handle(coroutine_handle<PromiseType>);
+  static coroutine_handle from_address(void *);
+};
 }
 }
 
@@ -52,6 +45,24 @@
   using promise_type = Promise;
 };
 
+struct awaitable {
+  bool await_ready();
+  void await_suspend(std::experimental::coroutine_handle<>);
+  void await_resume();
+} a;
+
+struct suspend_always {
+  bool await_ready() { return false; }
+  void await_suspend(std::experimental::coroutine_handle<>) {}
+  void await_resume() {}
+};
+
+struct suspend_never {
+  bool await_ready() { return true; }
+  void await_suspend(std::experimental::coroutine_handle<>) {}
+  void await_resume() {}
+};
+
 void no_specialization() {
   co_await a; // expected-error {{implicit instantiation of undefined template 'std::experimental::coroutine_traits<void>'}}
 }
@@ -86,13 +97,6 @@
 struct std::experimental::coroutine_traits<void, void_tag, T...>
 { using promise_type = promise_void; };
 
-namespace std {
-namespace experimental {
-template <typename Promise = void>
-struct coroutine_handle;
-}
-}
-
 // FIXME: This diagnostic is terrible.
 void undefined_promise() { // expected-error {{this function cannot be a coroutine: 'experimental::coroutine_traits<void>::promise_type' (aka 'promise') is an incomplete type}}
   co_await a;