Index: include/clang/Tooling/ASTDiff/ASTDiff.h =================================================================== --- include/clang/Tooling/ASTDiff/ASTDiff.h +++ include/clang/Tooling/ASTDiff/ASTDiff.h @@ -163,6 +163,11 @@ /// this node, that is, none of its descendants includes them. SmallVector getOwnedSourceRanges() const; + /// Like getSourceRange(), but if the node is an element of a comma-separated + /// list (e.g. a call argument), the range also covers one separating comma: + /// the one after the node, or the one before it if the node comes last. + CharSourceRange findRangeForDeletion() const; + /// Returns the offsets for the range returned by getSourceRange(). std::pair getSourceRangeOffsets() const; Index: include/clang/Tooling/ASTDiff/ASTPatch.h =================================================================== --- /dev/null +++ include/clang/Tooling/ASTDiff/ASTPatch.h @@ -0,0 +1,48 @@ +//===- ASTPatch.h - Structural patching based on ASTDiff ------*- C++ -*- -===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details.
+// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_CLANG_TOOLING_ASTDIFF_ASTPATCH_H +#define LLVM_CLANG_TOOLING_ASTDIFF_ASTPATCH_H + +#include "clang/Tooling/ASTDiff/ASTDiff.h" +#include "clang/Tooling/Refactoring.h" +#include "llvm/Support/Error.h" + +namespace clang { +namespace diff { + +enum class patching_error { + failed_to_build_AST, + failed_to_apply_replacements, + failed_to_overwrite_files, +}; + +class PatchingError : public llvm::ErrorInfo { +public: + explicit PatchingError(patching_error Err) : Err(Err) {} + std::string message() const override; + void log(raw_ostream &OS) const override { OS << message() << "\n"; } + patching_error get() const { return Err; } + static char ID; + +private: + std::error_code convertToErrorCode() const override { + return llvm::inconvertibleErrorCode(); + } + patching_error Err; +}; + +llvm::Error patch(tooling::RefactoringTool &TargetTool, SyntaxTree &Src, + SyntaxTree &Dst, const ComparisonOptions &Options, + bool Debug = false); + +} // end namespace diff +} // end namespace clang + +#endif // LLVM_CLANG_TOOLING_ASTDIFF_ASTPATCH_H Index: lib/Tooling/ASTDiff/ASTDiff.cpp =================================================================== --- lib/Tooling/ASTDiff/ASTDiff.cpp +++ lib/Tooling/ASTDiff/ASTDiff.cpp @@ -841,6 +841,40 @@ return SourceRanges; } +CharSourceRange Node::findRangeForDeletion() const { + CharSourceRange Range = getSourceRange(); + if (!getParent()) + return Range; + NodeRef Parent = *getParent(); + SyntaxTree &Tree = getTree(); + SourceManager &SM = Tree.getSourceManager(); + const LangOptions &LangOpts = Tree.getLangOpts(); + auto &DTN = ASTNode; + auto &ParentDTN = Parent.ASTNode; + size_t SiblingIndex = findPositionInParent(); + const auto &Siblings = Parent.Children; + // Remove the comma if the location is within a comma-separated list of + // at least size 2 (minus the callee for CallExpr).
+ if ((ParentDTN.get() && Siblings.size() > 2) || + (DTN.get() && Siblings.size() > 2)) { + bool LastSibling = SiblingIndex == Siblings.size() - 1; + SourceLocation CommaLoc; + if (LastSibling) { + CommaLoc = Parent.getChild(SiblingIndex - 1).getSourceRange().getEnd(); + Range.setBegin(CommaLoc); + } else { + // The range end is exclusive, so lex the token that starts there; + // findNextToken() would step over a comma directly following the node. + Token Comma; + bool Failed = Lexer::getRawToken(Range.getEnd(), Comma, SM, LangOpts, + /*IgnoreWhiteSpace=*/true); + if (!Failed && Comma.is(tok::comma)) + Range.setEnd(Comma.getEndLoc()); + } + } + return Range; +} + void forEachTokenInRange(CharSourceRange Range, SyntaxTree &Tree, std::function Body) { SourceLocation Begin = Range.getBegin(), End = Range.getEnd(); Index: lib/Tooling/ASTDiff/ASTPatch.cpp =================================================================== --- /dev/null +++ lib/Tooling/ASTDiff/ASTPatch.cpp @@ -0,0 +1,582 @@ +//===- ASTPatch.cpp - Structural patching based on ASTDiff ----*- C++ -*- -===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// + +#include "clang/Tooling/ASTDiff/ASTPatch.h" + +#include "clang/AST/DeclTemplate.h" +#include "clang/AST/ExprCXX.h" +#include "clang/Rewrite/Core/Rewriter.h" +#include "clang/Tooling/Core/Replacement.h" + +using namespace llvm; +using namespace clang; +using namespace tooling; + +namespace clang { +namespace diff { + +static Error error(patching_error code) { + return llvm::make_error(code); +} + +static CharSourceRange makeEmptyCharRange(SourceLocation Point) { + return CharSourceRange::getCharRange(Point, Point); +} + +// Returns a comparison function that orders invalid source locations before +// all valid ones.
+static std::function +makeTolerantLess(SourceManager &SM) { + return [&SM](SourceLocation A, SourceLocation B) { + if (A.isInvalid()) + return B.isValid(); // Irreflexive, as std::lower_bound requires. + if (B.isInvalid()) + return false; + BeforeThanCompare Less(SM); + return Less(A, B); + }; +} + +namespace { +// This wraps a node from Patcher::Target or Patcher::Dst. +class PatchedTreeNode { + NodeRef BaseNode; + +public: + operator NodeRef() const { return BaseNode; } + NodeRef originalNode() const { return *this; } + CharSourceRange getSourceRange() const { return BaseNode.getSourceRange(); } + NodeId getId() const { return BaseNode.getId(); } + SyntaxTree &getTree() const { return BaseNode.getTree(); } + StringRef getTypeLabel() const { return BaseNode.getTypeLabel(); } + decltype(BaseNode.getOwnedSourceRanges()) getOwnedSourceRanges() { + return BaseNode.getOwnedSourceRanges(); + } + + // This flag indicates whether this node, or any of its descendants was + // changed with regards to the original tree. + bool Changed = false; + // The pointers to the children, including nodes that have been inserted or + // moved here. + SmallVector Children; + // First location for each child. + SmallVector ChildrenLocations; + // The offsets at which the children should be inserted into OwnText. + SmallVector ChildrenOffsets; + + // This contains the text of this node, but not the text of its children.
+ Optional OwnText; + + PatchedTreeNode(NodeRef BaseNode) : BaseNode(BaseNode) {} + PatchedTreeNode(const PatchedTreeNode &Other) = delete; + PatchedTreeNode(PatchedTreeNode &&Other) = default; + + void addInsertion(PatchedTreeNode &PatchedNode, SourceLocation InsertionLoc) { + addChildAt(PatchedNode, InsertionLoc); + } + void addChild(PatchedTreeNode &PatchedNode) { + SourceLocation InsertionLoc = PatchedNode.getSourceRange().getBegin(); + addChildAt(PatchedNode, InsertionLoc); + } + +private: + void addChildAt(PatchedTreeNode &PatchedNode, SourceLocation InsertionLoc) { + auto Less = makeTolerantLess(getTree().getSourceManager()); + auto It = std::lower_bound(ChildrenLocations.begin(), + ChildrenLocations.end(), InsertionLoc, Less); + auto Offset = It - ChildrenLocations.begin(); + Children.insert(Children.begin() + Offset, &PatchedNode); + ChildrenLocations.insert(It, InsertionLoc); + } +}; +} // end anonymous namespace + +namespace { +class Patcher { + SyntaxTree &Dst, &Target; + SourceManager &SM; + const LangOptions &LangOpts; + BeforeThanCompare Less; + ASTDiff Diff, TargetDiff; + RefactoringTool &TargetTool; + bool Debug; + std::vector PatchedTreeNodes; + std::map InsertedNodes; + // Maps NodeId in Dst to a flag that is true if this node is + // part of an inserted subtree.
+ std::vector AtomicInsertions; + +public: + Patcher(SyntaxTree &Src, SyntaxTree &Dst, SyntaxTree &Target, + const ComparisonOptions &Options, RefactoringTool &TargetTool, + bool Debug) + : Dst(Dst), Target(Target), SM(Target.getSourceManager()), + LangOpts(Target.getLangOpts()), Less(SM), Diff(Src, Dst, Options), + TargetDiff(Src, Target, Options), TargetTool(TargetTool), Debug(Debug) { + } + + Error apply(); + +private: + void buildPatchedTree(); + void addInsertedAndMovedNodes(); + SourceLocation findLocationForInsertion(NodeRef DstNode, + PatchedTreeNode &InsertionTarget); + SourceLocation findLocationForMove(NodeRef DstNode, NodeRef TargetNode, + PatchedTreeNode &NewParent); + void markChangedNodes(); + Error addReplacementsForChangedNodes(); + Error addReplacementsForTopLevelChanges(); + + // Recursively builds the text that is represented by this subtree. + std::string buildSourceText(PatchedTreeNode &PatchedNode); + void setOwnedSourceText(PatchedTreeNode &PatchedNode); + std::pair + findPointOfInsertion(NodeRef N, PatchedTreeNode &TargetParent) const; + bool isInserted(const PatchedTreeNode &PatchedNode) const { + return isFromDst(PatchedNode); + } + ChangeKind getChange(NodeRef TargetNode) const { + if (!isFromTarget(TargetNode)) + return NoChange; + const Node *SrcNode = TargetDiff.getMapped(TargetNode); + if (!SrcNode) + return NoChange; + return Diff.getNodeChange(*SrcNode); + } + bool isRemoved(NodeRef TargetNode) const { + return getChange(TargetNode) == Delete; + } + bool isMoved(NodeRef TargetNode) const { + return getChange(TargetNode) == Move || getChange(TargetNode) == UpdateMove; + } + bool isRemovedOrMoved(NodeRef TargetNode) const { + return isRemoved(TargetNode) || isMoved(TargetNode); + } + PatchedTreeNode &findParent(NodeRef N) { + if (isFromDst(N)) + return findDstParent(N); + return findTargetParent(N); + } + PatchedTreeNode &findDstParent(NodeRef DstNode) { + const Node *SrcNode = Diff.getMapped(DstNode); + NodeRef DstParent =
*DstNode.getParent(); + if (SrcNode) { + assert(Diff.getNodeChange(*SrcNode) == Insert); + const Node *TargetParent = mapDstToTarget(DstParent); + assert(TargetParent); + return getTargetPatchedNode(*TargetParent); + } + return getPatchedNode(DstParent); + } + PatchedTreeNode &findTargetParent(NodeRef TargetNode) { + assert(isFromTarget(TargetNode)); + const Node *SrcNode = TargetDiff.getMapped(TargetNode); + if (SrcNode) { + ChangeKind Change = Diff.getNodeChange(*SrcNode); + if (Change == Move || Change == UpdateMove) { + NodeRef DstNode = *Diff.getMapped(*SrcNode); + return getPatchedNode(*DstNode.getParent()); + } + } + return getTargetPatchedNode(*TargetNode.getParent()); + } + CharSourceRange getRangeForReplacing(NodeRef TargetNode) const { + if (isRemovedOrMoved(TargetNode)) + return TargetNode.findRangeForDeletion(); + return TargetNode.getSourceRange(); + } + Error addReplacement(Replacement &&R) { + return TargetTool.getReplacements()[R.getFilePath()].add(R); + } + bool isFromTarget(NodeRef N) const { return &N.getTree() == &Target; } + bool isFromDst(NodeRef N) const { return &N.getTree() == &Dst; } + PatchedTreeNode &getTargetPatchedNode(NodeRef N) { + assert(isFromTarget(N)); + return PatchedTreeNodes[N.getId()]; + } + PatchedTreeNode &getPatchedNode(NodeRef N) { + if (isFromDst(N)) + return *InsertedNodes.at(N.getId()); + return PatchedTreeNodes[N.getId()]; + } + const Node *mapDstToTarget(NodeRef DstNode) const { + const Node *SrcNode = Diff.getMapped(DstNode); + if (!SrcNode) + return nullptr; + return TargetDiff.getMapped(*SrcNode); + } + const Node *mapTargetToDst(NodeRef TargetNode) const { + const Node *SrcNode = TargetDiff.getMapped(TargetNode); + if (!SrcNode) + return nullptr; + return Diff.getMapped(*SrcNode); + } +}; +} // end anonymous namespace + +static void markBiggestSubtrees(std::vector &Marked, SyntaxTree &Tree, + llvm::function_ref Predicate) { + Marked.resize(Tree.getSize()); + for (NodeRef N : Tree.postorder()) { + bool
AllChildrenMarked = + std::all_of(N.begin(), N.end(), + [&Marked](NodeRef Child) { return Marked[Child.getId()]; }); + Marked[N.getId()] = Predicate(N) && AllChildrenMarked; + } +} + +Error Patcher::apply() { + if (Debug) + Diff.dumpChanges(llvm::errs(), /*DumpMatches=*/true); + markBiggestSubtrees(AtomicInsertions, Dst, [this](NodeRef DstNode) { + return Diff.getNodeChange(DstNode) == Insert; + }); + buildPatchedTree(); + addInsertedAndMovedNodes(); + markChangedNodes(); + if (auto Err = addReplacementsForChangedNodes()) + return Err; + Rewriter Rewrite(SM, LangOpts); + if (!TargetTool.applyAllReplacements(Rewrite)) + return error(patching_error::failed_to_apply_replacements); + if (Rewrite.overwriteChangedFiles()) + // Some file has not been saved successfully. + return error(patching_error::failed_to_overwrite_files); + return Error::success(); +} + +static bool wantToInsertBefore(SourceLocation Insertion, SourceLocation Point, + BeforeThanCompare &Less) { + assert(Insertion.isValid()); + assert(Point.isValid()); + return Less(Insertion, Point); +} + +void Patcher::buildPatchedTree() { + // Firstly, add all nodes of the tree that will be patched, so that their + // offset (getId()) is the same as in the original tree. Reserve room for the + // inserted Dst nodes too: InsertedNodes holds pointers into this vector. + PatchedTreeNodes.reserve(Target.getSize() + Dst.getSize()); + for (NodeRef TargetNode : Target) + PatchedTreeNodes.emplace_back(TargetNode); + // Then add all inserted nodes, from Dst. + for (NodeId DstId = Dst.getRootId(), E = Dst.getSize(); DstId < E; ++DstId) { + NodeRef DstNode = Dst.getNode(DstId); + ChangeKind Change = Diff.getNodeChange(DstNode); + if (Change == Insert) { + PatchedTreeNodes.emplace_back(DstNode); + InsertedNodes.emplace(DstNode.getId(), &PatchedTreeNodes.back()); + // If the whole subtree is inserted, we can skip the children, as we + // will just copy the text of the entire subtree. + if (AtomicInsertions[DstId]) + DstId = DstNode.RightMostDescendant; + } + } + // Add existing children.
+ for (auto &PatchedNode : PatchedTreeNodes) { + if (isFromTarget(PatchedNode)) + for (auto &Child : PatchedNode.originalNode()) + if (!isRemovedOrMoved(Child)) + PatchedNode.addChild(getPatchedNode(Child)); + } +} + +void Patcher::addInsertedAndMovedNodes() { + ChangeKind Change = NoChange; + for (NodeId DstId = Dst.getRootId(), E = Dst.getSize(); DstId < E; + DstId = Change == Insert && AtomicInsertions[DstId] + ? Dst.getNode(DstId).RightMostDescendant + 1 + : DstId + 1) { + NodeRef DstNode = Dst.getNode(DstId); + Change = Diff.getNodeChange(DstNode); + if (!(Change == Move || Change == UpdateMove || Change == Insert)) + continue; + NodeRef DstParent = *DstNode.getParent(); + PatchedTreeNode *InsertionTarget, *NodeToInsert; + SourceLocation InsertionLoc; + if (Diff.getNodeChange(DstParent) == Insert) { + InsertionTarget = &getPatchedNode(DstParent); + } else { + const Node *TargetParent = mapDstToTarget(DstParent); + if (!TargetParent) + continue; + InsertionTarget = &getTargetPatchedNode(*TargetParent); + } + if (Change == Insert) { + NodeToInsert = &getPatchedNode(DstNode); + InsertionLoc = findLocationForInsertion(DstNode, *InsertionTarget); + } else { + assert(Change == Move || Change == UpdateMove); + const Node *TargetNode = mapDstToTarget(DstNode); + assert(TargetNode && "Node to update not found."); + NodeToInsert = &getTargetPatchedNode(*TargetNode); + InsertionLoc = + findLocationForMove(DstNode, *TargetNode, *InsertionTarget); + } + assert(InsertionLoc.isValid()); + InsertionTarget->addInsertion(*NodeToInsert, InsertionLoc); + } +} + +SourceLocation +Patcher::findLocationForInsertion(NodeRef DstNode, + PatchedTreeNode &InsertionTarget) { + assert(isFromDst(DstNode)); + assert(isFromDst(InsertionTarget) || isFromTarget(InsertionTarget)); + int ChildIndex; + bool RightOfChild; + unsigned NumChildren = InsertionTarget.Children.size(); + std::tie(ChildIndex, RightOfChild) = + findPointOfInsertion(DstNode, InsertionTarget); + if (NumChildren && ChildIndex
!= -1) { + auto NeighborRange = InsertionTarget.Children[ChildIndex]->getSourceRange(); + SourceLocation InsertionLocation = + RightOfChild ? NeighborRange.getEnd() : NeighborRange.getBegin(); + if (InsertionLocation.isValid()) + return InsertionLocation; + } + llvm_unreachable("Not implemented."); +} + +SourceLocation Patcher::findLocationForMove(NodeRef DstNode, NodeRef TargetNode, + PatchedTreeNode &NewParent) { + assert(isFromDst(DstNode)); + assert(isFromTarget(TargetNode)); + return DstNode.getSourceRange().getEnd(); +} + +void Patcher::markChangedNodes() { + for (const auto &Pair : InsertedNodes) { + NodeRef DstNode = Dst.getNode(Pair.first); + getPatchedNode(DstNode).Changed = true; + } + // Mark nodes in original as changed. + for (NodeRef TargetNode : Target.postorder()) { + auto &PatchedNode = PatchedTreeNodes[TargetNode.getId()]; + const Node *SrcNode = TargetDiff.getMapped(TargetNode); + // A node without a counterpart in Src has no change of its own, but it must + // still be marked when one of its children changed or was removed. + ChangeKind Change = SrcNode ? Diff.getNodeChange(*SrcNode) : NoChange; + auto &Children = PatchedNode.Children; + bool AnyChildChanged = + std::any_of(Children.begin(), Children.end(), + [](PatchedTreeNode *Child) { return Child->Changed; }); + bool AnyChildRemoved = std::any_of( + PatchedNode.originalNode().begin(), PatchedNode.originalNode().end(), + [this](NodeRef Child) { return isRemovedOrMoved(Child); }); + assert(!PatchedNode.Changed); + PatchedNode.Changed = + AnyChildChanged || AnyChildRemoved || Change != NoChange; + } +} + +Error Patcher::addReplacementsForChangedNodes() { + for (NodeId TargetId = Target.getRootId(), E = Target.getSize(); TargetId < E; + ++TargetId) { + NodeRef TargetNode = Target.getNode(TargetId); + auto &PatchedNode = getTargetPatchedNode(TargetNode); + if (!PatchedNode.Changed) + continue; + if (TargetId == Target.getRootId()) + return addReplacementsForTopLevelChanges(); + CharSourceRange Range = getRangeForReplacing(TargetNode); + std::string Text = + isRemovedOrMoved(PatchedNode) ?
"" : buildSourceText(PatchedNode); + if (auto Err = addReplacement({SM, Range, Text, LangOpts})) + return Err; + TargetId = TargetNode.RightMostDescendant; + } + return Error::success(); +} + +Error Patcher::addReplacementsForTopLevelChanges() { + auto &Root = getTargetPatchedNode(Target.getRoot()); + for (unsigned I = 0, E = Root.Children.size(); I < E; ++I) { + PatchedTreeNode *Child = Root.Children[I]; + if (!Child->Changed) + continue; + std::string ChildText = buildSourceText(*Child); + CharSourceRange ChildRange; + if (isInserted(*Child) || isMoved(*Child)) { + SourceLocation InsertionLoc; + unsigned NumChildren = Root.Children.size(); + int ChildIndex; + bool RightOfChild; + std::tie(ChildIndex, RightOfChild) = findPointOfInsertion(*Child, Root); + if (NumChildren && ChildIndex != -1) { + auto NeighborRange = Root.Children[ChildIndex]->getSourceRange(); + InsertionLoc = + RightOfChild ? NeighborRange.getEnd() : NeighborRange.getBegin(); + } else { + InsertionLoc = SM.getLocForEndOfFile(SM.getMainFileID()) + .getLocWithOffset(-int(strlen("\n"))); + } + ChildRange = makeEmptyCharRange(InsertionLoc); + } else { + ChildRange = Child->getSourceRange(); + } + if (auto Err = addReplacement({SM, ChildRange, ChildText, LangOpts})) { + return Err; + } + } + for (NodeRef Child : Root.originalNode()) { + if (isRemovedOrMoved(Child)) { + auto ChildRange = Child.findRangeForDeletion(); + if (auto Err = addReplacement({SM, ChildRange, "", LangOpts})) + return Err; + } + } + return Error::success(); +} + +static StringRef trailingText(SourceLocation Loc, SyntaxTree &Tree) { + Token NextToken; + bool Failure = Lexer::getRawToken(Loc, NextToken, Tree.getSourceManager(), + Tree.getLangOpts(), + /*IgnoreWhiteSpace=*/true); + if (Failure) + return StringRef(); + return Lexer::getSourceText( + CharSourceRange::getCharRange({Loc, NextToken.getLocation()}), + Tree.getSourceManager(), Tree.getLangOpts()); +} + +std::string Patcher::buildSourceText(PatchedTreeNode
&PatchedNode) { + auto &Children = PatchedNode.Children; + auto &ChildrenOffsets = PatchedNode.ChildrenOffsets; + auto &OwnText = PatchedNode.OwnText; + auto Range = PatchedNode.getSourceRange(); + SyntaxTree &Tree = PatchedNode.getTree(); + SourceManager &MySM = Tree.getSourceManager(); + const LangOptions &MyLangOpts = Tree.getLangOpts(); + assert(!isRemoved(PatchedNode)); + if (!PatchedNode.Changed || + (isFromDst(PatchedNode) && AtomicInsertions[PatchedNode.getId()])) { + std::string Text = Lexer::getSourceText(Range, MySM, MyLangOpts); + // TODO why + if (!isFromDst(PatchedNode)) + Text += trailingText(Range.getEnd(), Tree); + return Text; + } + setOwnedSourceText(PatchedNode); + std::string Result; + unsigned Offset = 0; + assert(ChildrenOffsets.size() == Children.size()); + for (unsigned I = 0, E = Children.size(); I < E; ++I) { + PatchedTreeNode *Child = Children[I]; + unsigned Start = ChildrenOffsets[I]; + Result += OwnText->substr(Offset, Start - Offset); + Result += buildSourceText(*Child); + Offset = Start; + } + assert(Offset <= OwnText->size()); + Result += OwnText->substr(Offset, OwnText->size() - Offset); + return Result; +} + +void Patcher::setOwnedSourceText(PatchedTreeNode &PatchedNode) { + assert(isFromTarget(PatchedNode) || isFromDst(PatchedNode)); + SyntaxTree &Tree = PatchedNode.getTree(); + const Node *SrcNode = nullptr; + bool IsUpdate = false; + auto &OwnText = PatchedNode.OwnText; + auto &Children = PatchedNode.Children; + auto &ChildrenLocations = PatchedNode.ChildrenLocations; + auto &ChildrenOffsets = PatchedNode.ChildrenOffsets; + OwnText = ""; + ChildrenOffsets.clear(); + unsigned NumChildren = Children.size(); + if (isFromTarget(PatchedNode)) { + SrcNode = TargetDiff.getMapped(PatchedNode); + ChangeKind Change = SrcNode ?
Diff.getNodeChange(*SrcNode) : NoChange; + IsUpdate = Change == Update || Change == UpdateMove; + } + unsigned ChildIndex = 0; + auto MySourceRanges = PatchedNode.getOwnedSourceRanges(); + BeforeThanCompare MyLess(Tree.getSourceManager()); + for (auto &MySubRange : MySourceRanges) { + SourceLocation ChildBegin; + while (ChildIndex < NumChildren && + ((ChildBegin = ChildrenLocations[ChildIndex]).isInvalid() || + wantToInsertBefore(ChildBegin, MySubRange.getEnd(), MyLess))) { + ChildrenOffsets.push_back(OwnText->size()); + ++ChildIndex; + } + if (IsUpdate) { + llvm_unreachable("Not implemented."); + } else { + *OwnText += Lexer::getSourceText(MySubRange, Tree.getSourceManager(), + Tree.getLangOpts()); + } + } + while (ChildIndex++ < NumChildren) + ChildrenOffsets.push_back(OwnText->size()); +} + +std::pair +Patcher::findPointOfInsertion(NodeRef N, PatchedTreeNode &TargetParent) const { + assert(isFromDst(N) || isFromTarget(N)); + assert(isFromTarget(TargetParent)); + auto MapFunction = [this, &N](PatchedTreeNode &Sibling) { + if (isFromDst(N) == isFromDst(Sibling)) + return &NodeRef(Sibling); + if (isFromDst(N)) + return mapTargetToDst(Sibling); + else + return mapDstToTarget(Sibling); + }; + unsigned NumChildren = TargetParent.Children.size(); + BeforeThanCompare Less(N.getTree().getSourceManager()); + auto NodeIndex = N.findPositionInParent(); + SourceLocation MyLoc = N.getSourceRange().getBegin(); + assert(MyLoc.isValid()); + for (unsigned I = 0; I < NumChildren; ++I) { + const Node *Sibling = MapFunction(*TargetParent.Children[I]); + if (!Sibling) + continue; + SourceLocation SiblingLoc = Sibling->getSourceRange().getBegin(); + if (SiblingLoc.isInvalid()) + continue; + if (NodeIndex && Sibling == &N.getParent()->getChild(NodeIndex - 1)) { + return {I, /*RightOfSibling=*/true}; + } + if (Less(MyLoc, SiblingLoc)) { + return {I, /*RightOfSibling=*/false}; + } + } + return {-1, true}; +} + +Error patch(RefactoringTool &TargetTool,
SyntaxTree &Src, SyntaxTree &Dst, + const ComparisonOptions &Options, bool Debug) { + std::vector> TargetASTs; + TargetTool.buildASTs(TargetASTs); + if (TargetASTs.size() != 1) + return error(patching_error::failed_to_build_AST); + SyntaxTree Target(*TargetASTs[0]); + return Patcher(Src, Dst, Target, Options, TargetTool, Debug).apply(); +} + +std::string PatchingError::message() const { + switch (Err) { + case patching_error::failed_to_build_AST: + return "Failed to build AST.\n"; + case patching_error::failed_to_apply_replacements: + return "Failed to apply replacements.\n"; + case patching_error::failed_to_overwrite_files: + return "Failed to overwrite some file(s).\n"; + } + llvm_unreachable("Unhandled patching_error value."); +} + +char PatchingError::ID = 1; + +} // end namespace diff +} // end namespace clang Index: lib/Tooling/ASTDiff/CMakeLists.txt =================================================================== --- lib/Tooling/ASTDiff/CMakeLists.txt +++ lib/Tooling/ASTDiff/CMakeLists.txt @@ -4,8 +4,13 @@ add_clang_library(clangToolingASTDiff ASTDiff.cpp + ASTPatch.cpp LINK_LIBS clangBasic clangAST clangLex + clangRewrite + clangFrontend + clangTooling + clangToolingCore ) Index: test/Tooling/clang-diff-patch.test =================================================================== --- /dev/null +++ test/Tooling/clang-diff-patch.test @@ -0,0 +1,9 @@ +// compare the file with an empty file, patch it to remove all code +RUN: rm -rf %t && mkdir -p %t +RUN: cp %S/clang-diff-ast.cpp %t +RUN: echo > %t/dst.cpp +RUN: clang-diff %t/clang-diff-ast.cpp %t/dst.cpp \ +RUN: -patch %t/clang-diff-ast.cpp -- -std=c++11 +// the resulting file should not contain anything other than comments and +// whitespace +RUN: cat %t/clang-diff-ast.cpp | grep -v '^#' | grep -v '^\s*//' | not grep -v '^\s*$' Index: tools/clang-diff/CMakeLists.txt =================================================================== --- tools/clang-diff/CMakeLists.txt +++ tools/clang-diff/CMakeLists.txt @@ -9,6 +9,8 @@ target_link_libraries(clang-diff
clangBasic clangFrontend + clangRewrite clangTooling + clangToolingCore clangToolingASTDiff ) Index: tools/clang-diff/ClangDiff.cpp =================================================================== --- tools/clang-diff/ClangDiff.cpp +++ tools/clang-diff/ClangDiff.cpp @@ -13,6 +13,7 @@ //===----------------------------------------------------------------------===// #include "clang/Tooling/ASTDiff/ASTDiff.h" +#include "clang/Tooling/ASTDiff/ASTPatch.h" #include "clang/Tooling/CommonOptionsParser.h" #include "clang/Tooling/Tooling.h" #include "llvm/Support/CommandLine.h" @@ -41,6 +42,12 @@ cl::desc("Output a side-by-side diff in HTML."), cl::init(false), cl::cat(ClangDiffCategory)); +static cl::opt + FileToPatch("patch", + cl::desc("Try to apply the edit actions between the two input " + "files to the specified target."), + cl::value_desc("filename"), cl::cat(ClangDiffCategory)); + static cl::opt SourcePath(cl::Positional, cl::desc(""), cl::Required, cl::cat(ClangDiffCategory)); @@ -453,6 +460,27 @@ } diff::SyntaxTree SrcTree(*Src); diff::SyntaxTree DstTree(*Dst); + + if (!FileToPatch.empty()) { + std::array Files = {{FileToPatch}}; + // RefactoringTool only keeps a reference to the database; keep it alive. + std::unique_ptr<CompilationDatabase> FileCompilations; + if (!CommonCompilations) + FileCompilations = getCompilationDatabase(FileToPatch); + RefactoringTool TargetTool(CommonCompilations + ?
*CommonCompilations + : *FileCompilations, + Files); + if (auto Err = diff::patch(TargetTool, SrcTree, DstTree, Options)) { + llvm::handleAllErrors( + std::move(Err), + [](const diff::PatchingError &PE) { PE.log(llvm::errs()); }, + [](const ReplacementError &RE) { RE.log(llvm::errs()); }); + llvm::errs() << "*** errors occurred, patching failed.\n"; + return 1; + } + return 0; + } + diff::ASTDiff Diff(SrcTree, DstTree, Options); if (HtmlDiff) { Index: unittests/Tooling/ASTPatchTest.cpp =================================================================== --- /dev/null +++ unittests/Tooling/ASTPatchTest.cpp @@ -0,0 +1,265 @@ +//===- unittest/Tooling/ASTPatchTest.cpp ----------------------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// + +#include "clang/Tooling/ASTDiff/ASTPatch.h" +#include "clang/Tooling/ASTDiff/ASTDiff.h" +#include "clang/Tooling/Tooling.h" +#include "llvm/Support/FileSystem.h" +#include "llvm/Support/Program.h" +#include "gtest/gtest.h" +#include + +using namespace clang; +using namespace tooling; + +static std::string ReadShellCommand(const Twine &Command) { + char Buffer[128]; + std::string Result; + std::shared_ptr Pipe(popen(Command.str().data(), "r"), pclose); + if (!Pipe) + return Result; + // Loop on fgets() itself: a feof()-driven loop never terminates if the + // read fails with an error instead of reaching end-of-file. + while (fgets(Buffer, sizeof(Buffer), Pipe.get()) != nullptr) + Result += Buffer; + return Result; +} + +class ASTPatchTest : public ::testing::Test { + llvm::SmallString<256> TargetFile, ExpectedFile; + std::array TargetFileArray; + +public: + void SetUp() override { + std::string Suffix = "cpp"; + ASSERT_FALSE(llvm::sys::fs::createTemporaryFile( + "clang-libtooling-patch-target", Suffix, TargetFile)); + ASSERT_FALSE(llvm::sys::fs::createTemporaryFile( + "clang-libtooling-patch-expected", Suffix, ExpectedFile));
+ TargetFileArray[0] = TargetFile.str(); + } + void TearDown() override { + llvm::sys::fs::remove(TargetFile); + llvm::sys::fs::remove(ExpectedFile); + } + + void WriteFile(StringRef Filename, StringRef Contents) { + std::ofstream OS(Filename); + OS << Contents.str(); + assert(OS.good()); + } + + std::string ReadFile(StringRef Filename) { + std::ifstream IS(Filename); + std::stringstream OS; + OS << IS.rdbuf(); + assert(IS.good()); + return OS.str(); + } + + std::string formatExpected(StringRef Code) { + WriteFile(ExpectedFile, Code); + return ReadShellCommand("clang-format " + ExpectedFile); + } + + llvm::Expected patchResult(const char *SrcCode, + const char *DstCode, + const char *TargetCode) { + std::unique_ptr SrcAST = buildASTFromCode(SrcCode), + DstAST = buildASTFromCode(DstCode); + if (!SrcAST || !DstAST) { + if (!SrcAST) + llvm::errs() << "Failed to build AST from code:\n" << SrcCode << "\n"; + if (!DstAST) + llvm::errs() << "Failed to build AST from code:\n" << DstCode << "\n"; + return llvm::make_error( + diff::patching_error::failed_to_build_AST); + } + + diff::SyntaxTree Src(*SrcAST); + diff::SyntaxTree Dst(*DstAST); + + WriteFile(TargetFile, TargetCode); + FixedCompilationDatabase Compilations(".", std::vector()); + RefactoringTool TargetTool(Compilations, TargetFileArray); + diff::ComparisonOptions Options; + + if (auto Err = diff::patch(TargetTool, Src, Dst, Options, /*Debug=*/false)) + return std::move(Err); + return ReadShellCommand("clang-format " + TargetFile); + } + +#define APPEND_NEWLINE(x) x "\n" +// use macros for this to make test failures have proper line numbers +#define PATCH(Src, Dst, Target, ExpectedResult) \ + { \ + llvm::Expected Result = patchResult( \ + APPEND_NEWLINE(Src), APPEND_NEWLINE(Dst), APPEND_NEWLINE(Target)); \ + ASSERT_TRUE(bool(Result)); \ + EXPECT_EQ(Result.get(), formatExpected(APPEND_NEWLINE(ExpectedResult))); \ + } +#define PATCH_ERROR(Src, Dst, Target, ErrorCode) \ + { \ + llvm::Expected Result = patchResult(Src,
Dst, Target); \ + ASSERT_FALSE(bool(Result)); \ + llvm::handleAllErrors(Result.takeError(), \ + [&](const diff::PatchingError &PE) { \ + EXPECT_EQ(PE.get(), ErrorCode); \ + }); \ + } +}; + +TEST_F(ASTPatchTest, Delete) { + PATCH(R"(void f() { { int x = 1; } })", + R"(void f() { })", + R"(void f() { { int x = 2; } })", + R"(void f() { })"); + PATCH(R"(void foo(){})", + R"()", + R"(int x; void foo() {;;} int y;)", + R"(int x; int y;)"); +} +TEST_F(ASTPatchTest, DeleteCallArguments) { + PATCH(R"(void foo(...); void test1() { foo ( 1 + 1); })", + R"(void foo(...); void test1() { foo ( ); })", + R"(void foo(...); void test2() { foo ( 1 + 1 ); })", + R"(void foo(...); void test2() { foo ( ); })"); +} +TEST_F(ASTPatchTest, DeleteParmVarDecl) { + PATCH(R"(void foo(int a);)", + R"(void foo();)", + R"(void bar(int x);)", + R"(void bar();)"); +} +TEST_F(ASTPatchTest, Insert) { + PATCH(R"(class C { C() {} };)", + R"(class C { int b; C() {} };)", + R"(class C { int c; C() {} };)", + R"(class C { int c;int b; C() {} };)"); + PATCH(R"(class C { C() {} };)", + R"(class C { int b; C() {} };)", + R"(class C { C() {} };)", + R"(class C { int b; C() {} };)"); + PATCH(R"(class C { int x; };)", + R"(class C { int x;int b; };)", + R"(class C { int x ;int c; };)", + R"(class C { int x;int b;int c; };)"); + PATCH(R"(class C { int x; };)", + R"(class C { int x;int b; };)", + R"(class C { int x; int c; };)", + R"(class C { int x;int b;int c; };)"); + PATCH(R"(class C { int x; };)", + R"(class C { int x;int b; };)", + R"(class C { int x;int c; };)", + R"(class C { int x;int b;int c; };)"); + PATCH(R"(int a;)", + R"(int a; int x();)", + R"(int a;)", + R"(int a; int x();)"); + PATCH(R"(int a; int b;)", + R"(int a; int x; int b;)", + R"(int a; int b;)", + R"(int a; int x; int b;)"); + PATCH(R"(int b;)", + R"(int x; int b;)", + R"(int b;)", + R"(int x; int b;)"); + PATCH(R"(void f() { int x = 1 + 1; })", + R"(void f() { { int x = 1 + 1; } })", + R"(void f() { int x = 1 + 1; })", + R"(void f() { {
int x = 1 + 1; } })"); +} +TEST_F(ASTPatchTest, InsertNoParent) { + PATCH(R"(void f() { })", + R"(void f() { int x; })", + R"()", + R"()"); +} +TEST_F(ASTPatchTest, InsertTopLevel) { + PATCH(R"(namespace a {})", + R"(namespace a {} void x();)", + R"(namespace a {})", + R"(namespace a {} void x();)"); +} +TEST_F(ASTPatchTest, Move) { + PATCH(R"(namespace a { void f(){} })", + R"(namespace a {} void f(){} )", + R"(namespace a { void f(){} })", + R"(namespace a {} void f(){} )"); + PATCH(R"(namespace a { void f(){} } int x;)", + R"(namespace a {} void f(){} int x;)", + R"(namespace a { void f(){} } int x;)", + R"(namespace a {} void f(){} int x;)"); + PATCH(R"(namespace a { namespace { } })", + R"(namespace a { })", + R"(namespace a { namespace { } })", + R"(namespace a { })"); + PATCH(R"(namespace { int x = 1 + 1; })", + R"(namespace { int x = 1 + 1; int y;})", + R"(namespace { int x = 1 + 1; })", + R"(namespace { int x = 1 + 1; int y;})"); + PATCH(R"(namespace { int y; int x = 1 + 1; })", + R"(namespace { int x = 1 + 1; int y; })", + R"(namespace { int y; int x = 1 + 1; })", + R"(namespace { int x = 1 + 1; int y; })"); + PATCH(R"(void f() { ; int x = 1 + 1; })", + R"(void f() { int x = 1 + 1; ; })", + R"(void f() { ; int x = 1 + 1; })", + R"(void f() { int x = 1 + 1; ; })"); + PATCH(R"(void f() { {{;;;}} })", + R"(void f() { {{{;;;}}} })", + R"(void f() { {{;;;}} })", + R"(void f() { {{{;;;}}} })"); +} +TEST_F(ASTPatchTest, MoveNoSource) { + PATCH(R"(void f() { })", + R"(void f() { int x; })", + R"()", + R"()"); +} +TEST_F(ASTPatchTest, MoveNoTarget) { + PATCH(R"(int x; void f() { })", + R"(void f() { int x; })", + R"(int x;)", + R"()"); +} +TEST_F(ASTPatchTest, Newline) { + PATCH(R"(void f(){ +; +})", + R"(void f(){ +; +int x; +})", + R"(void f(){ +; +})", + R"(void f(){ +; +int x; +})"); +} +TEST_F(ASTPatchTest, Nothing) { + PATCH(R"()", + R"()", + R"()", + R"()"); +} +TEST_F(ASTPatchTest, Update) { + PATCH(R"(class A { int x; };)", + R"(class A { int x; };)",
R"(class A { int y; };)", + R"(class A { int y; };)"); +} +TEST_F(ASTPatchTest, UpdateMove) { + PATCH(R"(void f() { { int x = 1; } })", + R"(void f() { })", + R"(void f() { { int x = 2; } })", + R"(void f() { })"); +} Index: unittests/Tooling/CMakeLists.txt =================================================================== --- unittests/Tooling/CMakeLists.txt +++ unittests/Tooling/CMakeLists.txt @@ -11,6 +11,7 @@ endif() add_clang_unittest(ToolingTests + ASTPatchTest.cpp ASTSelectionTest.cpp CastExprTest.cpp CommentHandlerTest.cpp @@ -45,4 +46,5 @@ clangTooling clangToolingCore clangToolingRefactor + clangToolingASTDiff )