Index: include/clang/Sema/ScopeInfo.h =================================================================== --- include/clang/Sema/ScopeInfo.h +++ include/clang/Sema/ScopeInfo.h @@ -24,6 +24,7 @@ #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringSwitch.h" #include namespace clang { @@ -139,6 +140,14 @@ /// to build, the initial and final coroutine suspend points bool NeedsCoroutineSuspends : 1; + /// \brief A 2-bit code representing the kind of the first coroutine statement + /// in the function: 0 = co_return, 1 = co_await, 2 = co_yield. + unsigned char FirstCoroutineStmtKind : 2; + + /// Location of the first coroutine statement (co_return, co_await, or + /// co_yield) in the current function; invalid until one has been recorded. + SourceLocation FirstCoroutineStmtLoc; + /// First 'return' statement in the current function. SourceLocation FirstReturnLoc; @@ -166,11 +175,6 @@ /// \brief The initial and final coroutine suspend points. std::pair CoroutineSuspends; - /// \brief The list of coroutine control flow constructs (co_await, co_yield, - /// co_return) that occur within the function or block. Empty if and only if - /// this function or block is not (yet known to be) a coroutine. - SmallVector CoroutineStmts; - /// \brief The stack of currently active compound stamement scopes in the /// function.
SmallVector CompoundScopes; @@ -384,6 +388,33 @@ (HasBranchProtectedScope && HasBranchIntoScope)); } + /// \brief Record \p Loc as the location of the first coroutine statement in + /// this function. \p Keyword is expected to be one of co_return, co_await, + /// or co_yield; it is stored in FirstCoroutineStmtKind as 0, 1, or 2. + void setFirstCoroutineStmt(SourceLocation Loc, StringRef Keyword) { + assert(FirstCoroutineStmtLoc.isInvalid() && + "first coroutine statement location already set"); + FirstCoroutineStmtLoc = Loc; + FirstCoroutineStmtKind = llvm::StringSwitch<unsigned char>(Keyword) + .Case("co_return", 0) + .Case("co_await", 1) + .Case("co_yield", 2); + } + + /// \brief Return the keyword (co_return, co_await, or co_yield) of the first + /// coroutine statement; one must already have been recorded. + StringRef getFirstCoroutineStmtKeyword() const { + assert(FirstCoroutineStmtLoc.isValid() + && "no coroutine statement available"); + switch (FirstCoroutineStmtKind) { + case 0: return "co_return"; + case 1: return "co_await"; + case 2: return "co_yield"; + default: + llvm_unreachable("FirstCoroutineStmtKind has an invalid value"); + } + } + void setNeedsCoroutineSuspends(bool value = true) { assert((!value || CoroutineSuspends.first == nullptr) && "we already have valid suspend points"); Index: lib/Sema/ScopeInfo.cpp =================================================================== --- lib/Sema/ScopeInfo.cpp +++ lib/Sema/ScopeInfo.cpp @@ -40,13 +40,15 @@ FirstCXXTryLoc = SourceLocation(); FirstSEHTryLoc = SourceLocation(); - SwitchStack.clear(); - Returns.clear(); + // Coroutine state + FirstCoroutineStmtLoc = SourceLocation(); CoroutinePromise = nullptr; NeedsCoroutineSuspends = true; CoroutineSuspends.first = nullptr; CoroutineSuspends.second = nullptr; - CoroutineStmts.clear(); + + SwitchStack.clear(); + Returns.clear(); ErrorTrap.reset(); PossiblyUnreachableDiags.clear(); WeakObjectUses.clear(); Index: lib/Sema/SemaCoroutine.cpp =================================================================== --- lib/Sema/SemaCoroutine.cpp +++ lib/Sema/SemaCoroutine.cpp @@ -400,7 +400,8 @@ /// Check that this is a context in which a coroutine suspension can appear.
static FunctionScopeInfo *checkCoroutineContext(Sema &S, SourceLocation Loc, - StringRef Keyword) { + StringRef Keyword, + bool IsImplicit = false) { if (!isValidCoroutineContext(S, Loc, Keyword)) return nullptr; @@ -409,6 +410,9 @@ auto *ScopeInfo = S.getCurFunction(); assert(ScopeInfo && "missing function scope for function"); + if (ScopeInfo->FirstCoroutineStmtLoc.isInvalid() && !IsImplicit) + ScopeInfo->setFirstCoroutineStmt(Loc, Keyword); /* Only the first non-implicit coroutine statement is recorded. */ + if (ScopeInfo->CoroutinePromise) return ScopeInfo; @@ -488,7 +492,7 @@ } ExprResult Sema::BuildUnresolvedCoawaitExpr(SourceLocation Loc, Expr *E, - UnresolvedLookupExpr *Lookup) { + UnresolvedLookupExpr *Lookup) { auto *FSI = checkCoroutineContext(*this, Loc, "co_await"); if (!FSI) return ExprError(); @@ -504,7 +508,6 @@ if (Promise->getType()->isDependentType()) { Expr *Res = new (Context) DependentCoawaitExpr(Loc, Context.DependentTy, E, Lookup); - FSI->CoroutineStmts.push_back(Res); return Res; } @@ -528,7 +531,7 @@ ExprResult Sema::BuildResolvedCoawaitExpr(SourceLocation Loc, Expr *E, bool IsImplicit) { - auto *Coroutine = checkCoroutineContext(*this, Loc, "co_await"); + auto *Coroutine = checkCoroutineContext(*this, Loc, "co_await", IsImplicit); if (!Coroutine) return ExprError(); @@ -541,8 +544,6 @@ if (E->getType()->isDependentType()) { Expr *Res = new (Context) CoawaitExpr(Loc, Context.DependentTy, E, IsImplicit); - if (!IsImplicit) - Coroutine->CoroutineStmts.push_back(Res); return Res; } @@ -560,8 +561,7 @@ Expr *Res = new (Context) CoawaitExpr(Loc, E, RSS.Results[0], RSS.Results[1], RSS.Results[2], RSS.OpaqueValue, IsImplicit); - if (!IsImplicit) - Coroutine->CoroutineStmts.push_back(Res); + return Res; } @@ -597,7 +597,6 @@ if (E->getType()->isDependentType()) { Expr *Res = new (Context) CoyieldExpr(Loc, Context.DependentTy, E); - Coroutine->CoroutineStmts.push_back(Res); return Res; } @@ -614,7 +613,7 @@ Expr *Res = new (Context) CoyieldExpr(Loc, E, RSS.Results[0], RSS.Results[1], RSS.Results[2],
RSS.OpaqueValue); - Coroutine->CoroutineStmts.push_back(Res); + return Res; } @@ -628,7 +627,7 @@ StmtResult Sema::BuildCoreturnStmt(SourceLocation Loc, Expr *E, bool IsImplicit) { - auto *FSI = checkCoroutineContext(*this, Loc, "co_return"); + auto *FSI = checkCoroutineContext(*this, Loc, "co_return", IsImplicit); if (!FSI) return StmtError(); @@ -656,8 +655,6 @@ Expr *PCE = ActOnFinishFullExpr(PC.get()).get(); Stmt *Res = new (Context) CoreturnStmt(Loc, E, PCE, IsImplicit); - if (!IsImplicit) - FSI->CoroutineStmts.push_back(Res); return Res; } @@ -774,15 +771,11 @@ // Coroutines [stmt.return]p1: // A return statement shall not appear in a coroutine. if (Fn->FirstReturnLoc.isValid()) { + assert(Fn->FirstCoroutineStmtLoc.isValid() && + "first coroutine location not set"); Diag(Fn->FirstReturnLoc, diag::err_return_in_coroutine); - // FIXME: Every Coroutine statement may be invalid and therefore not added - // to CoroutineStmts. Find another way to provide location information. - if (!Fn->CoroutineStmts.empty()) { - auto *First = Fn->CoroutineStmts[0]; - Diag(First->getLocStart(), diag::note_declared_coroutine_here) - << (isa(First) ? "co_await" : - isa(First) ? "co_yield" : "co_return"); - } + Diag(Fn->FirstCoroutineStmtLoc, diag::note_declared_coroutine_here) + << Fn->getFirstCoroutineStmtKeyword(); } SubStmtBuilder Builder(*this, *FD, *Fn, Body); if (Builder.isInvalid()) Index: lib/Sema/TreeTransform.h =================================================================== --- lib/Sema/TreeTransform.h +++ lib/Sema/TreeTransform.h @@ -6857,7 +6857,7 @@ ScopeInfo->NeedsCoroutineSuspends && ScopeInfo->CoroutineSuspends.first == nullptr && ScopeInfo->CoroutineSuspends.second == nullptr && - ScopeInfo->CoroutineStmts.empty() && "expected clean scope info"); + "expected clean scope info"); // Set that we have (possibly-invalid) suspend points before we do anything // that may fail.
Index: test/SemaCXX/coroutines.cpp =================================================================== --- test/SemaCXX/coroutines.cpp +++ test/SemaCXX/coroutines.cpp @@ -162,11 +162,59 @@ return; // expected-error {{not allowed in coroutine}} } +void mixed_yield_invalid() { + co_yield blah; // expected-error {{use of undeclared identifier}} + // expected-note@-1 {{function is a coroutine due to use of 'co_yield'}} + return; // expected-error {{return statement not allowed in coroutine}} +} + +template +void mixed_yield_template(T) { + co_yield blah; // expected-error {{use of undeclared identifier}} + // expected-note@-1 {{function is a coroutine due to use of 'co_yield'}} + return; // expected-error {{return statement not allowed in coroutine}} +} + +template +void mixed_yield_template2(T) { + co_yield 42; + // expected-note@-1 {{function is a coroutine due to use of 'co_yield'}} + return; // expected-error {{return statement not allowed in coroutine}} +} + +template +void mixed_yield_template3(T v) { + co_yield blah(v); + // expected-note@-1 {{function is a coroutine due to use of 'co_yield'}} + return; // expected-error {{return statement not allowed in coroutine}} +} + void mixed_await() { co_await a; // expected-note {{use of 'co_await'}} return; // expected-error {{not allowed in coroutine}} } +void mixed_await_invalid() { + co_await 42; // expected-error {{'int' is not a structure or union}} + // expected-note@-1 {{function is a coroutine due to use of 'co_await'}} + return; // expected-error {{not allowed in coroutine}} +} + +template +void mixed_await_template(T) { + co_await 42; + // expected-note@-1 {{function is a coroutine due to use of 'co_await'}} + return; // expected-error {{not allowed in coroutine}} +} + +template +void mixed_await_template2(T v) { + co_await v; // expected-error {{'long' is not a structure or union}} + // expected-note@-1 {{function is a coroutine due to use of 'co_await'}} + return; // expected-error {{not allowed in
coroutine}} +} +template void mixed_await_template2(long); // expected-note {{requested here}} + void only_coreturn(void_tag) { co_return; // OK } @@ -178,6 +226,33 @@ return; // expected-error {{not allowed in coroutine}} } +void mixed_coreturn_invalid(bool b) { + if (b) + co_return; // expected-note {{use of 'co_return'}} + // expected-error@-1 {{no member named 'return_void' in 'promise'}} + else + return; // expected-error {{not allowed in coroutine}} +} + +template +void mixed_coreturn_template(void_tag, bool b, T v) { + if (b) + co_return v; // expected-note {{use of 'co_return'}} + // expected-error@-1 {{no member named 'return_value' in 'promise_void'}} + else + return; // expected-error {{not allowed in coroutine}} +} +template void mixed_coreturn_template(void_tag, bool, int); // expected-note {{requested here}} + +template +void mixed_coreturn_template2(bool b, T) { + if (b) + co_return v; // expected-note {{use of 'co_return'}} + // expected-error@-1 {{use of undeclared identifier 'v'}} + else + return; // expected-error {{not allowed in coroutine}} +} + struct CtorDtor { CtorDtor() { co_yield 0; // expected-error {{'co_yield' cannot be used in a constructor}}