diff --git a/clang/include/clang/AST/StmtCXX.h b/clang/include/clang/AST/StmtCXX.h
--- a/clang/include/clang/AST/StmtCXX.h
+++ b/clang/include/clang/AST/StmtCXX.h
@@ -411,9 +411,10 @@
     return cast<Expr>(getStoredStmts()[SubStmt::ReturnValue]);
   }
+  /// Returns the operand of the coroutine's return statement, or null when no
+  /// return statement was built (e.g. for a void-returning coroutine).
   Expr *getReturnValue() const {
-    assert(getReturnStmt());
-    auto *RS = cast<clang::ReturnStmt>(getReturnStmt());
-    return RS->getRetValue();
+    auto *RS = dyn_cast_or_null<clang::ReturnStmt>(getReturnStmt());
+    return RS ? RS->getRetValue() : nullptr;
   }
   Stmt *getReturnStmt() const { return getStoredStmts()[SubStmt::ReturnStmt]; }
   Stmt *getReturnStmtOnAllocFailure() const {
diff --git a/clang/lib/CodeGen/CGCoroutine.cpp b/clang/lib/CodeGen/CGCoroutine.cpp
--- a/clang/lib/CodeGen/CGCoroutine.cpp
+++ b/clang/lib/CodeGen/CGCoroutine.cpp
@@ -472,13 +472,45 @@
   CodeGenFunction &CGF;
   CGBuilderTy &Builder;
   const CoroutineBodyStmt &S;
+  // When true, get_return_object() initializes the coroutine's return slot
+  // directly and no __coro_gro temporary is created.
+  bool DirectEmit = false;
 
   Address GroActiveFlag;
   CodeGenFunction::AutoVarEmission GroEmission;
 
   GetReturnObjectManager(CodeGenFunction &CGF, const CoroutineBodyStmt &S)
       : CGF(CGF), Builder(CGF.Builder), S(S), GroActiveFlag(Address::invalid()),
-        GroEmission(CodeGenFunction::AutoVarEmission::invalid()) {}
+        GroEmission(CodeGenFunction::AutoVarEmission::invalid()) {
+    // The call to get_return_object is sequenced before the call to
+    // initial_suspend and is invoked at most once, but there are caveats
+    // regarding whether the prvalue result object may be initialized
+    // directly (eagerly) or only later, depending on the types involved.
+    //
+    // The general cases:
+    // 1. get_return_object() returns exactly the coroutine's return type
+    //    (direct emission):
+    //    - The result is constructed directly in the return slot.
+    // 2. The types differ (delayed emission):
+    //    - A temporary is constructed before initial_suspend, initialized
+    //      with a call to get_return_object().
+    //    - When the coroutine returns to its caller, the return value is
+    //      initialized from that temporary as an xvalue.
+    //
+    // Void-returning coroutines always take the delayed path: there is no
+    // return slot, so the (discarded) get_return_object() call is emitted by
+    // EmitGroInit() on its "no Gro variable" path. The type comparison must
+    // match the one Sema uses when deciding whether to build __coro_gro;
+    // otherwise Sema and CodeGen would disagree on which path was taken.
+    DirectEmit = [&]() {
+      if (CGF.FnRetTy->isVoidType())
+        return false;
+      auto *RVI = S.getReturnValueInit();
+      if (!RVI)
+        return false;
+      return CGF.getContext().hasSameType(RVI->getType(), CGF.FnRetTy);
+    }();
+  }
 
   // The gro variable has to outlive coroutine frame and coroutine promise, but,
   // it can only be initialized after coroutine promise was created, thus, we
@@ -486,7 +518,12 @@
   // cleanups. Later when coroutine promise is available we initialize the gro
   // and sets the flag that the cleanup is now active.
   void EmitGroAlloca() {
-    auto *GroDeclStmt = dyn_cast<DeclStmt>(S.getResultDecl());
+    // Direct emission constructs the result in the return slot, so there
+    // is no GRO variable to allocate.
+    if (DirectEmit)
+      return;
+
+    auto *GroDeclStmt = dyn_cast_or_null<DeclStmt>(S.getResultDecl());
     if (!GroDeclStmt) {
       // If get_return_object returns void, no need to do an alloca.
       return;
@@ -519,6 +556,27 @@
   }
 
   void EmitGroInit() {
+    if (DirectEmit) {
+      // ReturnValue is expected to be valid iff the coroutine's return type
+      // is not void, which is also exactly when Sema builds a return statement.
+      assert(CGF.ReturnValue.isValid() == (bool)S.getReturnStmt());
+      // Now we have the promise, initialize the GRO.
+      // We need to emit `get_return_object` first. According to:
+      // [dcl.fct.def.coroutine]p7
+      // The call to get_return_object is sequenced before the call to
+      // initial_suspend and is invoked at most once.
+      //
+      // So we can't emit the return value when we emit the return
+      // statement, otherwise the call to get_return_object wouldn't be
+      // sequenced before initial_suspend.
+      if (CGF.ReturnValue.isValid()) {
+        CGF.EmitAnyExprToMem(S.getReturnValue(), CGF.ReturnValue,
+                             S.getReturnValue()->getType().getQualifiers(),
+                             /*IsInit*/ true);
+      }
+      return;
+    }
+
     if (!GroActiveFlag.isValid()) {
       // No Gro variable was allocated. Simply emit the call to
       // get_return_object.
@@ -598,10 +656,6 @@
                              CGM.getIntrinsic(llvm::Intrinsic::coro_begin), {CoroId, Phi});
   CurCoro.Data->CoroBegin = CoroBegin;
 
-  // We need to emit `get_­return_­object` first. According to:
-  // [dcl.fct.def.coroutine]p7
-  // The call to get_­return_­object is sequenced before the call to
-  // initial_­suspend and is invoked at most once.
   GetReturnObjectManager GroManager(*this, S);
   GroManager.EmitGroAlloca();
 
@@ -706,8 +760,14 @@
   llvm::Function *CoroEnd = CGM.getIntrinsic(llvm::Intrinsic::coro_end);
   Builder.CreateCall(CoroEnd, {NullPtr, Builder.getFalse()});
 
-  if (Stmt *Ret = S.getReturnStmt())
+  if (Stmt *Ret = S.getReturnStmt()) {
+    // The return value was already emitted into the return slot by
+    // GroManager.EmitGroInit(), so don't emit it again here. Note that this
+    // permanently strips the operand from the ReturnStmt in the AST.
+    if (GroManager.DirectEmit)
+      cast<ReturnStmt>(Ret)->setRetValue(nullptr);
     EmitStmt(Ret);
+  }
 
   // LLVM require the frontend to mark the coroutine.
   CurFn->setPresplitCoroutine();
diff --git a/clang/lib/Sema/SemaCoroutine.cpp b/clang/lib/Sema/SemaCoroutine.cpp
--- a/clang/lib/Sema/SemaCoroutine.cpp
+++ b/clang/lib/Sema/SemaCoroutine.cpp
@@ -1730,13 +1730,22 @@
   assert(!FnRetType->isDependentType() &&
          "get_return_object type must no longer be dependent");
 
+  // When get_return_object() yields exactly the function's return type, its
+  // result initializes the return object directly and no __coro_gro variable
+  // is created. Compare with hasSameType so that type sugar (typedefs,
+  // elaborated names) does not matter; CodeGen performs the same comparison.
+  bool GroMatchesRetType = S.Context.hasSameType(GroType, FnRetType);
+
   if (FnRetType->isVoidType()) {
     ExprResult Res =
         S.ActOnFinishFullExpr(this->ReturnValue, Loc, /*DiscardedValue*/ false);
     if (Res.isInvalid())
       return false;
 
+    // Always keep the call as the result declaration: a void-returning
+    // coroutine has no return slot, so CodeGen emits get_return_object()
+    // from here regardless of GroMatchesRetType.
     this->ResultDecl = Res.get();
     return true;
   }
 
@@ -1749,53 +1758,61 @@
     return false;
   }
 
-  // StmtResult ReturnStmt = S.BuildReturnStmt(Loc, ReturnValue);
-  auto *GroDecl = VarDecl::Create(
-      S.Context, &FD, FD.getLocation(), FD.getLocation(),
-      &S.PP.getIdentifierTable().get("__coro_gro"), GroType,
-      S.Context.getTrivialTypeSourceInfo(GroType, Loc), SC_None);
-  GroDecl->setImplicit();
-
-  S.CheckVariableDeclarationType(GroDecl);
-  if (GroDecl->isInvalidDecl())
-    return false;
+  StmtResult ReturnStmt;
+  clang::VarDecl *GroDecl = nullptr;
+  if (GroMatchesRetType) {
+    // Return the get_return_object() result directly; CodeGen emits it into
+    // the return slot before initial_suspend.
+    ReturnStmt = S.BuildReturnStmt(Loc, ReturnValue);
+  } else {
+    GroDecl = VarDecl::Create(
+        S.Context, &FD, FD.getLocation(), FD.getLocation(),
+        &S.PP.getIdentifierTable().get("__coro_gro"), GroType,
+        S.Context.getTrivialTypeSourceInfo(GroType, Loc), SC_None);
+    GroDecl->setImplicit();
+
+    S.CheckVariableDeclarationType(GroDecl);
+    if (GroDecl->isInvalidDecl())
+      return false;
 
-  InitializedEntity Entity = InitializedEntity::InitializeVariable(GroDecl);
-  ExprResult Res =
-      S.PerformCopyInitialization(Entity, SourceLocation(), ReturnValue);
-  if (Res.isInvalid())
-    return false;
+    InitializedEntity Entity = InitializedEntity::InitializeVariable(GroDecl);
+    ExprResult Res =
+        S.PerformCopyInitialization(Entity, SourceLocation(), ReturnValue);
+    if (Res.isInvalid())
+      return false;
 
-  Res = S.ActOnFinishFullExpr(Res.get(), /*DiscardedValue*/ false);
-  if (Res.isInvalid())
-    return false;
+    Res = S.ActOnFinishFullExpr(Res.get(), /*DiscardedValue*/ false);
+    if (Res.isInvalid())
+      return false;
 
-  S.AddInitializerToDecl(GroDecl, Res.get(),
-                         /*DirectInit=*/false);
+    S.AddInitializerToDecl(GroDecl, Res.get(),
+                           /*DirectInit=*/false);
 
-  S.FinalizeDeclaration(GroDecl);
+    S.FinalizeDeclaration(GroDecl);
 
-  // Form a declaration statement for the return declaration, so that AST
-  // visitors can more easily find it.
-  StmtResult GroDeclStmt =
-      S.ActOnDeclStmt(S.ConvertDeclToDeclGroup(GroDecl), Loc, Loc);
-  if (GroDeclStmt.isInvalid())
-    return false;
+    // Form a declaration statement for the return declaration, so that AST
+    // visitors can more easily find it.
+    StmtResult GroDeclStmt =
+        S.ActOnDeclStmt(S.ConvertDeclToDeclGroup(GroDecl), Loc, Loc);
+    if (GroDeclStmt.isInvalid())
+      return false;
 
-  this->ResultDecl = GroDeclStmt.get();
+    this->ResultDecl = GroDeclStmt.get();
 
-  ExprResult declRef = S.BuildDeclRefExpr(GroDecl, GroType, VK_LValue, Loc);
-  if (declRef.isInvalid())
-    return false;
+    ExprResult declRef = S.BuildDeclRefExpr(GroDecl, GroType, VK_LValue, Loc);
+    if (declRef.isInvalid())
+      return false;
 
-  StmtResult ReturnStmt = S.BuildReturnStmt(Loc, declRef.get());
+    ReturnStmt = S.BuildReturnStmt(Loc, declRef.get());
+  }
 
   if (ReturnStmt.isInvalid()) {
     noteMemberDeclaredHere(S, ReturnValue, Fn);
     return false;
   }
 
-  if (cast<clang::ReturnStmt>(ReturnStmt.get())->getNRVOCandidate() == GroDecl)
+  if (!GroMatchesRetType &&
+      cast<clang::ReturnStmt>(ReturnStmt.get())->getNRVOCandidate() == GroDecl)
     GroDecl->setNRVOVariable(true);
 
   this->ReturnStmt = ReturnStmt.get();
diff --git a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h
--- a/clang/lib/Sema/TreeTransform.h
+++ b/clang/lib/Sema/TreeTransform.h
@@ -8052,11 +8052,14 @@
     return StmtError();
   Builder.Deallocate = DeallocRes.get();
 
-  assert(S->getResultDecl() && "ResultDecl must already be built");
-  StmtResult ResultDecl = getDerived().TransformStmt(S->getResultDecl());
-  if (ResultDecl.isInvalid())
-    return StmtError();
-  Builder.ResultDecl = ResultDecl.get();
+  // There is no ResultDecl when get_return_object() matches the return type
+  // and initializes the return object directly.
+  if (auto *ResultDecl = S->getResultDecl()) {
+    StmtResult Res = getDerived().TransformStmt(ResultDecl);
+    if (Res.isInvalid())
+      return StmtError();
+    Builder.ResultDecl = Res.get();
+  }
 
   if (auto *ReturnStmt = S->getReturnStmt()) {
     StmtResult Res = getDerived().TransformStmt(ReturnStmt);
diff --git a/clang/test/SemaCXX/coroutine-no-move-ctor.cpp b/clang/test/SemaCXX/coroutine-no-move-ctor.cpp
--- a/clang/test/SemaCXX/coroutine-no-move-ctor.cpp
+++ b/clang/test/SemaCXX/coroutine-no-move-ctor.cpp
@@ -15,13 +15,10 @@
   };
   using promise_type = invoker_promise;
   invoker() {}
-  // TODO: implement RVO for get_return_object type matching
-  // function return type.
-  //
-  // invoker(const invoker &) = delete;
-  // invoker &operator=(const invoker &) = delete;
-  // invoker(invoker &&) = delete;
-  // invoker &operator=(invoker &&) = delete;
+  invoker(const invoker &) = delete;
+  invoker &operator=(const invoker &) = delete;
+  invoker(invoker &&) = delete;
+  invoker &operator=(invoker &&) = delete;
 };
 
 invoker f() {