diff --git a/clang/docs/ReleaseNotes.rst b/clang/docs/ReleaseNotes.rst --- a/clang/docs/ReleaseNotes.rst +++ b/clang/docs/ReleaseNotes.rst @@ -219,6 +219,12 @@ - Fix crash when using ``[[clang::always_inline]]`` or ``[[clang::noinline]]`` statement attributes on a call to a template function in the body of a template function. +- Fix coroutines issue where ``get_return_object()`` result was always eagerly + converted to the return type. Eager initialization (allowing RVO) is now only + performed when these types match; otherwise deferred initialization is used, + enabling short-circuiting coroutines use cases. This fixes + (`#56532 <https://github.com/llvm/llvm-project/issues/56532>`_) in + anticipation of `CWG2563 <https://cplusplus.github.io/CWG/issues/2563.html>`_. Bug Fixes to Compiler Builtins ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/clang/include/clang/AST/StmtCXX.h b/clang/include/clang/AST/StmtCXX.h --- a/clang/include/clang/AST/StmtCXX.h +++ b/clang/include/clang/AST/StmtCXX.h @@ -411,9 +411,8 @@ return cast(getStoredStmts()[SubStmt::ReturnValue]); } Expr *getReturnValue() const { - assert(getReturnStmt()); - auto *RS = cast(getReturnStmt()); - return RS->getRetValue(); + auto *RS = dyn_cast_or_null(getReturnStmt()); + return RS ? RS->getRetValue() : nullptr; } Stmt *getReturnStmt() const { return getStoredStmts()[SubStmt::ReturnStmt]; } Stmt *getReturnStmtOnAllocFailure() const { diff --git a/clang/lib/CodeGen/CGCoroutine.cpp b/clang/lib/CodeGen/CGCoroutine.cpp --- a/clang/lib/CodeGen/CGCoroutine.cpp +++ b/clang/lib/CodeGen/CGCoroutine.cpp @@ -472,13 +472,41 @@ CodeGenFunction &CGF; CGBuilderTy &Builder; const CoroutineBodyStmt &S; + // When true, performs RVO for the return object. 
+ bool DirectEmit = false; Address GroActiveFlag; CodeGenFunction::AutoVarEmission GroEmission; GetReturnObjectManager(CodeGenFunction &CGF, const CoroutineBodyStmt &S) : CGF(CGF), Builder(CGF.Builder), S(S), GroActiveFlag(Address::invalid()), - GroEmission(CodeGenFunction::AutoVarEmission::invalid()) {} + GroEmission(CodeGenFunction::AutoVarEmission::invalid()) { + // The call to get_return_object is sequenced before the call to + // initial_suspend and is invoked at most once, but there are caveats + // regarding whether the prvalue result object may be initialized + // directly/eagerly or delayed, depending on the types involved. + // + // More info at https://github.com/cplusplus/papers/issues/1414 + // + // The general cases: + // 1. Same type of get_return_object and coroutine return type (direct + // emission): + // - Constructed in the return slot. + // 2. Different types (delayed emission): + // - Constructed temporary object prior to initial suspend initialized with + // a call to get_return_object() + // - When coroutine needs to return to the caller and needs to construct + // return value for the coroutine it is initialized with expiring value of + // the temporary obtained above. + // + // Direct emission for void returning coroutines or GROs. + DirectEmit = [&]() { + auto *RVI = S.getReturnValueInit(); + assert(RVI && "expected RVI"); + auto GroType = RVI->getType(); + return CGF.getContext().hasSameType(GroType, CGF.FnRetTy); + }(); + } // The gro variable has to outlive coroutine frame and coroutine promise, but, // it can only be initialized after coroutine promise was created, thus, we @@ -486,7 +514,10 @@ // cleanups. Later when coroutine promise is available we initialize the gro // and sets the flag that the cleanup is now active. 
void EmitGroAlloca() { - auto *GroDeclStmt = dyn_cast(S.getResultDecl()); + if (DirectEmit) + return; + + auto *GroDeclStmt = dyn_cast_or_null(S.getResultDecl()); if (!GroDeclStmt) { // If get_return_object returns void, no need to do an alloca. return; @@ -519,6 +550,27 @@ } void EmitGroInit() { + if (DirectEmit) { + // ReturnValue should be valid as long as the coroutine's return type + // is not void. The assertion could help us to reduce the check later. + assert(CGF.ReturnValue.isValid() == (bool)S.getReturnStmt()); + // Now we have the promise, initialize the GRO. + // We need to emit `get_return_object` first. According to: + // [dcl.fct.def.coroutine]p7 + // The call to get_return_object is sequenced before the call to + // initial_suspend and is invoked at most once. + // + // So we cannot emit the return value when we emit the return statement, + // otherwise the call to get_return_object wouldn't be in front + // of initial_suspend. + if (CGF.ReturnValue.isValid()) { + CGF.EmitAnyExprToMem(S.getReturnValue(), CGF.ReturnValue, + S.getReturnValue()->getType().getQualifiers(), + /*IsInit*/ true); + } + return; + } + if (!GroActiveFlag.isValid()) { // No Gro variable was allocated. Simply emit the call to // get_return_object. @@ -598,10 +650,6 @@ CGM.getIntrinsic(llvm::Intrinsic::coro_begin), {CoroId, Phi}); CurCoro.Data->CoroBegin = CoroBegin; - // We need to emit `get_­return_­object` first. According to: - // [dcl.fct.def.coroutine]p7 - // The call to get_­return_­object is sequenced before the call to - // initial_­suspend and is invoked at most once. GetReturnObjectManager GroManager(*this, S); GroManager.EmitGroAlloca(); @@ -706,8 +754,13 @@ llvm::Function *CoroEnd = CGM.getIntrinsic(llvm::Intrinsic::coro_end); Builder.CreateCall(CoroEnd, {NullPtr, Builder.getFalse()}); - if (Stmt *Ret = S.getReturnStmt()) + if (Stmt *Ret = S.getReturnStmt()) { + // Since we already emitted the return value above, we shouldn't + // emit it again here. 
+ if (GroManager.DirectEmit) + cast(Ret)->setRetValue(nullptr); EmitStmt(Ret); + } // LLVM require the frontend to mark the coroutine. CurFn->setPresplitCoroutine(); diff --git a/clang/lib/Sema/SemaCoroutine.cpp b/clang/lib/Sema/SemaCoroutine.cpp --- a/clang/lib/Sema/SemaCoroutine.cpp +++ b/clang/lib/Sema/SemaCoroutine.cpp @@ -1730,13 +1730,22 @@ assert(!FnRetType->isDependentType() && "get_return_object type must no longer be dependent"); + // The call to get_return_object is sequenced before the call to + // initial_suspend and is invoked at most once, but there are caveats + // regarding whether the prvalue result object may be initialized + // directly/eagerly or delayed, depending on the types involved. + // + // More info at https://github.com/cplusplus/papers/issues/1414 + bool GroMatchesRetType = S.getASTContext().hasSameType(GroType, FnRetType); + if (FnRetType->isVoidType()) { ExprResult Res = S.ActOnFinishFullExpr(this->ReturnValue, Loc, /*DiscardedValue*/ false); if (Res.isInvalid()) return false; - this->ResultDecl = Res.get(); + if (!GroMatchesRetType) + this->ResultDecl = Res.get(); return true; } @@ -1749,52 +1758,59 @@ return false; } - auto *GroDecl = VarDecl::Create( - S.Context, &FD, FD.getLocation(), FD.getLocation(), - &S.PP.getIdentifierTable().get("__coro_gro"), GroType, - S.Context.getTrivialTypeSourceInfo(GroType, Loc), SC_None); - GroDecl->setImplicit(); - - S.CheckVariableDeclarationType(GroDecl); - if (GroDecl->isInvalidDecl()) - return false; + StmtResult ReturnStmt; + clang::VarDecl *GroDecl = nullptr; + if (GroMatchesRetType) { + ReturnStmt = S.BuildReturnStmt(Loc, ReturnValue); + } else { + GroDecl = VarDecl::Create( + S.Context, &FD, FD.getLocation(), FD.getLocation(), + &S.PP.getIdentifierTable().get("__coro_gro"), GroType, + S.Context.getTrivialTypeSourceInfo(GroType, Loc), SC_None); + GroDecl->setImplicit(); + + S.CheckVariableDeclarationType(GroDecl); + if (GroDecl->isInvalidDecl()) + return false; - InitializedEntity 
Entity = InitializedEntity::InitializeVariable(GroDecl); - ExprResult Res = - S.PerformCopyInitialization(Entity, SourceLocation(), ReturnValue); - if (Res.isInvalid()) - return false; + InitializedEntity Entity = InitializedEntity::InitializeVariable(GroDecl); + ExprResult Res = + S.PerformCopyInitialization(Entity, SourceLocation(), ReturnValue); + if (Res.isInvalid()) + return false; - Res = S.ActOnFinishFullExpr(Res.get(), /*DiscardedValue*/ false); - if (Res.isInvalid()) - return false; + Res = S.ActOnFinishFullExpr(Res.get(), /*DiscardedValue*/ false); + if (Res.isInvalid()) + return false; - S.AddInitializerToDecl(GroDecl, Res.get(), - /*DirectInit=*/false); + S.AddInitializerToDecl(GroDecl, Res.get(), + /*DirectInit=*/false); - S.FinalizeDeclaration(GroDecl); + S.FinalizeDeclaration(GroDecl); - // Form a declaration statement for the return declaration, so that AST - // visitors can more easily find it. - StmtResult GroDeclStmt = - S.ActOnDeclStmt(S.ConvertDeclToDeclGroup(GroDecl), Loc, Loc); - if (GroDeclStmt.isInvalid()) - return false; + // Form a declaration statement for the return declaration, so that AST + // visitors can more easily find it. 
+ StmtResult GroDeclStmt = + S.ActOnDeclStmt(S.ConvertDeclToDeclGroup(GroDecl), Loc, Loc); + if (GroDeclStmt.isInvalid()) + return false; - this->ResultDecl = GroDeclStmt.get(); + this->ResultDecl = GroDeclStmt.get(); - ExprResult declRef = S.BuildDeclRefExpr(GroDecl, GroType, VK_LValue, Loc); - if (declRef.isInvalid()) - return false; + ExprResult declRef = S.BuildDeclRefExpr(GroDecl, GroType, VK_LValue, Loc); + if (declRef.isInvalid()) + return false; - StmtResult ReturnStmt = S.BuildReturnStmt(Loc, declRef.get()); + ReturnStmt = S.BuildReturnStmt(Loc, declRef.get()); + } if (ReturnStmt.isInvalid()) { noteMemberDeclaredHere(S, ReturnValue, Fn); return false; } - if (cast(ReturnStmt.get())->getNRVOCandidate() == GroDecl) + if (!GroMatchesRetType && + cast(ReturnStmt.get())->getNRVOCandidate() == GroDecl) GroDecl->setNRVOVariable(true); this->ReturnStmt = ReturnStmt.get(); diff --git a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h --- a/clang/lib/Sema/TreeTransform.h +++ b/clang/lib/Sema/TreeTransform.h @@ -8103,11 +8103,12 @@ return StmtError(); Builder.Deallocate = DeallocRes.get(); - assert(S->getResultDecl() && "ResultDecl must already be built"); - StmtResult ResultDecl = getDerived().TransformStmt(S->getResultDecl()); - if (ResultDecl.isInvalid()) - return StmtError(); - Builder.ResultDecl = ResultDecl.get(); + if (auto *ResultDecl = S->getResultDecl()) { + StmtResult Res = getDerived().TransformStmt(ResultDecl); + if (Res.isInvalid()) + return StmtError(); + Builder.ResultDecl = Res.get(); + } if (auto *ReturnStmt = S->getReturnStmt()) { StmtResult Res = getDerived().TransformStmt(ReturnStmt); diff --git a/clang/test/CodeGenCoroutines/coro-gro.cpp b/clang/test/CodeGenCoroutines/coro-gro.cpp --- a/clang/test/CodeGenCoroutines/coro-gro.cpp +++ b/clang/test/CodeGenCoroutines/coro-gro.cpp @@ -2,26 +2,9 @@ // Verify that coroutine promise and allocated memory are freed up on exception. 
// RUN: %clang_cc1 -std=c++20 -triple=x86_64-unknown-linux-gnu -emit-llvm -o - %s -disable-llvm-passes | FileCheck %s -namespace std { -template struct coroutine_traits; +#include "Inputs/coroutine.h" -template struct coroutine_handle { - coroutine_handle() = default; - static coroutine_handle from_address(void *) noexcept; -}; -template <> struct coroutine_handle { - static coroutine_handle from_address(void *) noexcept; - coroutine_handle() = default; - template - coroutine_handle(coroutine_handle) noexcept; -}; -} // namespace std - -struct suspend_always { - bool await_ready() noexcept; - void await_suspend(std::coroutine_handle<>) noexcept; - void await_resume() noexcept; -}; +using namespace std; struct GroType { ~GroType(); @@ -51,8 +34,8 @@ // CHECK: %[[Size:.+]] = call i64 @llvm.coro.size.i64() // CHECK: call noalias noundef nonnull ptr @_Znwm(i64 noundef %[[Size]]) // CHECK: store i1 false, ptr %[[GroActive]] - // CHECK: call void @_ZNSt16coroutine_traitsIJiEE12promise_typeC1Ev( - // CHECK: call void @_ZNSt16coroutine_traitsIJiEE12promise_type17get_return_objectEv( + // CHECK: call void @_ZNSt16coroutine_traitsIiJEE12promise_typeC1Ev( + // CHECK: call void @_ZNSt16coroutine_traitsIiJEE12promise_type17get_return_objectEv( // CHECK: store i1 true, ptr %[[GroActive]] Cleanup cleanup; @@ -60,16 +43,18 @@ co_return; // CHECK: call void @_Z11doSomethingv( - // CHECK: call void @_ZNSt16coroutine_traitsIJiEE12promise_type11return_voidEv( + // CHECK: call void @_ZNSt16coroutine_traitsIiJEE12promise_type11return_voidEv( // CHECK: call void @_ZN7CleanupD1Ev( // Destroy promise and free the memory. 
- // CHECK: call void @_ZNSt16coroutine_traitsIJiEE12promise_typeD1Ev( + // CHECK: call void @_ZNSt16coroutine_traitsIiJEE12promise_typeD1Ev( // CHECK: %[[Mem:.+]] = call ptr @llvm.coro.free( // CHECK: call void @_ZdlPv(ptr noundef %[[Mem]]) // Initialize retval from Gro and destroy Gro + // Note this also tests delaying initialization when Gro and function return + // types mismatch (see cwg2563). // CHECK: %[[Conv:.+]] = call noundef i32 @_ZN7GroTypecviEv( // CHECK: store i32 %[[Conv]], ptr %[[RetVal]] @@ -84,3 +69,38 @@ // CHECK: %[[LoadRet:.+]] = load i32, ptr %[[RetVal]] // CHECK: ret i32 %[[LoadRet]] } + +class invoker { +public: + class invoker_promise { + public: + invoker get_return_object() { return invoker{}; } + auto initial_suspend() { return suspend_always{}; } + auto final_suspend() noexcept { return suspend_always{}; } + void return_void() {} + void unhandled_exception() {} + }; + using promise_type = invoker_promise; + invoker() {} + invoker(const invoker &) = delete; + invoker &operator=(const invoker &) = delete; + invoker(invoker &&) = delete; + invoker &operator=(invoker &&) = delete; +}; + +// According to cwg2563, matching GRO and function return type must allow +// for eager initialization and RVO. +// CHECK: define{{.*}} void @_Z1gv({{.*}} %[[AggRes:.+]]) +invoker g() { + // CHECK: %[[ResultPtr:.+]] = alloca ptr + // CHECK-NEXT: %[[Promise:.+]] = alloca %"class.invoker::invoker_promise" + + // CHECK: store ptr %[[AggRes]], ptr %[[ResultPtr]] + // CHECK: coro.init: + // CHECK: = call ptr @llvm.coro.begin + + // delayed GRO pattern stores a GRO active flag, make sure to not emit it. 
+ // CHECK-NOT: store i1 false, ptr + // CHECK: call void @_ZN7invoker15invoker_promise17get_return_objectEv({{.*}} %[[AggRes]] + co_return; +} diff --git a/clang/test/SemaCXX/coroutine-no-move-ctor.cpp b/clang/test/SemaCXX/coroutine-no-move-ctor.cpp --- a/clang/test/SemaCXX/coroutine-no-move-ctor.cpp +++ b/clang/test/SemaCXX/coroutine-no-move-ctor.cpp @@ -15,13 +15,10 @@ }; using promise_type = invoker_promise; invoker() {} - // TODO: implement RVO for get_return_object type matching - // function return type. - // - // invoker(const invoker &) = delete; - // invoker &operator=(const invoker &) = delete; - // invoker(invoker &&) = delete; - // invoker &operator=(invoker &&) = delete; + invoker(const invoker &) = delete; + invoker &operator=(const invoker &) = delete; + invoker(invoker &&) = delete; + invoker &operator=(invoker &&) = delete; }; invoker f() {