Coroutines TS: expose promise creation and coroutine-body finishing as Sema
entry points so that TreeTransform can rebuild a coroutine body when a
function template is instantiated.

Sema.h: CheckCompletedCoroutineBody(FunctionDecl *, Stmt *&) is replaced by
buildCoroutinePromise(SourceLocation KWLoc) and
ActOnFinishCoroutineBody(FunctionDecl *, Stmt *). The latter returns the new
CoroutineBodyStmt as a StmtResult, or StmtError() on failure, instead of
calling FD->setInvalidDecl() itself; ActOnFinishFunctionBody marks the
function invalid when it gets StmtError() back.

SemaCoroutine.cpp: lookupPromiseType now takes two locations. KWLoc (the
location handed to buildCoroutinePromise) is used for the "coroutine_traits
not found" diagnostics; FuncLoc (FD->getLocation()) is used for the
coroutine_traits lookup, the template-id, the completeness check and the
promise_type diagnostics. buildCoroutinePromise takes the FunctionDecl from
CurContext, uses DependentTy for a dependent function type and
lookupPromiseType otherwise, returns nullptr if that lookup fails, and
otherwise creates the "__promise" VarDecl, checks its type and
default-initializes it when it is valid. checkCoroutineContext now calls
buildCoroutinePromise instead of building the promise inline.

TreeTransform.h: TransformCoroutineBodyStmt calls buildCoroutinePromise with
the location of the pattern's promise decl, maps the old promise decl to the
new one with transformedLocalDecl, installs it as the current function's
CoroutinePromise, transforms the inner body and rebuilds the wrapper through
the new RebuildCoroutineBodyStmt hook (which calls ActOnFinishCoroutineBody).

test/SemaCXX/coroutines.cpp: the promise-lookup errors in no_specialization,
no_promise_type and bad_promise_type are now expected on the function's
declaration line instead of the co_await line; the "forward declaration"
note count goes from 2 to 3; undefined_promise_template and
promise_type_not_class check the diagnostics emitted on instantiation.

NOTE(review): in this copy the patch's line breaks appear to have been
collapsed and text inside angle brackets appears to have been stripped
(e.g. dyn_cast(CurContext), castAs(), R.getAsSingle()), so it will not
apply or compile as shown; regenerate it from the source tree before use.

Index: include/clang/Sema/Sema.h =================================================================== --- include/clang/Sema/Sema.h +++ include/clang/Sema/Sema.h @@ -8024,7 +8024,9 @@ ExprResult BuildCoyieldExpr(SourceLocation KwLoc, Expr *E); StmtResult BuildCoreturnStmt(SourceLocation KwLoc, Expr *E); - void CheckCompletedCoroutineBody(FunctionDecl *FD, Stmt *&Body); + VarDecl *buildCoroutinePromise(SourceLocation KWLoc); + + StmtResult ActOnFinishCoroutineBody(FunctionDecl *FD, Stmt *Body); //===--------------------------------------------------------------------===// // OpenMP directives and clauses. Index: lib/Sema/SemaCoroutine.cpp =================================================================== --- lib/Sema/SemaCoroutine.cpp +++ lib/Sema/SemaCoroutine.cpp @@ -24,18 +24,19 @@ /// Look up the std::coroutine_traits<...>::promise_type for the given /// function type. static QualType lookupPromiseType(Sema &S, const FunctionProtoType *FnType, - SourceLocation Loc) { + SourceLocation KWLoc, + SourceLocation FuncLoc) { // FIXME: Cache std::coroutine_traits once we've found it. NamespaceDecl *StdExp = S.lookupStdExperimentalNamespace(); if (!StdExp) { - S.Diag(Loc, diag::err_implied_std_coroutine_traits_not_found); + S.Diag(KWLoc, diag::err_implied_std_coroutine_traits_not_found); return QualType(); } LookupResult Result(S, &S.PP.getIdentifierTable().get("coroutine_traits"), - Loc, Sema::LookupOrdinaryName); + FuncLoc, Sema::LookupOrdinaryName); if (!S.LookupQualifiedName(Result, StdExp)) { - S.Diag(Loc, diag::err_implied_std_coroutine_traits_not_found); + S.Diag(KWLoc, diag::err_implied_std_coroutine_traits_not_found); return QualType(); } @@ -49,22 +50,22 @@ } // Form template argument list for coroutine_traits. 
- TemplateArgumentListInfo Args(Loc, Loc); + TemplateArgumentListInfo Args(FuncLoc, FuncLoc); Args.addArgument(TemplateArgumentLoc( TemplateArgument(FnType->getReturnType()), - S.Context.getTrivialTypeSourceInfo(FnType->getReturnType(), Loc))); + S.Context.getTrivialTypeSourceInfo(FnType->getReturnType(), FuncLoc))); // FIXME: If the function is a non-static member function, add the type // of the implicit object parameter before the formal parameters. for (QualType T : FnType->getParamTypes()) Args.addArgument(TemplateArgumentLoc( - TemplateArgument(T), S.Context.getTrivialTypeSourceInfo(T, Loc))); + TemplateArgument(T), S.Context.getTrivialTypeSourceInfo(T, FuncLoc))); // Build the template-id. QualType CoroTrait = - S.CheckTemplateIdType(TemplateName(CoroTraits), Loc, Args); + S.CheckTemplateIdType(TemplateName(CoroTraits), FuncLoc, Args); if (CoroTrait.isNull()) return QualType(); - if (S.RequireCompleteType(Loc, CoroTrait, + if (S.RequireCompleteType(FuncLoc, CoroTrait, diag::err_coroutine_traits_missing_specialization)) return QualType(); @@ -72,13 +73,14 @@ assert(RD && "specialization of class template is not a class?"); // Look up the ::promise_type member. 
- LookupResult R(S, &S.PP.getIdentifierTable().get("promise_type"), Loc, + LookupResult R(S, &S.PP.getIdentifierTable().get("promise_type"), FuncLoc, Sema::LookupOrdinaryName); S.LookupQualifiedName(R, RD); auto *Promise = R.getAsSingle(); if (!Promise) { - S.Diag(Loc, diag::err_implied_std_coroutine_traits_promise_type_not_found) - << RD; + S.Diag(FuncLoc, + diag::err_implied_std_coroutine_traits_promise_type_not_found) + << RD; return QualType(); } @@ -91,14 +93,38 @@ CoroTrait.getTypePtr()); PromiseType = S.Context.getElaboratedType(ETK_None, NNS, PromiseType); - S.Diag(Loc, diag::err_implied_std_coroutine_traits_promise_type_not_class) - << PromiseType; + S.Diag(FuncLoc, + diag::err_implied_std_coroutine_traits_promise_type_not_class) + << PromiseType; return QualType(); } return PromiseType; } +VarDecl *Sema::buildCoroutinePromise(SourceLocation KWLoc) { + auto *FD = dyn_cast(CurContext); + assert(FD && "Not inside a function context"); + SourceLocation FuncLoc = FD->getLocation(); + QualType T = + FD->getType()->isDependentType() + ? Context.DependentTy + : lookupPromiseType(*this, FD->getType()->castAs(), + KWLoc, FuncLoc); + if (T.isNull()) + return nullptr; + + // Create and default-initialize the promise. + VarDecl *VD = + VarDecl::Create(Context, FD, FD->getLocation(), FD->getLocation(), + &PP.getIdentifierTable().get("__promise"), T, + Context.getTrivialTypeSourceInfo(T, FuncLoc), SC_None); + CheckVariableDeclarationType(VD); + if (!VD->isInvalidDecl()) + ActOnUninitializedDecl(VD, false); + return VD; +} + /// Check that this is a context in which a coroutine suspension can appear. static FunctionScopeInfo * checkCoroutineContext(Sema &S, SourceLocation Loc, StringRef Keyword) { @@ -138,22 +164,8 @@ // If we don't have a promise variable, build one now. if (!ScopeInfo->CoroutinePromise) { - QualType T = - FD->getType()->isDependentType() - ? 
S.Context.DependentTy - : lookupPromiseType(S, FD->getType()->castAs(), - Loc); - if (T.isNull()) + if (!(ScopeInfo->CoroutinePromise = S.buildCoroutinePromise(Loc))) return nullptr; - - // Create and default-initialize the promise. - ScopeInfo->CoroutinePromise = - VarDecl::Create(S.Context, FD, FD->getLocation(), FD->getLocation(), - &S.PP.getIdentifierTable().get("__promise"), T, - S.Context.getTrivialTypeSourceInfo(T, Loc), SC_None); - S.CheckVariableDeclarationType(ScopeInfo->CoroutinePromise); - if (!ScopeInfo->CoroutinePromise->isInvalidDecl()) - S.ActOnUninitializedDecl(ScopeInfo->CoroutinePromise, false); } return ScopeInfo; @@ -378,7 +390,7 @@ return Res; } -void Sema::CheckCompletedCoroutineBody(FunctionDecl *FD, Stmt *&Body) { +StmtResult Sema::ActOnFinishCoroutineBody(FunctionDecl *FD, Stmt *Body) { FunctionScopeInfo *Fn = getCurFunction(); assert(Fn && !Fn->CoroutineStmts.empty() && "not a coroutine"); @@ -410,7 +422,7 @@ StmtResult PromiseStmt = ActOnDeclStmt(ConvertDeclToDeclGroup(Fn->CoroutinePromise), Loc, Loc); if (PromiseStmt.isInvalid()) - return FD->setInvalidDecl(); + return StmtError(); // Form and check implicit 'co_await p.initial_suspend();' statement. ExprResult InitialSuspend = @@ -420,7 +432,7 @@ InitialSuspend = BuildCoawaitExpr(Loc, InitialSuspend.get()); InitialSuspend = ActOnFinishFullExpr(InitialSuspend.get()); if (InitialSuspend.isInvalid()) - return FD->setInvalidDecl(); + return StmtError(); // Form and check implicit 'co_await p.final_suspend();' statement. ExprResult FinalSuspend = @@ -430,7 +442,7 @@ FinalSuspend = BuildCoawaitExpr(Loc, FinalSuspend.get()); FinalSuspend = ActOnFinishFullExpr(FinalSuspend.get()); if (FinalSuspend.isInvalid()) - return FD->setInvalidDecl(); + return StmtError(); // FIXME: Perform analysis of set_exception call. 
@@ -442,7 +454,7 @@ ExprResult ReturnObject = buildPromiseCall(*this, Fn, Loc, "get_return_object", None); if (ReturnObject.isInvalid()) - return FD->setInvalidDecl(); + return StmtError(); QualType RetType = FD->getReturnType(); if (!RetType->isDependentType()) { InitializedEntity Entity = @@ -450,18 +462,18 @@ ReturnObject = PerformMoveOrCopyInitialization(Entity, nullptr, RetType, ReturnObject.get()); if (ReturnObject.isInvalid()) - return FD->setInvalidDecl(); + return StmtError(); } ReturnObject = ActOnFinishFullExpr(ReturnObject.get(), Loc); if (ReturnObject.isInvalid()) - return FD->setInvalidDecl(); + return StmtError(); // FIXME: Perform move-initialization of parameters into frame-local copies. SmallVector ParamMoves; // Build body for the coroutine wrapper statement. - Body = new (Context) CoroutineBodyStmt( + return new (Context) CoroutineBodyStmt( Body, PromiseStmt.get(), InitialSuspend.get(), FinalSuspend.get(), - /*SetException*/nullptr, /*Fallthrough*/nullptr, - ReturnObject.get(), ParamMoves); + /*SetException*/ nullptr, /*Fallthrough*/ nullptr, ReturnObject.get(), + ParamMoves); } Index: lib/Sema/SemaDecl.cpp =================================================================== --- lib/Sema/SemaDecl.cpp +++ lib/Sema/SemaDecl.cpp @@ -11626,13 +11626,26 @@ Decl *Sema::ActOnFinishFunctionBody(Decl *dcl, Stmt *Body, bool IsInstantiation) { FunctionDecl *FD = dcl ? 
dcl->getAsFunction() : nullptr; sema::AnalysisBasedWarnings::Policy WP = AnalysisWarnings.getDefaultPolicy(); sema::AnalysisBasedWarnings::Policy *ActivePolicy = nullptr; - if (getLangOpts().CoroutinesTS && !getCurFunction()->CoroutineStmts.empty()) - CheckCompletedCoroutineBody(FD, Body); + if (getLangOpts().CoroutinesTS && !getCurFunction()->CoroutineStmts.empty()) { + // FIXME: support ObjC methods here + assert(FD && "Objective C methods are not supported"); + // During template instantiation TreeTransform::TransformCoroutineBodyStmt + // has already rebuilt the body via ActOnFinishCoroutineBody, so don't wrap + // an existing CoroutineBodyStmt a second time. Body is also checked for + // null (presumably only after an earlier error) since isa<> asserts on it. + if (Body && !isa<CoroutineBodyStmt>(Body)) { + StmtResult NewBody = ActOnFinishCoroutineBody(FD, Body); + if (NewBody.isInvalid()) + FD->setInvalidDecl(); + else + Body = NewBody.get(); + } + } if (FD) { FD->setBody(Body); Index: lib/Sema/TreeTransform.h =================================================================== --- lib/Sema/TreeTransform.h +++ lib/Sema/TreeTransform.h @@ -1326,6 +1326,16 @@ return getSema().BuildCoyieldExpr(CoyieldLoc, Result); } + /// \brief Build a new coroutine body. + /// + /// By default, performs semantic analysis to build the new body. + /// Subclasses may override this routine to provide different behavior. + StmtResult RebuildCoroutineBodyStmt(Stmt *Body) { + auto *FD = dyn_cast(getSema().CurContext); + assert(FD); // FIXME this assertion should never fire + return getSema().ActOnFinishCoroutineBody(FD, Body); + } + /// \brief Build a new Objective-C \@try statement. /// /// By default, performs semantic analysis to build the new statement. @@ -6651,8 +6661,25 @@ template StmtResult TreeTransform::TransformCoroutineBodyStmt(CoroutineBodyStmt *S) { - // The coroutine body should be re-formed by the caller if necessary. - return getDerived().TransformStmt(S->getBody()); + // FIXME: Don't rebuild the entire coroutine body. + // The coroutine body should only be re-formed by the caller if necessary. 
+ FunctionScopeInfo *FS = getSema().getCurFunction(); + assert(FS); + VarDecl *VD = + getSema().buildCoroutinePromise(S->getPromiseDecl()->getLocation()); + if (!VD || VD->isInvalidDecl()) + return StmtError(); + getDerived().transformedLocalDecl(S->getPromiseDecl(), VD); + // FIXME: Re-setting FS->CoroutinePromise feels like a hack. Is there a better + // way to do this? Currently this is needed so the rebuilt body uses the + // transformed promise type. + FS->CoroutinePromise = VD; + + StmtResult Body = getDerived().TransformStmt(S->getBody()); + if (Body.isInvalid()) + return StmtError(); + + return getDerived().RebuildCoroutineBodyStmt(Body.get()); } template Index: test/SemaCXX/coroutines.cpp =================================================================== --- test/SemaCXX/coroutines.cpp +++ test/SemaCXX/coroutines.cpp @@ -52,21 +52,21 @@ using promise_type = Promise; }; -void no_specialization() { - co_await a; // expected-error {{implicit instantiation of undefined template 'std::experimental::coroutine_traits'}} +void no_specialization() { // expected-error {{implicit instantiation of undefined template 'std::experimental::coroutine_traits'}} + co_await a; } template struct std::experimental::coroutine_traits {}; -int no_promise_type() { - co_await a; // expected-error {{this function cannot be a coroutine: 'std::experimental::coroutine_traits' has no member named 'promise_type'}} +int no_promise_type() { // expected-error {{this function cannot be a coroutine: 'std::experimental::coroutine_traits' has no member named 'promise_type'}} + co_await a; } template <> struct std::experimental::coroutine_traits { typedef int promise_type; }; -double bad_promise_type(double) { - co_await a; // expected-error {{this function cannot be a coroutine: 'experimental::coroutine_traits::promise_type' (aka 'int') is not a class}} +double bad_promise_type(double) { // expected-error {{this function cannot be a coroutine: 'experimental::coroutine_traits::promise_type' (aka 
'int') is not a class}} + co_await a; } template <> @@ -77,7 +77,7 @@ co_yield 0; // expected-error {{no member named 'yield_value' in 'std::experimental::coroutine_traits::promise_type'}} } -struct promise; // expected-note 2{{forward declaration}} +struct promise; // expected-note 3{{forward declaration}} template struct std::experimental::coroutine_traits { using promise_type = promise; }; @@ -94,6 +94,12 @@ // expected-error@-2 {{incomplete definition of type 'promise'}} co_await a; } +template +void undefined_promise_template(T) { // expected-error {{variable has incomplete type 'promise_type'}} + // FIXME: This diagnostic doesn't make any sense. + co_await a; +} +template void undefined_promise_template(int); // expected-note {{requested here}} struct yielded_thing { const char *p; short a, b; }; @@ -299,6 +305,16 @@ co_await a; } +struct not_class_tag {}; +template <> +struct std::experimental::coroutine_traits { using promise_type = int; }; + +template +void promise_type_not_class(T) { + // expected-error@-1 {{this function cannot be a coroutine: 'experimental::coroutine_traits::promise_type' (aka 'int') is not a class}} + co_await a; +} +template void promise_type_not_class(not_class_tag); // expected-note {{requested here}} template<> struct std::experimental::coroutine_traits { using promise_type = promise; };