Index: include/clang/AST/StmtCXX.h
===================================================================
--- include/clang/AST/StmtCXX.h
+++ include/clang/AST/StmtCXX.h
@@ -296,7 +296,9 @@
 /// \brief Represents the body of a coroutine. This wraps the normal function
 /// body and holds the additional semantic context required to set up and tear
 /// down the coroutine frame.
-class CoroutineBodyStmt : public Stmt {
+class CoroutineBodyStmt final
+    : public Stmt,
+      private llvm::TrailingObjects<CoroutineBodyStmt, Stmt *> {
   enum SubStmt {
     Body,          ///< The body of the coroutine.
     Promise,       ///< The promise statement.
@@ -309,52 +311,84 @@
     ReturnValue,   ///< Return value for thunk function.
     FirstParamMove ///< First offset for move construction of parameter copies.
   };
-  Stmt *SubStmts[SubStmt::FirstParamMove];
+  unsigned NumParams; ///< Number of trailing parameter move statements.

   friend class ASTStmtReader;
+  friend TrailingObjects;
+
+  /// Storage layout: the FirstParamMove fixed sub-statements (indexed by
+  /// SubStmt), immediately followed by NumParams parameter moves.
+  Stmt **getStoredStmts() { return getTrailingObjects<Stmt *>(); }
+
+  Stmt *const *getStoredStmts() const { return getTrailingObjects<Stmt *>(); }
+
 public:
-  CoroutineBodyStmt(Stmt *Body, Stmt *Promise, Stmt *InitSuspend,
-                    Stmt *FinalSuspend, Stmt *OnException, Stmt *OnFallthrough,
-                    Expr *Allocate, Stmt *Deallocate,
-                    Expr *ReturnValue, ArrayRef<Expr *> ParamMoves)
-      : Stmt(CoroutineBodyStmtClass) {
-    SubStmts[CoroutineBodyStmt::Body] = Body;
-    SubStmts[CoroutineBodyStmt::Promise] = Promise;
-    SubStmts[CoroutineBodyStmt::InitSuspend] = InitSuspend;
-    SubStmts[CoroutineBodyStmt::FinalSuspend] = FinalSuspend;
-    SubStmts[CoroutineBodyStmt::OnException] = OnException;
-    SubStmts[CoroutineBodyStmt::OnFallthrough] = OnFallthrough;
-    SubStmts[CoroutineBodyStmt::Allocate] = Allocate;
-    SubStmts[CoroutineBodyStmt::Deallocate] = Deallocate;
-    SubStmts[CoroutineBodyStmt::ReturnValue] = ReturnValue;
-    // FIXME: Tail-allocate space for parameter move expressions and store them.
-    assert(ParamMoves.empty() && "not implemented yet");
-  }
+
+  /// \brief The sub-statements from which a CoroutineBodyStmt is built.
+  struct CtorArgs {
+    Stmt *Body = nullptr;
+    Stmt *Promise = nullptr;
+    Expr *InitialSuspend = nullptr;
+    Expr *FinalSuspend = nullptr;
+    Stmt *OnException = nullptr;
+    Stmt *OnFallthrough = nullptr;
+    Expr *Allocate = nullptr;
+    Expr *Deallocate = nullptr;
+    Stmt *ReturnValue = nullptr;
+    ArrayRef<Stmt *> ParamMoves;
+  };
+
+private:
+
+  CoroutineBodyStmt(CtorArgs const &Args);
+
+public:
+  /// \brief Create a CoroutineBodyStmt in \p C, tail-allocating storage for
+  /// the fixed sub-statements and every entry of \p Args.ParamMoves.
+  static CoroutineBodyStmt *Create(const ASTContext &C, CtorArgs const &Args);

   /// \brief Retrieve the body of the coroutine as written. This will be either
   /// a CompoundStmt or a TryStmt.
   Stmt *getBody() const {
-    return SubStmts[SubStmt::Body];
+    return getStoredStmts()[SubStmt::Body];
   }

-  Stmt *getPromiseDeclStmt() const { return SubStmts[SubStmt::Promise]; }
+  Stmt *getPromiseDeclStmt() const {
+    return getStoredStmts()[SubStmt::Promise];
+  }
   VarDecl *getPromiseDecl() const {
     return cast<VarDecl>(cast<DeclStmt>(getPromiseDeclStmt())->getSingleDecl());
   }

-  Stmt *getInitSuspendStmt() const { return SubStmts[SubStmt::InitSuspend]; }
-  Stmt *getFinalSuspendStmt() const { return SubStmts[SubStmt::FinalSuspend]; }
+  Stmt *getInitSuspendStmt() const {
+    return getStoredStmts()[SubStmt::InitSuspend];
+  }
+  Stmt *getFinalSuspendStmt() const {
+    return getStoredStmts()[SubStmt::FinalSuspend];
+  }

-  Stmt *getExceptionHandler() const { return SubStmts[SubStmt::OnException]; }
+  Stmt *getExceptionHandler() const {
+    return getStoredStmts()[SubStmt::OnException];
+  }
   Stmt *getFallthroughHandler() const {
-    return SubStmts[SubStmt::OnFallthrough];
+    return getStoredStmts()[SubStmt::OnFallthrough];
   }

-  Expr *getAllocate() const { return cast<Expr>(SubStmts[SubStmt::Allocate]); }
-  Stmt *getDeallocate() const { return SubStmts[SubStmt::Deallocate]; }
+  // Allocate/Deallocate are left null when they were not built (e.g. for a
+  // dependent promise type), hence cast_or_null.
+  Expr *getAllocate() const {
+    return cast_or_null<Expr>(getStoredStmts()[SubStmt::Allocate]);
+  }
+  Expr *getDeallocate() const {
+    return cast_or_null<Expr>(getStoredStmts()[SubStmt::Deallocate]);
+  }

   Expr *getReturnValueInit() const {
-    return cast<Expr>(SubStmts[SubStmt::ReturnValue]);
+    return cast<Expr>(getStoredStmts()[SubStmt::ReturnValue]);
+  }
+  /// \brief Retrieve the move-construction statements for parameter copies.
+  ArrayRef<Stmt const *> getParamMoves() const {
+    return {getStoredStmts() + SubStmt::FirstParamMove, NumParams};
   }

   SourceLocation getLocStart() const LLVM_READONLY {
@@ -365,7 +399,8 @@
   }

   child_range children() {
-    return child_range(SubStmts, SubStmts + SubStmt::FirstParamMove);
+    return child_range(getStoredStmts(),
+                       getStoredStmts() + SubStmt::FirstParamMove + NumParams);
   }

   static bool classof(const Stmt *T) {
Index: include/clang/Sema/ScopeInfo.h
===================================================================
--- include/clang/Sema/ScopeInfo.h
+++ include/clang/Sema/ScopeInfo.h
@@ -157,7 +157,7 @@
   SmallVector<ReturnStmt*, 4> Returns;

   /// \brief The promise object for this coroutine, if any.
-  VarDecl *CoroutinePromise;
+  VarDecl *CoroutinePromise = nullptr;

   /// \brief The list of coroutine control flow constructs (co_await, co_yield,
   /// co_return) that occur within the function or block. Empty if and only if
Index: lib/AST/StmtCXX.cpp
===================================================================
--- lib/AST/StmtCXX.cpp
+++ lib/AST/StmtCXX.cpp
@@ -86,3 +86,31 @@
 const VarDecl *CXXForRangeStmt::getLoopVariable() const {
   return const_cast<CXXForRangeStmt *>(this)->getLoopVariable();
 }
+
+// Allocate the statement together with its trailing Stmt* storage: the fixed
+// sub-statements followed by one slot per parameter move.
+CoroutineBodyStmt *CoroutineBodyStmt::Create(
+    const ASTContext &C, CoroutineBodyStmt::CtorArgs const &Args) {
+  std::size_t Size = totalSizeToAlloc<Stmt *>(
+      CoroutineBodyStmt::FirstParamMove + Args.ParamMoves.size());
+
+  void *Mem = C.Allocate(Size, alignof(CoroutineBodyStmt));
+  return new (Mem) CoroutineBodyStmt(Args);
+}
+
+CoroutineBodyStmt::CoroutineBodyStmt(CoroutineBodyStmt::CtorArgs const &Args)
+    : Stmt(CoroutineBodyStmtClass), NumParams(Args.ParamMoves.size()) {
+  Stmt **SubStmts = getStoredStmts();
+  SubStmts[CoroutineBodyStmt::Body] = Args.Body;
+  SubStmts[CoroutineBodyStmt::Promise] = Args.Promise;
+  SubStmts[CoroutineBodyStmt::InitSuspend] = Args.InitialSuspend;
+  SubStmts[CoroutineBodyStmt::FinalSuspend] = Args.FinalSuspend;
+  SubStmts[CoroutineBodyStmt::OnException] = Args.OnException;
+  SubStmts[CoroutineBodyStmt::OnFallthrough] = Args.OnFallthrough;
+  SubStmts[CoroutineBodyStmt::Allocate] = Args.Allocate;
+  SubStmts[CoroutineBodyStmt::Deallocate] = Args.Deallocate;
+  SubStmts[CoroutineBodyStmt::ReturnValue] = Args.ReturnValue;
+  // Parameter moves occupy the slots after the fixed sub-statements.
+  std::copy(Args.ParamMoves.begin(), Args.ParamMoves.end(),
+            SubStmts + CoroutineBodyStmt::FirstParamMove);
+}
Index: lib/Sema/SemaCoroutine.cpp
===================================================================
--- lib/Sema/SemaCoroutine.cpp
+++ lib/Sema/SemaCoroutine.cpp
@@ -487,7 +487,7 @@
 static bool buildAllocationAndDeallocation(Sema &S, SourceLocation Loc,
                                            FunctionScopeInfo *Fn,
                                            Expr *&Allocation,
-                                           Stmt *&Deallocation) {
+                                           Expr *&Deallocation) {
   TypeSourceInfo *TInfo = Fn->CoroutinePromise->getTypeSourceInfo();
   QualType PromiseType = TInfo->getType();
   if (PromiseType->isDependentType())
@@ -564,6 +564,52 @@
   return true;
 }

+namespace {
+/// Builds the implicit sub-statements of a coroutine body (promise
+/// declaration, suspend points, handlers, allocation and return object) and
+/// records them in the CtorArgs used to create a CoroutineBodyStmt.
+class SubStmtBuilder : public CoroutineBodyStmt::CtorArgs {
+  Sema &S;
+  FunctionDecl &FD;
+  FunctionScopeInfo &Fn;
+  bool IsValid; ///< True iff every make* step succeeded.
+  SourceLocation Loc;
+  SmallVector<Stmt *, 4> ParamMovesVector;
+  const bool IsPromiseDependentType;
+  CXXRecordDecl *PromiseRecordDecl = nullptr;
+
+public:
+  SubStmtBuilder(Sema &S, FunctionDecl &FD, FunctionScopeInfo &Fn, Stmt *Body)
+      : S(S), FD(FD), Fn(Fn), Loc(FD.getLocation()),
+        IsPromiseDependentType(
+            !Fn.CoroutinePromise ||
+            Fn.CoroutinePromise->getType()->isDependentType()) {
+    this->Body = Body;
+    if (!IsPromiseDependentType) {
+      PromiseRecordDecl = Fn.CoroutinePromise->getType()->getAsCXXRecordDecl();
+      assert(PromiseRecordDecl && "Type should have already been checked");
+    }
+    this->IsValid = makePromiseStmt() && makeInitialSuspend() &&
+                    makeFinalSuspend() && makeOnException() &&
+                    makeOnFallthrough() && makeNewAndDeleteExpr() &&
+                    makeReturnObject() && makeParamMoves();
+  }
+
+  bool isInvalid() const { return !this->IsValid; }
+
+  // Each make* step fills in one or more CtorArgs fields and returns false
+  // if an error occurred, in which case the coroutine is invalid.
+  bool makePromiseStmt();
+  bool makeInitialSuspend();
+  bool makeFinalSuspend();
+  bool makeNewAndDeleteExpr();
+  bool makeOnFallthrough();
+  bool makeOnException();
+  bool makeReturnObject();
+  bool makeParamMoves();
+};
+} // namespace
+
 void Sema::CheckCompletedCoroutineBody(FunctionDecl *FD, Stmt *&Body) {
   FunctionScopeInfo *Fn = getCurFunction();
   assert(Fn && !Fn->CoroutineStmts.empty() && "not a coroutine");
@@ -577,120 +623,165 @@
       << (isa<CoawaitExpr>(First) ? 0 :
           isa<CoyieldExpr>(First) ? 1 : 2);
   }
+
+  // SubStmtBuilder uses the promise variable unconditionally (for instance,
+  // buildAllocationAndDeallocation dereferences it); bail out if it is null.
+  if (!Fn->CoroutinePromise)
+    return FD->setInvalidDecl();
+
+  SubStmtBuilder Builder(*this, *FD, *Fn, Body);
+  if (Builder.isInvalid())
+    return FD->setInvalidDecl();

-  SourceLocation Loc = FD->getLocation();
+  // Build body for the coroutine wrapper statement.
+  Body = CoroutineBodyStmt::Create(Context, Builder);
+}

+bool SubStmtBuilder::makePromiseStmt() {
   // Form a declaration statement for the promise declaration, so that AST
   // visitors can more easily find it.
   StmtResult PromiseStmt =
-      ActOnDeclStmt(ConvertDeclToDeclGroup(Fn->CoroutinePromise), Loc, Loc);
+      S.ActOnDeclStmt(S.ConvertDeclToDeclGroup(Fn.CoroutinePromise), Loc, Loc);
   if (PromiseStmt.isInvalid())
-    return FD->setInvalidDecl();
+    return false;
+
+  this->Promise = PromiseStmt.get();
+  return true;
+}

+bool SubStmtBuilder::makeInitialSuspend() {
   // Form and check implicit 'co_await p.initial_suspend();' statement.
   ExprResult InitialSuspend =
-      buildPromiseCall(*this, Fn, Loc, "initial_suspend", None);
+      buildPromiseCall(S, &Fn, Loc, "initial_suspend", None);
   // FIXME: Support operator co_await here.
   if (!InitialSuspend.isInvalid())
-    InitialSuspend = BuildCoawaitExpr(Loc, InitialSuspend.get());
-  InitialSuspend = ActOnFinishFullExpr(InitialSuspend.get());
+    InitialSuspend = S.BuildCoawaitExpr(Loc, InitialSuspend.get());
+  InitialSuspend = S.ActOnFinishFullExpr(InitialSuspend.get());
   if (InitialSuspend.isInvalid())
-    return FD->setInvalidDecl();
+    return false;
+
+  this->InitialSuspend = InitialSuspend.get();
+  return true;
+}

+bool SubStmtBuilder::makeFinalSuspend() {
   // Form and check implicit 'co_await p.final_suspend();' statement.
   ExprResult FinalSuspend =
-      buildPromiseCall(*this, Fn, Loc, "final_suspend", None);
+      buildPromiseCall(S, &Fn, Loc, "final_suspend", None);
   // FIXME: Support operator co_await here.
   if (!FinalSuspend.isInvalid())
-    FinalSuspend = BuildCoawaitExpr(Loc, FinalSuspend.get());
-  FinalSuspend = ActOnFinishFullExpr(FinalSuspend.get());
+    FinalSuspend = S.BuildCoawaitExpr(Loc, FinalSuspend.get());
+  FinalSuspend = S.ActOnFinishFullExpr(FinalSuspend.get());
   if (FinalSuspend.isInvalid())
-    return FD->setInvalidDecl();
+    return false;

+  this->FinalSuspend = FinalSuspend.get();
+  return true;
+}
+
+bool SubStmtBuilder::makeNewAndDeleteExpr() {
   // Form and check allocation and deallocation calls.
-  Expr *Allocation = nullptr;
-  Stmt *Deallocation = nullptr;
-  if (!buildAllocationAndDeallocation(*this, Loc, Fn, Allocation, Deallocation))
-    return FD->setInvalidDecl();
+  return buildAllocationAndDeallocation(S, Loc, &Fn, this->Allocate,
+                                        this->Deallocate);
+}
+
+bool SubStmtBuilder::makeOnFallthrough() {
+  if (!PromiseRecordDecl)
+    return true;
+
+  // [dcl.fct.def.coroutine]/4
+  // The unqualified-ids 'return_void' and 'return_value' are looked up in
+  // the scope of class P. If both are found, the program is ill-formed.
+  DeclarationName RVoidDN = S.PP.getIdentifierInfo("return_void");
+  LookupResult RVoidResult(S, RVoidDN, Loc, Sema::LookupMemberName);
+  const bool HasRVoid = S.LookupQualifiedName(RVoidResult, PromiseRecordDecl);

-  // control flowing off the end of the coroutine.
-  // Also try to form 'p.set_exception(std::current_exception());' to handle
+  DeclarationName RValueDN = S.PP.getIdentifierInfo("return_value");
+  LookupResult RValueResult(S, RValueDN, Loc, Sema::LookupMemberName);
+  const bool HasRValue = S.LookupQualifiedName(RValueResult, PromiseRecordDecl);
+
+  StmtResult Fallthrough;
+  if (HasRVoid && HasRValue) {
+    // FIXME Improve this diagnostic
+    S.Diag(FD.getLocation(), diag::err_coroutine_promise_return_ill_formed)
+        << PromiseRecordDecl;
+    return false;
+  } else if (HasRVoid) {
+    // If the unqualified-id return_void is found, flowing off the end of a
+    // coroutine is equivalent to a co_return with no operand. Otherwise,
+    // flowing off the end of a coroutine results in undefined behavior.
+    Fallthrough = S.BuildCoreturnStmt(FD.getLocation(), nullptr);
+    Fallthrough = S.ActOnFinishFullStmt(Fallthrough.get());
+    if (Fallthrough.isInvalid())
+      return false;
+  }
+
+  this->OnFallthrough = Fallthrough.get();
+  return true;
+}
+
+bool SubStmtBuilder::makeOnException() {
+  // Try to form 'p.set_exception(std::current_exception());' to handle
   // uncaught exceptions.
+  // TODO: Post WG21 Issaquah 2016 renamed set_exception to unhandled_exception
+  // TODO: and dropped exception_ptr parameter. Make it so.
+
+  if (!PromiseRecordDecl)
+    return true;
+
+  // If exceptions are disabled, don't try to build OnException.
+  if (!S.getLangOpts().CXXExceptions)
+    return true;
+
   ExprResult SetException;
-  StmtResult Fallthrough;
-  if (Fn->CoroutinePromise &&
-      !Fn->CoroutinePromise->getType()->isDependentType()) {
-    CXXRecordDecl *RD = Fn->CoroutinePromise->getType()->getAsCXXRecordDecl();
-    assert(RD && "Type should have already been checked");
-    // [dcl.fct.def.coroutine]/4
-    // The unqualified-ids 'return_void' and 'return_value' are looked up in
-    // the scope of class P. If both are found, the program is ill-formed.
-    DeclarationName RVoidDN = PP.getIdentifierInfo("return_void");
-    LookupResult RVoidResult(*this, RVoidDN, Loc, Sema::LookupMemberName);
-    const bool HasRVoid = LookupQualifiedName(RVoidResult, RD);
-
-    DeclarationName RValueDN = PP.getIdentifierInfo("return_value");
-    LookupResult RValueResult(*this, RValueDN, Loc, Sema::LookupMemberName);
-    const bool HasRValue = LookupQualifiedName(RValueResult, RD);
-
-    if (HasRVoid && HasRValue) {
-      // FIXME Improve this diagnostic
-      Diag(FD->getLocation(), diag::err_coroutine_promise_return_ill_formed)
-          << RD;
-      return FD->setInvalidDecl();
-    } else if (HasRVoid) {
-      // If the unqualified-id return_void is found, flowing off the end of a
-      // coroutine is equivalent to a co_return with no operand. Otherwise,
-      // flowing off the end of a coroutine results in undefined behavior.
-      Fallthrough = BuildCoreturnStmt(FD->getLocation(), nullptr);
-      Fallthrough = ActOnFinishFullStmt(Fallthrough.get());
-      if (Fallthrough.isInvalid())
-        return FD->setInvalidDecl();
-    }

-    // [dcl.fct.def.coroutine]/3
-    // The unqualified-id set_exception is found in the scope of P by class
-    // member access lookup (3.4.5).
-    DeclarationName SetExDN = PP.getIdentifierInfo("set_exception");
-    LookupResult SetExResult(*this, SetExDN, Loc, Sema::LookupMemberName);
-    if (LookupQualifiedName(SetExResult, RD)) {
-      // Form the call 'p.set_exception(std::current_exception())'
-      SetException = buildStdCurrentExceptionCall(*this, Loc);
-      if (SetException.isInvalid())
-        return FD->setInvalidDecl();
-      Expr *E = SetException.get();
-      SetException = buildPromiseCall(*this, Fn, Loc, "set_exception", E);
-      SetException = ActOnFinishFullExpr(SetException.get(), Loc);
-      if (SetException.isInvalid())
-        return FD->setInvalidDecl();
-    }
+  // [dcl.fct.def.coroutine]/3
+  // The unqualified-id set_exception is found in the scope of P by class
+  // member access lookup (3.4.5).
+  DeclarationName SetExDN = S.PP.getIdentifierInfo("set_exception");
+  LookupResult SetExResult(S, SetExDN, Loc, Sema::LookupMemberName);
+  if (S.LookupQualifiedName(SetExResult, PromiseRecordDecl)) {
+    // Form the call 'p.set_exception(std::current_exception())'
+    SetException = buildStdCurrentExceptionCall(S, Loc);
+    if (SetException.isInvalid())
+      return false;
+    Expr *E = SetException.get();
+    SetException = buildPromiseCall(S, &Fn, Loc, "set_exception", E);
+    SetException = S.ActOnFinishFullExpr(SetException.get(), Loc);
+    if (SetException.isInvalid())
+      return false;
   }

+  this->OnException = SetException.get();
+  return true;
+}
+
+bool SubStmtBuilder::makeReturnObject() {
+
   // Build implicit 'p.get_return_object()' expression and form initialization
   // of return type from it.
   ExprResult ReturnObject =
-      buildPromiseCall(*this, Fn, Loc, "get_return_object", None);
+      buildPromiseCall(S, &Fn, Loc, "get_return_object", None);
   if (ReturnObject.isInvalid())
-    return FD->setInvalidDecl();
-  QualType RetType = FD->getReturnType();
+    return false;
+  QualType RetType = FD.getReturnType();
   if (!RetType->isDependentType()) {
     InitializedEntity Entity =
         InitializedEntity::InitializeResult(Loc, RetType, false);
-    ReturnObject = PerformMoveOrCopyInitialization(Entity, nullptr, RetType,
+    ReturnObject = S.PerformMoveOrCopyInitialization(Entity, nullptr, RetType,
                                                    ReturnObject.get());
     if (ReturnObject.isInvalid())
-      return FD->setInvalidDecl();
+      return false;
   }
-  ReturnObject = ActOnFinishFullExpr(ReturnObject.get(), Loc);
+  ReturnObject = S.ActOnFinishFullExpr(ReturnObject.get(), Loc);
   if (ReturnObject.isInvalid())
-    return FD->setInvalidDecl();
+    return false;

-  // FIXME: Perform move-initialization of parameters into frame-local copies.
-  SmallVector<Expr*, 16> ParamMoves;
+  this->ReturnValue = ReturnObject.get();
+  return true;
+}

-  // Build body for the coroutine wrapper statement.
-  Body = new (Context) CoroutineBodyStmt(
-      Body, PromiseStmt.get(), InitialSuspend.get(), FinalSuspend.get(),
-      SetException.get(), Fallthrough.get(), Allocation, Deallocation,
-      ReturnObject.get(), ParamMoves);
+bool SubStmtBuilder::makeParamMoves() {
+  // FIXME: Perform move-initialization of parameters into frame-local copies.
+  return true;
 }
Index: test/SemaCXX/coroutines.cpp
===================================================================
--- test/SemaCXX/coroutines.cpp
+++ test/SemaCXX/coroutines.cpp
@@ -1,4 +1,4 @@
-// RUN: %clang_cc1 -std=c++14 -fcoroutines-ts -verify %s
+// RUN: %clang_cc1 -std=c++14 -fcoroutines-ts -verify %s -fcxx-exceptions

 void no_coroutine_traits_bad_arg_await() {
   co_await a; // expected-error {{include <experimental/coroutine>}}