Index: lib/Sema/SemaCoroutine.cpp =================================================================== --- lib/Sema/SemaCoroutine.cpp +++ lib/Sema/SemaCoroutine.cpp @@ -12,6 +12,7 @@ //===----------------------------------------------------------------------===// #include "CoroutineStmtBuilder.h" +#include "clang/AST/ASTLambda.h" #include "clang/AST/Decl.h" #include "clang/AST/ExprCXX.h" #include "clang/AST/StmtCXX.h" @@ -506,24 +507,15 @@ auto RefExpr = ExprEmpty(); auto Move = Moves.find(PD); - if (Move != Moves.end()) { - // If a reference to the function parameter exists in the coroutine - // frame, use that reference. - auto *MoveDecl = - cast(cast(Move->second)->getSingleDecl()); - RefExpr = BuildDeclRefExpr(MoveDecl, MoveDecl->getType(), - ExprValueKind::VK_LValue, FD->getLocation()); - } else { - // If the function parameter doesn't exist in the coroutine frame, it - // must be a scalar value. Use it directly. - assert(!PD->getType()->getAsCXXRecordDecl() && - "Non-scalar types should have been moved and inserted into the " - "parameter moves map"); - RefExpr = - BuildDeclRefExpr(PD, PD->getOriginalType().getNonReferenceType(), - ExprValueKind::VK_LValue, FD->getLocation()); - } - + assert(Move != Moves.end() && + "Coroutine function parameter not inserted into move map"); + // If a reference to the function parameter exists in the coroutine + // frame, use that reference. + auto *MoveDecl = + cast(cast(Move->second)->getSingleDecl()); + RefExpr = + BuildDeclRefExpr(MoveDecl, MoveDecl->getType().getNonReferenceType(), + ExprValueKind::VK_LValue, FD->getLocation()); if (RefExpr.isInvalid()) return nullptr; CtorArgExprs.push_back(RefExpr.get()); @@ -1050,7 +1042,12 @@ const bool RequiresNoThrowAlloc = ReturnStmtOnAllocFailure != nullptr; - // FIXME: Add support for stateful allocators. 
+ // [dcl.fct.def.coroutine]/7 + // Lookup allocation functions using a parameter list composed of the + // requested size of the coroutine state being allocated, followed by + // the coroutine function's arguments. If a matching allocation function + // exists, use it. Otherwise, use an allocation function that just takes + // the requested size. FunctionDecl *OperatorNew = nullptr; FunctionDecl *OperatorDelete = nullptr; @@ -1058,10 +1055,62 @@ bool PassAlignment = false; SmallVector PlacementArgs; + // [dcl.fct.def.coroutine]/7 + // "The allocation function’s name is looked up in the scope of P. + // [...] If the lookup finds an allocation function in the scope of P, + // overload resolution is performed on a function call created by assembling + // an argument list. The first argument is the amount of space requested, + // and has type std::size_t. The lvalues p1 ... pn are the succeeding + // arguments." + // + // ...where "p1 ... pn" are defined earlier as: + // + // [dcl.fct.def.coroutine]/3 + // "For a coroutine f that is a non-static member function, let P1 denote the + // type of the implicit object parameter (13.3.1) and P2 ... Pn be the types + // of the function parameters; otherwise let P1 ... Pn be the types of the + // function parameters. Let p1 ... pn be lvalues denoting those objects." + if (auto *MD = dyn_cast(&FD)) { + if (MD->isInstance() && !isLambdaCallOperator(MD)) { + ExprResult ThisExpr = S.ActOnCXXThis(Loc); + if (ThisExpr.isInvalid()) + return false; + ThisExpr = S.CreateBuiltinUnaryOp(Loc, UO_Deref, ThisExpr.get()); + if (ThisExpr.isInvalid()) + return false; + PlacementArgs.push_back(ThisExpr.get()); + } + } + for (auto *PD : FD.parameters()) { + if (PD->getType()->isDependentType()) + continue; + + // Build a reference to the parameter. 
+ auto PDLoc = PD->getLocation(); + ExprResult PDRefExpr = + S.BuildDeclRefExpr(PD, PD->getOriginalType().getNonReferenceType(), + ExprValueKind::VK_LValue, PDLoc); + if (PDRefExpr.isInvalid()) + return false; + + PlacementArgs.push_back(PDRefExpr.get()); + } + + // Per [dcl.fct.def.coroutine]/7, p1 ... pn are only passed when the lookup + // finds an allocation function in the scope of P. If P declares no + // operator new, FindAllocationFunctions falls back to the global scope, + // where the coroutine's arguments could select a global placement form + // such as ::operator new(std::size_t, void *) and place the coroutine + // frame in caller-provided storage. Drop the arguments in that case. + { + LookupResult R(S, + S.Context.DeclarationNames.getCXXOperatorName(OO_New), + Loc, Sema::LookupOrdinaryName); + if (auto *RD = PromiseType->getAsCXXRecordDecl()) + S.LookupQualifiedName(R, RD); + if (R.empty()) + PlacementArgs.clear(); + } S.FindAllocationFunctions(Loc, SourceRange(), /*UseGlobal*/ false, PromiseType, /*isArray*/ false, PassAlignment, PlacementArgs, - OperatorNew, UnusedResult); + OperatorNew, UnusedResult, /*Diagnose*/ false); + + // [dcl.fct.def.coroutine]/7 + // "If no matching function is found, overload resolution is performed again + // on a function call created by passing just the amount of space required as + // an argument of type std::size_t." + // + // Retry whenever the first lookup failed, even if PlacementArgs was already + // empty: that lookup ran with Diagnose=false, so this call is the one that + // reports the error. Skipping it would make this function return false + // without emitting any diagnostic. + if (!OperatorNew) { + PlacementArgs.clear(); + S.FindAllocationFunctions(Loc, SourceRange(), + /*UseGlobal*/ false, PromiseType, + /*isArray*/ false, PassAlignment, + PlacementArgs, OperatorNew, UnusedResult); + } bool IsGlobalOverload = OperatorNew && !isa(OperatorNew->getDeclContext()); @@ -1080,7 +1129,8 @@ OperatorNew, UnusedResult); } - assert(OperatorNew && "expected definition of operator new to be found"); + if (!OperatorNew) + return false; if (RequiresNoThrowAlloc) { const auto *FT = OperatorNew->getType()->getAs(); @@ -1386,25 +1436,28 @@ if (PD->getType()->isDependentType()) continue; - // No need to copy scalars, LLVM will take care of them. - if (PD->getType()->getAsCXXRecordDecl()) { - ExprResult PDRefExpr = BuildDeclRefExpr( - PD, PD->getType(), ExprValueKind::VK_LValue, Loc); // FIXME: scope? - if (PDRefExpr.isInvalid()) - return false; + ExprResult PDRefExpr = + BuildDeclRefExpr(PD, PD->getType().getNonReferenceType(), + ExprValueKind::VK_LValue, Loc); // FIXME: scope?
+ if (PDRefExpr.isInvalid()) + return false; - Expr *CExpr = castForMoving(*this, PDRefExpr.get()); + Expr *CExpr = nullptr; + if (PD->getType()->getAsCXXRecordDecl() || + PD->getType()->isRValueReferenceType()) + CExpr = castForMoving(*this, PDRefExpr.get()); + else + CExpr = PDRefExpr.get(); - auto D = buildVarDecl(*this, Loc, PD->getType(), PD->getIdentifier()); - AddInitializerToDecl(D, CExpr, /*DirectInit=*/true); + auto D = buildVarDecl(*this, Loc, PD->getType(), PD->getIdentifier()); + AddInitializerToDecl(D, CExpr, /*DirectInit=*/true); - // Convert decl to a statement. - StmtResult Stmt = ActOnDeclStmt(ConvertDeclToDeclGroup(D), Loc, Loc); - if (Stmt.isInvalid()) - return false; + // Convert decl to a statement. + StmtResult Stmt = ActOnDeclStmt(ConvertDeclToDeclGroup(D), Loc, Loc); + if (Stmt.isInvalid()) + return false; - ScopeInfo->CoroutineParameterMoves.insert(std::make_pair(PD, Stmt.get())); - } + ScopeInfo->CoroutineParameterMoves.insert(std::make_pair(PD, Stmt.get())); } return true; } Index: test/CodeGenCoroutines/coro-alloc.cpp =================================================================== --- test/CodeGenCoroutines/coro-alloc.cpp +++ test/CodeGenCoroutines/coro-alloc.cpp @@ -106,6 +106,34 @@ co_return; } +struct promise_matching_placement_new_tag {}; + +template<> +struct std::experimental::coroutine_traits { + struct promise_type { + void *operator new(unsigned long, promise_matching_placement_new_tag, + int, float, double); + void get_return_object() {} + suspend_always initial_suspend() { return {}; } + suspend_always final_suspend() { return {}; } + void return_void() {} + }; +}; + +// CHECK-LABEL: f1a( +extern "C" void f1a(promise_matching_placement_new_tag, int x, float y , double z) { + // CHECK: store i32 %x, i32* %x.addr, align 4 + // CHECK: store float %y, float* %y.addr, align 4 + // CHECK: store double %z, double* %z.addr, align 8 + // CHECK: %[[ID:.+]] = call token @llvm.coro.id(i32 16 + // CHECK: %[[SIZE:.+]] = call i64 
@llvm.coro.size.i64() + // CHECK: %[[INT:.+]] = load i32, i32* %x.addr, align 4 + // CHECK: %[[FLOAT:.+]] = load float, float* %y.addr, align 4 + // CHECK: %[[DOUBLE:.+]] = load double, double* %z.addr, align 8 + // CHECK: call i8* @_ZNSt12experimental16coroutine_traitsIJv34promise_matching_placement_new_tagifdEE12promise_typenwEmS1_ifd(i64 %[[SIZE]], i32 %[[INT]], float %[[FLOAT]], double %[[DOUBLE]]) + co_return; +} + struct promise_delete_tag {}; template<> Index: test/CodeGenCoroutines/coro-gro-nrvo.cpp =================================================================== --- test/CodeGenCoroutines/coro-gro-nrvo.cpp +++ test/CodeGenCoroutines/coro-gro-nrvo.cpp @@ -41,7 +41,7 @@ // CHECK: {{.*}}[[CoroInit]]: // CHECK: store i1 false, i1* %gro.active -// CHECK-NEXT: call void @{{.*get_return_objectEv}}(%struct.coro* sret %agg.result +// CHECK: call void @{{.*get_return_objectEv}}(%struct.coro* sret %agg.result // CHECK-NEXT: store i1 true, i1* %gro.active co_return; } @@ -78,7 +78,7 @@ // CHECK: {{.*}}[[InitOnSuccess]]: // CHECK: store i1 false, i1* %gro.active -// CHECK-NEXT: call void @{{.*get_return_objectEv}}(%struct.coro_two* sret %agg.result +// CHECK: call void @{{.*get_return_objectEv}}(%struct.coro_two* sret %agg.result // CHECK-NEXT: store i1 true, i1* %gro.active // CHECK: [[RetLabel]]: Index: test/CodeGenCoroutines/coro-params.cpp =================================================================== --- test/CodeGenCoroutines/coro-params.cpp +++ test/CodeGenCoroutines/coro-params.cpp @@ -69,12 +69,12 @@ // CHECK: store i32 %val, i32* %[[ValAddr:.+]] // CHECK: call i8* @llvm.coro.begin( - // CHECK-NEXT: call void @_ZN8MoveOnlyC1EOS_(%struct.MoveOnly* %[[MoCopy]], %struct.MoveOnly* dereferenceable(4) %[[MoParam]]) + // CHECK: call void @_ZN8MoveOnlyC1EOS_(%struct.MoveOnly* %[[MoCopy]], %struct.MoveOnly* dereferenceable(4) %[[MoParam]]) // CHECK-NEXT: call void @_ZN11MoveAndCopyC1EOS_(%struct.MoveAndCopy* %[[McCopy]], %struct.MoveAndCopy* dereferenceable(4) 
%[[McParam]]) # // CHECK-NEXT: invoke void @_ZNSt12experimental16coroutine_traitsIJvi8MoveOnly11MoveAndCopyEE12promise_typeC1Ev( // CHECK: call void @_ZN14suspend_always12await_resumeEv( - // CHECK: %[[IntParam:.+]] = load i32, i32* %val.addr + // CHECK: %[[IntParam:.+]] = load i32, i32* %val1 // CHECK: %[[MoGep:.+]] = getelementptr inbounds %struct.MoveOnly, %struct.MoveOnly* %[[MoCopy]], i32 0, i32 0 // CHECK: %[[MoVal:.+]] = load i32, i32* %[[MoGep]] // CHECK: %[[McGep:.+]] = getelementptr inbounds %struct.MoveAndCopy, %struct.MoveAndCopy* %[[McCopy]], i32 0, i32 0 @@ -150,9 +150,9 @@ // CHECK-LABEL: void @_Z38coroutine_matching_promise_constructor28promise_matching_constructorifd(i32, float, double) void coroutine_matching_promise_constructor(promise_matching_constructor, int, float, double) { - // CHECK: %[[INT:.+]] = load i32, i32* %.addr, align 4 - // CHECK: %[[FLOAT:.+]] = load float, float* %.addr1, align 4 - // CHECK: %[[DOUBLE:.+]] = load double, double* %.addr2, align 8 + // CHECK: %[[INT:.+]] = load i32, i32* %{{[0-9]+}}, align 4 + // CHECK: %[[FLOAT:.+]] = load float, float* %{{[0-9]+}}, align 4 + // CHECK: %[[DOUBLE:.+]] = load double, double* %{{[0-9]+}}, align 8 // CHECK: invoke void @_ZNSt12experimental16coroutine_traitsIJv28promise_matching_constructorifdEE12promise_typeC1ES1_ifd(%"struct.std::experimental::coroutine_traits::promise_type"* %__promise, i32 %[[INT]], float %[[FLOAT]], double %[[DOUBLE]]) co_return; } Index: test/SemaCXX/coroutines.cpp =================================================================== --- test/SemaCXX/coroutines.cpp +++ test/SemaCXX/coroutines.cpp @@ -781,6 +781,102 @@ } template coro dependent_uses_nothrow_new(good_promise_13); +struct good_promise_custom_new_operator { + coro get_return_object(); + suspend_always initial_suspend(); + suspend_always final_suspend(); + void return_void(); + void unhandled_exception(); + void *operator new(__SIZE_TYPE__, double, float, int); +}; + +coro +good_coroutine_calls_custom_new_operator(double, float,
int) { + co_return; +} + +struct coroutine_nonstatic_member_struct; + +struct good_promise_nonstatic_member_custom_new_operator { + coro get_return_object(); + suspend_always initial_suspend(); + suspend_always final_suspend(); + void return_void(); + void unhandled_exception(); + void *operator new(__SIZE_TYPE__, coroutine_nonstatic_member_struct &, double); +}; + +struct bad_promise_nonstatic_member_mismatched_custom_new_operator { + coro get_return_object(); + suspend_always initial_suspend(); + suspend_always final_suspend(); + void return_void(); + void unhandled_exception(); + // expected-note@+1 {{candidate function not viable: requires 2 arguments, but 1 was provided}} + void *operator new(__SIZE_TYPE__, double); +}; + +struct coroutine_nonstatic_member_struct { + coro + good_coroutine_calls_nonstatic_member_custom_new_operator(double) { + co_return; + } + + coro + bad_coroutine_calls_nonstatic_member_mismatched_custom_new_operator(double) { + // expected-error@-1 {{no matching function for call to 'operator new'}} + co_return; + } +}; + +struct bad_promise_mismatched_custom_new_operator { + coro get_return_object(); + suspend_always initial_suspend(); + suspend_always final_suspend(); + void return_void(); + void unhandled_exception(); + // expected-note@+1 {{candidate function not viable: requires 4 arguments, but 1 was provided}} + void *operator new(__SIZE_TYPE__, double, float, int); +}; + +coro +bad_coroutine_calls_mismatched_custom_new_operator(double) { + // expected-error@-1 {{no matching function for call to 'operator new'}} + co_return; +} + +struct bad_promise_throwing_custom_new_operator { + static coro get_return_object_on_allocation_failure(); + coro get_return_object(); + suspend_always initial_suspend(); + suspend_always final_suspend(); + void return_void(); + void unhandled_exception(); + // expected-error@+1 {{'operator new' is required to have a non-throwing noexcept specification when the promise type declares
'get_return_object_on_allocation_failure()'}} + void *operator new(__SIZE_TYPE__, double, float, int); +}; + +coro +bad_coroutine_calls_throwing_custom_new_operator(double, float, int) { + // expected-note@-1 {{call to 'operator new' implicitly required by coroutine function here}} + co_return; +} + +struct good_promise_noexcept_custom_new_operator { + static coro get_return_object_on_allocation_failure(); + coro get_return_object(); + suspend_always initial_suspend(); + suspend_always final_suspend(); + void return_void(); + void unhandled_exception(); + void *operator new(__SIZE_TYPE__, double, float, int) noexcept; +}; + +coro +good_coroutine_calls_noexcept_custom_new_operator(double, float, int) { + co_return; +} + struct mismatch_gro_type_tag1 {}; template<> struct std::experimental::coroutine_traits {