Index: include/clang/Tooling/ASTDiff/ASTPatch.h
===================================================================
--- /dev/null
+++ include/clang/Tooling/ASTDiff/ASTPatch.h
@@ -0,0 +1,108 @@
+//===- ASTPatch.h - AST patching --------------------------------*- C++ -*-===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// Helpers that remove a single Decl or Stmt from an already-built AST.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CLANG_TOOLING_ASTDIFF_ASTPATCH_H
+#define LLVM_CLANG_TOOLING_ASTDIFF_ASTPATCH_H
+
+#include "clang/AST/ASTContext.h"
+#include "clang/AST/ASTTypeTraits.h"
+#include <algorithm>
+#include <type_traits>
+#include <vector>
+
+namespace clang {
+namespace patch {
+
+/// Removes \p D from its lexical DeclContext.
+///
+/// \returns true if \p D was removed; false if \p D is null, has no lexical
+/// DeclContext, or is not contained in that context.
+inline bool remove(Decl *D, ASTContext &Context) {
+  if (!D)
+    return false;
+  DeclContext *Ctx = D->getLexicalDeclContext();
+  if (!Ctx || !Ctx->containsDecl(D))
+    return false;
+  Ctx->removeDecl(D);
+  return true;
+}
+
+/// Removes \p S from \p Parent. Only a CompoundStmt parent is supported.
+///
+/// \returns true if \p S was one of the parent's children and was removed;
+/// false otherwise (in which case the AST is left untouched).
+inline bool removeFrom(Stmt *S, ASTContext &Context,
+                       const ast_type_traits::DynTypedNode &Parent) {
+  const auto *ConstParentS = Parent.get<Stmt>();
+  if (!ConstParentS)
+    return false;
+  // Parent lookups only yield const nodes, but the AST is being edited here.
+  auto *CS = dyn_cast<CompoundStmt>(const_cast<Stmt *>(ConstParentS));
+  if (!CS)
+    return false;
+  std::vector<Stmt *> Stmts(CS->child_begin(), CS->child_end());
+  const size_t OldSize = Stmts.size();
+  Stmts.erase(std::remove(Stmts.begin(), Stmts.end(), S), Stmts.end());
+  // Compare sizes rather than iterators: erase() invalidates the old end().
+  if (Stmts.size() == OldSize)
+    return false;
+  CS->setStmts(Context, Stmts);
+  return true;
+}
+
+/// Removes \p Node from the first parent reported by
+/// ASTContext::getParents().
+///
+/// \returns true on success; false if \p Node is null, has no parent, or
+/// cannot be removed from its parent (see removeFrom()).
+inline bool remove(Stmt *Node, ASTContext &Context) {
+  if (!Node)
+    return false;
+  auto DTN = ast_type_traits::DynTypedNode::create(*Node);
+  const auto &Parents = Context.getParents(DTN);
+  if (Parents.empty())
+    return false;
+  return removeFrom(Node, Context, Parents[0]);
+}
+
+/// Overload for const pointers to Stmt, Decl or any of their subclasses, as
+/// returned by e.g. BoundNodes::getNodeAs<T>().
+///
+/// The pointer is explicitly converted to Stmt * or Decl * before
+/// dispatching. Passing a T * back to remove() would select this template
+/// again (a qualification conversion ranks above a derived-to-base one) and
+/// recurse forever for subclasses such as ReturnStmt or VarDecl.
+template <class T> bool remove(const T *N, ASTContext &Context) {
+  static_assert(std::is_base_of<Stmt, T>::value ||
+                    std::is_base_of<Decl, T>::value,
+                "patch::remove() supports only Stmt and Decl nodes");
+  using BaseT = typename std::conditional<std::is_base_of<Stmt, T>::value,
+                                          Stmt, Decl>::type;
+  return remove(static_cast<BaseT *>(const_cast<T *>(N)), Context);
+}
+
+/// Removes the Stmt or Decl held by \p DTN; returns false for any other kind
+/// of node.
+inline bool remove(const ast_type_traits::DynTypedNode &DTN,
+                   ASTContext &Context) {
+  if (const auto *S = DTN.get<Stmt>())
+    return remove(S, Context);
+  if (const auto *D = DTN.get<Decl>())
+    return remove(D, Context);
+  return false;
+}
+
+} // end namespace patch
+} // end namespace clang
+
+#endif // LLVM_CLANG_TOOLING_ASTDIFF_ASTPATCH_H
Index: unittests/Tooling/ASTPatchTest.cpp
===================================================================
--- /dev/null
+++ unittests/Tooling/ASTPatchTest.cpp
@@ -0,0 +1,70 @@
+//===- unittests/Tooling/ASTPatchTest.cpp ---------------------------------===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+#include "clang/ASTMatchers/ASTMatchFinder.h"
+#include "clang/Tooling/ASTDiff/ASTPatch.h"
+#include "clang/Tooling/Tooling.h"
+#include "gtest/gtest.h"
+
+using namespace clang;
+using namespace tooling;
+using namespace ast_matchers;
+
+namespace {
+template <class T> struct DeleteMatch : public MatchFinder::MatchCallback {
+  unsigned NumFound = 0;
+  bool Success = true;
+
+  void run(const MatchFinder::MatchResult &Result) override {
+    auto *Node = Result.Nodes.getNodeAs<T>("id");
+    if (!Node)
+      return;
+    ++NumFound;
+    Success = Success && patch::remove(Node, *Result.Context);
+  }
+};
+} // end anonymous namespace
+
+template <class T, class NodeMatcher>
+static testing::AssertionResult
+isRemovalSuccessful(const NodeMatcher &StmtMatch, StringRef Code) {
+  DeleteMatch<T> Deleter;
+  MatchFinder Finder;
+  Finder.addMatcher(StmtMatch, &Deleter);
+  std::unique_ptr<FrontendActionFactory> Factory(
+      newFrontendActionFactory(&Finder));
+  if (!runToolOnCode(Factory->create(), Code))
+    return testing::AssertionFailure()
+           << R"(Parsing error in ")" << Code.str() << R"(")";
+  if (Deleter.NumFound == 0)
+    return testing::AssertionFailure() << "Matcher didn't find any statements";
+  return testing::AssertionResult(Deleter.Success);
+}
+
+TEST(ASTPatch, RemoveStmt) {
+  ASSERT_TRUE(isRemovalSuccessful<Stmt>(returnStmt().bind("id"),
+                                        R"(void x(){ return;})"));
+  // Goes through the const T * overload with a proper Stmt subclass.
+  ASSERT_TRUE(isRemovalSuccessful<ReturnStmt>(returnStmt().bind("id"),
+                                              R"(void x(){ return;})"));
+}
+
+TEST(ASTPatch, RemoveDecl) {
+  ASSERT_TRUE(isRemovalSuccessful<Decl>(varDecl().bind("id"),
+                                        R"(int x = 0;)"));
+  // Goes through the const T * overload with a proper Decl subclass.
+  ASSERT_TRUE(isRemovalSuccessful<VarDecl>(varDecl().bind("id"),
+                                           R"(int x = 0;)"));
+  ASSERT_TRUE(isRemovalSuccessful<Decl>(functionTemplateDecl().bind("id"), R"(
+template <class T> struct pred {};
+template <class T> pred<pred<T> > swap();
+template <class T> pred<pred<T> > swap();
+void swap();
+)"));
+}
Index: unittests/Tooling/CMakeLists.txt
===================================================================
--- unittests/Tooling/CMakeLists.txt
+++ unittests/Tooling/CMakeLists.txt
@@ -11,6 +11,7 @@
 endif()
 
 add_clang_unittest(ToolingTests
+  ASTPatchTest.cpp
   CastExprTest.cpp
   CommentHandlerTest.cpp
   CompilationDatabaseTest.cpp