Index: include/clang/AST/StmtCXX.h =================================================================== --- include/clang/AST/StmtCXX.h +++ include/clang/AST/StmtCXX.h @@ -304,6 +304,8 @@ FinalSuspend, ///< The final suspend statement, run after the body. OnException, ///< Handler for exceptions thrown in the body. OnFallthrough, ///< Handler for control flow falling off the body. + Allocate, ///< Coroutine frame memory allocation. + Deallocate, ///< Coroutine frame memory deallocation. ReturnValue, ///< Return value for thunk function. FirstParamMove ///< First offset for move construction of parameter copies. }; @@ -312,7 +314,8 @@ friend class ASTStmtReader; public: CoroutineBodyStmt(Stmt *Body, Stmt *Promise, Stmt *InitSuspend, - Stmt *FinalSuspend, Stmt *OnException, Stmt *OnFallthrough, + LabelStmt *FinalSuspend, Stmt *OnException, + Stmt *OnFallthrough, Expr *Allocate, LabelStmt *Deallocate, Expr *ReturnValue, ArrayRef ParamMoves) : Stmt(CoroutineBodyStmtClass) { SubStmts[CoroutineBodyStmt::Body] = Body; @@ -321,6 +324,8 @@ SubStmts[CoroutineBodyStmt::FinalSuspend] = FinalSuspend; SubStmts[CoroutineBodyStmt::OnException] = OnException; SubStmts[CoroutineBodyStmt::OnFallthrough] = OnFallthrough; + SubStmts[CoroutineBodyStmt::Allocate] = Allocate; + SubStmts[CoroutineBodyStmt::Deallocate] = Deallocate; SubStmts[CoroutineBodyStmt::ReturnValue] = ReturnValue; // FIXME: Tail-allocate space for parameter move expressions and store them. assert(ParamMoves.empty() && "not implemented yet"); @@ -338,13 +343,20 @@ } Stmt *getInitSuspendStmt() const { return SubStmts[SubStmt::InitSuspend]; } - Stmt *getFinalSuspendStmt() const { return SubStmts[SubStmt::FinalSuspend]; } + LabelStmt *getFinalSuspendStmt() const { + return cast(SubStmts[SubStmt::FinalSuspend]); + } Stmt *getExceptionHandler() const { return SubStmts[SubStmt::OnException]; } Stmt *getFallthroughHandler() const { return SubStmts[SubStmt::OnFallthrough]; } + Expr *getAllocate() const { return cast(SubStmts[SubStmt::Allocate]); } + LabelStmt *getDeallocate() const { + return cast(SubStmts[SubStmt::Deallocate]); + } + Expr *getReturnValueInit() const { return cast(SubStmts[SubStmt::ReturnValue]); } Index: include/clang/Basic/DiagnosticSemaKinds.td =================================================================== --- include/clang/Basic/DiagnosticSemaKinds.td +++ include/clang/Basic/DiagnosticSemaKinds.td @@ -8567,10 +8567,6 @@ "'main' cannot be a coroutine">; def err_coroutine_varargs : Error< "'%0' cannot be used in a varargs function">; -def ext_coroutine_without_co_await_co_yield : ExtWarn< - "'co_return' used in a function " - "that uses neither 'co_await' nor 'co_yield'">, - InGroup>; def err_implied_std_coroutine_traits_not_found : Error< "you need to include before defining a coroutine">; def err_malformed_std_coroutine_traits : Error< Index: lib/CodeGen/CGCoroutine.cpp =================================================================== --- lib/CodeGen/CGCoroutine.cpp +++ lib/CodeGen/CGCoroutine.cpp @@ -12,6 +12,7 @@ //===----------------------------------------------------------------------===// #include "CodeGenFunction.h" +#include "clang/AST/StmtCXX.h" using namespace clang; using namespace CodeGen; @@ -58,6 +59,27 @@ return true; } +void CodeGenFunction::EmitCoreturnStmt(const CoreturnStmt &S) { + EmitStmt(S.getPromiseCall()); + // FIXME: Jump to final suspend label. +} + +void CodeGenFunction::EmitCoroutineBody(const CoroutineBodyStmt &S) { + auto *NullPtr = llvm::ConstantPointerNull::get(Builder.getInt8PtrTy()); + // FIXME: Instead of 0, pass an equivalent of alignas(maxalign_t). + auto *CoroId = + Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::coro_id), + {Builder.getInt32(0), NullPtr, NullPtr, NullPtr}); + if (!createCoroData(*this, CurCoro, CoroId, nullptr)) { + // User inserted __builtin_coro_id by hand. Should not try to emit anything. + return; + } + + EmitScalarExpr(S.getAllocate()); + // FIXME: Emit the rest of the coroutine. + EmitStmt(S.getDeallocate()); +} + // Emit coroutine intrinsic and patch up arguments of the token type. RValue CodeGenFunction::EmitCoroutineIntrinsic(const CallExpr *E, unsigned int IID) { Index: lib/CodeGen/CGStmt.cpp =================================================================== --- lib/CodeGen/CGStmt.cpp +++ lib/CodeGen/CGStmt.cpp @@ -142,9 +142,11 @@ case Stmt::GCCAsmStmtClass: // Intentional fall-through. case Stmt::MSAsmStmtClass: EmitAsmStmt(cast(*S)); break; case Stmt::CoroutineBodyStmtClass: - case Stmt::CoreturnStmtClass: - CGM.ErrorUnsupported(S, "coroutine"); + EmitCoroutineBody(cast(*S)); break; + case Stmt::CoreturnStmtClass: + EmitCoreturnStmt(cast(*S)); + break; case Stmt::CapturedStmtClass: { const CapturedStmt *CS = cast(S); EmitCapturedStmt(*CS, CS->getCapturedRegionKind()); Index: lib/CodeGen/CodeGenFunction.h =================================================================== --- lib/CodeGen/CodeGenFunction.h +++ lib/CodeGen/CodeGenFunction.h @@ -2301,6 +2301,8 @@ void EmitObjCAtSynchronizedStmt(const ObjCAtSynchronizedStmt &S); void EmitObjCAutoreleasePoolStmt(const ObjCAutoreleasePoolStmt &S); + void EmitCoroutineBody(const CoroutineBodyStmt &S); + void EmitCoreturnStmt(const CoreturnStmt &S); RValue EmitCoroutineIntrinsic(const CallExpr *E, unsigned int IID); void EnterCXXTryStmt(const CXXTryStmt &S, bool IsFnTryBlock = false); Index: lib/Sema/SemaCoroutine.cpp =================================================================== --- lib/Sema/SemaCoroutine.cpp +++ lib/Sema/SemaCoroutine.cpp @@ -378,6 +378,143 @@ return Res; } +static Expr *buildBuiltinCall(Sema &S, SourceLocation Loc, Builtin::ID id, + MutableArrayRef CallArgs) { + StringRef Name = S.Context.BuiltinInfo.getName(id); + LookupResult R(S, &S.Context.Idents.get(Name), Loc, Sema::LookupOrdinaryName); + S.LookupName(R, S.TUScope, true); + + FunctionDecl *BuiltInDecl = R.getAsSingle(); + assert(BuiltInDecl && "failed to find builtin declaration"); + + ExprResult DeclRef = S.BuildDeclRefExpr(BuiltInDecl, BuiltInDecl->getType(), + VK_RValue, Loc, nullptr); + assert(DeclRef.isUsable() && "Builtin reference cannot fail"); + + ExprResult Call = + S.ActOnCallExpr(/*Scope=*/nullptr, DeclRef.get(), Loc, CallArgs, Loc); + + assert(!Call.isInvalid() && "Call to builtin cannot fail!"); + return Call.get(); +} + +// Find an appropriate delete for the promise. +static FunctionDecl *findDeleteForPromise(Sema &S, SourceLocation Loc, + QualType PromiseType) { + FunctionDecl *OperatorDelete = nullptr; + + DeclarationName DeleteName = + S.Context.DeclarationNames.getCXXOperatorName(OO_Delete); + + CXXRecordDecl *PointeeRD = PromiseType->getAsCXXRecordDecl(); + assert(PointeeRD && "PromiseType must be a CxxRecordDecl type"); + + if (S.FindDeallocationFunction(Loc, PointeeRD, DeleteName, OperatorDelete)) + return nullptr; + + if (!OperatorDelete) { + // Look for a global declaration. + OperatorDelete = S.FindUsualDeallocationFunction( + Loc, S.isCompleteType(Loc, PromiseType), DeleteName); + + S.MarkFunctionReferenced(Loc, OperatorDelete); + } + return OperatorDelete; +} + +// Builds allocation and deallocation for the coroutine. Returns false on +// failure. +static bool buildAllocationAndDeallocation(Sema &S, SourceLocation Loc, + FunctionScopeInfo *Fn, + Expr *&Allocation, + LabelStmt *&Deallocation) { + TypeSourceInfo *TInfo = Fn->CoroutinePromise->getTypeSourceInfo(); + QualType PromiseType = TInfo->getType(); + if (PromiseType->isDependentType()) + return true; + + if (S.RequireCompleteType(Loc, PromiseType, diag::err_incomplete_type)) + return false; + + // FIXME: Add support for get_return_object_on_allocation failure. + // FIXME: Add support for stateful allocators. + + FunctionDecl *OperatorNew = nullptr; + FunctionDecl *OperatorDelete = nullptr; + FunctionDecl *UnusedResult = nullptr; + + S.FindAllocationFunctions(Loc, SourceRange(), + /*UseGlobal*/ false, PromiseType, + /*isArray*/ false, /*PlacementArgs*/ None, + OperatorNew, UnusedResult); + + OperatorDelete = findDeleteForPromise(S, Loc, PromiseType); + + if (!OperatorDelete || !OperatorNew) + return false; + + Expr *FramePtr = + buildBuiltinCall(S, Loc, Builtin::BI__builtin_coro_frame, {}); + + Expr *FrameSize = + buildBuiltinCall(S, Loc, Builtin::BI__builtin_coro_size, {}); + + // Make new call. + + ExprResult NewRef = + S.BuildDeclRefExpr(OperatorNew, OperatorNew->getType(), VK_LValue, Loc); + if (NewRef.isInvalid()) + return false; + + ExprResult NewExpr = S.ActOnCallExpr(S.getCurScope(), NewRef.get(), Loc, + FrameSize, Loc, nullptr); + if (NewExpr.isInvalid()) + return false; + + Allocation = NewExpr.get(); + + // Make delete call. + + QualType opDeleteQualType = OperatorDelete->getType(); + + ExprResult DeleteRef = + S.BuildDeclRefExpr(OperatorDelete, opDeleteQualType, VK_LValue, Loc); + if (DeleteRef.isInvalid()) + return false; + + Expr *CoroFree = + buildBuiltinCall(S, Loc, Builtin::BI__builtin_coro_free, {FramePtr}); + + SmallVector DeleteArgs{CoroFree}; + + // Check if we need to pass the size. + const FunctionProtoType *opDeleteType = + opDeleteQualType.getTypePtr()->getAs(); + if (opDeleteType->getNumParams() > 1) { + DeleteArgs.push_back(FrameSize); + } + + ExprResult DeleteExpr = S.ActOnCallExpr(S.getCurScope(), DeleteRef.get(), Loc, + DeleteArgs, Loc, nullptr); + if (DeleteExpr.isInvalid()) + return false; + + // Make it a labeled statement. Suspend point emission uses this label as a + // jump target for the cleanup branch. + LabelDecl *DestroyLabel = + LabelDecl::Create(S.Context, S.CurContext, SourceLocation(), + S.PP.getIdentifierInfo("coro.destroy.label")); + + StmtResult Stmt = S.ActOnLabelStmt(Loc, DestroyLabel, Loc, DeleteExpr.get()); + + if (Stmt.isInvalid()) + return false; + + Deallocation = cast(Stmt.get()); + + return true; +} + void Sema::CheckCompletedCoroutineBody(FunctionDecl *FD, Stmt *&Body) { FunctionScopeInfo *Fn = getCurFunction(); assert(Fn && !Fn->CoroutineStmts.empty() && "not a coroutine"); @@ -388,21 +525,9 @@ Diag(Fn->FirstReturnLoc, diag::err_return_in_coroutine); auto *First = Fn->CoroutineStmts[0]; Diag(First->getLocStart(), diag::note_declared_coroutine_here) - << (isa(First) ? 0 : - isa(First) ? 1 : 2); - } - - bool AnyCoawaits = false; - bool AnyCoyields = false; - for (auto *CoroutineStmt : Fn->CoroutineStmts) { - AnyCoawaits |= isa(CoroutineStmt); - AnyCoyields |= isa(CoroutineStmt); + << (isa(First) ? 0 : isa(First) ? 1 : 2); } - if (!AnyCoawaits && !AnyCoyields) - Diag(Fn->CoroutineStmts.front()->getLocStart(), - diag::ext_coroutine_without_co_await_co_yield); - SourceLocation Loc = FD->getLocation(); // Form a declaration statement for the promise declaration, so that AST @@ -432,6 +557,22 @@ if (FinalSuspend.isInvalid()) return FD->setInvalidDecl(); + // Add a label to a final suspend. It will be the jump target for co_return + // statements. + LabelDecl *FinalLabel = + LabelDecl::Create(Context, CurContext, SourceLocation(), + PP.getIdentifierInfo("coro.final.label")); + StmtResult FinalSuspendWithLabel = + ActOnLabelStmt(Loc, FinalLabel, Loc, FinalSuspend.get()); + if (FinalSuspendWithLabel.isInvalid()) + return FD->setInvalidDecl(); + + // Build allocation function and deallocation expressions. + Expr *Allocation = nullptr; + LabelStmt *Deallocation = nullptr; + if (!buildAllocationAndDeallocation(*this, Loc, Fn, Allocation, Deallocation)) + return FD->setInvalidDecl(); + // FIXME: Perform analysis of set_exception call. // FIXME: Try to form 'p.return_void();' expression statement to handle @@ -440,7 +581,7 @@ // Build implicit 'p.get_return_object()' expression and form initialization // of return type from it. ExprResult ReturnObject = - buildPromiseCall(*this, Fn, Loc, "get_return_object", None); + buildPromiseCall(*this, Fn, Loc, "get_return_object", None); if (ReturnObject.isInvalid()) return FD->setInvalidDecl(); QualType RetType = FD->getReturnType(); @@ -457,11 +598,12 @@ return FD->setInvalidDecl(); // FIXME: Perform move-initialization of parameters into frame-local copies. - SmallVector ParamMoves; + SmallVector ParamMoves; // Build body for the coroutine wrapper statement. Body = new (Context) CoroutineBodyStmt( - Body, PromiseStmt.get(), InitialSuspend.get(), FinalSuspend.get(), - /*SetException*/nullptr, /*Fallthrough*/nullptr, - ReturnObject.get(), ParamMoves); + Body, PromiseStmt.get(), InitialSuspend.get(), + cast_or_null(FinalSuspendWithLabel.get()), + /*SetException*/ nullptr, /*Fallthrough*/ nullptr, Allocation, + Deallocation, ReturnObject.get(), ParamMoves); } Index: test/CodeGenCoroutines/coro-alloc.cpp =================================================================== --- /dev/null +++ test/CodeGenCoroutines/coro-alloc.cpp @@ -0,0 +1,118 @@ +// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fcoroutines-ts -std=c++14 -emit-llvm %s -o - -disable-llvm-passes | FileCheck %s + +namespace std { +namespace experimental { +template +struct coroutine_traits; // expected-note {{declared here}} +} +} + +struct suspend_always { + bool await_ready() { return false; } + void await_suspend() {} + void await_resume() {} +}; + +struct global_new_delete_tag {}; + +template<> +struct std::experimental::coroutine_traits { + struct promise_type { + void get_return_object() {} + suspend_always initial_suspend() { return {}; } + suspend_always final_suspend() { return {}; } + void return_void() {} + }; +}; + +// CHECK-LABEL: f0( +extern "C" void f0(global_new_delete_tag) { + // CHECK: %[[ID:.+]] = call token @llvm.coro.id( + // CHECK: %[[SIZE:.+]] = call i64 @llvm.coro.size.i64() + // CHECK: call i8* @_Znwm(i64 %[[SIZE]]) + + // CHECK: coro.destroy.label: + // CHECK: %[[FRAME:.+]] = call i8* @llvm.coro.frame() + // CHECK: %[[MEM:.+]] = call i8* @llvm.coro.free(token %[[ID]], i8* %[[FRAME]]) + // CHECK: call void @_ZdlPv(i8* %[[MEM]]) + co_return; +} + +struct promise_new_tag {}; + +template<> +struct std::experimental::coroutine_traits { + struct promise_type { + void *operator new(unsigned long); + void get_return_object() {} + suspend_always initial_suspend() { return {}; } + suspend_always final_suspend() { return {}; } + void return_void() {} + }; +}; + +// CHECK-LABEL: f1( +extern "C" void f1(promise_new_tag ) { + // CHECK: %[[ID:.+]] = call token @llvm.coro.id( + // CHECK: %[[SIZE:.+]] = call i64 @llvm.coro.size.i64() + // CHECK: call i8* @_ZNSt12experimental16coroutine_traitsIJv15promise_new_tagEE12promise_typenwEm(i64 %[[SIZE]]) + + // CHECK: coro.destroy.label: + // CHECK: %[[FRAME:.+]] = call i8* @llvm.coro.frame() + // CHECK: %[[MEM:.+]] = call i8* @llvm.coro.free(token %[[ID]], i8* %[[FRAME]]) + // CHECK: call void @_ZdlPv(i8* %[[MEM]]) + co_return; +} + +struct promise_delete_tag {}; + +template<> +struct std::experimental::coroutine_traits { + struct promise_type { + void operator delete(void*); + void get_return_object() {} + suspend_always initial_suspend() { return {}; } + suspend_always final_suspend() { return {}; } + void return_void() {} + }; +}; + +// CHECK-LABEL: f2( +extern "C" void f2(promise_delete_tag) { + // CHECK: %[[ID:.+]] = call token @llvm.coro.id( + // CHECK: %[[SIZE:.+]] = call i64 @llvm.coro.size.i64() + // CHECK: call i8* @_Znwm(i64 %[[SIZE]]) + + // CHECK: coro.destroy.label: + // CHECK: %[[FRAME:.+]] = call i8* @llvm.coro.frame() + // CHECK: %[[MEM:.+]] = call i8* @llvm.coro.free(token %[[ID]], i8* %[[FRAME]]) + // CHECK: call void @_ZNSt12experimental16coroutine_traitsIJv18promise_delete_tagEE12promise_typedlEPv(i8* %[[MEM]]) + co_return; +} + +struct promise_sized_delete_tag {}; + +template<> +struct std::experimental::coroutine_traits { + struct promise_type { + void operator delete(void*, unsigned long); + void get_return_object() {} + suspend_always initial_suspend() { return {}; } + suspend_always final_suspend() { return {}; } + void return_void() {} + }; +}; + +// CHECK-LABEL: f3( +extern "C" void f3(promise_sized_delete_tag) { + // CHECK: %[[ID:.+]] = call token @llvm.coro.id( + // CHECK: %[[SIZE:.+]] = call i64 @llvm.coro.size.i64() + // CHECK: call i8* @_Znwm(i64 %[[SIZE]]) + + // CHECK: coro.destroy.label: + // CHECK: %[[FRAME:.+]] = call i8* @llvm.coro.frame() + // CHECK: %[[MEM:.+]] = call i8* @llvm.coro.free(token %[[ID]], i8* %[[FRAME]]) + // CHECK: %[[SIZE2:.+]] = call i64 @llvm.coro.size.i64() + // CHECK: call void @_ZNSt12experimental16coroutine_traitsIJv24promise_sized_delete_tagEE12promise_typedlEPvm(i8* %[[MEM]], i64 %[[SIZE2]]) + co_return; +} Index: test/SemaCXX/coroutines.cpp =================================================================== --- test/SemaCXX/coroutines.cpp +++ test/SemaCXX/coroutines.cpp @@ -143,13 +143,12 @@ } void only_coreturn() { - co_return; // expected-warning {{'co_return' used in a function that uses neither 'co_await' nor 'co_yield'}} + co_return; // OK } void mixed_coreturn(bool b) { if (b) - // expected-warning@+1 {{'co_return' used in a function that uses neither}} - co_return; // expected-note {{use of 'co_return'}} + co_return; // expected-note {{use of 'co_return' here}} else return; // expected-error {{not allowed in coroutine}} }