Index: cfe/trunk/include/clang/AST/StmtCXX.h
===================================================================
--- cfe/trunk/include/clang/AST/StmtCXX.h
+++ cfe/trunk/include/clang/AST/StmtCXX.h
@@ -309,6 +309,7 @@
     Allocate,      ///< Coroutine frame memory allocation.
     Deallocate,    ///< Coroutine frame memory deallocation.
     ReturnValue,   ///< Return value for thunk function.
+    ReturnStmtOnAllocFailure, ///< Return statement if allocation failed.
     FirstParamMove ///< First offset for move construction of parameter copies.
   };
   unsigned NumParams;
@@ -332,6 +333,7 @@
     Expr *Allocate = nullptr;
     Expr *Deallocate = nullptr;
     Stmt *ReturnValue = nullptr;
+    Stmt *ReturnStmtOnAllocFailure = nullptr;
     ArrayRef<Stmt *> ParamMoves;
   };
 
@@ -379,6 +381,10 @@
   Expr *getReturnValueInit() const {
     return cast_or_null<Expr>(getStoredStmts()[SubStmt::ReturnValue]);
   }
+  /// Statement to run when coroutine frame allocation fails, or null if none.
+  Stmt *getReturnStmtOnAllocFailure() const {
+    return getStoredStmts()[SubStmt::ReturnStmtOnAllocFailure];
+  }
   ArrayRef<Stmt const *> getParamMoves() const {
     return {getStoredStmts() + SubStmt::FirstParamMove, NumParams};
   }
Index: cfe/trunk/include/clang/Basic/DiagnosticSemaKinds.td
===================================================================
--- cfe/trunk/include/clang/Basic/DiagnosticSemaKinds.td
+++ cfe/trunk/include/clang/Basic/DiagnosticSemaKinds.td
@@ -8878,6 +8878,8 @@
 def warn_coroutine_promise_unhandled_exception_required_with_exceptions : Warning<
   "%0 is required to declare the member 'unhandled_exception()' when exceptions are enabled">,
   InGroup<CoroutineMissingUnhandledException>;
+def err_coroutine_promise_get_return_object_on_allocation_failure : Error<
+  "%0: 'get_return_object_on_allocation_failure()' must be a static member function">;
 }
 
 let CategoryName = "Documentation Issue" in {
Index: cfe/trunk/lib/AST/StmtCXX.cpp
===================================================================
--- cfe/trunk/lib/AST/StmtCXX.cpp
+++ cfe/trunk/lib/AST/StmtCXX.cpp
@@ -108,6 +108,8 @@
   SubStmts[CoroutineBodyStmt::Allocate] = Args.Allocate;
   SubStmts[CoroutineBodyStmt::Deallocate] = Args.Deallocate;
   SubStmts[CoroutineBodyStmt::ReturnValue] = Args.ReturnValue;
+  SubStmts[CoroutineBodyStmt::ReturnStmtOnAllocFailure] =
+      Args.ReturnStmtOnAllocFailure;
   std::copy(Args.ParamMoves.begin(), Args.ParamMoves.end(),
             const_cast<Stmt **>(getParamMoves().data()));
 }
\ No newline at end of file
Index: cfe/trunk/lib/CodeGen/CGCoroutine.cpp
===================================================================
--- cfe/trunk/lib/CodeGen/CGCoroutine.cpp
+++ cfe/trunk/lib/CodeGen/CGCoroutine.cpp
@@ -229,7 +229,24 @@
   createCoroData(*this, CurCoro, CoroId);
   CurCoro.Data->SuspendBB = RetBB;
 
-  EmitScalarExpr(S.getAllocate());
+  auto *AllocateCall = EmitScalarExpr(S.getAllocate());
+
+  // Handle allocation failure if 'ReturnStmtOnAllocFailure' was provided.
+  if (auto *RetOnAllocFailure = S.getReturnStmtOnAllocFailure()) {
+    auto *RetOnFailureBB = createBasicBlock("coro.ret.on.failure");
+    auto *InitBB = createBasicBlock("coro.init");
+
+    // Allocation succeeded iff the allocator returned a non-null pointer.
+    auto *NullPtr = llvm::ConstantPointerNull::get(Int8PtrTy);
+    auto *Cond = Builder.CreateICmpNE(AllocateCall, NullPtr);
+    Builder.CreateCondBr(Cond, InitBB, RetOnFailureBB);
+
+    // On failure, emit the return of get_return_object_on_allocation_failure().
+    EmitBlock(RetOnFailureBB);
+    EmitStmt(RetOnAllocFailure);
+
+    EmitBlock(InitBB);
+  }
 
   // FIXME: Setup cleanup scopes.
 
Index: cfe/trunk/lib/Sema/SemaCoroutine.cpp
===================================================================
--- cfe/trunk/lib/Sema/SemaCoroutine.cpp
+++ cfe/trunk/lib/Sema/SemaCoroutine.cpp
@@ -708,8 +708,8 @@
     }
     this->IsValid = makePromiseStmt() && makeInitialAndFinalSuspend() &&
                     makeOnException() && makeOnFallthrough() &&
-                    makeNewAndDeleteExpr() && makeReturnObject() &&
-                    makeParamMoves();
+                    makeReturnOnAllocFailure() && makeNewAndDeleteExpr() &&
+                    makeReturnObject() && makeParamMoves();
   }
 
   bool isInvalid() const { return !this->IsValid; }
@@ -720,6 +720,7 @@
   bool makeOnFallthrough();
   bool makeOnException();
   bool makeReturnObject();
+  bool makeReturnOnAllocFailure();
   bool makeParamMoves();
 };
 }
@@ -777,6 +778,73 @@
   return true;
 }
 
+/// Checks that the get_return_object_on_allocation_failure member named by
+/// \p E is a static member function. Otherwise emits an error (located at the
+/// member's declaration when it is a non-static method) and a note at the
+/// first coroutine statement, and returns false.
+static bool diagReturnOnAllocFailure(Sema &S, Expr *E,
+                                     CXXRecordDecl *PromiseRecordDecl,
+                                     FunctionScopeInfo &Fn) {
+  auto Loc = E->getExprLoc();
+  if (auto *DeclRef = dyn_cast<DeclRefExpr>(E)) {
+    auto *Decl = DeclRef->getDecl();
+    if (CXXMethodDecl *Method = dyn_cast_or_null<CXXMethodDecl>(Decl)) {
+      if (Method->isStatic())
+        return true;
+      // Non-static member: point the error at its declaration.
+      Loc = Decl->getLocation();
+    }
+  }
+
+  S.Diag(
+      Loc,
+      diag::err_coroutine_promise_get_return_object_on_allocation_failure)
+      << PromiseRecordDecl;
+  S.Diag(Fn.FirstCoroutineStmtLoc, diag::note_declared_coroutine_here)
+      << Fn.getFirstCoroutineStmtKeyword();
+  return false;
+}
+
+/// Forms 'return P::get_return_object_on_allocation_failure();' and stores it
+/// in ReturnStmtOnAllocFailure. Returns true without forming anything when
+/// there is no promise record or the member is not declared.
+bool SubStmtBuilder::makeReturnOnAllocFailure() {
+  if (!PromiseRecordDecl) return true;
+
+  // [dcl.fct.def.coroutine]/8
+  // The unqualified-id get_return_object_on_allocation_failure is looked up in
+  // the scope of class P by class member access lookup (3.4.5). ...
+  // If an allocation function returns nullptr, ... the coroutine return value
+  // is obtained by a call to ... get_return_object_on_allocation_failure().
+
+  DeclarationName DN =
+      S.PP.getIdentifierInfo("get_return_object_on_allocation_failure");
+  LookupResult Found(S, DN, Loc, Sema::LookupMemberName);
+  // Suppress diagnostics when a private member is selected. The same warnings
+  // will be produced again when building the call.
+  Found.suppressDiagnostics();
+  if (!S.LookupQualifiedName(Found, PromiseRecordDecl)) return true;
+
+  CXXScopeSpec SS;
+  ExprResult DeclNameExpr =
+      S.BuildDeclarationNameExpr(SS, Found, /*NeedsADL=*/false);
+  if (DeclNameExpr.isInvalid()) return false;
+
+  if (!diagReturnOnAllocFailure(S, DeclNameExpr.get(), PromiseRecordDecl, Fn))
+    return false;
+
+  ExprResult ReturnObjectOnAllocationFailure =
+      S.ActOnCallExpr(nullptr, DeclNameExpr.get(), Loc, {}, Loc);
+  if (ReturnObjectOnAllocationFailure.isInvalid()) return false;
+
+  StmtResult ReturnStmt = S.ActOnReturnStmt(
+      Loc, ReturnObjectOnAllocationFailure.get(), S.getCurScope());
+  if (ReturnStmt.isInvalid()) return false;
+
+  this->ReturnStmtOnAllocFailure = ReturnStmt.get();
+  return true;
+}
+
 bool SubStmtBuilder::makeNewAndDeleteExpr() {
   // Form and check allocation and deallocation calls.
   QualType PromiseType = Fn.CoroutinePromise->getType();
@@ -786,7 +854,8 @@
   if (S.RequireCompleteType(Loc, PromiseType, diag::err_incomplete_type))
     return false;
 
-  // FIXME: Add support for get_return_object_on_allocation failure.
+  // FIXME: Add nothrow_t placement arg for global alloc
+  //        if ReturnStmtOnAllocFailure != nullptr.
   // FIXME: Add support for stateful allocators.
 
   FunctionDecl *OperatorNew = nullptr;
Index: cfe/trunk/test/CodeGenCoroutines/coro-alloc.cpp
===================================================================
--- cfe/trunk/test/CodeGenCoroutines/coro-alloc.cpp
+++ cfe/trunk/test/CodeGenCoroutines/coro-alloc.cpp
@@ -40,7 +40,7 @@
   };
 };
 
-// CHECK-LABEL: f0(
+// CHECK-LABEL: f0(
 extern "C" void f0(global_new_delete_tag) {
   // CHECK: %[[ID:.+]] = call token @llvm.coro.id(i32 16
   // CHECK: %[[SIZE:.+]] = call i64 @llvm.coro.size.i64()
@@ -65,7 +65,7 @@
   };
 };
 
-// CHECK-LABEL: f1(
+// CHECK-LABEL: f1(
 extern "C" void f1(promise_new_tag ) {
   // CHECK: %[[ID:.+]] = call token @llvm.coro.id(i32 16
   // CHECK: %[[SIZE:.+]] = call i64 @llvm.coro.size.i64()
@@ -90,7 +90,7 @@
   };
 };
 
-// CHECK-LABEL: f2(
+// CHECK-LABEL: f2(
 extern "C" void f2(promise_delete_tag) {
   // CHECK: %[[ID:.+]] = call token @llvm.coro.id(i32 16
   // CHECK: %[[SIZE:.+]] = call i64 @llvm.coro.size.i64()
@@ -127,3 +127,30 @@
   // CHECK: call void @_ZNSt12experimental16coroutine_traitsIJv24promise_sized_delete_tagEE12promise_typedlEPvm(i8* %[[MEM]], i64 %[[SIZE2]])
   co_return;
 }
+
+struct promise_on_alloc_failure_tag {};
+
+template<>
+struct std::experimental::coroutine_traits<int, promise_on_alloc_failure_tag> {
+  struct promise_type {
+    int get_return_object() { return 0; }
+    suspend_always initial_suspend() { return {}; }
+    suspend_always final_suspend() { return {}; }
+    void return_void() {}
+    static int get_return_object_on_allocation_failure() { return -1; }
+  };
+};
+
+// CHECK-LABEL: f4(
+extern "C" int f4(promise_on_alloc_failure_tag) {
+  // CHECK: %[[ID:.+]] = call token @llvm.coro.id(i32 16
+  // CHECK: %[[SIZE:.+]] = call i64 @llvm.coro.size.i64()
+  // CHECK: %[[MEM:.+]] = call i8* @_Znwm(i64 %[[SIZE]])
+  // CHECK: %[[OK:.+]] = icmp ne i8* %[[MEM]], null
+  // CHECK: br i1 %[[OK]], label %[[OKBB:.+]], label %[[ERRBB:.+]]
+
+  // CHECK: [[ERRBB]]:
+  // CHECK: %[[RETVAL:.+]] = call i32 @_ZNSt12experimental16coroutine_traitsIJi28promise_on_alloc_failure_tagEE12promise_type39get_return_object_on_allocation_failureEv(
+  // CHECK: ret i32 %[[RETVAL]]
+  co_return;
+}
Index: cfe/trunk/test/SemaCXX/coroutines.cpp
===================================================================
--- cfe/trunk/test/SemaCXX/coroutines.cpp
+++ cfe/trunk/test/SemaCXX/coroutines.cpp
@@ -634,3 +634,21 @@
   //expected-note@-1 {{call to 'initial_suspend' implicitly required by the initial suspend point}}
   co_return; //expected-note {{function is a coroutine due to use of 'co_return' here}}
 }
+
+struct promise_on_alloc_failure_tag {};
+
+template<>
+struct std::experimental::coroutine_traits<int, promise_on_alloc_failure_tag> {
+  struct promise_type {
+    int get_return_object() { return 0; }
+    suspend_always initial_suspend() { return {}; }
+    suspend_always final_suspend() { return {}; }
+    void return_void() {}
+    int get_return_object_on_allocation_failure(); // expected-error{{'promise_type': 'get_return_object_on_allocation_failure()' must be a static member function}}
+    void unhandled_exception();
+  };
+};
+
+extern "C" int f(promise_on_alloc_failure_tag) {
+  co_return; //expected-note {{function is a coroutine due to use of 'co_return' here}}
+}