Index: include/clang/AST/Stmt.h =================================================================== --- include/clang/AST/Stmt.h +++ include/clang/AST/Stmt.h @@ -879,7 +879,7 @@ /// IfStmt - This represents an if/then/else. /// class IfStmt : public Stmt { - enum { VAR, COND, THEN, ELSE, END_EXPR }; + enum { INIT, VAR, COND, THEN, ELSE, END_EXPR }; Stmt* SubExprs[END_EXPR]; SourceLocation IfLoc; @@ -887,7 +887,7 @@ public: IfStmt(const ASTContext &C, SourceLocation IL, - bool IsConstexpr, VarDecl *var, Expr *cond, + bool IsConstexpr, Stmt *init, VarDecl *var, Expr *cond, Stmt *then, SourceLocation EL = SourceLocation(), Stmt *elsev = nullptr); @@ -905,6 +905,11 @@ VarDecl *getConditionVariable() const; void setConditionVariable(const ASTContext &C, VarDecl *V); + /// Retrieve the C++1z init-statement of this 'if', or null if it has none. + Stmt *getInit() { return SubExprs[INIT]; } + const Stmt *getInit() const { return SubExprs[INIT]; } + void setInit(Stmt *S) { SubExprs[INIT] = S; } + /// If this IfStmt has a condition variable, return the faux DeclStmt /// associated with the creation of that condition variable.
const DeclStmt *getConditionVariableDeclStmt() const { Index: include/clang/Parse/Parser.h =================================================================== --- include/clang/Parse/Parser.h +++ include/clang/Parse/Parser.h @@ -1682,7 +1682,8 @@ StmtResult ParseCompoundStatementBody(bool isStmtExpr = false); bool ParseParenExprOrCondition(Sema::ConditionResult &CondResult, SourceLocation Loc, - Sema::ConditionKind CK); + Sema::ConditionKind CK, + Stmt **Init = nullptr); StmtResult ParseIfStatement(SourceLocation *TrailingElseLoc); StmtResult ParseSwitchStatement(SourceLocation *TrailingElseLoc); StmtResult ParseWhileStatement(SourceLocation *TrailingElseLoc); @@ -1908,6 +1909,11 @@ return isDeclarationSpecifier(true); } + /// isForInitDeclarationWithSemi - is the same as isForInitDeclaration + /// but also checks that a semicolon follows the declaration + /// (needed for C++1z 'if' statements with an initializer). + bool isForInitDeclarationWithSemi(); + /// \brief Determine whether this is a C++1z for-range-identifier. bool isForRangeIdentifier(); Index: include/clang/Sema/Sema.h =================================================================== --- include/clang/Sema/Sema.h +++ include/clang/Sema/Sema.h @@ -3397,10 +3397,10 @@ class ConditionResult; StmtResult ActOnIfStmt(SourceLocation IfLoc, bool IsConstexpr, - ConditionResult Cond, Stmt *ThenVal, + ConditionResult Cond, Stmt *InitVal, Stmt *ThenVal, SourceLocation ElseLoc, Stmt *ElseVal); StmtResult BuildIfStmt(SourceLocation IfLoc, bool IsConstexpr, - ConditionResult Cond, Stmt *ThenVal, + ConditionResult Cond, Stmt *InitVal, Stmt *ThenVal, SourceLocation ElseLoc, Stmt *ElseVal); StmtResult ActOnStartOfSwitchStmt(SourceLocation SwitchLoc, ConditionResult Cond); @@ -8958,7 +8958,8 @@ enum class ConditionKind { Boolean, ///< A boolean condition, from 'if', 'while', 'for', or 'do'. ConstexprIf, ///< A constant boolean condition from 'if constexpr'. - Switch ///< An integral condition for a 'switch' statement.
+ Switch, ///< An integral condition for a 'switch' statement. + IfWithInit ///< A boolean condition from a C++1z 'if' that may have an init-statement. }; ConditionResult ActOnCondition(Scope *S, SourceLocation Loc, Index: lib/AST/ASTImporter.cpp =================================================================== --- lib/AST/ASTImporter.cpp +++ lib/AST/ASTImporter.cpp @@ -4964,6 +4964,9 @@ Stmt *ASTNodeImporter::VisitIfStmt(IfStmt *S) { SourceLocation ToIfLoc = Importer.Import(S->getIfLoc()); VarDecl *ToConditionVariable = nullptr; + Stmt *ToInit = Importer.Import(S->getInit()); + if (!ToInit && S->getInit()) + return nullptr; if (VarDecl *FromConditionVariable = S->getConditionVariable()) { ToConditionVariable = dyn_cast_or_null<VarDecl>(Importer.Import(FromConditionVariable)); @@ -4982,6 +4985,7 @@ return nullptr; return new (Importer.getToContext()) IfStmt(Importer.getToContext(), ToIfLoc, S->isConstexpr(), + ToInit, ToConditionVariable, ToCondition, ToThenStmt, ToElseLoc, ToElseStmt); Index: lib/AST/Stmt.cpp =================================================================== --- lib/AST/Stmt.cpp +++ lib/AST/Stmt.cpp @@ -764,11 +764,12 @@ } IfStmt::IfStmt(const ASTContext &C, SourceLocation IL, bool IsConstexpr, - VarDecl *var, Expr *cond, Stmt *then, SourceLocation EL, + Stmt *init, VarDecl *var, Expr *cond, Stmt *then, SourceLocation EL, Stmt *elsev) : Stmt(IfStmtClass), IfLoc(IL), ElseLoc(EL) { setConstexpr(IsConstexpr); setConditionVariable(C, var); + SubExprs[INIT] = init; SubExprs[COND] = cond; SubExprs[THEN] = then; SubExprs[ELSE] = elsev; Index: lib/Analysis/BodyFarm.cpp =================================================================== --- lib/Analysis/BodyFarm.cpp +++ lib/Analysis/BodyFarm.cpp @@ -239,7 +239,7 @@ SourceLocation()); // (5) Create the 'if' statement.
- IfStmt *If = new (C) IfStmt(C, SourceLocation(), false, nullptr, UO, CS); + IfStmt *If = new (C) IfStmt(C, SourceLocation(), false, nullptr, nullptr, UO, CS); return If; } @@ -343,7 +343,7 @@ /// Construct the If. Stmt *If = - new (C) IfStmt(C, SourceLocation(), false, nullptr, Comparison, Body, + new (C) IfStmt(C, SourceLocation(), false, nullptr, nullptr, Comparison, Body, SourceLocation(), Else); return If; Index: lib/CodeGen/CGStmt.cpp =================================================================== --- lib/CodeGen/CGStmt.cpp +++ lib/CodeGen/CGStmt.cpp @@ -557,6 +557,9 @@ // unequal to 0. The condition must be a scalar type. LexicalScope ConditionScope(*this, S.getCond()->getSourceRange()); + if (S.getInit()) + EmitStmt(S.getInit()); + if (S.getConditionVariable()) EmitAutoVarDecl(*S.getConditionVariable()); Index: lib/Parse/ParseStmt.cpp =================================================================== --- lib/Parse/ParseStmt.cpp +++ lib/Parse/ParseStmt.cpp @@ -1054,10 +1054,26 @@ /// errors in the condition. bool Parser::ParseParenExprOrCondition(Sema::ConditionResult &Cond, SourceLocation Loc, - Sema::ConditionKind CK) { + Sema::ConditionKind CK, + Stmt **Init) { BalancedDelimiterTracker T(*this, tok::l_paren); T.consumeOpen(); + if (CK == Sema::ConditionKind::IfWithInit && + isForInitDeclarationWithSemi()) { // if (int X = 4; + // Parse the declaration; the trailing ';' is left for us to consume. + ParsedAttributesWithRange attrs(AttrFactory); + MaybeParseCXX11Attributes(attrs); + + SourceLocation DeclStart = Tok.getLocation(), DeclEnd; + DeclGroupPtrTy DG = ParseSimpleDeclaration( + Declarator::ForContext, DeclEnd, attrs, false, nullptr); + StmtResult InitStmt = Actions.ActOnDeclStmt(DG, DeclStart, Tok.getLocation()); + if (Init) *Init = InitStmt.get(); + + ExpectAndConsume(tok::semi); // Diagnose a missing ';' instead of eating it. + } + if (getLangOpts().CPlusPlus) Cond = ParseCXXCondition(Loc, CK); else { @@ -1140,9 +1156,13 @@ // Parse the condition.
Sema::ConditionResult Cond; + Stmt *Init = nullptr; if (ParseParenExprOrCondition(Cond, IfLoc, IsConstexpr ? Sema::ConditionKind::ConstexprIf - : Sema::ConditionKind::Boolean)) + : getLangOpts().CPlusPlus1z + ? Sema::ConditionKind::IfWithInit + : Sema::ConditionKind::Boolean, + &Init)) return StmtError(); llvm::Optional<bool> ConstexprCondition; @@ -1241,8 +1261,8 @@ if (ElseStmt.isInvalid()) ElseStmt = Actions.ActOnNullStmt(ElseStmtLoc); - return Actions.ActOnIfStmt(IfLoc, IsConstexpr, Cond, ThenStmt.get(), ElseLoc, - ElseStmt.get()); + return Actions.ActOnIfStmt(IfLoc, IsConstexpr, Cond, Init, + ThenStmt.get(), ElseLoc, ElseStmt.get()); } /// ParseSwitchStatement Index: lib/Parse/ParseTentative.cpp =================================================================== --- lib/Parse/ParseTentative.cpp +++ lib/Parse/ParseTentative.cpp @@ -66,6 +66,23 @@ } } +/// isForInitDeclarationWithSemi - is the same as isForInitDeclaration +/// but also requires a ';' after the declaration, before the closing ')' +/// (needed for C++1z 'if' statements with an initializer). +bool Parser::isForInitDeclarationWithSemi() { + if (!isForInitDeclaration()) + return false; + + // isForInitDeclaration also accepts a condition declaration such as + // 'if (int X = f())'. Only a ';' reached before the closing ')' at this + // nesting level makes this an init-statement rather than a condition. + TentativeParsingAction PA(*this); + SkipUntil(tok::semi, tok::r_paren, StopBeforeMatch); + bool HasSemi = Tok.is(tok::semi); + PA.Revert(); + return HasSemi; +} + /// isCXXSimpleDeclaration - C++-specialized function that disambiguates /// between a simple-declaration or an expression-statement.
/// If during the disambiguation process a parsing error is encountered, Index: lib/Sema/SemaExpr.cpp =================================================================== --- lib/Sema/SemaExpr.cpp +++ lib/Sema/SemaExpr.cpp @@ -14385,6 +14385,7 @@ ExprResult Cond; switch (CK) { case ConditionKind::Boolean: + case ConditionKind::IfWithInit: Cond = CheckBooleanCondition(Loc, SubExpr); break; Index: lib/Sema/SemaExprCXX.cpp =================================================================== --- lib/Sema/SemaExprCXX.cpp +++ lib/Sema/SemaExprCXX.cpp @@ -3095,6 +3095,7 @@ switch (CK) { case ConditionKind::Boolean: + case ConditionKind::IfWithInit: return CheckBooleanCondition(StmtLoc, Condition.get()); case ConditionKind::ConstexprIf: Index: lib/Sema/SemaStmt.cpp =================================================================== --- lib/Sema/SemaStmt.cpp +++ lib/Sema/SemaStmt.cpp @@ -505,7 +505,7 @@ StmtResult Sema::ActOnIfStmt(SourceLocation IfLoc, bool IsConstexpr, ConditionResult Cond, - Stmt *thenStmt, SourceLocation ElseLoc, + Stmt *initStmt, Stmt *thenStmt, SourceLocation ElseLoc, Stmt *elseStmt) { if (Cond.isInvalid()) Cond = ConditionResult( @@ -524,11 +524,11 @@ DiagnoseEmptyStmtBody(CondExpr->getLocEnd(), thenStmt, diag::warn_empty_if_body); - return BuildIfStmt(IfLoc, IsConstexpr, Cond, thenStmt, ElseLoc, elseStmt); + return BuildIfStmt(IfLoc, IsConstexpr, Cond, initStmt, thenStmt, ElseLoc, elseStmt); } StmtResult Sema::BuildIfStmt(SourceLocation IfLoc, bool IsConstexpr, - ConditionResult Cond, Stmt *thenStmt, + ConditionResult Cond, Stmt *initStmt, Stmt *thenStmt, SourceLocation ElseLoc, Stmt *elseStmt) { if (Cond.isInvalid()) return StmtError(); @@ -539,7 +539,7 @@ DiagnoseUnusedExprResult(thenStmt); DiagnoseUnusedExprResult(elseStmt); - return new (Context) IfStmt(Context, IfLoc, IsConstexpr, Cond.get().first, + return new (Context) IfStmt(Context, IfLoc, IsConstexpr, initStmt, Cond.get().first, Cond.get().second, thenStmt, ElseLoc, elseStmt); } Index:
lib/Sema/TreeTransform.h =================================================================== --- lib/Sema/TreeTransform.h +++ lib/Sema/TreeTransform.h @@ -1174,9 +1174,9 @@ /// By default, performs semantic analysis to build the new statement. /// Subclasses may override this routine to provide different behavior. StmtResult RebuildIfStmt(SourceLocation IfLoc, bool IsConstexpr, - Sema::ConditionResult Cond, Stmt *Then, + Sema::ConditionResult Cond, Stmt *Init, Stmt *Then, SourceLocation ElseLoc, Stmt *Else) { - return getSema().ActOnIfStmt(IfLoc, IsConstexpr, Cond, Then, ElseLoc, Else); + return getSema().ActOnIfStmt(IfLoc, IsConstexpr, Cond, Init, Then, ElseLoc, Else); } /// \brief Start building a new switch statement. @@ -6225,6 +6225,11 @@ template<typename Derived> StmtResult TreeTransform<Derived>::TransformIfStmt(IfStmt *S) { + // Transform the initialization statement + StmtResult Init = getDerived().TransformStmt(S->getInit()); + if (Init.isInvalid()) + return StmtError(); + // Transform the condition Sema::ConditionResult Cond = getDerived().TransformCondition( S->getIfLoc(), S->getConditionVariable(), S->getCond(), @@ -6257,6 +6262,7 @@ } if (!getDerived().AlwaysRebuild() && + Init.get() == S->getInit() && Cond.get() == std::make_pair(S->getConditionVariable(), S->getCond()) && Then.get() == S->getThen() && Else.get() == S->getElse()) @@ -6263,7 +6269,7 @@ return S; return getDerived().RebuildIfStmt(S->getIfLoc(), S->isConstexpr(), Cond, - Then.get(), S->getElseLoc(), Else.get()); + Init.get(), Then.get(), S->getElseLoc(), Else.get()); } template<typename Derived> Index: lib/Serialization/ASTReaderStmt.cpp =================================================================== --- lib/Serialization/ASTReaderStmt.cpp +++ lib/Serialization/ASTReaderStmt.cpp @@ -185,6 +185,7 @@ void ASTStmtReader::VisitIfStmt(IfStmt *S) { VisitStmt(S); S->setConstexpr(Record[Idx++]); + S->setInit(Reader.ReadSubStmt()); S->setConditionVariable(Reader.getContext(), ReadDeclAs<VarDecl>(Record, Idx)); S->setCond(Reader.ReadSubExpr());
Index: lib/Serialization/ASTWriterStmt.cpp =================================================================== --- lib/Serialization/ASTWriterStmt.cpp +++ lib/Serialization/ASTWriterStmt.cpp @@ -129,6 +129,7 @@ void ASTStmtWriter::VisitIfStmt(IfStmt *S) { VisitStmt(S); Record.push_back(S->isConstexpr()); + Record.AddStmt(S->getInit()); Record.AddDeclRef(S->getConditionVariable()); Record.AddStmt(S->getCond()); Record.AddStmt(S->getThen());