[AST] Store the ReturnStmt NRVO candidate in a trailing object

ReturnStmt previously carried an unconditional "const VarDecl *NRVOCandidate"
member. With this change a new bit, ReturnStmtBits.HasNRVOCandidate, records
whether storage for a trailing "const VarDecl *" was allocated, and the
candidate is kept in that llvm::TrailingObjects slot.

The ReturnStmt constructors become private. All callers now use
ReturnStmt::Create (storage is allocated only when the candidate is non-null)
or ReturnStmt::CreateEmpty (storage allocated on request, for deserialization).
setNRVOCandidate() asserts that the storage exists.

Because the constructors are private, CodeGen's synthesized Objective-C getter
now allocates its temporary ReturnStmt in the ASTContext instead of on the
stack.

Serialization writes the HasNRVOCandidate flag immediately after the common
Stmt fields and emits the candidate decl only when the flag is set. The reader
reads Record[ASTStmtReader::NumStmtFields] to size the allocation before
visiting the statement.

Index: include/clang/AST/Stmt.h
===================================================================
--- include/clang/AST/Stmt.h
+++ include/clang/AST/Stmt.h
@@ -256,6 +256,9 @@
 
     unsigned : NumStmtBits;
 
+    /// True if this ReturnStmt has storage for an NRVO candidate.
+    unsigned HasNRVOCandidate : 1;
+
     /// The location of the "return".
     SourceLocation RetLoc;
   };
@@ -1999,40 +2002,67 @@
 /// return a value, and it allows returning a value in functions declared to
 /// return void.  We explicitly model this in the AST, which means you can't
 /// depend on the return type of the function and the presence of an argument.
-class ReturnStmt : public Stmt {
+class ReturnStmt final
+    : public Stmt,
+      private llvm::TrailingObjects<ReturnStmt, const VarDecl *> {
+  friend TrailingObjects;
+
+  /// The return expression.
   Stmt *RetExpr;
-  const VarDecl *NRVOCandidate;
 
-public:
-  explicit ReturnStmt(SourceLocation RL) : ReturnStmt(RL, nullptr, nullptr) {}
+  // ReturnStmt is followed optionally by a trailing "const VarDecl *"
+  // for the NRVO candidate. Present if and only if hasNRVOCandidate().
 
-  ReturnStmt(SourceLocation RL, Expr *E, const VarDecl *NRVOCandidate)
-      : Stmt(ReturnStmtClass), RetExpr((Stmt *)E),
-        NRVOCandidate(NRVOCandidate) {
-    ReturnStmtBits.RetLoc = RL;
+  /// True if this ReturnStmt has storage for an NRVO candidate.
+  bool hasNRVOCandidate() const { return ReturnStmtBits.HasNRVOCandidate; }
+
+  unsigned numTrailingObjects(OverloadToken<const VarDecl *>) const {
+    return hasNRVOCandidate();
   }
 
-  /// Build an empty return expression.
-  explicit ReturnStmt(EmptyShell Empty) : Stmt(ReturnStmtClass, Empty) {}
+  /// Build a return statement.
+  ReturnStmt(SourceLocation RL, Expr *E, const VarDecl *NRVOCandidate);
 
-  const Expr *getRetValue() const;
-  Expr *getRetValue();
-  void setRetValue(Expr *E) { RetExpr = reinterpret_cast<Stmt*>(E); }
+  /// Build an empty return statement.
+  explicit ReturnStmt(EmptyShell Empty, bool HasNRVOCandidate);
 
-  SourceLocation getReturnLoc() const { return ReturnStmtBits.RetLoc; }
-  void setReturnLoc(SourceLocation L) { ReturnStmtBits.RetLoc = L; }
+public:
+  /// Create a return statement.
+  static ReturnStmt *Create(const ASTContext &Ctx, SourceLocation RL, Expr *E,
+                            const VarDecl *NRVOCandidate);
+
+  /// Create an empty return statement, optionally with
+  /// storage for an NRVO candidate.
+  static ReturnStmt *CreateEmpty(const ASTContext &Ctx, bool HasNRVOCandidate);
+
+  Expr *getRetValue() { return reinterpret_cast<Expr *>(RetExpr); }
+  const Expr *getRetValue() const { return reinterpret_cast<Expr *>(RetExpr); }
+  void setRetValue(Expr *E) { RetExpr = reinterpret_cast<Stmt *>(E); }
 
   /// Retrieve the variable that might be used for the named return
   /// value optimization.
   ///
   /// The optimization itself can only be performed if the variable is
   /// also marked as an NRVO object.
-  const VarDecl *getNRVOCandidate() const { return NRVOCandidate; }
-  void setNRVOCandidate(const VarDecl *Var) { NRVOCandidate = Var; }
+  const VarDecl *getNRVOCandidate() const {
+    return hasNRVOCandidate() ? *getTrailingObjects<const VarDecl *>()
+                              : nullptr;
+  }
 
-  SourceLocation getBeginLoc() const { return getReturnLoc(); }
+  /// Set the variable that might be used for the named return value
+  /// optimization. The return statement must have storage for it,
+  /// which is the case if and only if hasNRVOCandidate() is true.
+  void setNRVOCandidate(const VarDecl *Var) {
+    assert(hasNRVOCandidate() &&
+           "This return statement has no storage for an NRVO candidate!");
+    *getTrailingObjects<const VarDecl *>() = Var;
+  }
 
-  SourceLocation getEndLoc() const {
+  SourceLocation getReturnLoc() const { return ReturnStmtBits.RetLoc; }
+  void setReturnLoc(SourceLocation L) { ReturnStmtBits.RetLoc = L; }
+
+  SourceLocation getBeginLoc() const { return getReturnLoc(); }
+  SourceLocation getEndLoc() const LLVM_READONLY {
     return RetExpr ? RetExpr->getEndLoc() : getReturnLoc();
   }
 
@@ -2042,7 +2072,8 @@
 
   // Iterators
   child_range children() {
-    if (RetExpr) return child_range(&RetExpr, &RetExpr+1);
+    if (RetExpr)
+      return child_range(&RetExpr, &RetExpr + 1);
     return child_range(child_iterator(), child_iterator());
   }
 };
Index: lib/AST/ASTImporter.cpp
===================================================================
--- lib/AST/ASTImporter.cpp
+++ lib/AST/ASTImporter.cpp
@@ -5957,8 +5957,8 @@
   const VarDecl *ToNRVOCandidate;
   std::tie(ToReturnLoc, ToRetValue, ToNRVOCandidate) = *Imp;
 
-  return new (Importer.getToContext()) ReturnStmt(
-      ToReturnLoc, ToRetValue, ToNRVOCandidate);
+  return ReturnStmt::Create(Importer.getToContext(), ToReturnLoc, ToRetValue,
+                            ToNRVOCandidate);
 }
 
 ExpectedStmt ASTNodeImporter::VisitCXXCatchStmt(CXXCatchStmt *S) {
Index: lib/AST/Stmt.cpp
===================================================================
--- lib/AST/Stmt.cpp
+++ lib/AST/Stmt.cpp
@@ -1042,11 +1042,33 @@
 }
 
 // ReturnStmt
-const Expr* ReturnStmt::getRetValue() const {
-  return cast_or_null<Expr>(RetExpr);
-}
-Expr* ReturnStmt::getRetValue() {
-  return cast_or_null<Expr>(RetExpr);
+ReturnStmt::ReturnStmt(SourceLocation RL, Expr *E, const VarDecl *NRVOCandidate)
+    : Stmt(ReturnStmtClass), RetExpr(E) {
+  bool HasNRVOCandidate = NRVOCandidate != nullptr;
+  ReturnStmtBits.HasNRVOCandidate = HasNRVOCandidate;
+  if (HasNRVOCandidate)
+    setNRVOCandidate(NRVOCandidate);
+  setReturnLoc(RL);
+}
+
+ReturnStmt::ReturnStmt(EmptyShell Empty, bool HasNRVOCandidate)
+    : Stmt(ReturnStmtClass, Empty) {
+  ReturnStmtBits.HasNRVOCandidate = HasNRVOCandidate;
+}
+
+ReturnStmt *ReturnStmt::Create(const ASTContext &Ctx, SourceLocation RL,
+                               Expr *E, const VarDecl *NRVOCandidate) {
+  bool HasNRVOCandidate = NRVOCandidate != nullptr;
+  void *Mem = Ctx.Allocate(totalSizeToAlloc<const VarDecl *>(HasNRVOCandidate),
+                           alignof(ReturnStmt));
+  return new (Mem) ReturnStmt(RL, E, NRVOCandidate);
+}
+
+ReturnStmt *ReturnStmt::CreateEmpty(const ASTContext &Ctx,
+                                    bool HasNRVOCandidate) {
+  void *Mem = Ctx.Allocate(totalSizeToAlloc<const VarDecl *>(HasNRVOCandidate),
+                           alignof(ReturnStmt));
+  return new (Mem) ReturnStmt(EmptyShell(), HasNRVOCandidate);
 }
 
 // CaseStmt
Index: lib/Analysis/BodyFarm.cpp
===================================================================
--- lib/Analysis/BodyFarm.cpp
+++ lib/Analysis/BodyFarm.cpp
@@ -201,10 +201,9 @@
                                  /*arrow=*/true, /*free=*/false);
 }
 
-
 ReturnStmt *ASTMaker::makeReturn(const Expr *RetVal) {
-  return new (C) ReturnStmt(SourceLocation(), const_cast<Expr*>(RetVal),
-                            nullptr);
+  return ReturnStmt::Create(C, SourceLocation(), const_cast<Expr *>(RetVal),
+                            /* NRVOCandidate=*/nullptr);
 }
 
 IntegerLiteral *ASTMaker::makeIntegerLiteral(uint64_t Value, QualType Ty) {
Index: lib/CodeGen/CGObjC.cpp
===================================================================
--- lib/CodeGen/CGObjC.cpp
+++ lib/CodeGen/CGObjC.cpp
@@ -883,9 +883,10 @@
   // If there's a non-trivial 'get' expression, we just have to emit that.
   if (!hasTrivialGetExpr(propImpl)) {
     if (!AtomicHelperFn) {
-      ReturnStmt ret(SourceLocation(), propImpl->getGetterCXXConstructor(),
-                     /*nrvo*/ nullptr);
-      EmitReturnStmt(ret);
+      auto *ret = ReturnStmt::Create(getContext(), SourceLocation(),
+                                     propImpl->getGetterCXXConstructor(),
+                                     /* NRVOCandidate=*/nullptr);
+      EmitReturnStmt(*ret);
     }
     else {
       ObjCIvarDecl *ivar = propImpl->getPropertyIvarDecl();
Index: lib/Sema/SemaStmt.cpp
===================================================================
--- lib/Sema/SemaStmt.cpp
+++ lib/Sema/SemaStmt.cpp
@@ -3226,7 +3226,8 @@
         return StmtError();
       RetValExp = ER.get();
     }
-    return new (Context) ReturnStmt(ReturnLoc, RetValExp, nullptr);
+    return ReturnStmt::Create(Context, ReturnLoc, RetValExp,
+                              /* NRVOCandidate=*/nullptr);
   }
 
   if (HasDeducedReturnType) {
@@ -3352,8 +3353,8 @@
       return StmtError();
     RetValExp = ER.get();
   }
-  ReturnStmt *Result = new (Context) ReturnStmt(ReturnLoc, RetValExp,
-                                                NRVOCandidate);
+  auto *Result =
+      ReturnStmt::Create(Context, ReturnLoc, RetValExp, NRVOCandidate);
 
   // If we need to check for the named return value optimization,
   // or if we need to infer the return type,
@@ -3582,7 +3583,8 @@
         return StmtError();
       RetValExp = ER.get();
     }
-    return new (Context) ReturnStmt(ReturnLoc, RetValExp, nullptr);
+    return ReturnStmt::Create(Context, ReturnLoc, RetValExp,
+                              /* NRVOCandidate=*/nullptr);
   }
 
   // FIXME: Add a flag to the ScopeInfo to indicate whether we're performing
@@ -3677,7 +3679,8 @@
       }
     }
 
-    Result = new (Context) ReturnStmt(ReturnLoc, RetValExp, nullptr);
+    Result = ReturnStmt::Create(Context, ReturnLoc, RetValExp,
+                                /* NRVOCandidate=*/nullptr);
   } else if (!RetValExp && !HasDependentReturnType) {
     FunctionDecl *FD = getCurFunctionDecl();
 
@@ -3699,7 +3702,8 @@
     else
       Diag(ReturnLoc, DiagID) << getCurMethodDecl()->getDeclName() << 1/*meth*/;
 
-    Result = new (Context) ReturnStmt(ReturnLoc);
+    Result = ReturnStmt::Create(Context, ReturnLoc, /* RetExpr=*/nullptr,
+                                /* NRVOCandidate=*/nullptr);
   } else {
     assert(RetValExp || HasDependentReturnType);
     const VarDecl *NRVOCandidate = nullptr;
@@ -3752,7 +3756,7 @@
         return StmtError();
       RetValExp = ER.get();
     }
-    Result = new (Context) ReturnStmt(ReturnLoc, RetValExp, NRVOCandidate);
+    Result = ReturnStmt::Create(Context, ReturnLoc, RetValExp, NRVOCandidate);
   }
 
   // If we need to check for the named return value optimization, save the
Index: lib/Serialization/ASTReaderStmt.cpp
===================================================================
--- lib/Serialization/ASTReaderStmt.cpp
+++ lib/Serialization/ASTReaderStmt.cpp
@@ -328,9 +328,14 @@
 
 void ASTStmtReader::VisitReturnStmt(ReturnStmt *S) {
   VisitStmt(S);
+
+  bool HasNRVOCandidate = Record.readInt();
+
   S->setRetValue(Record.readSubExpr());
+  if (HasNRVOCandidate)
+    S->setNRVOCandidate(ReadDeclAs<VarDecl>());
+
   S->setReturnLoc(ReadSourceLocation());
-  S->setNRVOCandidate(ReadDeclAs<VarDecl>());
 }
 
 void ASTStmtReader::VisitDeclStmt(DeclStmt *S) {
@@ -2359,7 +2364,8 @@
       break;
 
     case STMT_RETURN:
-      S = new (Context) ReturnStmt(Empty);
+      S = ReturnStmt::CreateEmpty(
+          Context, /* HasNRVOCandidate=*/Record[ASTStmtReader::NumStmtFields]);
       break;
 
     case STMT_DECL:
Index: lib/Serialization/ASTWriterStmt.cpp
===================================================================
--- lib/Serialization/ASTWriterStmt.cpp
+++ lib/Serialization/ASTWriterStmt.cpp
@@ -249,9 +249,15 @@
 
 void ASTStmtWriter::VisitReturnStmt(ReturnStmt *S) {
   VisitStmt(S);
+
+  bool HasNRVOCandidate = S->getNRVOCandidate() != nullptr;
+  Record.push_back(HasNRVOCandidate);
+
   Record.AddStmt(S->getRetValue());
+  if (HasNRVOCandidate)
+    Record.AddDeclRef(S->getNRVOCandidate());
+
   Record.AddSourceLocation(S->getReturnLoc());
-  Record.AddDeclRef(S->getNRVOCandidate());
   Code = serialization::STMT_RETURN;
 }
 