// [Sema][Coroutines] Build coroutine parameter moves before the promise and
// pass them to the promise type's constructor.
//
// What the hunks below change:
//  * Sema::buildCoroutineParameterMoves(SourceLocation) (new, in
//    SemaCoroutine.cpp) replaces CoroutineStmtBuilder::makeParamMoves() and
//    CoroutineStmtBuilder::buildParameterMoves().
//      - It considers each parameter whose type is not dependent and has a
//        CXXRecordDecl.
//      - For each one it builds an implicit VarDecl of the parameter's type,
//        direct-initialized from castForMoving(<reference to the parameter>).
//      - It wraps that VarDecl in a DeclStmt and records it in
//        FunctionScopeInfo::CoroutineParameterMoves, keyed by the ParmVarDecl.
//      - Other parameters get no copy ("No need to copy scalars").
//      - It returns false if building the parameter reference or the DeclStmt
//        fails.
//  * ScopeInfo.h / ScopeInfo.cpp add the CoroutineParameterMoves map (an llvm
//    SmallMapVector) and clear it together with the other coroutine state.
//  * buildCoroutinePromise builds one argument per non-dependent parameter:
//      - a DeclRefExpr to the moved copy when one was recorded;
//      - otherwise a DeclRefExpr to the parameter itself.
//    It then tries to direct-initialize the promise variable from that
//    argument list:
//      - If the InitializationSequence is viable, it is performed, and the
//        decl is marked invalid if performing it fails.
//      - Otherwise the code falls back to ActOnUninitializedDecl(VD), i.e.
//        default construction.
//  * The helper that creates ScopeInfo->CoroutinePromise only when it is not
//    already set (hunk "@@ -518,6 +576,9 @@") now calls
//    buildCoroutineParameterMoves before buildCoroutinePromise.
//  * TreeTransform also calls buildCoroutineParameterMoves before
//    buildCoroutinePromise. It no longer calls Builder.buildParameterMoves()
//    at the end.
//  * CoroutineStmtBuilder's constructor copies the recorded move statements
//    into ParamMovesVector/ParamMoves. buildStatements() no longer builds
//    them itself.
//  * Tests:
//      - CodeGenCoroutines/coro-params.cpp checks that a promise constructor
//        taking (promise_matching_constructor, int, float, double) is invoked
//        with the loaded int/float/double values.
//      - SemaCXX/coroutines.cpp covers the default-constructor fallback and a
//        matching custom constructor.
//      - SemaCXX/coroutines.cpp also covers "call to deleted constructor"
//        errors when neither a matching constructor nor a usable default
//        constructor exists.
//
// Review notes:
//  * NOTE(review): this copy of the patch appears damaged.
//      - Line breaks are collapsed.
//      - Text between '<' and '>' looks stripped. Examples:
//        "llvm::SmallMapVector CoroutineParameterMoves;",
//        "std::pair CoroutineSuspends;",
//        "llvm::SmallVector CtorArgExprs;",
//        "cast(cast(Move->second)->getSingleDecl())",
//        "isa(CurContext)",
//        "template<> struct std::experimental::coroutine_traits {",
//        the bare "coro" return types in SemaCXX/coroutines.cpp, and
//        "coroutine_traits, int, int>::promise_type" in an expected-error.
//      - Do not expect it to apply or compile as-is. Regenerate it from the
//        original change rather than hand-repairing it.
//  * The patch removes "assert(!VD->isInvalidDecl())" from
//    buildCoroutinePromise.
//      - The function can now return a VarDecl marked invalid, when
//        InitSeq.Perform fails.
//      - The callers visible here only test for nullptr.
//      - TODO: confirm that an invalid promise decl is handled downstream.
//  * "for (auto KV : Fn.CoroutineParameterMoves)" copies each map entry.
//    "const auto &KV" would state the read-only intent.
//  * "if (InitSeq) { ... } else ActOnUninitializedDecl(VD);" braces only one
//    branch. Consider bracing the else for consistency.
//  * Comment nit in buildCoroutinePromise: "coroutine functions arguments"
//    should read "coroutine function's arguments".
//
Index: include/clang/Sema/ScopeInfo.h =================================================================== --- include/clang/Sema/ScopeInfo.h +++ include/clang/Sema/ScopeInfo.h @@ -22,6 +22,7 @@ #include "clang/Sema/CleanupInfo.h" #include "clang/Sema/Ownership.h" #include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/MapVector.h" #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringSwitch.h" @@ -172,6 +173,10 @@ /// \brief The promise object for this coroutine, if any. VarDecl *CoroutinePromise = nullptr; + /// \brief A mapping between the coroutine function parameters that were moved + /// to the coroutine frame, and their move statements. + llvm::SmallMapVector CoroutineParameterMoves; + /// \brief The initial and final coroutine suspend points. std::pair CoroutineSuspends; Index: include/clang/Sema/Sema.h =================================================================== --- include/clang/Sema/Sema.h +++ include/clang/Sema/Sema.h @@ -8478,6 +8478,7 @@ StmtResult BuildCoreturnStmt(SourceLocation KwLoc, Expr *E, bool IsImplicit = false); StmtResult BuildCoroutineBodyStmt(CoroutineBodyStmt::CtorArgs); + bool buildCoroutineParameterMoves(SourceLocation Loc); VarDecl *buildCoroutinePromise(SourceLocation Loc); void CheckCompletedCoroutineBody(FunctionDecl *FD, Stmt *&Body); Index: lib/Sema/CoroutineStmtBuilder.h =================================================================== --- lib/Sema/CoroutineStmtBuilder.h +++ lib/Sema/CoroutineStmtBuilder.h @@ -51,9 +51,6 @@ /// name lookup. bool buildDependentStatements(); - /// \brief Build just parameter moves. To use for rebuilding in TreeTransform.
- bool buildParameterMoves(); - bool isInvalid() const { return !this->IsValid; } private: @@ -65,7 +62,6 @@ bool makeReturnObject(); bool makeGroDeclAndReturnStmt(); bool makeReturnOnAllocFailure(); - bool makeParamMoves(); }; } // end namespace clang Index: lib/Sema/ScopeInfo.cpp =================================================================== --- lib/Sema/ScopeInfo.cpp +++ lib/Sema/ScopeInfo.cpp @@ -43,6 +43,7 @@ // Coroutine state FirstCoroutineStmtLoc = SourceLocation(); CoroutinePromise = nullptr; + CoroutineParameterMoves.clear(); NeedsCoroutineSuspends = true; CoroutineSuspends.first = nullptr; CoroutineSuspends.second = nullptr; Index: lib/Sema/SemaCoroutine.cpp =================================================================== --- lib/Sema/SemaCoroutine.cpp +++ lib/Sema/SemaCoroutine.cpp @@ -494,9 +494,67 @@ CheckVariableDeclarationType(VD); if (VD->isInvalidDecl()) return nullptr; - ActOnUninitializedDecl(VD); + + auto *ScopeInfo = getCurFunction(); + // Build a list of arguments, based on the coroutine functions arguments, + // that will be passed to the promise type's constructor. + llvm::SmallVector CtorArgExprs; + auto &Moves = ScopeInfo->CoroutineParameterMoves; + for (auto *PD : FD->parameters()) { + if (PD->getType()->isDependentType()) + continue; + + auto RefExpr = ExprEmpty(); + auto Move = Moves.find(PD); + if (Move != Moves.end()) { + // If a reference to the function parameter exists in the coroutine + // frame, use that reference. + auto *MoveDecl = + cast(cast(Move->second)->getSingleDecl()); + RefExpr = BuildDeclRefExpr(MoveDecl, MoveDecl->getType(), + ExprValueKind::VK_LValue, FD->getLocation()); + } else { + // If the function parameter doesn't exist in the coroutine frame, it + // must be a scalar value. Use it directly.
+ assert(!PD->getType()->getAsCXXRecordDecl() && + "Non-scalar types should have been moved and inserted into the " + "parameter moves map"); + RefExpr = + BuildDeclRefExpr(PD, PD->getOriginalType().getNonReferenceType(), + ExprValueKind::VK_LValue, FD->getLocation()); + } + + if (RefExpr.isInvalid()) + return nullptr; + CtorArgExprs.push_back(RefExpr.get()); + } + + // Create an initialization sequence for the promise type using the + // constructor arguments, wrapped in a parenthesized list expression. + Expr *PLE = new (Context) ParenListExpr(Context, FD->getLocation(), + CtorArgExprs, FD->getLocation()); + InitializedEntity Entity = InitializedEntity::InitializeVariable(VD); + InitializationKind Kind = InitializationKind::CreateForInit( + VD->getLocation(), /*DirectInit=*/true, PLE); + InitializationSequence InitSeq(*this, Entity, Kind, CtorArgExprs, + /*TopLevelOfInitList=*/false, + /*TreatUnavailableAsInvalid=*/false); + + // Attempt to initialize the promise type with the arguments. + // If that fails, fall back to the promise type's default constructor.
+ if (InitSeq) { + ExprResult Result = InitSeq.Perform(*this, Entity, Kind, CtorArgExprs); + if (Result.isInvalid()) { + VD->setInvalidDecl(); + } else if (Result.get()) { + VD->setInit(MaybeCreateExprWithCleanups(Result.get())); + VD->setInitStyle(VarDecl::CallInit); + CheckCompleteVariableDeclaration(VD); + } + } else + ActOnUninitializedDecl(VD); + FD->addDecl(VD); - assert(!VD->isInvalidDecl()); return VD; } @@ -518,6 +576,9 @@ if (ScopeInfo->CoroutinePromise) return ScopeInfo; + if (!S.buildCoroutineParameterMoves(Loc)) + return nullptr; + ScopeInfo->CoroutinePromise = S.buildCoroutinePromise(Loc); if (!ScopeInfo->CoroutinePromise) return nullptr; @@ -861,6 +922,11 @@ !Fn.CoroutinePromise || Fn.CoroutinePromise->getType()->isDependentType()) { this->Body = Body; + + for (auto KV : Fn.CoroutineParameterMoves) + this->ParamMovesVector.push_back(KV.second); + this->ParamMoves = this->ParamMovesVector; + if (!IsPromiseDependentType) { PromiseRecordDecl = Fn.CoroutinePromise->getType()->getAsCXXRecordDecl(); assert(PromiseRecordDecl && "Type should have already been checked"); @@ -870,7 +936,7 @@ bool CoroutineStmtBuilder::buildStatements() { assert(this->IsValid && "coroutine already invalid"); - this->IsValid = makeReturnObject() && makeParamMoves(); + this->IsValid = makeReturnObject(); if (this->IsValid && !IsPromiseDependentType) buildDependentStatements(); return this->IsValid; @@ -886,12 +952,6 @@ return this->IsValid; } -bool CoroutineStmtBuilder::buildParameterMoves() { - assert(this->IsValid && "coroutine already invalid"); - assert(this->ParamMoves.empty() && "param moves already built"); - return this->IsValid = makeParamMoves(); -} - bool CoroutineStmtBuilder::makePromiseStmt() { // Form a declaration statement for the promise declaration, so that AST // visitors can more easily find it. @@ -1304,47 +1364,50 @@ .get(); } - /// \brief Build a variable declaration for move parameter.
static VarDecl *buildVarDecl(Sema &S, SourceLocation Loc, QualType Type, IdentifierInfo *II) { TypeSourceInfo *TInfo = S.Context.getTrivialTypeSourceInfo(Type, Loc); - VarDecl *Decl = - VarDecl::Create(S.Context, S.CurContext, Loc, Loc, II, Type, TInfo, SC_None); + VarDecl *Decl = VarDecl::Create(S.Context, S.CurContext, Loc, Loc, II, Type, + TInfo, SC_None); Decl->setImplicit(); return Decl; } -bool CoroutineStmtBuilder::makeParamMoves() { - for (auto *paramDecl : FD.parameters()) { - auto Ty = paramDecl->getType(); - if (Ty->isDependentType()) +// Build statements that move coroutine function parameters to the coroutine +// frame, and store them on the function scope info. +bool Sema::buildCoroutineParameterMoves(SourceLocation Loc) { + assert(isa(CurContext) && "not in a function scope"); + auto *FD = cast(CurContext); + + auto *ScopeInfo = getCurFunction(); + assert(ScopeInfo->CoroutineParameterMoves.empty() && + "Should not build parameter moves twice"); + + for (auto *PD : FD->parameters()) { + if (PD->getType()->isDependentType()) continue; - // No need to copy scalars, llvm will take care of them. - if (Ty->getAsCXXRecordDecl()) { - ExprResult ParamRef = - S.BuildDeclRefExpr(paramDecl, paramDecl->getType(), - ExprValueKind::VK_LValue, Loc); // FIXME: scope? - if (ParamRef.isInvalid()) + // No need to copy scalars, LLVM will take care of them. + if (PD->getType()->getAsCXXRecordDecl()) { + ExprResult PDRefExpr = BuildDeclRefExpr( + PD, PD->getType(), ExprValueKind::VK_LValue, Loc); // FIXME: scope? + if (PDRefExpr.isInvalid()) return false; - Expr *RCast = castForMoving(S, ParamRef.get()); + Expr *CExpr = castForMoving(*this, PDRefExpr.get()); - auto D = buildVarDecl(S, Loc, Ty, paramDecl->getIdentifier()); - S.AddInitializerToDecl(D, RCast, /*DirectInit=*/true); + auto D = buildVarDecl(*this, Loc, PD->getType(), PD->getIdentifier()); + AddInitializerToDecl(D, CExpr, /*DirectInit=*/true); // Convert decl to a statement.
- StmtResult Stmt = S.ActOnDeclStmt(S.ConvertDeclToDeclGroup(D), Loc, Loc); + StmtResult Stmt = ActOnDeclStmt(ConvertDeclToDeclGroup(D), Loc, Loc); if (Stmt.isInvalid()) return false; - ParamMovesVector.push_back(Stmt.get()); + ScopeInfo->CoroutineParameterMoves.insert(std::make_pair(PD, Stmt.get())); } } - - // Convert to ArrayRef in CtorArgs structure that builder inherits from. - ParamMoves = ParamMovesVector; return true; } Index: lib/Sema/TreeTransform.h =================================================================== --- lib/Sema/TreeTransform.h +++ lib/Sema/TreeTransform.h @@ -6944,6 +6944,8 @@ // The new CoroutinePromise object needs to be built and put into the current // FunctionScopeInfo before any transformations or rebuilding occurs. + if (!SemaRef.buildCoroutineParameterMoves(FD->getLocation())) + return StmtError(); auto *Promise = SemaRef.buildCoroutinePromise(FD->getLocation()); if (!Promise) return StmtError(); @@ -7034,8 +7036,6 @@ Builder.ReturnStmt = Res.get(); } } - if (!Builder.buildParameterMoves()) - return StmtError(); return getDerived().RebuildCoroutineBodyStmt(Builder); } Index: test/CodeGenCoroutines/coro-params.cpp =================================================================== --- test/CodeGenCoroutines/coro-params.cpp +++ test/CodeGenCoroutines/coro-params.cpp @@ -1,6 +1,7 @@ // Verifies that parameters are copied with move constructors // Verifies that parameter copies are destroyed // Vefifies that parameter copies are used in the body of the coroutine +// Verifies that parameter copies are used to construct the promise type, if that type has a matching constructor // RUN: %clang_cc1 -std=c++1z -fcoroutines-ts -triple=x86_64-unknown-linux-gnu -emit-llvm -o - %s -disable-llvm-passes -fexceptions | FileCheck %s namespace std::experimental { @@ -127,3 +128,31 @@ void call_dependent_params() { dependent_params(A{}, B{}, B{}); } + +// Test that, when the promise type has a constructor whose signature matches +// that of the
coroutine function, that constructor is used. This is an +// experimental feature that will be proposed for the Coroutines TS. + +struct promise_matching_constructor {}; + +template<> +struct std::experimental::coroutine_traits { + struct promise_type { + promise_type(promise_matching_constructor, int, float, double) {} + promise_type() = delete; + void get_return_object() {} + suspend_always initial_suspend() { return {}; } + suspend_always final_suspend() { return {}; } + void return_void() {} + void unhandled_exception() {} + }; +}; + +// CHECK-LABEL: void @_Z38coroutine_matching_promise_constructor28promise_matching_constructorifd(i32, float, double) +void coroutine_matching_promise_constructor(promise_matching_constructor, int, float, double) { + // CHECK: %[[INT:.+]] = load i32, i32* %.addr, align 4 + // CHECK: %[[FLOAT:.+]] = load float, float* %.addr1, align 4 + // CHECK: %[[DOUBLE:.+]] = load double, double* %.addr2, align 8 + // CHECK: invoke void @_ZNSt12experimental16coroutine_traitsIJv28promise_matching_constructorifdEE12promise_typeC1ES1_ifd(%"struct.std::experimental::coroutine_traits::promise_type"* %__promise, i32 %[[INT]], float %[[FLOAT]], double %[[DOUBLE]]) + co_return; +} Index: test/SemaCXX/coroutines.cpp =================================================================== --- test/SemaCXX/coroutines.cpp +++ test/SemaCXX/coroutines.cpp @@ -1171,4 +1171,73 @@ template CoroMemberTag DepTestType::test_static_template(const char *volatile &, unsigned); +struct bad_promise_deleted_constructor { + // expected-note@+1 {{'bad_promise_deleted_constructor' has been explicitly marked deleted here}} + bad_promise_deleted_constructor() = delete; + coro get_return_object(); + suspend_always initial_suspend(); + suspend_always final_suspend(); + void return_void(); + void unhandled_exception(); +}; + +coro +bad_coroutine_calls_deleted_promise_constructor() { + // expected-error@-1 {{call to deleted constructor of
'std::experimental::coroutine_traits>::promise_type' (aka 'CoroHandleMemberFunctionTest::bad_promise_deleted_constructor')}} + co_return; +} + +// Test that, when the promise type has a constructor whose signature matches +// that of the coroutine function, that constructor is used. If no matching +// constructor exists, the default constructor is used as a fallback. If no +// matching constructors exist at all, an error is emitted. This is an +// experimental feature that will be proposed for the Coroutines TS. + +struct good_promise_default_constructor { + good_promise_default_constructor(double, float, int); + good_promise_default_constructor() = default; + coro get_return_object(); + suspend_always initial_suspend(); + suspend_always final_suspend(); + void return_void(); + void unhandled_exception(); +}; + +coro +good_coroutine_calls_default_constructor() { + co_return; +} + +struct good_promise_custom_constructor { + good_promise_custom_constructor(double, float, int); + good_promise_custom_constructor() = delete; + coro get_return_object(); + suspend_always initial_suspend(); + suspend_always final_suspend(); + void return_void(); + void unhandled_exception(); +}; + +coro +good_coroutine_calls_custom_constructor(double, float, int) { + co_return; +} + +struct bad_promise_no_matching_constructor { + bad_promise_no_matching_constructor(int, int, int); + // expected-note@+1 {{'bad_promise_no_matching_constructor' has been explicitly marked deleted here}} + bad_promise_no_matching_constructor() = delete; + coro get_return_object(); + suspend_always initial_suspend(); + suspend_always final_suspend(); + void return_void(); + void unhandled_exception(); +}; + +coro +bad_coroutine_calls_with_no_matching_constructor(int, int) { + // expected-error@-1 {{call to deleted constructor of 'std::experimental::coroutine_traits, int, int>::promise_type' (aka 'CoroHandleMemberFunctionTest::bad_promise_no_matching_constructor')}} + co_return; +} + } // namespace
CoroHandleMemberFunctionTest