Index: include/clang/AST/ExprCXX.h =================================================================== --- include/clang/AST/ExprCXX.h +++ include/clang/AST/ExprCXX.h @@ -4194,11 +4194,16 @@ friend class ASTStmtReader; public: CoawaitExpr(SourceLocation CoawaitLoc, Expr *Operand, Expr *Ready, - Expr *Suspend, Expr *Resume) + Expr *Suspend, Expr *Resume, bool IsImplicit = false) : CoroutineSuspendExpr(CoawaitExprClass, CoawaitLoc, Operand, Ready, - Suspend, Resume) {} - CoawaitExpr(SourceLocation CoawaitLoc, QualType Ty, Expr *Operand) - : CoroutineSuspendExpr(CoawaitExprClass, CoawaitLoc, Ty, Operand) {} + Suspend, Resume) { + CoawaitBits.IsImplicit = IsImplicit; + } + CoawaitExpr(SourceLocation CoawaitLoc, QualType Ty, Expr *Operand, + bool IsImplicit = false) + : CoroutineSuspendExpr(CoawaitExprClass, CoawaitLoc, Ty, Operand) { + CoawaitBits.IsImplicit = IsImplicit; + } CoawaitExpr(EmptyShell Empty) : CoroutineSuspendExpr(CoawaitExprClass, Empty) {} @@ -4207,11 +4212,57 @@ return getCommonExpr(); } + bool isImplicit() const { return CoawaitBits.IsImplicit; } + void setIsImplicit(bool value = true) { CoawaitBits.IsImplicit = value; } + static bool classof(const Stmt *T) { return T->getStmtClass() == CoawaitExprClass; } }; +/// \brief Represents a 'co_await' expression while the type of the promise +/// is dependent. +class DependentCoawaitExpr : public Expr { + SourceLocation KeywordLoc; + Stmt *SubExprs[2]; + + friend class ASTStmtReader; + +public: + DependentCoawaitExpr(SourceLocation KeywordLoc, QualType Ty, Expr *Op, + UnresolvedLookupExpr *OpCoawait) + : Expr(DependentCoawaitExprClass, Ty, VK_RValue, OK_Ordinary, + /*TypeDependent*/ true, /*ValueDependent*/ true, + /*InstantiationDependent*/ true, + Op->containsUnexpandedParameterPack()), + KeywordLoc(KeywordLoc) { + assert(Op->isTypeDependent() && Ty->isDependentType() && + "wrong constructor for non-dependent co_await/co_yield expression"); + SubExprs[0] = Op; + SubExprs[1] = OpCoawait; + } + + DependentCoawaitExpr(EmptyShell Empty) + : Expr(DependentCoawaitExprClass, Empty) {} + + Expr *getOperand() const { return cast(SubExprs[0]); } + UnresolvedLookupExpr *getOperatorCoawaitLookup() const { + return cast(SubExprs[1]); + } + SourceLocation getKeywordLoc() const { return KeywordLoc; } + + SourceLocation getLocStart() const LLVM_READONLY { return KeywordLoc; } + SourceLocation getLocEnd() const LLVM_READONLY { + return getOperand()->getLocEnd(); + } + + child_range children() { return child_range(SubExprs, SubExprs + 2); } + + static bool classof(const Stmt *T) { + return T->getStmtClass() == DependentCoawaitExprClass; + } +}; + /// \brief Represents a 'co_yield' expression. class CoyieldExpr : public CoroutineSuspendExpr { friend class ASTStmtReader; Index: include/clang/AST/RecursiveASTVisitor.h =================================================================== --- include/clang/AST/RecursiveASTVisitor.h +++ include/clang/AST/RecursiveASTVisitor.h @@ -2509,6 +2509,12 @@ ShouldVisitChildren = false; } }) +DEF_TRAVERSE_STMT(DependentCoawaitExpr, { + if (!getDerived().shouldVisitImplicitCode()) { + TRY_TO_TRAVERSE_OR_ENQUEUE_STMT(S->getOperand()); + ShouldVisitChildren = false; + } +}) DEF_TRAVERSE_STMT(CoyieldExpr, { if (!getDerived().shouldVisitImplicitCode()) { TRY_TO_TRAVERSE_OR_ENQUEUE_STMT(S->getOperand()); Index: include/clang/AST/Stmt.h =================================================================== --- include/clang/AST/Stmt.h +++ include/clang/AST/Stmt.h @@ -252,6 +252,17 @@ unsigned NumArgs : 32 - 8 - 1 - NumExprBits; }; + class CoawaitExprBitfields { + friend class CoawaitExpr; + + unsigned : NumExprBits; + + unsigned IsImplicit : 1; + + /// \brief The number of arguments to this type trait. + unsigned NumArgs : 32 - 1 - NumExprBits; + }; + union { StmtBitfields StmtBits; CompoundStmtBitfields CompoundStmtBits; @@ -268,6 +279,7 @@ ObjCIndirectCopyRestoreExprBitfields ObjCIndirectCopyRestoreExprBits; InitListExprBitfields InitListExprBits; TypeTraitExprBitfields TypeTraitExprBits; + CoawaitExprBitfields CoawaitBits; }; friend class ASTStmtReader; Index: include/clang/AST/StmtCXX.h =================================================================== --- include/clang/AST/StmtCXX.h +++ include/clang/AST/StmtCXX.h @@ -370,24 +370,25 @@ } Expr *getAllocate() const { - return cast(getStoredStmts()[SubStmt::Allocate]); + return cast_or_null(getStoredStmts()[SubStmt::Allocate]); } Expr *getDeallocate() const { - return cast(getStoredStmts()[SubStmt::Deallocate]); + return cast_or_null(getStoredStmts()[SubStmt::Deallocate]); } Expr *getReturnValueInit() const { - return cast(getStoredStmts()[SubStmt::ReturnValue]); + return cast_or_null(getStoredStmts()[SubStmt::ReturnValue]); } ArrayRef getParamMoves() const { return {getStoredStmts() + SubStmt::FirstParamMove, NumParams}; } SourceLocation getLocStart() const LLVM_READONLY { - return getBody()->getLocStart(); + return getBody() ? getBody()->getLocStart() + : getPromiseDecl()->getLocStart(); } SourceLocation getLocEnd() const LLVM_READONLY { - return getBody()->getLocEnd(); + return getBody() ? getBody()->getLocEnd() : getPromiseDecl()->getLocEnd(); } child_range children() { @@ -417,10 +418,14 @@ enum SubStmt { Operand, PromiseCall, Count }; Stmt *SubStmts[SubStmt::Count]; + bool IsImplicit : 1; + friend class ASTStmtReader; public: - CoreturnStmt(SourceLocation CoreturnLoc, Stmt *Operand, Stmt *PromiseCall) - : Stmt(CoreturnStmtClass), CoreturnLoc(CoreturnLoc) { + CoreturnStmt(SourceLocation CoreturnLoc, Stmt *Operand, Stmt *PromiseCall, + bool IsImplicit = false) + : Stmt(CoreturnStmtClass), CoreturnLoc(CoreturnLoc), + IsImplicit(IsImplicit) { SubStmts[SubStmt::Operand] = Operand; SubStmts[SubStmt::PromiseCall] = PromiseCall; } @@ -438,6 +443,9 @@ return static_cast(SubStmts[PromiseCall]); } + bool isImplicit() const { return IsImplicit; } + void setIsImplicit(bool value = true) { IsImplicit = value; } + SourceLocation getLocStart() const LLVM_READONLY { return CoreturnLoc; } SourceLocation getLocEnd() const LLVM_READONLY { return getOperand() ? getOperand()->getLocEnd() : getLocStart(); Index: include/clang/Basic/DiagnosticSemaKinds.td =================================================================== --- include/clang/Basic/DiagnosticSemaKinds.td +++ include/clang/Basic/DiagnosticSemaKinds.td @@ -8793,8 +8793,7 @@ def err_return_in_coroutine : Error< "return statement not allowed in coroutine; did you mean 'co_return'?">; def note_declared_coroutine_here : Note< - "function is a coroutine due to use of " - "'%select{co_await|co_yield|co_return}0' here">; + "function is a coroutine due to use of '%0' here">; def err_coroutine_objc_method : Error< "Objective-C methods as coroutines are not yet supported">; def err_coroutine_unevaluated_context : Error< @@ -8814,6 +8813,8 @@ "this function cannot be a coroutine: %q0 has no member named 'promise_type'">; def err_implied_std_coroutine_traits_promise_type_not_class : Error< "this function cannot be a coroutine: %0 is not a class">; +def err_coroutine_promise_type_incomplete : Error< + "this function cannot be a coroutine: %0 is an incomplete type">; def err_coroutine_traits_missing_specialization : Error< "this function cannot be a coroutine: missing definition of " "specialization %q0">; @@ -8824,6 +8825,11 @@ "'std::current_exception' must be a function">; def err_coroutine_promise_return_ill_formed : Error< "%0 declares both 'return_value' and 'return_void'">; +def note_coroutine_promise_implicit_await_transform_required_here : Note< + "call to 'await_transform' implicitly required by 'co_await' here">; +def note_coroutine_promise_call_implicitly_required : Note< + "call to '%select{initial_suspend|final_suspend}0' implicitly " + "required by the %select{initial suspend point|final suspend point}0">; } let CategoryName = "Documentation Issue" in { Index: include/clang/Basic/StmtNodes.td =================================================================== --- include/clang/Basic/StmtNodes.td +++ include/clang/Basic/StmtNodes.td @@ -150,6 +150,7 @@ // C++ Coroutines TS expressions def CoroutineSuspendExpr : DStmt; def CoawaitExpr : DStmt; +def DependentCoawaitExpr : DStmt; def CoyieldExpr : DStmt; // Obj-C Expressions. Index: include/clang/Sema/ScopeInfo.h =================================================================== --- include/clang/Sema/ScopeInfo.h +++ include/clang/Sema/ScopeInfo.h @@ -135,6 +135,10 @@ /// false if there is an invocation of an initializer on 'self'. bool ObjCWarnForNoInitDelegation : 1; + /// \brief Whether this function has already built, or tried to build, the + /// the initial and final coroutine suspend points. + bool NeedsCoroutineSuspends : 1; + /// First 'return' statement in the current function. SourceLocation FirstReturnLoc; @@ -159,6 +163,9 @@ /// \brief The promise object for this coroutine, if any. VarDecl *CoroutinePromise = nullptr; + /// \brief The initial and final coroutine suspend points. + std::pair CoroutineSuspends; + /// \brief The list of coroutine control flow constructs (co_await, co_yield, /// co_return) that occur within the function or block. Empty if and only if /// this function or block is not (yet known to be) a coroutine. @@ -376,7 +383,25 @@ (HasIndirectGoto || (HasBranchProtectedScope && HasBranchIntoScope)); } - + + void setNeedsCoroutineSuspends(bool value = true) { + assert(NeedsCoroutineSuspends && CoroutineSuspends.first == nullptr && + "we already have valid suspend points"); + NeedsCoroutineSuspends = value; + } + + bool hasInvalidCoroutineSuspends() const { + return !NeedsCoroutineSuspends && CoroutineSuspends.first == nullptr; + } + + void setCoroutineSuspends(Stmt *Initial, Stmt *Final) { + assert(Initial && Final && "suspend points cannot be null"); + assert(CoroutineSuspends.first == nullptr && "suspend points already set"); + NeedsCoroutineSuspends = false; + CoroutineSuspends.first = Initial; + CoroutineSuspends.second = Final; + } + FunctionScopeInfo(DiagnosticsEngine &Diag) : Kind(SK_Function), HasBranchProtectedScope(false), @@ -386,6 +411,7 @@ HasOMPDeclareReductionCombiner(false), HasFallthroughStmt(false), HasPotentialAvailabilityViolations(false), + NeedsCoroutineSuspends(true), ObjCShouldCallSuper(false), ObjCIsDesignatedInit(false), ObjCWarnForNoDesignatedInitChain(false), Index: include/clang/Sema/Sema.h =================================================================== --- include/clang/Sema/Sema.h +++ include/clang/Sema/Sema.h @@ -26,6 +26,7 @@ #include "clang/AST/MangleNumberingContext.h" #include "clang/AST/NSAPI.h" #include "clang/AST/PrettyPrinter.h" +#include "clang/AST/StmtCXX.h" #include "clang/AST/TypeLoc.h" #include "clang/AST/TypeOrdering.h" #include "clang/Basic/ExpressionTraits.h" @@ -101,6 +102,7 @@ class CodeCompletionAllocator; class CodeCompletionTUInfo; class CodeCompletionResult; + class CoroutineBodyStmt; class Decl; class DeclAccessPair; class DeclContext; @@ -8169,12 +8171,17 @@ // ExprResult ActOnCoawaitExpr(Scope *S, SourceLocation KwLoc, Expr *E); ExprResult ActOnCoyieldExpr(Scope *S, SourceLocation KwLoc, Expr *E); - StmtResult ActOnCoreturnStmt(SourceLocation KwLoc, Expr *E); + StmtResult ActOnCoreturnStmt(Scope *S, SourceLocation KwLoc, Expr *E); - ExprResult BuildCoawaitExpr(SourceLocation KwLoc, Expr *E); + ExprResult BuildCoawaitExpr(SourceLocation KwLoc, Expr *E, + bool IsImplicit = false); + ExprResult BuildDependentCoawaitExpr(SourceLocation KwLoc, Expr *E, + UnresolvedLookupExpr* Lookup); ExprResult BuildCoyieldExpr(SourceLocation KwLoc, Expr *E); - StmtResult BuildCoreturnStmt(SourceLocation KwLoc, Expr *E); - + StmtResult BuildCoreturnStmt(SourceLocation KwLoc, Expr *E, + bool IsImplicit = false); + StmtResult BuildCoroutineBodyStmt(CoroutineBodyStmt::CtorArgs); + VarDecl *buildCoroutinePromise(SourceLocation Loc); void CheckCompletedCoroutineBody(FunctionDecl *FD, Stmt *&Body); //===--------------------------------------------------------------------===// Index: lib/AST/Expr.cpp =================================================================== --- lib/AST/Expr.cpp +++ lib/AST/Expr.cpp @@ -2953,6 +2953,7 @@ case CXXNewExprClass: case CXXDeleteExprClass: case CoawaitExprClass: + case DependentCoawaitExprClass: case CoyieldExprClass: // These always have a side-effect. return true; Index: lib/AST/ExprClassification.cpp =================================================================== --- lib/AST/ExprClassification.cpp +++ lib/AST/ExprClassification.cpp @@ -129,6 +129,7 @@ case Expr::UnresolvedLookupExprClass: case Expr::UnresolvedMemberExprClass: case Expr::TypoExprClass: + case Expr::DependentCoawaitExprClass: case Expr::CXXDependentScopeMemberExprClass: case Expr::DependentScopeDeclRefExprClass: // ObjC instance variables are lvalues Index: lib/AST/ExprConstant.cpp =================================================================== --- lib/AST/ExprConstant.cpp +++ lib/AST/ExprConstant.cpp @@ -10098,6 +10098,7 @@ case Expr::LambdaExprClass: case Expr::CXXFoldExprClass: case Expr::CoawaitExprClass: + case Expr::DependentCoawaitExprClass: case Expr::CoyieldExprClass: return ICEDiag(IK_NotICE, E->getLocStart()); Index: lib/AST/ItaniumMangle.cpp =================================================================== --- lib/AST/ItaniumMangle.cpp +++ lib/AST/ItaniumMangle.cpp @@ -4037,6 +4037,12 @@ mangleExpression(cast(E)->getOperand()); break; + case Expr::DependentCoawaitExprClass: + // FIXME: Propose a non-vendor mangling. + Out << "v18co_await"; + mangleExpression(cast(E)->getOperand()); + break; + case Expr::CoyieldExprClass: // FIXME: Propose a non-vendor mangling. Out << "v18co_yield"; Index: lib/AST/StmtPrinter.cpp =================================================================== --- lib/AST/StmtPrinter.cpp +++ lib/AST/StmtPrinter.cpp @@ -2475,6 +2475,13 @@ PrintExpr(S->getOperand()); } + +void StmtPrinter::VisitDependentCoawaitExpr(DependentCoawaitExpr *S) { + OS << "co_await "; + PrintExpr(S->getOperand()); +} + + void StmtPrinter::VisitCoyieldExpr(CoyieldExpr *S) { OS << "co_yield "; PrintExpr(S->getOperand()); Index: lib/AST/StmtProfile.cpp =================================================================== --- lib/AST/StmtProfile.cpp +++ lib/AST/StmtProfile.cpp @@ -1599,6 +1599,10 @@ VisitExpr(S); } +void StmtProfiler::VisitDependentCoawaitExpr(const DependentCoawaitExpr *S) { + VisitExpr(S); +} + void StmtProfiler::VisitCoyieldExpr(const CoyieldExpr *S) { VisitExpr(S); } Index: lib/Parse/ParseStmt.cpp =================================================================== --- lib/Parse/ParseStmt.cpp +++ lib/Parse/ParseStmt.cpp @@ -1898,7 +1898,7 @@ } } if (IsCoreturn) - return Actions.ActOnCoreturnStmt(ReturnLoc, R.get()); + return Actions.ActOnCoreturnStmt(getCurScope(), ReturnLoc, R.get()); return Actions.ActOnReturnStmt(ReturnLoc, R.get(), getCurScope()); } Index: lib/Sema/ScopeInfo.cpp =================================================================== --- lib/Sema/ScopeInfo.cpp +++ lib/Sema/ScopeInfo.cpp @@ -43,6 +43,9 @@ SwitchStack.clear(); Returns.clear(); CoroutinePromise = nullptr; + NeedsCoroutineSuspends = true; + CoroutineSuspends.first = nullptr; + CoroutineSuspends.second = nullptr; CoroutineStmts.clear(); ErrorTrap.reset(); PossiblyUnreachableDiags.clear(); Index: lib/Sema/SemaCoroutine.cpp =================================================================== --- lib/Sema/SemaCoroutine.cpp +++ lib/Sema/SemaCoroutine.cpp @@ -21,21 +21,32 @@ using namespace clang; using namespace sema; +static bool lookupMember(Sema &S, const char *Name, CXXRecordDecl *RD, + SourceLocation Loc) { + DeclarationName DN = S.PP.getIdentifierInfo(Name); + LookupResult LR(S, DN, Loc, Sema::LookupMemberName); + // Suppress diagnostics when a private member is selected. The same warnings + // will be produced again when building the call. + LR.suppressDiagnostics(); + return S.LookupQualifiedName(LR, RD); +} + /// Look up the std::coroutine_traits<...>::promise_type for the given /// function type. static QualType lookupPromiseType(Sema &S, const FunctionProtoType *FnType, - SourceLocation Loc) { + SourceLocation KwLoc, + SourceLocation FuncLoc) { // FIXME: Cache std::coroutine_traits once we've found it. NamespaceDecl *StdExp = S.lookupStdExperimentalNamespace(); if (!StdExp) { - S.Diag(Loc, diag::err_implied_std_coroutine_traits_not_found); + S.Diag(KwLoc, diag::err_implied_std_coroutine_traits_not_found); return QualType(); } LookupResult Result(S, &S.PP.getIdentifierTable().get("coroutine_traits"), - Loc, Sema::LookupOrdinaryName); + FuncLoc, Sema::LookupOrdinaryName); if (!S.LookupQualifiedName(Result, StdExp)) { - S.Diag(Loc, diag::err_implied_std_coroutine_traits_not_found); + S.Diag(KwLoc, diag::err_implied_std_coroutine_traits_not_found); return QualType(); } @@ -49,52 +60,58 @@ } // Form template argument list for coroutine_traits. - TemplateArgumentListInfo Args(Loc, Loc); + TemplateArgumentListInfo Args(KwLoc, KwLoc); Args.addArgument(TemplateArgumentLoc( TemplateArgument(FnType->getReturnType()), - S.Context.getTrivialTypeSourceInfo(FnType->getReturnType(), Loc))); + S.Context.getTrivialTypeSourceInfo(FnType->getReturnType(), KwLoc))); // FIXME: If the function is a non-static member function, add the type // of the implicit object parameter before the formal parameters. for (QualType T : FnType->getParamTypes()) Args.addArgument(TemplateArgumentLoc( - TemplateArgument(T), S.Context.getTrivialTypeSourceInfo(T, Loc))); + TemplateArgument(T), S.Context.getTrivialTypeSourceInfo(T, KwLoc))); // Build the template-id. QualType CoroTrait = - S.CheckTemplateIdType(TemplateName(CoroTraits), Loc, Args); + S.CheckTemplateIdType(TemplateName(CoroTraits), KwLoc, Args); if (CoroTrait.isNull()) return QualType(); - if (S.RequireCompleteType(Loc, CoroTrait, + if (S.RequireCompleteType(KwLoc, CoroTrait, diag::err_coroutine_traits_missing_specialization)) return QualType(); - CXXRecordDecl *RD = CoroTrait->getAsCXXRecordDecl(); + auto *RD = CoroTrait->getAsCXXRecordDecl(); assert(RD && "specialization of class template is not a class?"); // Look up the ::promise_type member. - LookupResult R(S, &S.PP.getIdentifierTable().get("promise_type"), Loc, + LookupResult R(S, &S.PP.getIdentifierTable().get("promise_type"), KwLoc, Sema::LookupOrdinaryName); S.LookupQualifiedName(R, RD); auto *Promise = R.getAsSingle(); if (!Promise) { - S.Diag(Loc, diag::err_implied_std_coroutine_traits_promise_type_not_found) + S.Diag(FuncLoc, + diag::err_implied_std_coroutine_traits_promise_type_not_found) << RD; return QualType(); } - // The promise type is required to be a class type. QualType PromiseType = S.Context.getTypeDeclType(Promise); - if (!PromiseType->getAsCXXRecordDecl()) { - // Use the fully-qualified name of the type. + + auto buildElaboratedType = [&]() { auto *NNS = NestedNameSpecifier::Create(S.Context, nullptr, StdExp); NNS = NestedNameSpecifier::Create(S.Context, NNS, false, CoroTrait.getTypePtr()); - PromiseType = S.Context.getElaboratedType(ETK_None, NNS, PromiseType); + return S.Context.getElaboratedType(ETK_None, NNS, PromiseType); + }; - S.Diag(Loc, diag::err_implied_std_coroutine_traits_promise_type_not_class) - << PromiseType; + if (!PromiseType->getAsCXXRecordDecl()) { + S.Diag(FuncLoc, + diag::err_implied_std_coroutine_traits_promise_type_not_class) + << buildElaboratedType(); return QualType(); } + if (S.RequireCompleteType(FuncLoc, buildElaboratedType(), + diag::err_coroutine_promise_type_incomplete)) + return QualType(); return PromiseType; } @@ -160,41 +177,49 @@ return !Diagnosed; } -/// Check that this is a context in which a coroutine suspension can appear. -static FunctionScopeInfo *checkCoroutineContext(Sema &S, SourceLocation Loc, - StringRef Keyword) { - if (!isValidCoroutineContext(S, Loc, Keyword)) - return nullptr; +static ExprResult buildOperatorCoawaitLookupExpr(Sema &SemaRef, Scope *S, + SourceLocation Loc) { + DeclarationName OpName = + SemaRef.Context.DeclarationNames.getCXXOperatorName(OO_Coawait); + LookupResult Operators(SemaRef, OpName, SourceLocation(), + Sema::LookupOperatorName); + SemaRef.LookupName(Operators, S); + + assert(!Operators.isAmbiguous() && "Operator lookup cannot be ambiguous"); + const auto &Functions = Operators.asUnresolvedSet(); + bool IsOverloaded = + Functions.size() > 1 || + (Functions.size() == 1 && isa(*Functions.begin())); + Expr *CoawaitOp = UnresolvedLookupExpr::Create( + SemaRef.Context, /*NamingClass*/ nullptr, NestedNameSpecifierLoc(), + DeclarationNameInfo(OpName, Loc), /*RequiresADL*/ true, IsOverloaded, + Functions.begin(), Functions.end()); + assert(CoawaitOp); + return CoawaitOp; +} - assert(isa(S.CurContext) && "not in a function scope"); - auto *FD = cast(S.CurContext); - auto *ScopeInfo = S.getCurFunction(); - assert(ScopeInfo && "missing function scope for function"); +/// Build a call to 'operator co_await' if there is a suitable operator for +/// the given expression. +static ExprResult buildOperatorCoawaitCall(Sema &SemaRef, SourceLocation Loc, + Expr *E, + UnresolvedLookupExpr *Lookup) { - // If we don't have a promise variable, build one now. - if (!ScopeInfo->CoroutinePromise) { - QualType T = FD->getType()->isDependentType() - ? S.Context.DependentTy - : lookupPromiseType( - S, FD->getType()->castAs(), Loc); - if (T.isNull()) - return nullptr; - - // Create and default-initialize the promise. - ScopeInfo->CoroutinePromise = - VarDecl::Create(S.Context, FD, FD->getLocation(), FD->getLocation(), - &S.PP.getIdentifierTable().get("__promise"), T, - S.Context.getTrivialTypeSourceInfo(T, Loc), SC_None); - S.CheckVariableDeclarationType(ScopeInfo->CoroutinePromise); - if (!ScopeInfo->CoroutinePromise->isInvalidDecl()) - S.ActOnUninitializedDecl(ScopeInfo->CoroutinePromise); - } + UnresolvedSet<16> Functions; + Functions.append(Lookup->decls_begin(), Lookup->decls_end()); + return SemaRef.CreateOverloadedUnaryOp(Loc, UO_Coawait, Functions, E); +} - return ScopeInfo; +static ExprResult buildOperatorCoawaitCall(Sema &SemaRef, Scope *S, + SourceLocation Loc, Expr *E) { + ExprResult R = buildOperatorCoawaitLookupExpr(SemaRef, S, Loc); + if (R.isInvalid()) + return ExprError(); + return buildOperatorCoawaitCall(SemaRef, Loc, E, + cast(R.get())); } static Expr *buildBuiltinCall(Sema &S, SourceLocation Loc, Builtin::ID Id, - MutableArrayRef CallArgs) { + MultiExprArg CallArgs) { StringRef Name = S.Context.BuiltinInfo.getName(Id); LookupResult R(S, &S.Context.Idents.get(Name), Loc, Sema::LookupOrdinaryName); S.LookupName(R, S.TUScope, /*AllowBuiltinCreation=*/true); @@ -213,15 +238,6 @@ return Call.get(); } -/// Build a call to 'operator co_await' if there is a suitable operator for -/// the given expression. -static ExprResult buildOperatorCoawaitCall(Sema &SemaRef, Scope *S, - SourceLocation Loc, Expr *E) { - UnresolvedSet<16> Functions; - SemaRef.LookupOverloadedOperatorName(OO_Coawait, S, E->getType(), QualType(), - Functions); - return SemaRef.CreateOverloadedUnaryOp(Loc, UO_Coawait, Functions, E); -} struct ReadySuspendResumeResult { bool IsInvalid; @@ -229,8 +245,7 @@ }; static ExprResult buildMemberCall(Sema &S, Expr *Base, SourceLocation Loc, - StringRef Name, - MutableArrayRef Args) { + StringRef Name, MultiExprArg Args) { DeclarationNameInfo NameInfo(&S.PP.getIdentifierTable().get(Name), Loc); // FIXME: Fix BuildMemberReferenceExpr to take a const CXXScopeSpec&. @@ -268,25 +283,174 @@ return Calls; } +static ExprResult buildPromiseCall(Sema &S, VarDecl *Promise, + SourceLocation Loc, StringRef Name, + MultiExprArg Args) { + + // Form a reference to the promise. + ExprResult PromiseRef = S.BuildDeclRefExpr( + Promise, Promise->getType().getNonReferenceType(), VK_LValue, Loc); + if (PromiseRef.isInvalid()) + return ExprError(); + + // Call 'yield_value', passing in E. + return buildMemberCall(S, PromiseRef.get(), Loc, Name, Args); +} + +VarDecl *Sema::buildCoroutinePromise(SourceLocation Loc) { + assert(isa(CurContext) && "not in a function scope"); + auto *FD = cast(CurContext); + + QualType T = + FD->getType()->isDependentType() + ? Context.DependentTy + : lookupPromiseType(*this, FD->getType()->castAs(), + Loc, FD->getLocation()); + if (T.isNull()) + return nullptr; + + auto *VD = VarDecl::Create(Context, FD, FD->getLocation(), FD->getLocation(), + &PP.getIdentifierTable().get("__promise"), T, + Context.getTrivialTypeSourceInfo(T, Loc), SC_None); + CheckVariableDeclarationType(VD); + if (VD->isInvalidDecl()) + return nullptr; + ActOnUninitializedDecl(VD); + assert(!VD->isInvalidDecl()); + return VD; +} + +/// Check that this is a context in which a coroutine suspension can appear. +static FunctionScopeInfo *checkCoroutineContext(Sema &S, SourceLocation Loc, + StringRef Keyword) { + if (!isValidCoroutineContext(S, Loc, Keyword)) + return nullptr; + + assert(isa(S.CurContext) && "not in a function scope"); + auto *FD = cast(S.CurContext); + + auto *ScopeInfo = S.getCurFunction(); + assert(ScopeInfo && "missing function scope for function"); + + if (ScopeInfo->CoroutinePromise) + return ScopeInfo; + + ScopeInfo->CoroutinePromise = S.buildCoroutinePromise(Loc); + if (!ScopeInfo->CoroutinePromise) + return nullptr; + + return ScopeInfo; +} + +static bool actOnCoroutineBodyStart(Sema &S, Scope *SC, SourceLocation KWLoc, + StringRef Keyword) { + if (!checkCoroutineContext(S, KWLoc, Keyword)) + return false; + auto *ScopeInfo = S.getCurFunction(); + assert(ScopeInfo->CoroutinePromise); + + // If we have existing coroutine statements then we have already built + // the initial and final suspend points. + if (!ScopeInfo->NeedsCoroutineSuspends) + return true; + + ScopeInfo->setNeedsCoroutineSuspends(false); + + auto *Fn = cast(S.CurContext); + SourceLocation Loc = Fn->getLocation(); + // Build the initial suspend point + auto buildSuspends = [&](StringRef Name) mutable -> StmtResult { + ExprResult Suspend = + buildPromiseCall(S, ScopeInfo->CoroutinePromise, Loc, Name, None); + if (Suspend.isInvalid()) + return StmtError(); + Suspend = buildOperatorCoawaitCall(S, SC, Loc, Suspend.get()); + if (Suspend.isInvalid()) + return StmtError(); + Suspend = S.BuildCoawaitExpr(Loc, Suspend.get(), + /*IsImplicitlyCreated*/ true); + Suspend = S.ActOnFinishFullExpr(Suspend.get()); + if (Suspend.isInvalid()) { + S.Diag(Loc, diag::note_coroutine_promise_call_implicitly_required) + << ((Name == "initial_suspend") ? 0 : 1); + S.Diag(KWLoc, diag::note_declared_coroutine_here) << Keyword; + return StmtError(); + } + return cast(Suspend.get()); + }; + + StmtResult InitSuspend = buildSuspends("initial_suspend"); + if (InitSuspend.isInvalid()) + return true; + + StmtResult FinalSuspend = buildSuspends("final_suspend"); + if (FinalSuspend.isInvalid()) + return true; + + ScopeInfo->setCoroutineSuspends(InitSuspend.get(), FinalSuspend.get()); + + return true; +} + ExprResult Sema::ActOnCoawaitExpr(Scope *S, SourceLocation Loc, Expr *E) { - auto *Coroutine = checkCoroutineContext(*this, Loc, "co_await"); - if (!Coroutine) { + if (!actOnCoroutineBodyStart(*this, S, Loc, "co_await")) { CorrectDelayedTyposInExpr(E); return ExprError(); } + if (E->getType()->isPlaceholderType()) { ExprResult R = CheckPlaceholderExpr(E); if (R.isInvalid()) return ExprError(); E = R.get(); } + ExprResult Lookup = buildOperatorCoawaitLookupExpr(*this, S, Loc); + if (Lookup.isInvalid()) + return ExprError(); + return BuildDependentCoawaitExpr(Loc, E, + cast(Lookup.get())); +} + +ExprResult Sema::BuildDependentCoawaitExpr(SourceLocation Loc, Expr *E, + UnresolvedLookupExpr *Lookup) { + auto *FSI = checkCoroutineContext(*this, Loc, "co_await"); + if (!FSI) + return ExprError(); + + if (E->getType()->isPlaceholderType()) { + ExprResult R = CheckPlaceholderExpr(E); + if (R.isInvalid()) + return ExprError(); + E = R.get(); + } + + auto *Promise = FSI->CoroutinePromise; + if (Promise->getType()->isDependentType()) { + Expr *Res = + new (Context) DependentCoawaitExpr(Loc, Context.DependentTy, E, Lookup); + FSI->CoroutineStmts.push_back(Res); + return Res; + } - ExprResult Awaitable = buildOperatorCoawaitCall(*this, S, Loc, E); + auto *RD = Promise->getType()->getAsCXXRecordDecl(); + if (lookupMember(*this, "await_transform", RD, Loc)) { + ExprResult R = buildPromiseCall(*this, Promise, Loc, "await_transform", E); + if (R.isInvalid()) { + Diag(Loc, + diag::note_coroutine_promise_implicit_await_transform_required_here) + << E->getSourceRange(); + return ExprError(); + } + E = R.get(); + } + ExprResult Awaitable = buildOperatorCoawaitCall(*this, Loc, E, Lookup); if (Awaitable.isInvalid()) return ExprError(); return BuildCoawaitExpr(Loc, Awaitable.get()); } -ExprResult Sema::BuildCoawaitExpr(SourceLocation Loc, Expr *E) { + +ExprResult Sema::BuildCoawaitExpr(SourceLocation Loc, Expr *E, + bool IsImplicitlyCreated) { auto *Coroutine = checkCoroutineContext(*this, Loc, "co_await"); if (!Coroutine) return ExprError(); @@ -298,8 +462,10 @@ } if (E->getType()->isDependentType()) { - Expr *Res = new (Context) CoawaitExpr(Loc, Context.DependentTy, E); - Coroutine->CoroutineStmts.push_back(Res); + Expr *Res = new (Context) + CoawaitExpr(Loc, Context.DependentTy, E, IsImplicitlyCreated); + if (!IsImplicitlyCreated) + Coroutine->CoroutineStmts.push_back(Res); return Res; } @@ -314,37 +480,21 @@ return ExprError(); Expr *Res = new (Context) CoawaitExpr(Loc, E, RSS.Results[0], RSS.Results[1], - RSS.Results[2]); - Coroutine->CoroutineStmts.push_back(Res); + RSS.Results[2], IsImplicitlyCreated); + if (!IsImplicitlyCreated) + Coroutine->CoroutineStmts.push_back(Res); return Res; } -static ExprResult buildPromiseCall(Sema &S, FunctionScopeInfo *Coroutine, - SourceLocation Loc, StringRef Name, - MutableArrayRef Args) { - assert(Coroutine->CoroutinePromise && "no promise for coroutine"); - - // Form a reference to the promise. - auto *Promise = Coroutine->CoroutinePromise; - ExprResult PromiseRef = S.BuildDeclRefExpr( - Promise, Promise->getType().getNonReferenceType(), VK_LValue, Loc); - if (PromiseRef.isInvalid()) - return ExprError(); - - // Call 'yield_value', passing in E. - return buildMemberCall(S, PromiseRef.get(), Loc, Name, Args); -} - ExprResult Sema::ActOnCoyieldExpr(Scope *S, SourceLocation Loc, Expr *E) { - auto *Coroutine = checkCoroutineContext(*this, Loc, "co_yield"); - if (!Coroutine) { + if (!actOnCoroutineBodyStart(*this, S, Loc, "co_yield")) { CorrectDelayedTyposInExpr(E); return ExprError(); } // Build yield_value call. - ExprResult Awaitable = - buildPromiseCall(*this, Coroutine, Loc, "yield_value", E); + ExprResult Awaitable = buildPromiseCall( + *this, getCurFunction()->CoroutinePromise, Loc, "yield_value", E); if (Awaitable.isInvalid()) return ExprError(); @@ -388,18 +538,18 @@ return Res; } -StmtResult Sema::ActOnCoreturnStmt(SourceLocation Loc, Expr *E) { - auto *Coroutine = checkCoroutineContext(*this, Loc, "co_return"); - if (!Coroutine) { +StmtResult Sema::ActOnCoreturnStmt(Scope *S, SourceLocation Loc, Expr *E) { + if (!actOnCoroutineBodyStart(*this, S, Loc, "co_return")) { CorrectDelayedTyposInExpr(E); return StmtError(); } return BuildCoreturnStmt(Loc, E); } -StmtResult Sema::BuildCoreturnStmt(SourceLocation Loc, Expr *E) { - auto *Coroutine = checkCoroutineContext(*this, Loc, "co_return"); - if (!Coroutine) +StmtResult Sema::BuildCoreturnStmt(SourceLocation Loc, Expr *E, + bool IsImplicitlyCreated) { + auto *FSI = checkCoroutineContext(*this, Loc, "co_return"); + if (!FSI) return StmtError(); if (E && E->getType()->isPlaceholderType() && @@ -412,20 +562,22 @@ // FIXME: If the operand is a reference to a variable that's about to go out // of scope, we should treat the operand as an xvalue for this overload // resolution. + VarDecl *Promise = FSI->CoroutinePromise; ExprResult PC; if (E && (isa(E) || !E->getType()->isVoidType())) { - PC = buildPromiseCall(*this, Coroutine, Loc, "return_value", E); + PC = buildPromiseCall(*this, Promise, Loc, "return_value", E); } else { E = MakeFullDiscardedValueExpr(E).get(); - PC = buildPromiseCall(*this, Coroutine, Loc, "return_void", None); + PC = buildPromiseCall(*this, Promise, Loc, "return_void", None); } if (PC.isInvalid()) return StmtError(); Expr *PCE = ActOnFinishFullExpr(PC.get()).get(); - Stmt *Res = new (Context) CoreturnStmt(Loc, E, PCE); - Coroutine->CoroutineStmts.push_back(Res); + Stmt *Res = new (Context) CoreturnStmt(Loc, E, PCE, IsImplicitlyCreated); + if (!IsImplicitlyCreated) + FSI->CoroutineStmts.push_back(Res); return Res; } @@ -482,14 +634,82 @@ return OperatorDelete; } +static bool buildFallthrough(Sema &S, SourceLocation Loc, + FunctionDecl *FD, + FunctionScopeInfo *FTI, + Stmt *&OnFallthrough) +{ + assert(!OnFallthrough && "rebuilding existing OnFallthrough"); + auto *Promise = FTI->CoroutinePromise; + if (Promise->getType()->isDependentType()) + return true; + + CXXRecordDecl *RD = Promise->getType()->getAsCXXRecordDecl(); + + // [dcl.fct.def.coroutine]/4 + // The unqualified-ids 'return_void' and 'return_value' are looked up in + // the scope of class P. If both are found, the program is ill-formed. + const bool HasRVoid = lookupMember(S, "return_void", RD, Loc); + const bool HasRValue = lookupMember(S, "return_value", RD, Loc); + if (HasRVoid && HasRValue) { + // FIXME Improve this diagnostic + S.Diag(FD->getLocation(), diag::err_coroutine_promise_return_ill_formed) + << RD; + return false; + } else if (HasRVoid) { + // If the unqualified-id return_void is found, flowing off the end of a + // coroutine is equivalent to a co_return with no operand. Otherwise, + // flowing off the end of a coroutine results in undefined behavior. + StmtResult Fallthrough = S.BuildCoreturnStmt(FD->getLocation(), nullptr, + /*IsImplicitlyCreated*/ true); + if (!Fallthrough.isInvalid()) + Fallthrough = S.ActOnFinishFullStmt(Fallthrough.get()); + if (Fallthrough.isInvalid()) + return false; + OnFallthrough = Fallthrough.get(); + } + return true; +} + +static bool buildSetException(Sema &S, SourceLocation Loc, + FunctionDecl *FD, + FunctionScopeInfo *FTI, + Stmt *&OnException) +{ + assert(!OnException && "rebuilding existing set_exception"); + auto *Promise = FTI->CoroutinePromise; + if (Promise->getType()->isDependentType()) + return true; + + CXXRecordDecl *RD = Promise->getType()->getAsCXXRecordDecl(); + + // [dcl.fct.def.coroutine]/3 + // The unqualified-id set_exception is found in the scope of P by class + // member access lookup (3.4.5). + if (lookupMember(S, "set_exception", RD, Loc)) { + // Form the call 'p.set_exception(std::current_exception())' + ExprResult SetException = buildStdCurrentExceptionCall(S, Loc); + if (SetException.isInvalid()) + return false; + Expr *E = SetException.get(); + SetException = buildPromiseCall(S, Promise, Loc, "set_exception", E); + SetException = S.ActOnFinishFullExpr(SetException.get(), Loc); + if (SetException.isInvalid()) + return false; + OnException = SetException.get(); + } + return true; +} + + // Builds allocation and deallocation for the coroutine. Returns false on // failure. static bool buildAllocationAndDeallocation(Sema &S, SourceLocation Loc, FunctionScopeInfo *Fn, Expr *&Allocation, Expr *&Deallocation) { - TypeSourceInfo *TInfo = Fn->CoroutinePromise->getTypeSourceInfo(); - QualType PromiseType = TInfo->getType(); + assert(!Allocation && !Deallocation && "alloc/dealloc statements have already been built"); + QualType PromiseType = Fn->CoroutinePromise->getType(); if (PromiseType->isDependentType()) return true; @@ -532,8 +752,6 @@ if (NewExpr.isInvalid()) return false; - Allocation = NewExpr.get(); - // Make delete call. QualType OpDeleteQualType = OperatorDelete->getType(); @@ -559,6 +777,7 @@ if (DeleteExpr.isInvalid()) return false; + Allocation = NewExpr.get(); Deallocation = DeleteExpr.get(); return true; @@ -608,7 +827,7 @@ void Sema::CheckCompletedCoroutineBody(FunctionDecl *FD, Stmt *&Body) { FunctionScopeInfo *Fn = getCurFunction(); - assert(Fn && !Fn->CoroutineStmts.empty() && "not a coroutine"); + assert(Fn && Fn->CoroutinePromise && "not a coroutine"); // Coroutines [stmt.return]p1: // A return statement shall not appear in a coroutine. @@ -616,8 +835,8 @@ Diag(Fn->FirstReturnLoc, diag::err_return_in_coroutine); auto *First = Fn->CoroutineStmts[0]; Diag(First->getLocStart(), diag::note_declared_coroutine_here) - << (isa(First) ? 0 : - isa(First) ? 1 : 2); + << (isa(First) ? "co_await" : + isa(First) ? "co_yield" : "co_return"); } SubStmtBuilder Builder(*this, *FD, *Fn, Body); if (Builder.isInvalid()) @@ -640,32 +859,16 @@ } bool SubStmtBuilder::makeInitialSuspend() { - // Form and check implicit 'co_await p.initial_suspend();' statement. - ExprResult InitialSuspend = - buildPromiseCall(S, &Fn, Loc, "initial_suspend", None); - // FIXME: Support operator co_await here. - if (!InitialSuspend.isInvalid()) - InitialSuspend = S.BuildCoawaitExpr(Loc, InitialSuspend.get()); - InitialSuspend = S.ActOnFinishFullExpr(InitialSuspend.get()); - if (InitialSuspend.isInvalid()) + if (Fn.hasInvalidCoroutineSuspends()) return false; - - this->InitialSuspend = InitialSuspend.get(); + this->InitialSuspend = cast(Fn.CoroutineSuspends.first); return true; } bool SubStmtBuilder::makeFinalSuspend() { - // Form and check implicit 'co_await p.final_suspend();' statement. - ExprResult FinalSuspend = - buildPromiseCall(S, &Fn, Loc, "final_suspend", None); - // FIXME: Support operator co_await here. - if (!FinalSuspend.isInvalid()) - FinalSuspend = S.BuildCoawaitExpr(Loc, FinalSuspend.get()); - FinalSuspend = S.ActOnFinishFullExpr(FinalSuspend.get()); - if (FinalSuspend.isInvalid()) + if (Fn.hasInvalidCoroutineSuspends()) return false; - - this->FinalSuspend = FinalSuspend.get(); + this->FinalSuspend = cast(Fn.CoroutineSuspends.second); return true; } @@ -736,7 +939,7 @@ if (SetException.isInvalid()) return false; Expr *E = SetException.get(); - SetException = buildPromiseCall(S, &Fn, Loc, "set_exception", E); + SetException = buildPromiseCall(S, Fn.CoroutinePromise, Loc, "set_exception", E); SetException = S.ActOnFinishFullExpr(SetException.get(), Loc); if (SetException.isInvalid()) return false; @@ -751,7 +954,7 @@ // Build implicit 'p.get_return_object()' expression and form initialization // of return type from it. ExprResult ReturnObject = - buildPromiseCall(S, &Fn, Loc, "get_return_object", None); + buildPromiseCall(S, Fn.CoroutinePromise, Loc, "get_return_object", None); if (ReturnObject.isInvalid()) return false; QualType RetType = FD.getReturnType(); @@ -775,3 +978,10 @@ // FIXME: Perform move-initialization of parameters into frame-local copies. return true; } + +StmtResult Sema::BuildCoroutineBodyStmt(CoroutineBodyStmt::CtorArgs Args) { + CoroutineBodyStmt *Res = CoroutineBodyStmt::Create(Context, Args); + if (!Res) + return StmtError(); + return Res; +} Index: lib/Sema/SemaDecl.cpp =================================================================== --- lib/Sema/SemaDecl.cpp +++ lib/Sema/SemaDecl.cpp @@ -11984,7 +11984,7 @@ sema::AnalysisBasedWarnings::Policy WP = AnalysisWarnings.getDefaultPolicy(); sema::AnalysisBasedWarnings::Policy *ActivePolicy = nullptr; - if (getLangOpts().CoroutinesTS && !getCurFunction()->CoroutineStmts.empty()) + if (getLangOpts().CoroutinesTS && getCurFunction()->CoroutinePromise) CheckCompletedCoroutineBody(FD, Body); if (FD) { Index: lib/Sema/SemaExceptionSpec.cpp =================================================================== --- lib/Sema/SemaExceptionSpec.cpp +++ lib/Sema/SemaExceptionSpec.cpp @@ -1182,6 +1182,7 @@ case Expr::ArraySubscriptExprClass: case Expr::OMPArraySectionExprClass: case Expr::BinaryOperatorClass: + case Expr::DependentCoawaitExprClass: case Expr::CompoundAssignOperatorClass: case Expr::CStyleCastExprClass: case Expr::CXXStaticCastExprClass: Index: lib/Sema/TreeTransform.h =================================================================== --- lib/Sema/TreeTransform.h +++ lib/Sema/TreeTransform.h @@ -1362,16 +1362,28 @@ /// /// By default, performs semantic analysis to build the new statement. /// Subclasses may override this routine to provide different behavior. - StmtResult RebuildCoreturnStmt(SourceLocation CoreturnLoc, Expr *Result) { - return getSema().BuildCoreturnStmt(CoreturnLoc, Result); + StmtResult RebuildCoreturnStmt(SourceLocation CoreturnLoc, Expr *Result, + bool IsImplicit) { + return getSema().BuildCoreturnStmt(CoreturnLoc, Result, IsImplicit); } /// \brief Build a new co_await expression. /// /// By default, performs semantic analysis to build the new expression. /// Subclasses may override this routine to provide different behavior. - ExprResult RebuildCoawaitExpr(SourceLocation CoawaitLoc, Expr *Result) { - return getSema().BuildCoawaitExpr(CoawaitLoc, Result); + ExprResult RebuildCoawaitExpr(SourceLocation CoawaitLoc, Expr *Result, + bool IsImplicit) { + return getSema().BuildCoawaitExpr(CoawaitLoc, Result, IsImplicit); + } + + /// \brief Build a new co_await expression. + /// + /// By default, performs semantic analysis to build the new expression. + /// Subclasses may override this routine to provide different behavior. + ExprResult RebuildDependentCoawaitExpr(SourceLocation CoawaitLoc, + Expr *Result, + UnresolvedLookupExpr *Lookup) { + return getSema().BuildDependentCoawaitExpr(CoawaitLoc, Result, Lookup); } /// \brief Build a new co_yield expression. @@ -1382,6 +1394,10 @@ return getSema().BuildCoyieldExpr(CoyieldLoc, Result); } + StmtResult RebuildCoroutineBodyStmt(CoroutineBodyStmt::CtorArgs Args) { + return getSema().BuildCoroutineBodyStmt(Args); + } + /// \brief Build a new Objective-C \@try statement. /// /// By default, performs semantic analysis to build the new statement. @@ -6826,7 +6842,91 @@ TreeTransform::TransformCoroutineBodyStmt(CoroutineBodyStmt *S) { // The coroutine body should be re-formed by the caller if necessary. // FIXME: The coroutine body is always rebuilt by ActOnFinishFunctionBody - return getDerived().TransformStmt(S->getBody()); + CoroutineBodyStmt::CtorArgs BodyArgs; + + auto *ScopeInfo = SemaRef.getCurFunction(); + auto *FD = cast(SemaRef.CurContext); + assert(ScopeInfo && !ScopeInfo->CoroutinePromise && + ScopeInfo->NeedsCoroutineSuspends && + ScopeInfo->CoroutineSuspends.first == nullptr && + ScopeInfo->CoroutineSuspends.second == nullptr && + ScopeInfo->CoroutineStmts.empty() && "expected clean scope info"); + + // Set that we have (possibly-invalid) suspend points before we do anything + // that may fail. + ScopeInfo->setNeedsCoroutineSuspends(false); + + // The new CoroutinePromise object needs to be built and put into the current + // FunctionScopeInfo before any transformations or rebuilding occurs. + auto *Promise = S->getPromiseDecl(); + auto *NewPromise = SemaRef.buildCoroutinePromise(FD->getLocation()); + if (!NewPromise) + return StmtError(); + getDerived().transformedLocalDecl(Promise, NewPromise); + ScopeInfo->CoroutinePromise = NewPromise; + StmtResult PromiseStmt = SemaRef.ActOnDeclStmt( + SemaRef.ConvertDeclToDeclGroup(NewPromise), + FD->getLocation(), FD->getLocation()); + assert(!PromiseStmt.isInvalid()); + BodyArgs.Promise = PromiseStmt.get(); + + // Transform the implicit coroutine statements we built during the initial + // parse. + StmtResult InitSuspend = getDerived().TransformStmt(S->getInitSuspendStmt()); + if (InitSuspend.isInvalid()) + return StmtError(); + StmtResult FinalSuspend = + getDerived().TransformStmt(S->getFinalSuspendStmt()); + if (FinalSuspend.isInvalid()) + return StmtError(); + ScopeInfo->setCoroutineSuspends(InitSuspend.get(), FinalSuspend.get()); + assert(isa(InitSuspend.get()) && isa(FinalSuspend.get())); + BodyArgs.InitialSuspend = cast(InitSuspend.get()); + BodyArgs.FinalSuspend = cast(FinalSuspend.get()); + + StmtResult BodyRes = getDerived().TransformStmt(S->getBody()); + if (BodyRes.isInvalid()) + return StmtError(); + BodyArgs.Body = BodyRes.get(); + + if (S->getFallthroughHandler()) { + StmtResult Res = getDerived().TransformStmt(S->getFallthroughHandler()); + if (Res.isInvalid()) + return StmtError(); + BodyArgs.OnFallthrough = Res.get(); + } + + if (S->getExceptionHandler()) { + StmtResult Res = getDerived().TransformStmt(S->getExceptionHandler()); + if (Res.isInvalid()) + return StmtError(); + BodyArgs.OnException = Res.get(); + } + + // Transform any additional statements we may have already built + if (S->getAllocate() && S->getDeallocate()) { + ExprResult AllocRes = getDerived().TransformExpr(S->getAllocate()); + if (AllocRes.isInvalid()) + return StmtError(); + BodyArgs.Allocate = AllocRes.get(); + + ExprResult DeallocRes = getDerived().TransformExpr(S->getDeallocate()); + if (DeallocRes.isInvalid()) + return StmtError(); + BodyArgs.Deallocate = DeallocRes.get(); + } + + Expr *ReturnObject = S->getReturnValueInit(); + if (ReturnObject) { + ExprResult Res = getDerived().TransformInitializer(ReturnObject, + /*NoCopyInit*/false); + if (Res.isInvalid()) + return StmtError(); + BodyArgs.ReturnValue = Res.get(); + } + + // Do a partial rebuild of the coroutine body and stash it in the ScopeInfo + return getDerived().RebuildCoroutineBodyStmt(BodyArgs); } template @@ -6839,7 +6939,8 @@ // Always rebuild; we don't know if this needs to be injected into a new // context or if the promise type has changed. - return getDerived().RebuildCoreturnStmt(S->getKeywordLoc(), Result.get()); + return getDerived().RebuildCoreturnStmt(S->getKeywordLoc(), Result.get(), + S->isImplicit()); } template @@ -6852,7 +6953,32 @@ // Always rebuild; we don't know if this needs to be injected into a new // context or if the promise type has changed. - return getDerived().RebuildCoawaitExpr(E->getKeywordLoc(), Result.get()); + return getDerived().RebuildCoawaitExpr(E->getKeywordLoc(), Result.get(), + E->isImplicit()); +} + +template +ExprResult +TreeTransform::TransformDependentCoawaitExpr(DependentCoawaitExpr *E) { + ExprResult OperandResult = getDerived().TransformInitializer(E->getOperand(), + /*NotCopyInit*/ false); + if (OperandResult.isInvalid()) + return ExprError(); + + ExprResult LookupResult = getDerived().TransformUnresolvedLookupExpr( + E->getOperatorCoawaitLookup()); + + if (LookupResult.isInvalid()) + return ExprError(); + + // FIXME(EricWF): Remove this + assert(isa(LookupResult.get()) && "Expected lookup expr"); + + // Always rebuild; we don't know if this needs to be injected into a new + // context or if the promise type has changed. + return getDerived().RebuildDependentCoawaitExpr( + E->getKeywordLoc(), OperandResult.get(), + cast(LookupResult.get())); } template Index: lib/Serialization/ASTReaderStmt.cpp =================================================================== --- lib/Serialization/ASTReaderStmt.cpp +++ lib/Serialization/ASTReaderStmt.cpp @@ -381,6 +381,11 @@ llvm_unreachable("unimplemented"); } +void ASTStmtReader::VisitDependentCoawaitExpr(DependentCoawaitExpr *S) { + // FIXME: Implement coroutine serialization. + llvm_unreachable("unimplemented"); +} + void ASTStmtReader::VisitCoyieldExpr(CoyieldExpr *S) { // FIXME: Implement coroutine serialization. llvm_unreachable("unimplemented"); Index: lib/Serialization/ASTWriterStmt.cpp =================================================================== --- lib/Serialization/ASTWriterStmt.cpp +++ lib/Serialization/ASTWriterStmt.cpp @@ -315,6 +315,11 @@ llvm_unreachable("unimplemented"); } +void ASTStmtWriter::VisitDependentCoawaitExpr(DependentCoawaitExpr *S) { + // FIXME: Implement coroutine serialization. + llvm_unreachable("unimplemented"); +} + void ASTStmtWriter::VisitCoyieldExpr(CoyieldExpr *S) { // FIXME: Implement coroutine serialization. llvm_unreachable("unimplemented"); Index: lib/StaticAnalyzer/Core/ExprEngine.cpp =================================================================== --- lib/StaticAnalyzer/Core/ExprEngine.cpp +++ lib/StaticAnalyzer/Core/ExprEngine.cpp @@ -792,6 +792,7 @@ case Stmt::FunctionParmPackExprClass: case Stmt::CoroutineBodyStmtClass: case Stmt::CoawaitExprClass: + case Stmt::DependentCoawaitExprClass: case Stmt::CoreturnStmtClass: case Stmt::CoyieldExprClass: case Stmt::SEHTryStmtClass: Index: test/SemaCXX/coroutines.cpp =================================================================== --- test/SemaCXX/coroutines.cpp +++ test/SemaCXX/coroutines.cpp @@ -59,25 +59,25 @@ template struct std::experimental::coroutine_traits {}; -int no_promise_type() { - co_await a; // expected-error {{this function cannot be a coroutine: 'std::experimental::coroutine_traits' has no member named 'promise_type'}} +int no_promise_type() { // expected-error {{this function cannot be a coroutine: 'std::experimental::coroutine_traits' has no member named 'promise_type'}} + co_await a; } template <> struct std::experimental::coroutine_traits { typedef int promise_type; }; -double bad_promise_type(double) { - co_await a; // expected-error {{this function cannot be a coroutine: 'experimental::coroutine_traits::promise_type' (aka 'int') is not a class}} +double bad_promise_type(double) { // expected-error {{this function cannot be a coroutine: 'experimental::coroutine_traits::promise_type' (aka 'int') is not a class}} + co_await a; } template <> struct std::experimental::coroutine_traits { struct promise_type {}; }; -double bad_promise_type_2(int) { +double bad_promise_type_2(int) { // expected-error {{no member named 'initial_suspend'}} co_yield 0; // expected-error {{no member named 'yield_value' in 'std::experimental::coroutine_traits::promise_type'}} } -struct promise; // expected-note 2{{forward declaration}} +struct promise; // expected-note {{forward declaration}} struct promise_void; struct void_tag {}; template @@ -94,9 +94,7 @@ } // FIXME: This diagnostic is terrible. -void undefined_promise() { // expected-error {{variable has incomplete type 'promise_type'}} - // FIXME: This diagnostic doesn't make any sense. - // expected-error@-2 {{incomplete definition of type 'promise'}} +void undefined_promise() { // expected-error {{this function cannot be a coroutine: 'experimental::coroutine_traits::promise_type' (aka 'promise') is an incomplete type}} co_await a; } @@ -216,6 +214,13 @@ } struct outer {}; +struct await_arg_1 {}; +struct await_arg_2 {}; + +namespace adl_ns { +struct coawait_arg_type {}; +awaitable operator co_await(coawait_arg_type); +} namespace dependent_operator_co_await_lookup { template void await_template(T t) { @@ -238,6 +243,94 @@ }; template void await_template(outer); // expected-note {{instantiation}} template void await_template_2(outer); + + struct transform_awaitable {}; + struct transformed {}; + + struct transform_promise { + typedef transform_awaitable await_arg; + coro get_return_object(); + transformed initial_suspend(); + ::adl_ns::coawait_arg_type final_suspend(); + transformed await_transform(transform_awaitable); + }; + template + struct basic_promise { + typedef AwaitArg await_arg; + coro get_return_object(); + awaitable initial_suspend(); + awaitable final_suspend(); + }; + + awaitable operator co_await(await_arg_1); + + template + coro await_template_3(U t) { + co_await t; + } + + template coro> await_template_3>(await_arg_1); + + template + struct dependent_member { + coro mem_fn() const { + co_await typename T::await_arg{}; // expected-error {{call to function 'operator co_await'}}} + } + template + coro dep_mem_fn(U t) { + co_await t; + } + }; + + template <> + struct dependent_member { + // FIXME this diagnostic is terrible + coro mem_fn() const { // expected-error {{no member named 'await_ready' in 'dependent_operator_co_await_lookup::transformed'}} + // expected-note@-1 {{call to 'initial_suspend' implicitly required by the initial suspend point}} + // expected-note@+1 {{function is a coroutine due to use of 'co_await' here}} + co_await transform_awaitable{}; + // expected-error@-1 {{no member named 'await_ready'}} + } + template + coro dep_mem_fn(U u) { co_await u; } + }; + + awaitable operator co_await(await_arg_2); // expected-note {{'operator co_await' should be declared prior to the call site}} + + template struct dependent_member, 0>; + template struct dependent_member, 0>; // expected-note {{in instantiation}} + + template <> + coro + // FIXME this diagnostic is terrible + dependent_member::dep_mem_fn(int) { // expected-error {{no member named 'await_ready' in 'dependent_operator_co_await_lookup::transformed'}} + //expected-note@-1 {{call to 'initial_suspend' implicitly required by the initial suspend point}} + //expected-note@+1 {{function is a coroutine due to use of 'co_await' here}} + co_await transform_awaitable{}; + // expected-error@-1 {{no member named 'await_ready'}} + } + + void operator co_await(transform_awaitable) = delete; + awaitable operator co_await(transformed); + + template coro + dependent_member::dep_mem_fn(transform_awaitable); + + template <> + coro dependent_member::dep_mem_fn(long) { + co_await transform_awaitable{}; + } + + template <> + struct dependent_member { + coro mem_fn() const { + co_await transform_awaitable{}; + } + }; + + template coro await_template_3(transform_awaitable); + template struct dependent_member; + template coro dependent_member::dep_mem_fn(transform_awaitable); } struct yield_fn_tag {}; @@ -293,6 +386,7 @@ // FIXME: We shouldn't offer a typo-correction here! suspend_always final_suspend(); // expected-note {{here}} }; +// FIXME: This shouldn't happen twice coro missing_initial_suspend() { // expected-error {{no member named 'initial_suspend' in 'bad_promise_2'}} co_await a; } @@ -313,7 +407,8 @@ }; // FIXME: This diagnostic is terrible. coro bad_initial_suspend() { // expected-error {{no member named 'await_ready' in 'not_awaitable'}} - co_await a; + // expected-note@-1 {{call to 'initial_suspend' implicitly required by the initial suspend point}} + co_await a; // expected-note {{function is a coroutine due to use of 'co_await' here}} } struct bad_promise_5 { @@ -323,7 +418,8 @@ }; // FIXME: This diagnostic is terrible. coro bad_final_suspend() { // expected-error {{no member named 'await_ready' in 'not_awaitable'}} - co_await a; + // expected-note@-1 {{call to 'final_suspend' implicitly required by the final suspend point}} + co_await a; // expected-note {{function is a coroutine due to use of 'co_await' here}} } struct bad_promise_6 { Index: tools/libclang/CXCursor.cpp =================================================================== --- tools/libclang/CXCursor.cpp +++ tools/libclang/CXCursor.cpp @@ -231,6 +231,7 @@ case Stmt::TypeTraitExprClass: case Stmt::CoroutineBodyStmtClass: case Stmt::CoawaitExprClass: + case Stmt::DependentCoawaitExprClass: case Stmt::CoreturnStmtClass: case Stmt::CoyieldExprClass: case Stmt::CXXBindTemporaryExprClass: