Index: include/clang/Tooling/RefactoringCallbacks.h =================================================================== --- include/clang/Tooling/RefactoringCallbacks.h +++ include/clang/Tooling/RefactoringCallbacks.h @@ -47,6 +47,32 @@ Replacements Replace; }; +/// \brief Adaptor between \c ast_matchers::MatchFinder and \c +/// tooling::RefactoringTool. +/// +/// Runs AST matchers and stores the \c tooling::Replacements in a map. +class ASTMatchRefactorer { +public: + ASTMatchRefactorer(std::map &FileToReplaces); + + template + void addMatcher(const T &Matcher, RefactoringCallback *Callback) { + MatchFinder.addMatcher(Matcher, Callback); + Callbacks.emplace_back(Callback); + } + + void addDynamicMatcher(const ast_matchers::internal::DynTypedMatcher &Matcher, + RefactoringCallback *Callback); + + std::unique_ptr newASTConsumer(); + +private: + friend class RefactoringASTConsumer; + std::vector Callbacks; + ast_matchers::MatchFinder MatchFinder; + std::map &FileToReplaces; +}; + /// \brief Replace the text of the statement bound to \c FromId with the text in /// \c ToText. class ReplaceStmtWithText : public RefactoringCallback { @@ -59,6 +85,22 @@ std::string ToText; }; +/// \brief Replace the text of an AST node bound to \c FromId with the result of +/// evaluating the template in \c ToTemplate. +/// +/// Expressions of the form ${NodeName} in \c ToTemplate will be +/// replaced by the text of the node bound to ${NodeName}. The string +/// "$$" will be replaced by "$". +class ReplaceNodeWithTemplate : public RefactoringCallback { +public: + ReplaceNodeWithTemplate(StringRef FromId, StringRef ToTemplate); + void run(const ast_matchers::MatchFinder::MatchResult &Result) override; + +private: + std::string FromId; + std::string ToTemplate; +}; + /// \brief Replace the text of the statement bound to \c FromId with the text of /// the statement bound to \c ToId. 
class ReplaceStmtWithStmt : public RefactoringCallback { @@ -84,7 +126,7 @@ const bool PickTrueBranch; }; -} // end namespace tooling -} // end namespace clang +} // end namespace tooling +} // end namespace clang #endif Index: lib/Tooling/RefactoringCallbacks.cpp =================================================================== --- lib/Tooling/RefactoringCallbacks.cpp +++ lib/Tooling/RefactoringCallbacks.cpp @@ -9,8 +9,10 @@ // // //===----------------------------------------------------------------------===// -#include "clang/Lex/Lexer.h" #include "clang/Tooling/RefactoringCallbacks.h" +#include "clang/ASTMatchers/ASTMatchFinder.h" +#include "clang/Basic/SourceLocation.h" +#include "clang/Lex/Lexer.h" namespace clang { namespace tooling { @@ -20,18 +22,60 @@ return Replace; } -static Replacement replaceStmtWithText(SourceManager &Sources, - const Stmt &From, +ASTMatchRefactorer::ASTMatchRefactorer( + std::map &FileToReplaces) + : FileToReplaces(FileToReplaces) {} + +void ASTMatchRefactorer::addDynamicMatcher( + const ast_matchers::internal::DynTypedMatcher &Matcher, + RefactoringCallback *Callback) { + MatchFinder.addDynamicMatcher(Matcher, Callback); + Callbacks.emplace_back(Callback); +} + +class RefactoringASTConsumer : public ASTConsumer { +public: + RefactoringASTConsumer(ASTMatchRefactorer &Refactoring) + : Refactoring(Refactoring) {} + + void HandleTranslationUnit(ASTContext &Context) override { + for (const auto &Callback : Refactoring.Callbacks) { + Callback->getReplacements().clear(); + } + Refactoring.MatchFinder.matchAST(Context); + for (const auto &Callback : Refactoring.Callbacks) { + for (const auto &Replacement : Callback->getReplacements()) { + llvm::Error Err = + Refactoring.FileToReplaces[Replacement.getFilePath()].add( + Replacement); + if (Err) { + llvm::errs() << "Skipping replacement " << Replacement.toString() + << " due to this error:\n" + << toString(std::move(Err)) << "\n"; + } + } + } + } + +private: + ASTMatchRefactorer 
&Refactoring; +}; + +std::unique_ptr ASTMatchRefactorer::newASTConsumer() { + return llvm::make_unique(*this); +} + +static Replacement replaceStmtWithText(SourceManager &Sources, const Stmt &From, StringRef Text) { - return tooling::Replacement(Sources, CharSourceRange::getTokenRange( - From.getSourceRange()), Text); + return tooling::Replacement( + Sources, CharSourceRange::getTokenRange(From.getSourceRange()), Text); } -static Replacement replaceStmtWithStmt(SourceManager &Sources, - const Stmt &From, +static Replacement replaceStmtWithStmt(SourceManager &Sources, const Stmt &From, const Stmt &To) { - return replaceStmtWithText(Sources, From, Lexer::getSourceText( - CharSourceRange::getTokenRange(To.getSourceRange()), - Sources, LangOptions())); + return replaceStmtWithText( + Sources, From, + Lexer::getSourceText(CharSourceRange::getTokenRange(To.getSourceRange()), + Sources, LangOptions())); } ReplaceStmtWithText::ReplaceStmtWithText(StringRef FromId, StringRef ToText) @@ -103,5 +147,66 @@ } } -} // end namespace tooling -} // end namespace clang +ReplaceNodeWithTemplate::ReplaceNodeWithTemplate(StringRef FromId, + StringRef ToTemplate) + : FromId(FromId), ToTemplate(ToTemplate) {} + +void ReplaceNodeWithTemplate::run( + const ast_matchers::MatchFinder::MatchResult &Result) { + const auto &NodeMap = Result.Nodes.getMap(); + + std::string ToText; + for (size_t Index = 0; Index < ToTemplate.size();) { + if (ToTemplate[Index] == '$') { + if (ToTemplate.substr(Index, 2) == "$$") { + Index += 2; + ToText += "$"; + continue; + } else if (ToTemplate.substr(Index, 2) == "${") { + size_t EndOfIdentifier = ToTemplate.find("}", Index); + if (EndOfIdentifier == std::string::npos) { + llvm::errs() << "Unterminated ${...} in replacement template near " + << ToTemplate.substr(Index) << "\n"; + assert(false); + } + std::string SourceNodeName = + ToTemplate.substr(Index + 2, EndOfIdentifier - Index - 2); + if (NodeMap.count(SourceNodeName) == 0) { + llvm::errs() + << "Node " 
<< SourceNodeName + << " used in replacement template not bound in Matcher \n"; + assert(false); + } + CharSourceRange Source = CharSourceRange::getTokenRange( + NodeMap.at(SourceNodeName).getSourceRange()); + ToText += Lexer::getSourceText(Source, *Result.SourceManager, + Result.Context->getLangOpts()); + Index = EndOfIdentifier + 1; + } else { + llvm::errs() << "Invalid $ in replacement template near " + << ToTemplate.substr(Index) << "\n"; + assert(false); + } + } else { + size_t NextIndex = ToTemplate.find('$', Index + 1); + ToText = ToText + ToTemplate.substr(Index, NextIndex - Index); + Index = NextIndex; + } + } + if (NodeMap.count(FromId) == 0) { + llvm::errs() << "Node to be replaced " << FromId << " not bound in query\n"; + assert(false); + } + auto Replacement = + tooling::Replacement(*Result.SourceManager, &NodeMap.at(FromId), ToText, + Result.Context->getLangOpts()); + llvm::Error Err = Replace.add(Replacement); + if (Err) { + llvm::errs() << "Query and replace failed in " << Replacement.getFilePath() + << "! 
" << llvm::toString(std::move(Err)) << "\n"; + assert(false); + } +} + +} // end namespace tooling +} // end namespace clang Index: unittests/Tooling/RefactoringCallbacksTest.cpp =================================================================== --- unittests/Tooling/RefactoringCallbacksTest.cpp +++ unittests/Tooling/RefactoringCallbacksTest.cpp @@ -7,10 +7,10 @@ // //===----------------------------------------------------------------------===// -#include "clang/Tooling/RefactoringCallbacks.h" #include "RewriterTestContext.h" #include "clang/ASTMatchers/ASTMatchFinder.h" #include "clang/ASTMatchers/ASTMatchers.h" +#include "clang/Tooling/RefactoringCallbacks.h" #include "gtest/gtest.h" namespace clang { @@ -19,11 +19,10 @@ using namespace ast_matchers; template -void expectRewritten(const std::string &Code, - const std::string &Expected, - const T &AMatcher, - RefactoringCallback &Callback) { - MatchFinder Finder; +void expectRewritten(const std::string &Code, const std::string &Expected, + const T &AMatcher, RefactoringCallback &Callback) { + std::map FileToReplace; + ASTMatchRefactorer Finder(FileToReplace); Finder.addMatcher(AMatcher, &Callback); std::unique_ptr Factory( tooling::newFrontendActionFactory(&Finder)); @@ -31,7 +30,7 @@ << "Parsing error in \"" << Code << "\""; RewriterTestContext Context; FileID ID = Context.createInMemoryFile("input.cc", Code); - EXPECT_TRUE(tooling::applyAllReplacements(Callback.getReplacements(), + EXPECT_TRUE(tooling::applyAllReplacements(FileToReplace["input.cc"], Context.Rewrite)); EXPECT_EQ(Expected, Context.getRewrittenText(ID)); } @@ -61,18 +60,18 @@ std::string Code = "void f() { int i = 1; }"; std::string Expected = "void f() { int i = 2; }"; ReplaceStmtWithText Callback("id", "2"); - expectRewritten(Code, Expected, id("id", expr(integerLiteral())), - Callback); + expectRewritten(Code, Expected, id("id", expr(integerLiteral())), Callback); } TEST(RefactoringCallbacksTest, ReplacesStmtWithStmt) { std::string Code = "void 
f() { int i = false ? 1 : i * 2; }"; std::string Expected = "void f() { int i = i * 2; }"; ReplaceStmtWithStmt Callback("always-false", "should-be"); - expectRewritten(Code, Expected, - id("always-false", conditionalOperator( - hasCondition(cxxBoolLiteral(equals(false))), - hasFalseExpression(id("should-be", expr())))), + expectRewritten( + Code, Expected, + id("always-false", + conditionalOperator(hasCondition(cxxBoolLiteral(equals(false))), + hasFalseExpression(id("should-be", expr())))), Callback); } @@ -80,10 +79,10 @@ std::string Code = "bool a; void f() { if (a) f(); else a = true; }"; std::string Expected = "bool a; void f() { f(); }"; ReplaceIfStmtWithItsBody Callback("id", true); - expectRewritten(Code, Expected, - id("id", ifStmt( - hasCondition(implicitCastExpr(hasSourceExpression( - declRefExpr(to(varDecl(hasName("a"))))))))), + expectRewritten( + Code, Expected, + id("id", ifStmt(hasCondition(implicitCastExpr(hasSourceExpression( + declRefExpr(to(varDecl(hasName("a"))))))))), Callback); } @@ -92,9 +91,34 @@ std::string Expected = "void f() { }"; ReplaceIfStmtWithItsBody Callback("id", false); expectRewritten(Code, Expected, - id("id", ifStmt(hasCondition(cxxBoolLiteral(equals(false))))), - Callback); + id("id", ifStmt(hasCondition(cxxBoolLiteral(equals(false))))), + Callback); +} + +TEST(RefactoringCallbacksTest, TemplateJustText) { + std::string Code = "void f() { int i = 1; }"; + std::string Expected = "void f() { FOO }"; + ReplaceNodeWithTemplate Callback("id", "FOO"); + expectRewritten(Code, Expected, id("id", declStmt()), Callback); +} + +TEST(RefactoringCallbacksTest, TemplateSimpleSubst) { + std::string Code = "void f() { int i = 1; }"; + std::string Expected = "void f() { long x = 1; }"; + ReplaceNodeWithTemplate Callback("decl", "long x = ${init}"); + expectRewritten(Code, Expected, + id("decl", varDecl(hasInitializer(id("init", expr())))), + Callback); +} + +TEST(RefactoringCallbacksTest, TemplateLiteral) { + std::string Code = "void f() { 
int i = 1; }"; + std::string Expected = "void f() { string x = \"$-1\"; }"; + ReplaceNodeWithTemplate Callback("decl", "string x = \"$$-${init}\""); + expectRewritten(Code, Expected, + id("decl", varDecl(hasInitializer(id("init", expr())))), + Callback); } -} // end namespace ast_matchers -} // end namespace clang +} // end namespace ast_matchers +} // end namespace clang