Index: include/clang/AST/Stmt.h
===================================================================
--- include/clang/AST/Stmt.h
+++ include/clang/AST/Stmt.h
@@ -131,7 +131,8 @@
 
     unsigned : NumStmtBits;
 
-    unsigned NumStmts : 32 - NumStmtBits;
+    unsigned WasReplaced : 1;
+    unsigned NumStmts : 32 - (NumStmtBits + 1);
 
     /// The location of the opening "{".
     SourceLocation LBraceLoc;
@@ -1328,6 +1329,7 @@
   explicit CompoundStmt(SourceLocation Loc)
       : Stmt(CompoundStmtClass), RBraceLoc(Loc) {
     CompoundStmtBits.NumStmts = 0;
+    CompoundStmtBits.WasReplaced = 0;
     CompoundStmtBits.LBraceLoc = Loc;
   }
 
@@ -1341,7 +1343,14 @@
   using body_range = llvm::iterator_range<body_iterator>;
 
   body_range body() { return body_range(body_begin(), body_end()); }
-  body_iterator body_begin() { return getTrailingObjects<Stmt *>(); }
+  body_iterator body_begin() {
+    // Once replaceStmts() has moved the body out of line, the first trailing
+    // slot holds a pointer to the new body rather than a statement.
+    Stmt **Trailing = getTrailingObjects<Stmt *>();
+    return CompoundStmtBits.WasReplaced
+               ? reinterpret_cast<body_iterator>(Trailing[0])
+               : Trailing;
+  }
   body_iterator body_end() { return body_begin() + size(); }
   Stmt *body_front() { return !body_empty() ? body_begin()[0] : nullptr; }
 
@@ -1357,7 +1366,7 @@
   }
 
   const_body_iterator body_begin() const {
-    return getTrailingObjects<Stmt *>();
+    return const_cast<CompoundStmt *>(this)->body_begin();
   }
 
   const_body_iterator body_end() const { return body_begin() + size(); }
@@ -1391,6 +1400,13 @@
     return const_reverse_body_iterator(body_begin());
   }
 
+  /// Replace the statements of this compound statement with \p Stmts.
+  ///
+  /// The new statements are stored in memory allocated from \p C. Unless it
+  /// was already replaced, this compound statement must not be empty: the
+  /// out-of-line body is reached through its first trailing slot.
+  void replaceStmts(const ASTContext &C, llvm::ArrayRef<Stmt *> Stmts);
+
   // Get the Stmt that StmtExpr would consider to be the result of this
   // compound statement. This is used by StmtExpr to properly emulate the GCC
   // compound expression extension, which ignores trailing NullStmts when
Index: lib/AST/Stmt.cpp
===================================================================
--- lib/AST/Stmt.cpp
+++ lib/AST/Stmt.cpp
@@ -293,6 +293,7 @@
                            SourceLocation RB)
     : Stmt(CompoundStmtClass), RBraceLoc(RB) {
   CompoundStmtBits.NumStmts = Stmts.size();
+  CompoundStmtBits.WasReplaced = 0;
   setStmts(Stmts);
   CompoundStmtBits.LBraceLoc = LB;
 }
@@ -300,6 +301,7 @@
 void CompoundStmt::setStmts(ArrayRef<Stmt *> Stmts) {
   assert(CompoundStmtBits.NumStmts == Stmts.size() &&
          "NumStmts doesn't fit in bits of CompoundStmtBits.NumStmts!");
+  assert(!CompoundStmtBits.WasReplaced && "Call replaceStmts!");
 
   std::copy(Stmts.begin(), Stmts.end(), body_begin());
 }
@@ -316,10 +318,33 @@
   void *Mem =
       C.Allocate(totalSizeToAlloc<Stmt *>(NumStmts), alignof(CompoundStmt));
   CompoundStmt *New = new (Mem) CompoundStmt(EmptyShell());
+  New->CompoundStmtBits.WasReplaced = 0;
   New->CompoundStmtBits.NumStmts = NumStmts;
   return New;
 }
 
+void CompoundStmt::replaceStmts(const ASTContext &C,
+                                llvm::ArrayRef<Stmt *> Stmts) {
+  // The new body lives out of line and the first trailing slot stores a
+  // pointer to it. A body that was never replaced has exactly size() trailing
+  // slots, so it must not be empty.
+  assert((CompoundStmtBits.WasReplaced || !body_empty()) &&
+         "No trailing storage to hold the replaced body!");
+
+  // Copy before releasing the old body in case Stmts points into it.
+  Stmt **NewBody = new (C) Stmt *[Stmts.size()];
+  std::copy(Stmts.begin(), Stmts.end(), NewBody);
+  if (CompoundStmtBits.WasReplaced)
+    C.Deallocate(body_begin());
+
+  CompoundStmtBits.NumStmts = Stmts.size();
+  assert(CompoundStmtBits.NumStmts == Stmts.size() &&
+         "NumStmts doesn't fit in bits of CompoundStmtBits.NumStmts!");
+
+  getTrailingObjects<Stmt *>()[0] = reinterpret_cast<Stmt *>(NewBody);
+  CompoundStmtBits.WasReplaced = 1;
+}
+
 const Expr *ValueStmt::getExprStmt() const {
   const Stmt *S = this;
   do {
Index: unittests/AST/CMakeLists.txt
===================================================================
--- unittests/AST/CMakeLists.txt
+++ unittests/AST/CMakeLists.txt
@@ -28,6 +28,7 @@
   OMPStructuredBlockTest.cpp
   SourceLocationTest.cpp
   StmtPrinterTest.cpp
+  StmtTest.cpp
   StructuralEquivalenceTest.cpp
   )
 
Index: unittests/AST/StmtTest.cpp
===================================================================
--- /dev/null
+++ unittests/AST/StmtTest.cpp
@@ -0,0 +1,62 @@
+//===- unittests/AST/StmtTest.cpp --- Stmt tests --------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Unit tests for Stmt nodes in the AST.
+//
+//===----------------------------------------------------------------------===//
+
+#include "clang/ASTMatchers/ASTMatchFinder.h"
+#include "clang/Frontend/ASTUnit.h"
+#include "clang/Tooling/Tooling.h"
+
+#include "DeclMatcher.h"
+
+#include "gtest/gtest.h"
+
+#include <memory>
+#include <string>
+#include <vector>
+
+using namespace clang;
+using namespace clang::ast_matchers;
+using namespace clang::tooling;
+
+TEST(Stmt, CompoundStmtReplaceStmts) {
+  const char *const InputFileName = "input.cc";
+  std::string Code = "void f() { int i = 1; } void g() { float f = 2.; }";
+  std::vector<std::string> Args;
+  std::unique_ptr<ASTUnit> AST =
+      tooling::buildASTFromCodeWithArgs(Code, Args, InputFileName);
+  ASSERT_TRUE(AST != nullptr);
+
+  ASTContext &C = AST->getASTContext();
+  TranslationUnitDecl *TU = C.getTranslationUnitDecl();
+  auto MakeMatcher = [](const char *Id) { return functionDecl(hasName(Id)); };
+
+  auto *FDf = FirstDeclMatcher<FunctionDecl>().match(TU, MakeMatcher("f"));
+  auto *FDg = FirstDeclMatcher<FunctionDecl>().match(TU, MakeMatcher("g"));
+
+  auto *CSf = cast<CompoundStmt>(FDf->getBody());
+  auto *CSg = cast<CompoundStmt>(FDg->getBody());
+  Stmt *Sf = CSf->body_front();
+  Stmt *Sg = CSg->body_front();
+
+  // Growing the body moves it out of line.
+  // FIXME: Clone the CSg content to keep the AST invariants.
+  llvm::SmallVector<Stmt *, 2> Stmts = {Sf, Sg};
+  CSf->replaceStmts(C, Stmts);
+  ASSERT_EQ(2u, CSf->size());
+  EXPECT_EQ(Sf, CSf->body_front());
+  EXPECT_EQ(Sg, CSf->body_back());
+
+  // A body that was already replaced can be replaced again.
+  Stmt *Single[] = {Sg};
+  CSf->replaceStmts(C, Single);
+  ASSERT_EQ(1u, CSf->size());
+  EXPECT_EQ(Sg, CSf->body_front());
+}