Index: include/clang/Tooling/ASTDiff/ASTDiff.h =================================================================== --- include/clang/Tooling/ASTDiff/ASTDiff.h +++ include/clang/Tooling/ASTDiff/ASTDiff.h @@ -21,6 +21,7 @@ #define LLVM_CLANG_TOOLING_ASTDIFF_ASTDIFF_H #include "clang/Frontend/ASTUnit.h" +#include "clang/Rewrite/Core/Rewriter.h" #include "clang/Tooling/ASTDiff/ASTDiffInternal.h" namespace clang { @@ -51,6 +52,13 @@ llvm::Optional getQualifiedIdentifier() const; }; +/// Computes the changes between ModelSrc and ModelDst and tries to apply them +/// to TargetSrc. Currently only deletions are transferred. The rewritten main +/// file of TargetSrc is written to OS. Returns false if the resulting +/// replacements could not be applied. +bool patch(SyntaxTree &ModelSrc, SyntaxTree &ModelDst, SyntaxTree &TargetSrc, + const ComparisonOptions &Options, raw_ostream &OS); + class ASTDiff { public: ASTDiff(SyntaxTree &Src, SyntaxTree &Dst, const ComparisonOptions &Options); @@ -69,13 +77,17 @@ /// They can be constructed from any Decl or Stmt. class SyntaxTree { public: + /// Empty (invalid) SyntaxTree. + SyntaxTree(); /// Constructs a tree from a translation unit. SyntaxTree(ASTUnit &AST); /// Constructs a tree from any AST node. template SyntaxTree(T *Node, ASTUnit &AST) : TreeImpl(llvm::make_unique(this, Node, AST)) {} - SyntaxTree(SyntaxTree &&Other) = default; + SyntaxTree(SyntaxTree &&Other); + SyntaxTree &operator=(SyntaxTree &&Other); + explicit SyntaxTree(const SyntaxTree &Other); ~SyntaxTree(); ASTUnit &getASTUnit() const; @@ -93,7 +105,7 @@ /// Returns the range that contains the text that is associated with this /// node. - /* SourceRange getSourceRange(const Node &N) const; */ + SourceRange getSourceRange(const Node &N) const; /// Returns the offsets for the range returned by getSourceRange.
std::pair getSourceRangeOffsets(const Node &N) const; Index: lib/Tooling/ASTDiff/ASTDiff.cpp =================================================================== --- lib/Tooling/ASTDiff/ASTDiff.cpp +++ lib/Tooling/ASTDiff/ASTDiff.cpp @@ -18,6 +18,8 @@ #include "clang/AST/LexicallyOrderedRecursiveASTVisitor.h" #include "clang/AST/StmtVisitor.h" #include "clang/Lex/Lexer.h" +#include "clang/Rewrite/Core/Rewriter.h" +#include "clang/Tooling/Core/Replacement.h" #include "llvm/ADT/PriorityQueue.h" #include "llvm/Support/MD5.h" @@ -27,6 +29,7 @@ using namespace llvm; using namespace clang; +using namespace tooling; namespace clang { namespace diff { @@ -139,6 +142,7 @@ typename std::enable_if::value, T>::type *Node, ASTUnit &AST) : Impl(Parent, dyn_cast(Node), AST) {} + explicit Impl(SyntaxTree *Parent, const Impl &Other); SyntaxTree *Parent; ASTUnit &AST; @@ -175,6 +179,8 @@ HashType hashNode(const Node &N) const; + SourceRange getSourceRange(const Node &N) const; + private: void initTree(); void setLeftMostDescendants(); @@ -337,6 +343,15 @@ initTree(); } +SyntaxTree::Impl::Impl(SyntaxTree *Parent, const Impl &Other) + : Impl(Parent, Other.AST) { + Nodes = Other.Nodes; + Leaves = Other.Leaves; + PostorderIds = Other.PostorderIds; + NodesBfs = Other.NodesBfs; + TemplateArgumentLocations = Other.TemplateArgumentLocations; +} + static std::vector getSubtreePostorder(const SyntaxTree::Impl &Tree, NodeId Root) { std::vector Postorder; @@ -638,6 +653,26 @@ return HashResult; } +SourceRange SyntaxTree::Impl::getSourceRange(const Node &N) const { + SourceRange Range; + if (auto *Arg = N.ASTNode.get()) + Range = TemplateArgumentLocations.at(&N - &Nodes[0]); + else { + Range = N.ASTNode.getSourceRange(); + if (auto *ThisExpr = N.ASTNode.get()) + if (ThisExpr->isImplicit()) + Range.setEnd(Range.getBegin()); + // If it is a CXXConstructExpr that is not a temporary, then there is + // probably an identifier of an initialization that is included in the + // range.
This identifier belongs to the parent node, so stick to the + // ctor arguments only. + if (auto *CE = N.ASTNode.get()) + if (!isa(CE)) + Range = CE->getParenOrBraceRange(); + } + return getSourceExtent(AST, Range); +} + /// Identifies a node in a subtree by its postorder offset, starting at 1. struct SNodeId { int Id = 0; @@ -1210,10 +1245,19 @@ return DiffImpl->getMapped(SourceTree.TreeImpl, Id); } +SyntaxTree::SyntaxTree() : TreeImpl(nullptr) {} + SyntaxTree::SyntaxTree(ASTUnit &AST) : TreeImpl(llvm::make_unique( this, AST.getASTContext().getTranslationUnitDecl(), AST)) {} +SyntaxTree::SyntaxTree(SyntaxTree &&Other) = default; + +SyntaxTree &SyntaxTree::operator=(SyntaxTree &&Other) = default; + +SyntaxTree::SyntaxTree(const SyntaxTree &Other) + : TreeImpl(llvm::make_unique(this, *Other.TreeImpl)) {} + SyntaxTree::~SyntaxTree() = default; ASTUnit &SyntaxTree::getASTUnit() const { return TreeImpl->AST; } @@ -1237,19 +1281,14 @@ return TreeImpl->findPositionInParent(Id); } +SourceRange SyntaxTree::getSourceRange(const Node &N) const { + return TreeImpl->getSourceRange(N); +} + std::pair SyntaxTree::getSourceRangeOffsets(const Node &N) const { const SourceManager &SrcMgr = TreeImpl->AST.getSourceManager(); - SourceRange Range; - if (auto *Arg = N.ASTNode.get()) - Range = TreeImpl->TemplateArgumentLocations.at(&N - &TreeImpl->Nodes[0]); - else { - Range = N.ASTNode.getSourceRange(); - if (auto *ThisExpr = N.ASTNode.get()) - if (ThisExpr->isImplicit()) - Range.setEnd(Range.getBegin()); - } - Range = getSourceExtent(TreeImpl->AST, Range); + SourceRange Range = TreeImpl->getSourceRange(N); unsigned Begin = SrcMgr.getFileOffset(Range.getBegin()); unsigned End = SrcMgr.getFileOffset(Range.getEnd()); return {Begin, End}; @@ -1263,5 +1302,95 @@ return TreeImpl->getNodeValue(N); } +struct Patcher { + SyntaxTree::Impl &ModelSrc, &ModelDst, &Target; + const ComparisonOptions &Options; + raw_ostream &OS; + SourceManager &SrcMgr; + const LangOptions &LangOpts; + Replacements
Replaces; + SyntaxTree ModelSrcCopy; + ASTDiff ModelDiff, ModelTargetDiff; + + Patcher(SyntaxTree &ModelSrc, SyntaxTree &ModelDst, SyntaxTree &Target, + const ComparisonOptions &Options, raw_ostream &OS) + : ModelSrc(*ModelSrc.TreeImpl), ModelDst(*ModelDst.TreeImpl), + Target(*Target.TreeImpl), Options(Options), OS(OS), + SrcMgr(this->Target.AST.getSourceManager()), + LangOpts(this->Target.AST.getLangOpts()), ModelSrcCopy(ModelSrc), + ModelDiff(ModelSrc, ModelDst, Options), + ModelTargetDiff(ModelSrcCopy, Target, Options) {} + + bool apply() { + addDeletions(); + Rewriter Rewrite(SrcMgr, LangOpts); + if (!applyAllReplacements(Replaces, Rewrite)) { + llvm::errs() << "failed to apply replacements\n"; + return false; + } + Rewrite.getEditBuffer(SrcMgr.getMainFileID()).write(OS); + return true; + } + +private: + void addDeletions() { + for (NodeId Id = ModelSrc.getRootId(), E = ModelSrc.getSize(); Id < E; + ++Id) { + const Node &ModelNode = ModelSrc.getNode(Id); + if (ModelNode.Change != Delete) + continue; + NodeId TargetId = ModelTargetDiff.getMapped(ModelSrcCopy, Id); + if (TargetId.isInvalid()) + continue; + Replacement R(SrcMgr, findRangeForDeletion(TargetId), "", LangOpts); + // Replacements::add returns an llvm::Error that must be consumed. + if (auto Err = Replaces.add(R)) + llvm::errs() << "Info: Failed to add replacement: " + << llvm::toString(std::move(Err)) << "\n"; + Id = ModelNode.RightMostDescendant; + } + } + + CharSourceRange findRangeForDeletion(NodeId Id) { + const Node &N = Target.getNode(Id); + SourceRange Range = Target.getSourceRange(N); + if (N.Parent.isInvalid()) + return {Range, false}; + const Node &Parent = Target.getNode(N.Parent); + auto &DTN = Parent.ASTNode; + size_t SiblingIndex = Target.findPositionInParent(Id); + const auto &Siblings = Parent.Children; + // Remove the comma if the location is within a comma-separated list of at + // least size 2 (minus the callee for CallExpr).
+ if (DTN.get() && Siblings.size() > 2) { + bool LastSibling = SiblingIndex == Siblings.size() - 1; + SourceLocation CommaLoc = Range.getEnd(); + if (LastSibling) + CommaLoc = + Target.getSourceRange(Target.getNode(Siblings[SiblingIndex - 1])) + .getEnd() + .getLocWithOffset(-1); + CommaLoc = + Lexer::findLocationAfterToken(CommaLoc, tok::comma, SrcMgr, LangOpts, + /*SkipTrailingWhitespaceAndNewLine=*/ + false); + // findLocationAfterToken returns an invalid location if no comma follows + // (or inside a macro); keep the plain node range in that case. + if (CommaLoc.isInvalid()) + return {Range, false}; + if (LastSibling) + Range.setBegin(CommaLoc.getLocWithOffset(-1)); + else + Range.setEnd(CommaLoc); + } + return {Range, false}; + } +}; + +bool patch(SyntaxTree &ModelSrc, SyntaxTree &ModelDst, SyntaxTree &Target, + const ComparisonOptions &Options, raw_ostream &OS) { + return Patcher(ModelSrc, ModelDst, Target, Options, OS).apply(); +} + } // end namespace diff } // end namespace clang Index: lib/Tooling/ASTDiff/CMakeLists.txt =================================================================== --- lib/Tooling/ASTDiff/CMakeLists.txt +++ lib/Tooling/ASTDiff/CMakeLists.txt @@ -8,4 +8,6 @@ clangBasic clangAST clangLex + clangRewrite + clangToolingCore ) Index: tools/clang-diff/CMakeLists.txt =================================================================== --- tools/clang-diff/CMakeLists.txt +++ tools/clang-diff/CMakeLists.txt @@ -9,6 +9,7 @@ target_link_libraries(clang-diff clangBasic clangFrontend + clangRewrite clangTooling clangToolingASTDiff ) Index: tools/clang-diff/ClangDiff.cpp =================================================================== --- tools/clang-diff/ClangDiff.cpp +++ tools/clang-diff/ClangDiff.cpp @@ -42,6 +42,12 @@ cl::desc("Output a side-by-side diff in HTML."), cl::init(false), cl::cat(ClangDiffCategory)); +static cl::opt + Patch("patch", + cl::desc("Try to apply the edit actions between the two input " + "files to the specified target."), + cl::value_desc("target"), cl::cat(ClangDiffCategory)); + static cl::opt SourcePath(cl::Positional, cl::desc(""), cl::Required, cl::cat(ClangDiffCategory)); @@ -563,6 +569,17 @@ } diff::SyntaxTree
SrcTree(*Src); diff::SyntaxTree DstTree(*Dst); + + if (!Patch.empty()) { + auto Target = getAST(CommonCompilations, Patch); + if (!Target) + return 1; + diff::SyntaxTree TargetTree(*Target); + if (!diff::patch(SrcTree, DstTree, TargetTree, Options, llvm::outs())) + return 1; + return 0; + } + diff::ASTDiff Diff(SrcTree, DstTree, Options); if (HtmlDiff) { Index: unittests/Tooling/ASTDiffTest.cpp =================================================================== --- /dev/null +++ unittests/Tooling/ASTDiffTest.cpp @@ -0,0 +1,84 @@ +//===- unittest/Tooling/ASTDiffTest.cpp -----------------------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// + +#include "clang/Tooling/ASTDiff/ASTDiff.h" +#include "clang/Tooling/Tooling.h" +#include "gtest/gtest.h" + +using namespace clang; +using namespace tooling; + +static std::string patchResult(std::array Codes) { + diff::SyntaxTree Trees[3]; + std::unique_ptr ASTs[3]; + for (int I = 0; I < 3; I++) { + ASTs[I] = buildASTFromCode(Codes[I]); + if (!ASTs[I]) { + llvm::errs() << "Failed to build AST from code:\n" << Codes[I] << "\n"; + return ""; + } + Trees[I] = diff::SyntaxTree(*ASTs[I]); + } + + diff::ComparisonOptions Options; + std::string TargetDstCode; + llvm::raw_string_ostream OS(TargetDstCode); + if (!diff::patch(/*ModelSrc=*/Trees[0], /*ModelDst=*/Trees[1], + /*TargetSrc=*/Trees[2], Options, OS)) + return ""; + return OS.str(); +} + +// abstract the EXPECT_EQ call so that the code snippets align properly +// use macros for this to make test failures have proper line numbers +#define PATCH(Preamble, ModelSrc, ModelDst, Target, Expected) \ + EXPECT_EQ(patchResult({{std::string(Preamble) + ModelSrc, \ + std::string(Preamble) + ModelDst, \ + std::string(Preamble) + Target}}), \ + std::string(Preamble) +
Expected) +TEST(ASTDiff, TestDeleteArguments) { + PATCH(R"(void printf(const char *, ...);)", + R"(void foo(int x) { printf("%d", x, x); })", + R"(void foo(int x) { printf("%d", x); })", + R"(void foo(int x) { printf("different string %d", x, x); })", + R"(void foo(int x) { printf("different string %d", x); })"); + + PATCH(R"(void foo(...);)", + R"(void test1() { foo ( 1 + 1); })", + R"(void test1() { foo ( ); })", + R"(void test2() { foo ( 1 + 1 ); })", + R"(void test2() { foo ( ); })"); + + PATCH(R"(void foo(...);)", + R"(void test1() { foo (1, 2 + 2); })", + R"(void test1() { foo (2 + 2); })", + R"(void test2() { foo (/*L*/ 0 /*R*/ , 2 + 2); })", + R"(void test2() { foo (/*L*/ 2 + 2); })"); + + PATCH(R"(void foo(...);)", + R"(void test1() { foo (1, 2); })", + R"(void test1() { foo (1); })", + R"(void test2() { foo (0, /*L*/ 0 /*R*/); })", + R"(void test2() { foo (0 /*R*/); })"); +} + +TEST(ASTDiff, TestDeleteDecls) { + PATCH(R"()", + R"()", + R"()", + R"()", + R"()"); + + PATCH(R"()", + R"(void foo(){})", + R"()", + R"(int x; void foo() {;;} int y;)", + R"(int x; int y;)"); +} Index: unittests/Tooling/CMakeLists.txt =================================================================== --- unittests/Tooling/CMakeLists.txt +++ unittests/Tooling/CMakeLists.txt @@ -11,6 +11,7 @@ endif() add_clang_unittest(ToolingTests + ASTDiffTest.cpp ASTSelectionTest.cpp CastExprTest.cpp CommentHandlerTest.cpp @@ -43,4 +44,5 @@ clangTooling clangToolingCore clangToolingRefactor + clangToolingASTDiff )