diff --git a/clang-tools-extra/clang-tidy/readability/DeleteNullPointerCheck.h b/clang-tools-extra/clang-tidy/readability/DeleteNullPointerCheck.h --- a/clang-tools-extra/clang-tidy/readability/DeleteNullPointerCheck.h +++ b/clang-tools-extra/clang-tidy/readability/DeleteNullPointerCheck.h @@ -10,6 +10,7 @@ #define LLVM_CLANG_TOOLS_EXTRA_CLANG_TIDY_READABILITY_DELETE_NULL_POINTER_H #include "../ClangTidy.h" +#include "../utils/TransformerTidy.h" namespace clang { namespace tidy { @@ -20,12 +21,12 @@ /// /// For the user-facing documentation see: /// http://clang.llvm.org/extra/clang-tidy/checks/readability-delete-null-pointer.html -class DeleteNullPointerCheck : public ClangTidyCheck { +tooling::RewriteRule RewriteDeleteNullPointer(); + +class DeleteNullPointerCheck : public utils::TransformerTidy { public: DeleteNullPointerCheck(StringRef Name, ClangTidyContext *Context) - : ClangTidyCheck(Name, Context) {} - void registerMatchers(ast_matchers::MatchFinder *Finder) override; - void check(const ast_matchers::MatchFinder::MatchResult &Result) override; + : TransformerTidy(RewriteDeleteNullPointer(), Name, Context) {} }; } // namespace readability diff --git a/clang-tools-extra/clang-tidy/readability/DeleteNullPointerCheck.cpp b/clang-tools-extra/clang-tidy/readability/DeleteNullPointerCheck.cpp --- a/clang-tools-extra/clang-tidy/readability/DeleteNullPointerCheck.cpp +++ b/clang-tools-extra/clang-tidy/readability/DeleteNullPointerCheck.cpp @@ -10,6 +10,7 @@ #include "clang/AST/ASTContext.h" #include "clang/ASTMatchers/ASTMatchFinder.h" #include "clang/Lex/Lexer.h" +#include "clang/Tooling/Refactoring/Transformer.h" using namespace clang::ast_matchers; @@ -17,16 +18,12 @@ namespace tidy { namespace readability { -void DeleteNullPointerCheck::registerMatchers(MatchFinder *Finder) { - const auto DeleteExpr = - cxxDeleteExpr(has(castExpr(has(declRefExpr( - to(decl(equalsBoundNode("deletedPointer")))))))) - .bind("deleteExpr"); +tooling::RewriteRule 
RewriteDeleteNullPointer() { + const auto DeleteExpr = cxxDeleteExpr(has( + castExpr(has(declRefExpr(to(decl(equalsBoundNode("deletedPointer")))))))); - const auto DeleteMemberExpr = - cxxDeleteExpr(has(castExpr(has(memberExpr(hasDeclaration( - fieldDecl(equalsBoundNode("deletedMemberPointer")))))))) - .bind("deleteMemberExpr"); + const auto DeleteMemberExpr = cxxDeleteExpr(has(castExpr(has(memberExpr( + hasDeclaration(fieldDecl(equalsBoundNode("deletedMemberPointer")))))))); const auto PointerExpr = ignoringImpCasts(anyOf( declRefExpr(to(decl().bind("deletedPointer"))), @@ -38,39 +35,23 @@ binaryOperator(hasEitherOperand(castExpr(hasCastKind(CK_NullToPointer))), hasEitherOperand(PointerExpr)); - Finder->addMatcher( - ifStmt(hasCondition(anyOf(PointerCondition, BinaryPointerCheckCondition)), - hasThen(anyOf( - DeleteExpr, DeleteMemberExpr, - compoundStmt(anyOf(has(DeleteExpr), has(DeleteMemberExpr)), - statementCountIs(1)) - .bind("compound")))) - .bind("ifWithDelete"), - this); -} - -void DeleteNullPointerCheck::check(const MatchFinder::MatchResult &Result) { - const auto *IfWithDelete = Result.Nodes.getNodeAs("ifWithDelete"); - const auto *Compound = Result.Nodes.getNodeAs("compound"); - - auto Diag = diag( - IfWithDelete->getBeginLoc(), - "'if' statement is unnecessary; deleting null pointer has no effect"); - if (IfWithDelete->getElse()) - return; - // FIXME: generate fixit for this case. 
- - Diag << FixItHint::CreateRemoval(CharSourceRange::getTokenRange( - IfWithDelete->getBeginLoc(), - Lexer::getLocForEndOfToken(IfWithDelete->getCond()->getEndLoc(), 0, - *Result.SourceManager, - Result.Context->getLangOpts()))); - if (Compound) { - Diag << FixItHint::CreateRemoval( - CharSourceRange::getTokenRange(Compound->getLBracLoc())); - Diag << FixItHint::CreateRemoval( - CharSourceRange::getTokenRange(Compound->getRBracLoc())); - } + tooling::StmtId DelStmt; + using tooling::bind; + using tooling::stencil_generators::statements; + return tooling::RewriteRule() + .matching(ifStmt( + hasCondition(anyOf(PointerCondition, BinaryPointerCheckCondition)), + hasThen(bind( + DelStmt, + stmt(anyOf(DeleteExpr, DeleteMemberExpr, + compoundStmt(statementCountIs(1), + has(stmt(anyOf(DeleteExpr, + DeleteMemberExpr)))))))), + // FIXME: handle else case. + unless(hasElse(stmt())))) + .replaceWith(statements(DelStmt)) + .explain( + "'if' statement is unnecessary; deleting null pointer has no effect"); } } // namespace readability diff --git a/clang-tools-extra/clang-tidy/readability/ElseAfterReturnCheck.h b/clang-tools-extra/clang-tidy/readability/ElseAfterReturnCheck.h --- a/clang-tools-extra/clang-tidy/readability/ElseAfterReturnCheck.h +++ b/clang-tools-extra/clang-tidy/readability/ElseAfterReturnCheck.h @@ -10,20 +10,21 @@ #define LLVM_CLANG_TOOLS_EXTRA_CLANG_TIDY_READABILITY_ELSEAFTERRETURNCHECK_H #include "../ClangTidy.h" +#include "../utils/TransformerTidy.h" namespace clang { namespace tidy { namespace readability { +std::vector RewriteElseAfterBranch(); + /// Flags the usages of `else` after `return`. 
/// /// http://llvm.org/docs/CodingStandards.html#don-t-use-else-after-a-return -class ElseAfterReturnCheck : public ClangTidyCheck { +class ElseAfterReturnCheck : public utils::MultiTransformerTidy { public: ElseAfterReturnCheck(StringRef Name, ClangTidyContext *Context) - : ClangTidyCheck(Name, Context) {} - void registerMatchers(ast_matchers::MatchFinder *Finder) override; - void check(const ast_matchers::MatchFinder::MatchResult &Result) override; + : MultiTransformerTidy(RewriteElseAfterBranch(), Name, Context) {} }; } // namespace readability diff --git a/clang-tools-extra/clang-tidy/readability/ElseAfterReturnCheck.cpp b/clang-tools-extra/clang-tidy/readability/ElseAfterReturnCheck.cpp --- a/clang-tools-extra/clang-tidy/readability/ElseAfterReturnCheck.cpp +++ b/clang-tools-extra/clang-tidy/readability/ElseAfterReturnCheck.cpp @@ -10,6 +10,7 @@ #include "clang/AST/ASTContext.h" #include "clang/ASTMatchers/ASTMatchFinder.h" #include "clang/Tooling/FixIt.h" +#include "clang/Tooling/Refactoring/Transformer.h" using namespace clang::ast_matchers; @@ -17,14 +18,14 @@ namespace tidy { namespace readability { -void ElseAfterReturnCheck::registerMatchers(MatchFinder *Finder) { - const auto InterruptsControlFlow = - stmt(anyOf(returnStmt().bind("return"), continueStmt().bind("continue"), - breakStmt().bind("break"), - expr(ignoringImplicit(cxxThrowExpr().bind("throw"))))); - Finder->addMatcher( - compoundStmt(forEach( - ifStmt(unless(isConstexpr()), +static tooling::RewriteRule RewriteElse( + ast_matchers::StatementMatcher InterruptsControlFlow, + StringRef ControlFlowInterruptor) { + using tooling::stencil_generators::statements; + tooling::StmtId IfS("If"), CondS("C"), ThenS("T"), ElseS("E"); + return tooling::RewriteRule() + .matching(compoundStmt(forEach( + ifStmt(IfS.bind(), hasCondition(CondS.bind()), unless(isConstexpr()), // FIXME: Explore alternatives for the // `if (T x = ...) {... 
return; } else { }` // pattern: @@ -32,30 +33,20 @@ // * fix by pulling out the variable declaration out of // the condition. unless(hasConditionVariableStatement(anything())), - hasThen(stmt(anyOf(InterruptsControlFlow, + hasThen(stmt(ThenS.bind(), + anyOf(InterruptsControlFlow, compoundStmt(has(InterruptsControlFlow))))), - hasElse(stmt().bind("else"))) - .bind("if"))), - this); + hasElse(ElseS.bind()))))) + .change(IfS) + .replaceWith("if (", CondS, ") ", ThenS, " ", statements(ElseS)) + .explain("do not use 'else' after '", ControlFlowInterruptor, "'"); } -void ElseAfterReturnCheck::check(const MatchFinder::MatchResult &Result) { - const auto *If = Result.Nodes.getNodeAs("if"); - SourceLocation ElseLoc = If->getElseLoc(); - std::string ControlFlowInterruptor; - for (const auto *BindingName : {"return", "continue", "break", "throw"}) - if (Result.Nodes.getNodeAs(BindingName)) - ControlFlowInterruptor = BindingName; - - DiagnosticBuilder Diag = diag(ElseLoc, "do not use 'else' after '%0'") - << ControlFlowInterruptor; - Diag << tooling::fixit::createRemoval(ElseLoc); - - // FIXME: Removing the braces isn't always safe. Do a more careful analysis. - // FIXME: Change clang-format to correctly un-indent the code. 
- if (const auto *CS = Result.Nodes.getNodeAs("else")) - Diag << tooling::fixit::createRemoval(CS->getLBracLoc()) - << tooling::fixit::createRemoval(CS->getRBracLoc()); +std::vector RewriteElseAfterBranch() { + return {RewriteElse(returnStmt(), "return"), + RewriteElse(continueStmt(), "continue"), + RewriteElse(breakStmt(), "break"), + RewriteElse(expr(ignoringImplicit(cxxThrowExpr())), "throw")}; } } // namespace readability diff --git a/clang-tools-extra/clang-tidy/utils/CMakeLists.txt b/clang-tools-extra/clang-tidy/utils/CMakeLists.txt --- a/clang-tools-extra/clang-tidy/utils/CMakeLists.txt +++ b/clang-tools-extra/clang-tidy/utils/CMakeLists.txt @@ -13,6 +13,7 @@ LexerUtils.cpp NamespaceAliaser.cpp OptionsUtils.cpp + TransformerTidy.cpp TypeTraits.cpp UsingInserter.cpp @@ -22,4 +23,5 @@ clangBasic clangLex clangTidy + clangToolingRefactor ) diff --git a/clang-tools-extra/clang-tidy/utils/TransformerTidy.h b/clang-tools-extra/clang-tidy/utils/TransformerTidy.h new file mode 100644 --- /dev/null +++ b/clang-tools-extra/clang-tidy/utils/TransformerTidy.h @@ -0,0 +1,56 @@ +//===---------- TransformerTidy.h - clang-tidy ----------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_CLANG_TOOLS_EXTRA_CLANG_TIDY_TRANSFORMER_TIDY_H +#define LLVM_CLANG_TOOLS_EXTRA_CLANG_TIDY_TRANSFORMER_TIDY_H + +#include "../ClangTidy.h" +#include "clang/ASTMatchers/ASTMatchFinder.h" +#include "clang/Tooling/Refactoring/Transformer.h" +#include +#include + +namespace clang { +namespace tidy { +namespace utils { + +// A ClangTidy encompassing a single rewrite rule. 
+class TransformerTidy : public ClangTidyCheck { +public: + TransformerTidy(tooling::RewriteRule R, StringRef Name, + ClangTidyContext *Context) + : ClangTidyCheck(Name, Context), Rule(std::move(R)) {} + + void registerMatchers(ast_matchers::MatchFinder *Finder) override; + void check(const ast_matchers::MatchFinder::MatchResult &Result) override; + +private: + tooling::RewriteRule Rule; +}; + +class MultiTransformerTidy : public ClangTidyCheck { +public: + MultiTransformerTidy(std::vector Rules, StringRef Name, + ClangTidyContext *Context); + + void registerMatchers(ast_matchers::MatchFinder *Finder) override; + + // `check` will never be called, since all of the matchers are registered to + // child tidies. + void check(const ast_matchers::MatchFinder::MatchResult &Result) override {} + +private: + // Use a deque to ensure pointer stability of elements. + std::deque Tidies; +}; + +} // namespace utils +} // namespace tidy +} // namespace clang + +#endif // LLVM_CLANG_TOOLS_EXTRA_CLANG_TIDY_TRANSFORMER_TIDY_H diff --git a/clang-tools-extra/clang-tidy/utils/TransformerTidy.cpp b/clang-tools-extra/clang-tidy/utils/TransformerTidy.cpp new file mode 100644 --- /dev/null +++ b/clang-tools-extra/clang-tidy/utils/TransformerTidy.cpp @@ -0,0 +1,65 @@ +//===---------- TransformerTidy.cpp - clang-tidy -------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "TransformerTidy.h" + +namespace clang { +namespace tidy { +namespace utils { + +void TransformerTidy::registerMatchers(ast_matchers::MatchFinder *Finder) { + Finder->addDynamicMatcher(Rule.matcher(), this); +} + +void TransformerTidy::check( + const ast_matchers::MatchFinder::MatchResult &Result) { + auto ChangeOrErr = tooling::internal::transform(Result, Rule); + if (auto Err = ChangeOrErr.takeError()) { + llvm::errs() << "Rewrite failed: " << llvm::toString(std::move(Err)) + << "\n"; + return; + } + auto &Change = *ChangeOrErr; + auto &Range = Change.Range; + if (Range.isInvalid()) { + // No rewrite applied (but no error encountered either). + return; + } + auto MessageOrErr = Rule.explanation().eval(Result); + if (auto Err = MessageOrErr.takeError()) { + llvm::errs() << "Evaluation of the explanation stencil failed: " + << llvm::toString(std::move(Err)) << "\n"; + return; + } + StringRef Message = *MessageOrErr; + if (Message.empty()) { + Message = "no explanation"; + } + diag(Range.getBegin(), Message) + << FixItHint::CreateReplacement(Range, Change.Replacement); + // FIXME: Should we warn if added_headers/removed_headers is non empty? 
+} + +MultiTransformerTidy::MultiTransformerTidy( + std::vector Rules, StringRef Name, + ClangTidyContext *Context) + : ClangTidyCheck(Name, Context) { + for (auto &R : Rules) { + Tidies.emplace_back(std::move(R), Name, Context); + } +} + +void MultiTransformerTidy::registerMatchers(ast_matchers::MatchFinder *Finder) { + for (auto &T : Tidies) { + T.registerMatchers(Finder); + } +} + +} // namespace utils +} // namespace tidy +} // namespace clang diff --git a/clang-tools-extra/test/clang-tidy/readability-delete-null-pointer.cpp b/clang-tools-extra/test/clang-tidy/readability-delete-null-pointer.cpp --- a/clang-tools-extra/test/clang-tidy/readability-delete-null-pointer.cpp +++ b/clang-tools-extra/test/clang-tidy/readability-delete-null-pointer.cpp @@ -12,9 +12,9 @@ // CHECK-MESSAGES: :[[@LINE-3]]:3: warning: 'if' statement is unnecessary; deleting null pointer has no effect [readability-delete-null-pointer] // CHECK-FIXES: {{^ }}// #1 - // CHECK-FIXES-NEXT: {{^ }}// #2 + // CHECK-FIXES-NEXT: // #2 // CHECK-FIXES-NEXT: delete p; - // CHECK-FIXES-NEXT: {{^ }}// #3 + // CHECK-FIXES-NEXT: // #3 int *p2 = new int[3]; // #4 @@ -23,7 +23,7 @@ // CHECK-MESSAGES: :[[@LINE-2]]:3: warning: 'if' statement is unnecessary; // CHECK-FIXES: // #4 - // CHECK-FIXES-NEXT: {{^ }}// #5 + // CHECK-FIXES-NOT: if (p2) // #5 // CHECK-FIXES-NEXT: delete[] p2; int *p3 = 0; @@ -52,9 +52,6 @@ char *c2; if (c2) { - // CHECK-MESSAGES: :[[@LINE-1]]:3: warning: 'if' statement is unnecessary; - // CHECK-FIXES: } else { - // CHECK-FIXES: c2 = c; delete c2; } else { c2 = c; @@ -64,8 +61,8 @@ if (mp) // #6 delete mp; // CHECK-MESSAGES: :[[@LINE-2]]:7: warning: 'if' statement is unnecessary; deleting null pointer has no effect [readability-delete-null-pointer] - // CHECK-FIXES: {{^ }}// #6 - // CHECK-FIXES-NEXT: delete mp; + // CHECK-FIXES-NOT: if (mp) // #6 + // CHECK-FIXES: delete mp; } int *mp; }; diff --git a/clang-tools-extra/test/clang-tidy/readability-else-after-return-if-constexpr.cpp 
b/clang-tools-extra/test/clang-tidy/readability-else-after-return-if-constexpr.cpp --- a/clang-tools-extra/test/clang-tidy/readability-else-after-return-if-constexpr.cpp +++ b/clang-tools-extra/test/clang-tidy/readability-else-after-return-if-constexpr.cpp @@ -4,9 +4,9 @@ void f() { if (sizeof(int) > 4) return; + // CHECK-MESSAGES: [[@LINE-2]]:3: warning: do not use 'else' after 'return' else return; - // CHECK-MESSAGES: [[@LINE-2]]:3: warning: do not use 'else' after 'return' if constexpr (sizeof(int) > 4) return; diff --git a/clang-tools-extra/test/clang-tidy/readability-else-after-return.cpp b/clang-tools-extra/test/clang-tidy/readability-else-after-return.cpp --- a/clang-tools-extra/test/clang-tidy/readability-else-after-return.cpp +++ b/clang-tools-extra/test/clang-tidy/readability-else-after-return.cpp @@ -13,17 +13,17 @@ void f(int a) { if (a > 0) + // CHECK-MESSAGES: :[[@LINE-1]]:3: warning: do not use 'else' after 'return' return; else // comment-0 - // CHECK-MESSAGES: :[[@LINE-1]]:3: warning: do not use 'else' after 'return' - // CHECK-FIXES: {{^}} // comment-0 return; + // CHECK-FIXES: {{^}} if (a > 0) return; return; if (a > 0) { + // CHECK-MESSAGES: :[[@LINE-1]]:3: warning: do not use 'else' after 'return' return; } else { // comment-1 - // CHECK-MESSAGES: :[[@LINE-1]]:5: warning: do not use 'else' after 'return' - // CHECK-FIXES: {{^}} } // comment-1 + // CHECK-FIXES: {{^}} }{{ *}}// comment-1 return; } @@ -58,18 +58,18 @@ if (a > 0) { if (a < 10) + // CHECK-MESSAGES: :[[@LINE-1]]:5: warning: do not use 'else' after 'return' return; else // comment-5 - // CHECK-MESSAGES: :[[@LINE-1]]:5: warning: do not use 'else' after 'return' - // CHECK-FIXES: {{^}} // comment-5 f(0); + // CHECK-FIXES: {{^}} if (a < 10) return; f(0); } else { if (a > 10) + // CHECK-MESSAGES: :[[@LINE-1]]:5: warning: do not use 'else' after 'return' return; else // comment-6 - // CHECK-MESSAGES: :[[@LINE-1]]:5: warning: do not use 'else' after 'return' - // CHECK-FIXES: {{^}} // 
comment-6 f(0); + // CHECK-FIXES: {{^}} if (a > 10) return; f(0); } } @@ -78,29 +78,29 @@ if (x) { continue; } else { // comment-7 - // CHECK-MESSAGES: :[[@LINE-1]]:7: warning: do not use 'else' after 'continue' - // CHECK-FIXES: {{^}} } // comment-7 + // CHECK-MESSAGES: :[[@LINE-3]]:5: warning: do not use 'else' after 'continue' + // CHECK-FIXES: {{^}} }{{ *}}// comment-7 x++; } if (x) { break; } else { // comment-8 - // CHECK-MESSAGES: :[[@LINE-1]]:7: warning: do not use 'else' after 'break' - // CHECK-FIXES: {{^}} } // comment-8 + // CHECK-MESSAGES: :[[@LINE-3]]:5: warning: do not use 'else' after 'break' + // CHECK-FIXES: {{^}} }{{ *}}// comment-8 x++; } if (x) { throw 42; } else { // comment-9 - // CHECK-MESSAGES: :[[@LINE-1]]:7: warning: do not use 'else' after 'throw' - // CHECK-FIXES: {{^}} } // comment-9 + // CHECK-MESSAGES: :[[@LINE-3]]:5: warning: do not use 'else' after 'throw' + // CHECK-FIXES: {{^}} }{{ *}}// comment-9 x++; } if (x) { throw my_exception("foo"); } else { // comment-10 - // CHECK-MESSAGES: :[[@LINE-1]]:7: warning: do not use 'else' after 'throw' - // CHECK-FIXES: {{^}} } // comment-10 + // CHECK-MESSAGES: :[[@LINE-3]]:5: warning: do not use 'else' after 'throw' + // CHECK-FIXES: {{^}} }{{ *}}// comment-10 x++; } } diff --git a/clang-tools-extra/unittests/clang-tidy/CMakeLists.txt b/clang-tools-extra/unittests/clang-tidy/CMakeLists.txt --- a/clang-tools-extra/unittests/clang-tidy/CMakeLists.txt +++ b/clang-tools-extra/unittests/clang-tidy/CMakeLists.txt @@ -16,7 +16,8 @@ ObjCModuleTest.cpp OverlappingReplacementsTest.cpp UsingInserterTest.cpp - ReadabilityModuleTest.cpp) + ReadabilityModuleTest.cpp + TransformerTidyTest.cpp) target_link_libraries(ClangTidyTests PRIVATE @@ -35,4 +36,5 @@ clangTidyUtils clangTooling clangToolingCore + clangToolingRefactor ) diff --git a/clang-tools-extra/unittests/clang-tidy/TransformerTidyTest.cpp b/clang-tools-extra/unittests/clang-tidy/TransformerTidyTest.cpp new file mode 100644 --- /dev/null +++ 
b/clang-tools-extra/unittests/clang-tidy/TransformerTidyTest.cpp @@ -0,0 +1,68 @@ +//===---- TransformerTidyTest.cpp - clang-tidy ----------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "../clang-tidy/utils/TransformerTidy.h" + +#include "ClangTidyTest.h" +#include "clang/ASTMatchers/ASTMatchers.h" +#include "gtest/gtest.h" + +namespace clang { +namespace tidy { +namespace utils { +namespace { +// Change `if ($c) $t $e` to `if (!$c) $e $t`. +// +// N.B. This rule is oversimplified (since it is just for testing): it won't +// construct the correct result if the input has compound statements. +tooling::RewriteRule invertIf() { + using ast_matchers::hasCondition; + using ast_matchers::hasElse; + using ast_matchers::hasThen; + using ast_matchers::ifStmt; + + tooling::ExprId C; + tooling::StmtId T, E; + + return tooling::RewriteRule() + .matching( + ifStmt(hasCondition(C.bind()), hasThen(T.bind()), hasElse(E.bind()))) + .replaceWith("if(!(", C, ")) ", E, " else ", T); +} + +class IfInverterTidy : public TransformerTidy { + public: + IfInverterTidy(StringRef Name, ClangTidyContext* Context) + : TransformerTidy(invertIf(), Name, Context) {} +}; + +// Basic test of using a rewrite rule as a ClangTidy. 
+TEST(TransformerTidyTest, Basic) { + const std::string Input = R"cc( + void log(const char* msg); + void foo() { + if (10 > 1.0) + log("oh no!"); + else + log("ok"); + } + )cc"; + + const std::string Expected = R"( + void log(const char* msg); + void foo() { + if(!(10 > 1.0)) log("ok"); else log("oh no!"); + } + )"; + + EXPECT_EQ(Expected, test::runCheckOnCode(Input)); +} +} // namespace +} // namespace utils +} // namespace tidy +} // namespace clang diff --git a/clang/include/clang/Tooling/Refactoring/Stencil.h b/clang/include/clang/Tooling/Refactoring/Stencil.h new file mode 100644 --- /dev/null +++ b/clang/include/clang/Tooling/Refactoring/Stencil.h @@ -0,0 +1,290 @@ +//===--- Stencil.h - Stencil class ------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the *Stencil* abstraction: a code-generating object, +// parameterized by named references to (bound) AST nodes. Given a match +// result, a stencil can be evaluated to a string of source code. +// +// A stencil is similar in spirit to a format string: it is composed of a +// series of raw text strings, references to nodes (the parameters) and helper +// code-generation operations. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_CLANG_TOOLING_REFACTOR_STENCIL_H_ +#define LLVM_CLANG_TOOLING_REFACTOR_STENCIL_H_ + +#include +#include + +#include "clang/AST/ASTContext.h" +#include "clang/AST/ASTTypeTraits.h" +#include "clang/ASTMatchers/ASTMatchFinder.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Error.h" + +namespace clang { +namespace tooling { + +// A strong type for AST node identifiers. 
The standard API uses StringRefs for +// identifiers. The strong type allows us to distinguish ids from arbitrary +// text snippets in various parts of the API. +class NodeId { +public: + explicit NodeId(std::string Id) : Id(std::move(Id)) {} + + // Creates a NodeId whose name is based on the id. Guarantees that unique ids + // map to unique NodeIds. + explicit NodeId(size_t Id) : Id("id" + std::to_string(Id)) {} + + // Convenience constructor that generates a fresh id (with respect to other + // generated ids). + NodeId(); + + llvm::StringRef id() const { return Id; } + + // Gets the AST node in `result` corresponding to this NodeId, if + // any. Otherwise, returns null. + template + const Node * + getNodeAs(const ast_matchers::MatchFinder::MatchResult &Result) const { + return Result.Nodes.getNodeAs(Id); + } + +private: + std::string Id; +}; + +// A stencil is represented as a sequence of "parts" that can each individually +// generate a code string based on a match result. The different kinds of parts +// include (raw) text, references to bound nodes and assorted operations on +// bound nodes. +// +// Users can create custom Stencil operations by implementing this interface. +class StencilPartInterface { +public: + StencilPartInterface() = default; + virtual ~StencilPartInterface() = default; + + // Evaluates this part to a string and appends it to `result`. + virtual llvm::Error eval(const ast_matchers::MatchFinder::MatchResult &Match, + std::string *Result) const = 0; + + virtual std::unique_ptr clone() const = 0; + +protected: + // Since this is an abstract class, copying/assigning only make sense for + // derived classes implementing `Clone()`. + StencilPartInterface(const StencilPartInterface &) = default; + StencilPartInterface &operator=(const StencilPartInterface &) = default; +}; + +// A copyable facade for a std::unique_ptr. Copies result +// in a copy of the underlying pointee object. 
+class StencilPart { +public: + explicit StencilPart(std::unique_ptr Impl) + : Impl(std::move(Impl)) {} + + // Copy constructor/assignment produce a deep copy. + StencilPart(const StencilPart &P) : Impl(P.Impl->clone()) {} + StencilPart(StencilPart &&) = default; + StencilPart &operator=(const StencilPart &P) { + Impl = P.Impl->clone(); + return *this; + } + StencilPart &operator=(StencilPart &&) = default; + + // See StencilPartInterface::Eval. + llvm::Error eval(const ast_matchers::MatchFinder::MatchResult &Match, + std::string *Result) const { + return Impl->eval(Match, Result); + } + +private: + std::unique_ptr Impl; +}; + +// Include directive modification +// +// Stencils also support operations to add and remove preprocessor include +// directives. Users specify the included file with a string, which can +// optionally be enclosed in <> or "". If unenclosed, surrounding double quotes +// are implied. The resulting string is treated literally in the relevant +// operation. No attempt is made to interpret the path in the string; for +// example, to identify it with another path that resolves to the same file. + +// Add an #include for the specified path to the file being rewritten. No-op +// if the directive is already present. A `path` surrounded by <> adds a +// directive that uses <>; surrounded by "" (explicit or implicit) adds a +// directive that uses "". +struct AddIncludeOp { + std::string Path; +}; + +// Remove an #include of the specified path from the file being rewritten. +// No-op if the include isn't present. A `path` surrounded by <> removes a +// directive that uses <>; surrounded by "" (explicit or implicit) removes a +// directive that uses "". +struct RemoveIncludeOp { + std::string Path; +}; + +// A sequence of code fragments, references to parameters and code-generation +// operations that together can be evaluated to (a fragment of) source code, +// given a match result. 
+class Stencil { +public: + Stencil() = default; + + Stencil(const Stencil &) = default; + Stencil(Stencil &&) = default; + Stencil &operator=(const Stencil &) = default; + Stencil &operator=(Stencil &&) = default; + + // Compose a stencil from a series of parts. + template static Stencil cat(Ts &&... Parts) { + Stencil Stencil; + Stencil.Parts.reserve(sizeof...(Parts)); + auto Unused = {(Stencil.append(std::forward(Parts)), true)...}; + (void)Unused; + return Stencil; + } + + // Evaluates the stencil given a match result. Requires that the nodes in the + // result include any ids referenced in the stencil. References to missing + // nodes will result in an invalid_argument error. + llvm::Expected + eval(const ast_matchers::MatchFinder::MatchResult &Match) const; + + // List of paths for which an include directive should be added. See + // AddIncludeOp for the meaning of the path strings. + const std::vector &addedIncludes() const { + return AddedIncludes; + } + + // List of paths for which an include directive should be removed. See + // RemoveIncludeOp for the meaning of the path strings. + const std::vector &removedIncludes() const { + return RemovedIncludes; + } + +private: + void append(const NodeId &Id); + void append(llvm::StringRef Text); + void append(StencilPart Part) { Parts.push_back(std::move(Part)); } + void append(AddIncludeOp Op) { AddedIncludes.push_back(std::move(Op.Path)); } + void append(RemoveIncludeOp Op) { + RemovedIncludes.push_back(std::move(Op.Path)); + } + + std::vector Parts; + // See corresponding accessors for descriptions of these two fields. + std::vector AddedIncludes; + std::vector RemovedIncludes; +}; + +// Functions for conveniently building stencil parts. +namespace stencil_generators { +// Abbreviation for NodeId construction allowing for more concise references to +// node ids in stencils. +inline NodeId id(llvm::StringRef Id) { return NodeId(Id); } + +// Yields exactly the text provided. 
+StencilPart text(llvm::StringRef Text); + +// Yields the source corresponding to the identified node. +StencilPart node(const NodeId &Id); +StencilPart node(llvm::StringRef Id); + +// Given a reference to a node e and a member m, yields "e->m", when e is a +// pointer, "e2->m" when e = "*e2" and "e.m" otherwise. "e" is wrapped in +// parentheses, if needed. Objects can be identified by NodeIds or strings and +// members can be identified by other parts (e.g. Name()) or raw text, hence the +// 4 overloads. +StencilPart member(const NodeId &ObjectId, StencilPart Member); +StencilPart member(const NodeId &ObjectId, llvm::StringRef Member); +StencilPart member(llvm::StringRef ObjectId, StencilPart Member); +StencilPart member(llvm::StringRef ObjectId, llvm::StringRef Member); + +// Renders a node's source as a value, even if the node is a pointer. +// Specifically, given a reference to a node "e", +// * when "e" has the form `&$expr`, yields `$expr`. +// * when "e" is a pointer, yields `*$e`. +// * otherwise, yields `$e`. +StencilPart asValue(const NodeId &Id); +StencilPart asValue(llvm::StringRef Id); + +// Renders a node's source as an address, even if the node is an lvalue. +// Specifically, given a reference to a node "e", +// * when "e" has the form `*$expr` (with '*' the builtin operator and `$expr` +// source code of an arbitrary expression), yields `$expr`. +// * when "e" is a pointer, yields `$e`, +// * otherwise, yields `&$e`. +StencilPart asAddress(const NodeId &Id); +StencilPart asAddress(llvm::StringRef Id); + +// Given a reference to a node "e", yields `($e)` if "e" may parse differently +// depending on context. For example, a binary operation is always wrapped while +// a variable reference is never wrapped. +StencilPart parens(const NodeId &Id); +StencilPart parens(llvm::StringRef Id); + +// Given a reference to a named declaration "d" (that is, a node of type +// NamedDecl or one of its derived classes), yields the name. 
"d" must have +// an identifier name (that is, constructors are not valid arguments to the Name +// operation). +StencilPart name(const NodeId &DeclId); +StencilPart name(llvm::StringRef DeclId); + +// Given a reference to call expression node, yields the source text of the +// arguments (all source between the call's parentheses). +StencilPart args(const NodeId &CallId); +StencilPart args(llvm::StringRef CallId); + +// Given a reference to a compound statement node, yields the source text of the +// statements (all source between the braces). If the statement is not compound, +// yields the statement's source text. +StencilPart statements(const NodeId &StmtId); +StencilPart statements(llvm::StringRef StmtId); + +// Derive a string from a node. +using NodeFunction = std::function; + +// Derive a string from the result of a stencil-part evaluation. +using StringFunction = std::function; + +// Yields the string from applying `fn` to the referenced node. +StencilPart apply(NodeFunction Fn, const NodeId &Id); +StencilPart apply(NodeFunction Fn, llvm::StringRef Id); + +// Yields the string from applying `fn` to the evaluation of `part`. +StencilPart apply(StringFunction Fn, StencilPart Part); + +// Convenience overloads for case where target part is a node. +StencilPart apply(StringFunction Fn, const NodeId &Id); +StencilPart apply(StringFunction Fn, llvm::StringRef Id); + +// Add an include directive for 'path' into the file that is being rewritten. +// See comments on AddIncludeOp for more details. Not (yet) supported by clang +// tidy. +AddIncludeOp addInclude(llvm::StringRef Path); +// Remove an include directive for 'path' in the file that is being rewritten. +// See comments on RemoveIncludeOp for more details. Not (yet) supported by +// clang tidy. +RemoveIncludeOp removeInclude(llvm::StringRef Path); + +// For debug use only; semantics are not guaranteed. Generates the string +// resulting from calling the node's print() method. 
+StencilPart dPrint(const NodeId &Id); +StencilPart dPrint(llvm::StringRef Id); +} // namespace stencil_generators +} // namespace tooling +} // namespace clang +#endif // LLVM_CLANG_TOOLING_REFACTOR_STENCIL_H_ diff --git a/clang/include/clang/Tooling/Refactoring/Transformer.h b/clang/include/clang/Tooling/Refactoring/Transformer.h new file mode 100644 --- /dev/null +++ b/clang/include/clang/Tooling/Refactoring/Transformer.h @@ -0,0 +1,361 @@ +//===--- Transformer.h - Clang term-rewriting library -----------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines a library supporting the concise specification of clang- +// based source-to-source transformations. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_CLANG_TOOLING_REFACTOR_TRANSFORMER_H_ +#define LLVM_CLANG_TOOLING_REFACTOR_TRANSFORMER_H_ + +#include +#include +#include +#include +#include +#include + +#include "Stencil.h" +#include "clang/ASTMatchers/ASTMatchFinder.h" +#include "clang/ASTMatchers/ASTMatchers.h" +#include "clang/ASTMatchers/ASTMatchersInternal.h" +#include "clang/Tooling/Refactoring/AtomicChange.h" +#include "llvm/ADT/STLExtras.h" + +namespace clang { +namespace tooling { + +// Derivation of NodeId that identifies the intended node type for the id, which +// allows us to select appropriate overloads or constrain use of various +// combinators. `Node` is the AST node type corresponding to this id. +template +class TypedNodeId : public NodeId { + public: + using NodeId::NodeId; + using MatcherType = ast_matchers::internal::Matcher; + + // Creates a matcher corresponding to the AST-node type of this id and bound + // to this id. 
Targeted for settings where the type of matcher is + // obvious/uninteresting. For example, + // + // ExprId arg; + // auto matcher = callExpr(callee(IsFunctionNamed("foo")), + // hasArgument(0, arg.bind())); + MatcherType bind() const { + return ast_matchers::internal::BindableMatcher( + ast_matchers::internal::TrueMatcher()) + .bind(id()); + } +}; + +using ExprId = TypedNodeId; +using StmtId = TypedNodeId; +using DeclId = TypedNodeId; +using TypeId = TypedNodeId; + +// Introduce/define matcher-type abbreviations for all top-level classes in the +// AST class hierarchy. +using ast_matchers::CXXCtorInitializerMatcher; +using ast_matchers::DeclarationMatcher; +using ast_matchers::NestedNameSpecifierLocMatcher; +using ast_matchers::NestedNameSpecifierMatcher; +using ast_matchers::StatementMatcher; +using ast_matchers::TypeLocMatcher; +using ast_matchers::TypeMatcher; +using TemplateArgumentMatcher = + ast_matchers::internal::Matcher; +using TemplateNameMatcher = ast_matchers::internal::Matcher; +using ast_matchers::internal::DynTypedMatcher; + +// A simple abstraction of a filter for match results. Currently, it simply +// wraps a predicate, but we may extend the functionality to support a simple +// boolean expression language for constructing filters. +class MatchFilter { + public: + using Predicate = + std::function; + + MatchFilter() + : Filter([](const ast_matchers::MatchFinder::MatchResult&) { + return true; + }) {} + explicit MatchFilter(Predicate P) : Filter(std::move(P)) {} + + MatchFilter(const MatchFilter&) = default; + MatchFilter(MatchFilter&&) = default; + MatchFilter& operator=(const MatchFilter&) = default; + MatchFilter& operator=(MatchFilter&&) = default; + + bool matches(const ast_matchers::MatchFinder::MatchResult& Result) const { + return Filter(Result); + } + + private: + Predicate Filter; +}; + +// Selects the part of the AST node to replace. 
We support this to work around +// the fact that the AST does not differentiate various syntactic elements into +// their own nodes, so users can specify them relative to a node, instead. +// +// TODO(yitzhakm): Add tests for kMember and kName. +enum class NodePart { + // The node itself. + kNode, + // Given a MemberExpr, selects the member's token. + kMember, + // Given a NamedDecl or CXXCtorInitializer, selects the token of the relevant + // name, not including qualifiers. + kName, +}; + +// A *rewrite rule* describes a transformation of source code. It has the +// following components: +// +// * Matcher: the pattern term, expressed as clang matchers (with Transformer +// extensions). +// +// * Where: a "where clause" -- that is, a predicate over (matched) AST nodes +// that restricts matches beyond what is (easily) expressible as a pattern. +// +// * Target: the source code impacted by the rule. This identifies an AST node, +// or part thereof, whose source range indicates the extent of the replacement +// applied by the replacement term. By default, the extent is the node +// matched by the pattern term. +// +// * Replacement: the replacement term, expressed as a code Stencil, which +// represents code or text interspersed with references to AST nodes. +// +// * Explanation: explanation of the rewrite. This, too, is represented as a +// Stencil to allow specializing the message based on parts of the matched +// code fragment. +// +// Rules have an additional, implicit, component: the parameters. These are +// portions of the pattern which are left unspecified, yet named so that we can +// reference them in the replacement term. The structure of parameters can be +// partially or even fully specified, in which case they serve just to identify +// matched nodes for later reference rather than abstract over portions of the +// AST. However, in all cases, we refer to named portions of the pattern as +// parameters.
+// +// Parameters can be declared explicitly using the NodeId type and its +// derivatives or left implicit by using the native support for binding ids in +// the clang matchers and corresponding support for string identifiers in +// Stencils. +// +// All rule components are optional. An empty RewriteRule, however, matches any +// statement and replaces it with the empty string, so setting at least some +// parameters is recommended. +// +// RewriteRule is constructed in a "fluent" style, by chaining setters of +// individual components. We provide ref-qualified overloads of the setters to +// avoid an unnecessary copy when a RewriteRule is initialized from a temporary, +// like: +// +// RewriteRule r = RewriteRule().matching()... +class RewriteRule { + public: + RewriteRule(); + + RewriteRule(const RewriteRule&) = default; + RewriteRule(RewriteRule&&) = default; + RewriteRule& operator=(const RewriteRule&) = default; + RewriteRule& operator=(RewriteRule&&) = default; + + // `matching()` supports all top-level nodes in the AST hierarchy. We spell + // out all of the permitted overloads, rather than defining a template, for + // documentation purposes and to give the user clear error messages if they + // pass a node that is not one of the permitted types.
+ RewriteRule& matching(CXXCtorInitializerMatcher M) & { + return setMatcher(std::move(M)); + } + RewriteRule& matching(DeclarationMatcher M) & { + return setMatcher(std::move(M)); + } + RewriteRule& matching(NestedNameSpecifierMatcher M) & { + return setMatcher(std::move(M)); + } + RewriteRule& matching(NestedNameSpecifierLocMatcher M) & { + return setMatcher(std::move(M)); + } + RewriteRule& matching(StatementMatcher M) & { + return setMatcher(std::move(M)); + } + RewriteRule& matching(TemplateArgumentMatcher M) & { + return setMatcher(std::move(M)); + } + RewriteRule& matching(TemplateNameMatcher M) & { + return setMatcher(std::move(M)); + } + RewriteRule& matching(TypeLocMatcher M) & { + return setMatcher(std::move(M)); + } + RewriteRule& matching(TypeMatcher M) & { return setMatcher(std::move(M)); } + + template + RewriteRule&& matching(MatcherT M) && { + return std::move(matching(std::move(M))); + } + + RewriteRule& where(MatchFilter::Predicate Filter) &; + RewriteRule&& where(MatchFilter::Predicate Filter) && { + return std::move(where(std::move(Filter))); + } + + RewriteRule& change(const NodeId& Target, NodePart Part = NodePart::kNode) &; + RewriteRule&& change(const NodeId& Target, + NodePart Part = NodePart::kNode) && { + return std::move(change(Target, Part)); + } + + RewriteRule& replaceWith(Stencil S) &; + RewriteRule&& replaceWith(Stencil S) && { + return std::move(replaceWith(std::move(S))); + } + + template + RewriteRule& replaceWith(Ts&&... Args) & { + Replacement = Stencil::cat(std::forward(Args)...); + return *this; + } + template + RewriteRule&& replaceWith(Ts&&... Args) && { + return std::move(replaceWith(std::forward(Args)...)); + } + + template + RewriteRule& explain(Ts&&... Args) & { + Explanation = Stencil::cat(std::forward(Args)...); + return *this; + } + template + RewriteRule&& explain(Ts&&... 
Args) && { + return std::move(explain(std::forward(Args)...)); + } + + const DynTypedMatcher& matcher() const { return Matcher; } + const MatchFilter& filter() const { return Filter; } + llvm::StringRef target() const { return Target; } + NodePart targetPart() const { return TargetPart; } + const Stencil& replacement() const { return Replacement; } + const Stencil& explanation() const { return Explanation; } + + private: + template + RewriteRule& setMatcher(MatcherT M) & { + auto DM = DynTypedMatcher(M); + DM.setAllowBind(true); + // The default target is `RootId`, so we bind it here. `tryBind` is + // guaranteed to succeed, because `AllowBind` is true. + Matcher = *DM.tryBind(RootId); + return *this; + } + + // Id used as the default target of each match. + static constexpr char RootId[] = "___root___"; + + // Supports any (top-level node) matcher type. + DynTypedMatcher Matcher; + MatchFilter Filter; + // The (bound) id of the node whose source will be replaced. This id should + // never be the empty string. By default, refers to the node matched by + // `matcher_`. + std::string Target; + NodePart TargetPart; + Stencil Replacement; + Stencil Explanation; +}; + +// Convenience factory function for the common case where a rule has a statement +// matcher, template and explanation. +RewriteRule makeRule(StatementMatcher Matcher, Stencil Replacement, + std::string Explanation); + +// A class that handles the matcher and callback registration for a single +// rewrite rule, as defined by the arguments of the constructor. +class Transformer : public ast_matchers::MatchFinder::MatchCallback { + public: + using ChangeConsumer = + std::function; + + Transformer(RewriteRule Rule, ChangeConsumer Consumer) + : Rule(std::move(Rule)), Consumer(std::move(Consumer)) {} + + // N.B. Passes `this` pointer to `match_finder`. So, this object should not + // be moved after this call. 
+ void registerMatchers(ast_matchers::MatchFinder* MatchFinder); + + // Not called directly by users -- called by the framework, via base class + // pointer. + void run(const ast_matchers::MatchFinder::MatchResult& Result) override; + + private: + RewriteRule Rule; + ChangeConsumer Consumer; +}; + +// Convenience class to manage creation and storage of multiple rewriters. +class MultiTransformer { + public: + MultiTransformer(std::vector Rules, + const Transformer::ChangeConsumer& Consumer, + ast_matchers::MatchFinder* MF); + + private: + // Transformers register their `this` pointer with MatchFinder, so we use + // a deque to ensure stable pointers for each Transformer. + std::deque Transformers; +}; + +// Attempts to apply the rule to the given node to yield a string. Ignores the +// rule's `target` and `explanation` fields. The rule must match at most once; +// otherwise, the call will fail. +// +// Returns: +// * if the rewrite is successful, a string representing the replacement text +// for the given node, +// * if the rewrite does not apply (but no errors encountered), returns `None`. +// * if there is a failure, returns an `Error`. +llvm::Expected> maybeTransform( + const RewriteRule& Rule, const ast_type_traits::DynTypedNode& Node, + ASTContext* Context); + +template +llvm::Expected> maybeTransform( + const RewriteRule& Rule, const T& Node, ASTContext* Context) { + return maybeTransform(Rule, ast_type_traits::DynTypedNode::create(Node), + Context); +} + +// Binds the node described by `matcher` to the given node id. +template +ast_matchers::internal::Matcher bind( + const NodeId& Id, ast_matchers::internal::BindableMatcher Matcher) { + return Matcher.bind(Id.id()); +} + +namespace internal { +// A source "transformation," represented by a character range in the source to +// be replaced and a corresponding replacement string. 
+struct Transformation { + CharSourceRange Range; + std::string Replacement; +}; + +// Given a match and rule, tries to generate a transformation for the target of +// the rule. Fails if the match is not eligible for rewriting or any invariants +// are violated relating to bound nodes in the match. +Expected +transform(const ast_matchers::MatchFinder::MatchResult &Result, + const RewriteRule &Rule); +} // namespace internal +} // namespace tooling +} // namespace clang + +#endif // LLVM_CLANG_TOOLING_REFACTOR_TRANSFORMER_H_ diff --git a/clang/lib/Tooling/Refactoring/CMakeLists.txt b/clang/lib/Tooling/Refactoring/CMakeLists.txt --- a/clang/lib/Tooling/Refactoring/CMakeLists.txt +++ b/clang/lib/Tooling/Refactoring/CMakeLists.txt @@ -12,6 +12,8 @@ Rename/USRFinder.cpp Rename/USRFindingAction.cpp Rename/USRLocFinder.cpp + Stencil.cpp + Transformer.cpp LINK_LIBS clangAST diff --git a/clang/lib/Tooling/Refactoring/Stencil.cpp b/clang/lib/Tooling/Refactoring/Stencil.cpp new file mode 100644 --- /dev/null +++ b/clang/lib/Tooling/Refactoring/Stencil.cpp @@ -0,0 +1,760 @@ +//===--- Stencil.cpp - Stencil implementation -------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "clang/Tooling/Refactoring/Stencil.h" + +#include +#include + +#include "clang/AST/ASTContext.h" +#include "clang/AST/ASTTypeTraits.h" +#include "clang/AST/Expr.h" +#include "clang/ASTMatchers/ASTMatchFinder.h" +#include "clang/ASTMatchers/ASTMatchers.h" +#include "clang/Lex/Lexer.h" +#include "clang/Tooling/FixIt.h" +#include "llvm/Support/Errc.h" + +namespace clang { +namespace tooling { +// +// BEGIN Utilities -- the following functions all belong in a separate utilities +// library.
We include them here for the purposes of this demo so that it will +// compile +// + +// Returns true if expr needs to be put in parens when it is the target +// of a dot or arrow, i.e. when it is an operator syntactically. +static bool needParensBeforeDotOrArrow(const clang::Expr &Expr) { + // We always want parens around unary, binary, and ternary operators. + if (llvm::dyn_cast(&Expr) || + llvm::dyn_cast(&Expr) || + llvm::dyn_cast(&Expr)) { + return true; + } + + // We need parens around calls to all overloaded operators except for function + // calls, subscripts, and expressions that are already part of an implicit + // call to operator->. + if (const auto *Op = llvm::dyn_cast(&Expr)) { + return Op->getOperator() != clang::OO_Call && + Op->getOperator() != clang::OO_Subscript && + Op->getOperator() != clang::OO_Arrow; + } + + return false; +} + +// BEGIN from clang-tidy/readability/RedundantStringCStrCheck.cpp + +// Return true if expr needs to be put in parens when it is an argument of a +// prefix unary operator, e.g. when it is a binary or ternary operator +// syntactically. +static bool needParensAfterUnaryOperator(const Expr &ExprNode) { + if (isa(&ExprNode) || + isa(&ExprNode)) { + return true; + } + if (const auto *Op = dyn_cast(&ExprNode)) { + return Op->getNumArgs() == 2 && Op->getOperator() != OO_PlusPlus && + Op->getOperator() != OO_MinusMinus && Op->getOperator() != OO_Call && + Op->getOperator() != OO_Subscript; + } + return false; +} + +// Format a pointer to an expression: prefix with '*' but simplify +// when it already begins with '&'. Return empty string on failure. +static std::string formatDereference(const ASTContext &Context, + const Expr &ExprNode) { + if (const auto *Op = dyn_cast(&ExprNode)) { + if (Op->getOpcode() == UO_AddrOf) { + // Strip leading '&'. 
+ return tooling::fixit::getText(*Op->getSubExpr()->IgnoreParens(), + Context); + } + } + StringRef Text = tooling::fixit::getText(ExprNode, Context); + + if (Text.empty()) + return std::string(); + // Add leading '*'. + if (needParensAfterUnaryOperator(ExprNode)) { + return (llvm::Twine("*(") + Text + ")").str(); + } + return (llvm::Twine("*") + Text).str(); +} + +// END from clang-tidy/readability/RedundantStringCStrCheck.cpp + +// Format a pointer to an expression: prefix with '&' but simplify when it +// already begins with '*'. Returns empty string on failure. +static std::string formatAddressOf(const ASTContext &Context, + const clang::Expr &Expr) { + if (const auto *Op = llvm::dyn_cast(&Expr)) { + if (Op->getOpcode() == clang::UO_Deref) { + // Strip leading '*'. + return tooling::fixit::getText(*Op->getSubExpr()->IgnoreParens(), + Context); + } + } + // Add leading '&'. + const std::string Text = fixit::getText(Expr, Context); + if (Text.empty()) + return std::string(); + if (needParensAfterUnaryOperator(Expr)) { + return (llvm::Twine("&(") + Text + ")").str(); + } + return (llvm::Twine("&") + Text).str(); +} + +static std::string formatDot(const ASTContext &Context, + const clang::Expr &Expr) { + if (const auto *Op = llvm::dyn_cast(&Expr)) { + if (Op->getOpcode() == clang::UO_Deref) { + // Strip leading '*', add following '->'. + const clang::Expr *SubExpr = Op->getSubExpr()->IgnoreParenImpCasts(); + const std::string DerefText = fixit::getText(*SubExpr, Context); + if (DerefText.empty()) + return std::string(); + if (needParensBeforeDotOrArrow(*SubExpr)) { + return (llvm::Twine("(") + DerefText + ")->").str(); + } + return (llvm::Twine(DerefText) + "->").str(); + } + } + // Add following '.'. 
+ const std::string Text = fixit::getText(Expr, Context); + if (Text.empty()) + return std::string(); + if (needParensBeforeDotOrArrow(Expr)) { + return (llvm::Twine("(") + Text + ").").str(); + } + return (llvm::Twine(Text) + ".").str(); +} + +static std::string formatArrow(const ASTContext &Context, + const clang::Expr &Expr) { + if (const auto *Op = llvm::dyn_cast(&Expr)) { + if (Op->getOpcode() == clang::UO_AddrOf) { + // Strip leading '&', add following '.'. + const clang::Expr *SubExpr = Op->getSubExpr()->IgnoreParenImpCasts(); + const std::string DerefText = fixit::getText(*SubExpr, Context); + if (DerefText.empty()) + return std::string(); + if (needParensBeforeDotOrArrow(*SubExpr)) { + return (llvm::Twine("(") + DerefText + ").").str(); + } + return (llvm::Twine(DerefText) + ".").str(); + } + } + // Add following '->'. + const std::string Text = fixit::getText(Expr, Context); + if (Text.empty()) + return std::string(); + if (needParensBeforeDotOrArrow(Expr)) { + return (llvm::Twine("(") + Text + ")->").str(); + } + return (llvm::Twine(Text) + "->").str(); +} + +// BEGIN from: clang-tidy/utils/LexerUtils.cpp +static SourceLocation findPreviousTokenStart(SourceLocation Start, + const SourceManager &SM, + const LangOptions &LangOpts) { + if (Start.isInvalid() || Start.isMacroID()) + return SourceLocation(); + + SourceLocation BeforeStart = Start.getLocWithOffset(-1); + if (BeforeStart.isInvalid() || BeforeStart.isMacroID()) + return SourceLocation(); + + return Lexer::GetBeginningOfToken(BeforeStart, SM, LangOpts); +} + +static SourceLocation findPreviousTokenKind(SourceLocation Start, + const SourceManager &SM, + const LangOptions &LangOpts, + tok::TokenKind TK) { + while (true) { + SourceLocation L = findPreviousTokenStart(Start, SM, LangOpts); + if (L.isInvalid() || L.isMacroID()) + return SourceLocation(); + + Token T; + if (Lexer::getRawToken(L, T, SM, LangOpts, /*IgnoreWhiteSpace=*/true)) + return SourceLocation(); + + if (T.is(TK)) + return 
T.getLocation(); + + Start = L; + } +} +// END From: clang-tidy/utils/LexerUtils + +// For refactoring purposes, some expressions should be wrapped in parentheses +// to avoid changes in the order of operation, assuming no other information +// about the surrounding context. +static bool needsParens(const Expr *E) { + return isa(E) || isa(E) || + isa(E) || isa(E); +} + +// Finds the open paren of the call expression and return its location. Returns +// an invalid location if not found. +static SourceLocation +getOpenParen(const CallExpr &E, + const ast_matchers::MatchFinder::MatchResult &Result) { + SourceLocation EndLoc = + E.getNumArgs() == 0 ? E.getRParenLoc() : E.getArg(0)->getBeginLoc(); + return findPreviousTokenKind(EndLoc, *Result.SourceManager, + Result.Context->getLangOpts(), + tok::TokenKind::l_paren); +} + +// For a given range, returns the lexed token immediately after the range if +// and only if it's a semicolon. +static Optional getTrailingSemi(SourceLocation EndLoc, + const ASTContext &Context) { + if (Optional Next = Lexer::findNextToken( + EndLoc, Context.getSourceManager(), Context.getLangOpts())) { + return Next->is(clang::tok::TokenKind::semi) ? Next : None; + } + return None; +} + +static const clang::Stmt *getStatementParent(const clang::Stmt &node, + ASTContext &context) { + using namespace ast_matchers; + + auto is_or_has_node = + anyOf(equalsNode(&node), hasDescendant(equalsNode(&node))); + auto not_in_condition = unless(hasCondition(is_or_has_node)); + // Note that SwitchCase nodes have the subsequent statement as substatement. + // For example, in "case 1: a(); b();", a() will be the child of the + // SwitchCase "case 1:". + // TODO(djasper): Also handle other labels, probably not important in google3. + // missing: switchStmt() (although this is a weird corner case). 
+ auto statement = stmt(hasParent( + stmt(anyOf(compoundStmt(), whileStmt(not_in_condition), + doStmt(not_in_condition), switchCase(), + ifStmt(not_in_condition), + forStmt(not_in_condition, unless(hasIncrement(is_or_has_node)), + unless(hasLoopInit(is_or_has_node))))) + .bind("parent"))); + return selectFirst("parent", + match(statement, node, context)); +} + +// Is a real statement (not an expression inside another expression). That is, +// not an expression with an expression parent. +static bool isRealStatement(const Stmt &S, ASTContext &Context) { + return !isa(S) || getStatementParent(S, Context) != nullptr; +} + +// For all non-expression statements, extend the source to include any trailing +// semi. Returns a SourceRange representing a token range. +static SourceRange getTokenRange(const Stmt &S, ASTContext &Context) { + // Only exclude non-statement expressions. + if (isRealStatement(S, Context)) { + // TODO: exclude case where last token is a right brace? + if (auto Tok = getTrailingSemi(S.getEndLoc(), Context)) + return SourceRange(S.getBeginLoc(), Tok->getLocation()); + } + return S.getSourceRange(); +} + +static SourceRange getTokenRange(const ast_type_traits::DynTypedNode &Node, + ASTContext &Context) { + if (const auto *S = Node.get()) + return getTokenRange(*S, Context); + return Node.getSourceRange(); +} + +template +StringRef getText(const T &Node, ASTContext &Context) { + return Lexer::getSourceText( + CharSourceRange::getTokenRange(getTokenRange(Node, Context)), + Context.getSourceManager(), Context.getLangOpts()); +} + +// +// END Utilities +// + +// For guaranteeing unique ids on NodeId creation. +static size_t nextId() { + // Start with a relatively high number to avoid bugs if the user mixes + // explicitly-numbered ids with those generated with `nextId()`. Similarly, we + // choose a number that allows generated ids to be easily recognized.
+ static std::atomic Next(2222); + return Next.fetch_add(1, std::memory_order_relaxed); +} + +// Gets the source text of the arguments to the call expression. Includes all +// source between the parentheses delimiting the call. +static StringRef +getArgumentsText(const CallExpr &CE, + const ast_matchers::MatchFinder::MatchResult &Result) { + auto Range = CharSourceRange::getCharRange( + getOpenParen(CE, Result).getLocWithOffset(1), CE.getRParenLoc()); + return Lexer::getSourceText(Range, Result.Context->getSourceManager(), + Result.Context->getLangOpts()); +} + +// Gets the source text of the statements in the compound statement. Includes +// all source between the braces. +static StringRef +getStatementsText(const CompoundStmt &CS, + const ast_matchers::MatchFinder::MatchResult &Result) { + auto Range = CharSourceRange::getCharRange( + CS.getLBracLoc().getLocWithOffset(1), CS.getRBracLoc()); + return Lexer::getSourceText(Range, Result.Context->getSourceManager(), + Result.Context->getLangOpts()); +} + +static Expected +getNode(const ast_matchers::BoundNodes &Nodes, llvm::StringRef Id) { + auto &NodesMap = Nodes.getMap(); + auto It = NodesMap.find(Id); + if (It == NodesMap.end()) { + return llvm::make_error(llvm::errc::invalid_argument, + "Id not bound: " + Id); + } + return It->second; +} + +namespace { +using ::clang::ast_matchers::MatchFinder; +using ::llvm::errc; +using ::llvm::Error; +using ::llvm::Expected; +using ::llvm::StringError; +using ::llvm::StringRef; + +// An arbitrary fragment of code within a stencil. +class RawText : public StencilPartInterface { +public: + explicit RawText(StringRef Text) : Text(Text) {} + + Error eval(const MatchFinder::MatchResult &, + std::string *Result) const override { + Result->append(Text); + return Error::success(); + } + + std::unique_ptr clone() const override { + return llvm::make_unique(*this); + } + +private: + std::string Text; +}; + +// A debugging operation to dump the AST for a particular (bound) AST node. 
+class DebugPrintNodeOp : public StencilPartInterface { +public: + explicit DebugPrintNodeOp(StringRef Id) : Id(Id) {} + + Error eval(const MatchFinder::MatchResult &Match, + std::string *Result) const override { + std::string Output; + llvm::raw_string_ostream Os(Output); + auto NodeOrErr = getNode(Match.Nodes, Id); + if (auto Err = NodeOrErr.takeError()) { + return Err; + } + NodeOrErr->print(Os, PrintingPolicy(Match.Context->getLangOpts())); + *Result += Os.str(); + return Error::success(); + } + + std::unique_ptr clone() const override { + return llvm::make_unique(*this); + } + +private: + std::string Id; +}; + +// A reference to a particular (bound) AST node. +class NodeRef : public StencilPartInterface { +public: + explicit NodeRef(StringRef Id) : Id(Id) {} + + Error eval(const MatchFinder::MatchResult &Match, + std::string *Result) const override { + auto NodeOrErr = getNode(Match.Nodes, Id); + if (auto Err = NodeOrErr.takeError()) { + return Err; + } + *Result += getText(NodeOrErr.get(), *Match.Context); + return Error::success(); + } + + std::unique_ptr clone() const override { + return llvm::make_unique(*this); + } + +private: + std::string Id; +}; + +// A stencil operation that, given a reference to an expression e and a Part +// describing a member m, yields "e->m", when e is a pointer, "e2->m" when e = +// "*e2" and "e.m" otherwise. +class MemberOp : public StencilPartInterface { +public: + MemberOp(StringRef ObjectId, StencilPart Member) + : ObjectId(ObjectId), Member(std::move(Member)) {} + + Error eval(const MatchFinder::MatchResult &Match, + std::string *Result) const override { + const auto *E = Match.Nodes.getNodeAs(ObjectId); + if (E == nullptr) { + return llvm::make_error(errc::invalid_argument, + "Id not bound: " + ObjectId); + } + // N.B. The RHS is a google string. TODO(yitzhakm): fix the RHS to be a + // std::string. + if (!E->isImplicitCXXThis()) { + *Result += E->getType()->isAnyPointerType() + ? 
formatArrow(*Match.Context, *E) + : formatDot(*Match.Context, *E); + } + return Member.eval(Match, Result); + } + + std::unique_ptr clone() const override { + return llvm::make_unique(*this); + } + +private: + std::string ObjectId; + StencilPart Member; +}; + +// Operations all take a single reference to a Expr parameter, e. +class ExprOp : public StencilPartInterface { +public: + enum class Operator { + // Yields "e2" when e = "&e2" (with '&' the builtin operator), "*e" when e + // is a pointer and "e" otherwise. + kValue, + // Yields "e2" when e = "*e2" (with '*' the builtin operator), "e" when e is + // a pointer and "&e" otherwise. + kAddress, + // Wraps e in parens if it may parse differently depending on context. For + // example, a binary operation is always wrapped while a variable reference + // is never wrapped. + kParens, + }; + + ExprOp(Operator Op, StringRef Id) : Op(Op), Id(Id) {} + + Error eval(const MatchFinder::MatchResult &Match, + std::string *Result) const override { + const auto *Expression = Match.Nodes.getNodeAs(Id); + if (Expression == nullptr) { + return llvm::make_error(errc::invalid_argument, + "Id not bound: " + Id); + } + const auto &Context = *Match.Context; + switch (Op) { + case ExprOp::Operator::kValue: + if (Expression->getType()->isAnyPointerType()) { + *Result += formatDereference(Context, *Expression); + } else { + *Result += fixit::getText(*Expression, Context); + } + break; + case ExprOp::Operator::kAddress: + if (Expression->getType()->isAnyPointerType()) { + *Result += fixit::getText(*Expression, Context); + } else { + *Result += formatAddressOf(Context, *Expression); + } + break; + case ExprOp::Operator::kParens: + if (needsParens(Expression)) { + *Result += "("; + *Result += fixit::getText(*Expression, Context); + *Result += ")"; + } else { + *Result += fixit::getText(*Expression, Context); + } + break; + } + return Error::success(); + } + + std::unique_ptr clone() const override { + return llvm::make_unique(*this); + } + 
+private: + Operator Op; + std::string Id; +}; + +// Given a reference to a named declaration d (NamedDecl), yields +// the name. "d" must have an identifier name (that is, constructors are +// not valid arguments to the Name operation). +class NameOp : public StencilPartInterface { +public: + explicit NameOp(StringRef Id) : Id(Id) {} + + Error eval(const MatchFinder::MatchResult &Match, + std::string *Result) const override { + const NamedDecl *Decl; + if (const auto *Init = Match.Nodes.getNodeAs(Id)) { + Decl = Init->getMember(); + if (Decl == nullptr) { + return llvm::make_error(errc::invalid_argument, + "non-member initializer: " + Id); + } + } else { + Decl = Match.Nodes.getNodeAs(Id); + if (Decl == nullptr) { + return llvm::make_error( + errc::invalid_argument, + "Id not bound or wrong type for Name op: " + Id); + } + } + // getIdentifier() guards the validity of getName(). + if (Decl->getIdentifier() == nullptr) { + return llvm::make_error(errc::invalid_argument, + "Decl is not identifier: " + Id); + } + *Result += Decl->getName(); + return Error::success(); + } + + std::unique_ptr clone() const override { + return llvm::make_unique(*this); + } + +private: + std::string Id; +}; + +// Given a reference to a call expression (CallExpr), yields the +// arguments as a comma separated list. 
+class ArgsOp : public StencilPartInterface { +public: + explicit ArgsOp(StringRef Id) : Id(Id) {} + + Error eval(const MatchFinder::MatchResult &Match, + std::string *Result) const override { + const auto *CE = Match.Nodes.getNodeAs(Id); + if (CE == nullptr) { + return llvm::make_error(errc::invalid_argument, + "Id not bound: " + Id); + } + *Result += getArgumentsText(*CE, Match); + return Error::success(); + } + + std::unique_ptr clone() const override { + return llvm::make_unique(*this); + } + +private: + std::string Id; +}; + +// Given a reference to a statement, yields the contents between the braces, if +// it is compound, or the statement and its trailing semicolon (if any) +// otherwise. +class StatementsOp : public StencilPartInterface { +public: + explicit StatementsOp(StringRef Id) : Id(Id) {} + + Error eval(const MatchFinder::MatchResult &Match, + std::string *Result) const override { + if (const auto *CS = Match.Nodes.getNodeAs(Id)) { + *Result += getStatementsText(*CS, Match); + return Error::success(); + } + if (const auto *S = Match.Nodes.getNodeAs(Id)) { + *Result += getText(*S, *Match.Context); + return Error::success(); + } + return llvm::make_error(errc::invalid_argument, + "Id not bound: " + Id); + } + + std::unique_ptr clone() const override { + return llvm::make_unique(*this); + } + +private: + std::string Id; +}; + +// Given a function and a reference to a node, yields the string that results +// from applying the function to the referenced node. 
+class NodeFunctionOp : public StencilPartInterface { +public: + NodeFunctionOp(stencil_generators::NodeFunction F, StringRef Id) + : F(std::move(F)), Id(Id) {} + + Error eval(const MatchFinder::MatchResult &Match, + std::string *Result) const override { + auto NodeOrErr = getNode(Match.Nodes, Id); + if (auto Err = NodeOrErr.takeError()) { + return Err; + } + *Result += F(*NodeOrErr, *Match.Context); + return Error::success(); + } + + std::unique_ptr clone() const override { + return llvm::make_unique(*this); + } + +private: + stencil_generators::NodeFunction F; + std::string Id; +}; + +// Given a function and a stencil part, yields the string that results from +// applying the function to the part's evaluation. +class StringFunctionOp : public StencilPartInterface { +public: + StringFunctionOp(stencil_generators::StringFunction F, StencilPart Part) + : F(std::move(F)), Part(std::move(Part)) {} + + Error eval(const MatchFinder::MatchResult &Match, + std::string *Result) const override { + std::string PartResult; + if (auto Err = Part.eval(Match, &PartResult)) { + return Err; + } + *Result += F(PartResult); + return Error::success(); + } + + std::unique_ptr clone() const override { + return llvm::make_unique(*this); + } + +private: + stencil_generators::StringFunction F; + StencilPart Part; +}; +} // namespace + +NodeId::NodeId() : NodeId(nextId()) {} + +void Stencil::append(const NodeId &Id) { + Parts.emplace_back(llvm::make_unique(Id.id())); +} + +void Stencil::append(StringRef Text) { + Parts.emplace_back(llvm::make_unique(Text)); +} + +llvm::Expected +Stencil::eval(const MatchFinder::MatchResult &Match) const { + std::string Result; + for (const auto &Part : Parts) { + if (auto Err = Part.eval(Match, &Result)) { + return std::move(Err); + } + } + return Result; +} + +namespace stencil_generators { +StencilPart text(StringRef Text) { + return StencilPart(llvm::make_unique(Text)); +} + +StencilPart node(llvm::StringRef Id) { + return 
StencilPart(llvm::make_unique(Id)); +} +StencilPart node(const NodeId &Id) { return node(Id.id()); } + +StencilPart member(StringRef Id, StringRef Member) { + return StencilPart(llvm::make_unique(Id, text(Member))); +} +StencilPart member(const NodeId &ObjectId, StringRef Member) { + return member(ObjectId.id(), Member); +} + +StencilPart member(StringRef Id, StencilPart Member) { + return StencilPart(llvm::make_unique(Id, std::move(Member))); +} +StencilPart member(const NodeId &ObjectId, StencilPart Member) { + return member(ObjectId.id(), std::move(Member)); +} + +StencilPart asValue(StringRef Id) { + return StencilPart(llvm::make_unique(ExprOp::Operator::kValue, Id)); +} +StencilPart asValue(const NodeId &Id) { return asValue(Id.id()); } + +StencilPart asAddress(StringRef Id) { + return StencilPart(llvm::make_unique(ExprOp::Operator::kAddress, Id)); +} +StencilPart asAddress(const NodeId &Id) { return asAddress(Id.id()); } + +StencilPart parens(StringRef Id) { + return StencilPart(llvm::make_unique(ExprOp::Operator::kParens, Id)); +} +StencilPart parens(const NodeId &Id) { return parens(Id.id()); } + +StencilPart name(StringRef DeclId) { + return StencilPart(llvm::make_unique(DeclId)); +} +StencilPart name(const NodeId &DeclId) { return name(DeclId.id()); } + +StencilPart apply(NodeFunction Fn, StringRef Id) { + return StencilPart(llvm::make_unique(std::move(Fn), Id)); +} +StencilPart apply(NodeFunction Fn, const NodeId &Id) { + return apply(std::move(Fn), Id.id()); +} + +StencilPart apply(StringFunction Fn, StencilPart Part) { + return StencilPart( + llvm::make_unique(std::move(Fn), std::move(Part))); +} +StencilPart apply(StringFunction Fn, llvm::StringRef Id) { + return apply(std::move(Fn), node(Id)); +} +StencilPart apply(StringFunction Fn, const NodeId &Id) { + return apply(std::move(Fn), node(Id)); +} + +StencilPart args(StringRef CallId) { + return StencilPart(llvm::make_unique(CallId)); +} +StencilPart args(const NodeId &CallId) { return 
args(CallId.id()); } + +StencilPart statements(llvm::StringRef StmtId) { + return StencilPart(llvm::make_unique(StmtId)); +} +StencilPart statements(const NodeId &StmtId) { return statements(StmtId.id()); } + +StencilPart dPrint(StringRef Id) { + return StencilPart(llvm::make_unique(Id)); +} +StencilPart dPrint(const NodeId &Id) { return dPrint(Id.id()); } + +AddIncludeOp addInclude(StringRef Path) { + return AddIncludeOp{std::string(Path)}; +} +RemoveIncludeOp removeInclude(StringRef Path) { + return RemoveIncludeOp{std::string(Path)}; +} +} // namespace stencil_generators +} // namespace tooling +} // namespace clang diff --git a/clang/lib/Tooling/Refactoring/Transformer.cpp b/clang/lib/Tooling/Refactoring/Transformer.cpp new file mode 100644 --- /dev/null +++ b/clang/lib/Tooling/Refactoring/Transformer.cpp @@ -0,0 +1,341 @@ +//===--- Transformer.cpp - Transformer library implementation ---*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "clang/Tooling/Refactoring/Transformer.h" + +#include +#include +#include +#include + +#include "clang/AST/Expr.h" +#include "clang/ASTMatchers/ASTMatchFinder.h" +#include "clang/ASTMatchers/ASTMatchers.h" +#include "clang/Basic/Diagnostic.h" +#include "clang/Basic/SourceLocation.h" +#include "clang/Rewrite/Core/Rewriter.h" +#include "clang/Tooling/Refactoring.h" +#include "clang/Tooling/Refactoring/AtomicChange.h" +#include "clang/Tooling/Refactoring/Stencil.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Errc.h" +#include "llvm/Support/Error.h" + +namespace clang { +namespace tooling { +namespace { +using ::clang::ast_matchers::MatchFinder; +using ::clang::ast_matchers::stmt; +using ::llvm::Error; +using ::llvm::Expected; +using ::llvm::Optional; +using ::llvm::StringError; +using ::llvm::StringRef; + +using MatchResult = MatchFinder::MatchResult; +} // namespace + +// For a given range, returns the lexed token immediately after the range if +// and only if it's a semicolon. +static Optional getTrailingSemi(SourceLocation EndLoc, + const ASTContext &Context) { + if (Optional Next = Lexer::findNextToken( + EndLoc, Context.getSourceManager(), Context.getLangOpts())) { + return Next->is(clang::tok::TokenKind::semi) ? Next : None; + } + return None; +} + +static const clang::Stmt *getStatementParent(const clang::Stmt &node, + ASTContext &context) { + using namespace ast_matchers; + + auto is_or_has_node = + anyOf(equalsNode(&node), hasDescendant(equalsNode(&node))); + auto not_in_condition = unless(hasCondition(is_or_has_node)); + // Note that SwitchCase nodes have the subsequent statement as substatement. + // For example, in "case 1: a(); b();", a() will be the child of the + // SwitchCase "case 1:". 
+ // TODO(djasper): Also handle other labels, probably not important in google3. + // missing: switchStmt() (although this is a weird corner case). + auto statement = stmt(hasParent( + stmt(anyOf(compoundStmt(), whileStmt(not_in_condition), + doStmt(not_in_condition), switchCase(), + ifStmt(not_in_condition), + forStmt(not_in_condition, unless(hasIncrement(is_or_has_node)), + unless(hasLoopInit(is_or_has_node))))) + .bind("parent"))); + return selectFirst("parent", + match(statement, node, context)); +} + +// Is a real statement (not an expression inside another expression). That is, +// not an expression with an expression parent. +static bool isRealStatement(const Stmt &S, ASTContext &Context) { + return !isa(S) || getStatementParent(S, Context) != nullptr; +} + +// For all non-expression statements, extend the source to include any trailing +// semi. Returns a SourceRange representing a token range. +static SourceRange getTokenRange(const Stmt &S, ASTContext &Context) { + if (isRealStatement(S, Context)) { + // TODO: exclude case where last token is a right brace? 
+ if (auto Tok = getTrailingSemi(S.getEndLoc(), Context)) + return SourceRange(S.getBeginLoc(), Tok->getLocation()); + } + return S.getSourceRange(); +} + +static SourceRange getTokenRange(const ast_type_traits::DynTypedNode &Node, + ASTContext &Context) { + if (const auto *S = Node.get()) + return getTokenRange(*S, Context); + return Node.getSourceRange(); +} + +static llvm::Error invalidArgumentError(llvm::Twine Message) { + return llvm::make_error(llvm::errc::invalid_argument, Message); +} + +static llvm::Error unboundNodeError(StringRef Role, StringRef Id) { + return invalidArgumentError(Role + " (" + Id + ") references unbound node"); +} + +static llvm::Error typeError(llvm::Twine Message, + const clang::ast_type_traits::ASTNodeKind& Kind) { + return invalidArgumentError(Message + " (node kind is " + Kind.asStringRef() + + ")"); +} + +static llvm::Error missingPropertyError(llvm::Twine Description, + StringRef Property) { + return invalidArgumentError(Description + " requires property '" + Property + + "'"); +} + +// Verifies that `node` is appropriate for the given `target_part`. 
+static Error verifyTarget(const clang::ast_type_traits::DynTypedNode& Node, + NodePart TargetPart) { + switch (TargetPart) { + case NodePart::kNode: + return Error::success(); + case NodePart::kMember: + if (Node.get() != nullptr) { + return Error::success(); + } + return typeError("NodePart::kMember applied to non-MemberExpr", + Node.getNodeKind()); + case NodePart::kName: + if (const auto* D = Node.get()) { + if (D->getDeclName().isIdentifier()) { + return Error::success(); + } + return missingPropertyError("NodePart::kName", "identifier"); + } + if (const auto* E = Node.get()) { + if (E->getNameInfo().getName().isIdentifier()) { + return Error::success(); + } + return missingPropertyError("NodePart::kName", "identifier"); + } + if (const auto* I = Node.get()) { + if (I->isMemberInitializer()) { + return Error::success(); + } + return missingPropertyError("NodePart::kName", "member initializer"); + } + return typeError( + "NodePart::kName applied to neither DeclRefExpr, NamedDecl nor " + "CXXCtorInitializer", + Node.getNodeKind()); + } + llvm_unreachable("Unexpected case in NodePart type."); +} + +// Requires VerifyTarget(node, target_part) == success. +static SourceRange getTarget(const clang::ast_type_traits::DynTypedNode &Node, + NodePart TargetPart, ASTContext &Context) { + switch (TargetPart) { + case NodePart::kNode: + return getTokenRange(Node, Context); + case NodePart::kMember: + return SourceRange(Node.get()->getMemberLoc()); + case NodePart::kName: + if (const auto* D = Node.get()) { + return SourceRange(D->getLocation()); + } + if (const auto* E = Node.get()) { + return SourceRange(E->getLocation()); + } + if (const auto* I = Node.get()) { + return SourceRange(I->getMemberLocation()); + } + // This should be unreachable if the target was already verified. 
+ llvm_unreachable( + "NodePart::kName applied to neither NamedDecl nor " + "CXXCtorInitializer"); + } + llvm_unreachable("Unexpected case in NodePart type."); +} + +// TODO: move to shared utility lib. +static bool isOriginMacroBody(const clang::SourceManager& source_manager, + clang::SourceLocation loc) { + while (loc.isMacroID()) { + if (source_manager.isMacroBodyExpansion(loc)) return true; + // Otherwise, it must be in an argument, so we continue searching up the + // invocation stack. getImmediateMacroCallerLoc() gives the location of the + // argument text, inside the call text. + loc = source_manager.getImmediateMacroCallerLoc(loc); + } + return false; +} + +namespace internal { +Expected transform(const MatchResult &Result, + const RewriteRule &Rule) { + // Ignore results in failing TUs or those rejected by the where clause. + if (Result.Context->getDiagnostics().hasErrorOccurred() || + !Rule.filter().matches(Result)) { + return Transformation(); + } + + auto &NodesMap = Result.Nodes.getMap(); + auto It = NodesMap.find(Rule.target()); + if (It == NodesMap.end()) { + return unboundNodeError("rule.target()", Rule.target()); + } + if (auto Err = llvm::handleErrors( + verifyTarget(It->second, Rule.targetPart()), [&Rule](StringError &E) { + return invalidArgumentError("Failure targeting node" + + Rule.target() + ": " + E.getMessage()); + })) { + return std::move(Err); + } + SourceRange Target = + getTarget(It->second, Rule.targetPart(), *Result.Context); + if (Target.isInvalid() || + isOriginMacroBody(*Result.SourceManager, Target.getBegin())) { + return Transformation(); + } + + if (auto ReplacementOrErr = Rule.replacement().eval(Result)) { + return Transformation{clang::CharSourceRange::getTokenRange(Target), + std::move(*ReplacementOrErr)}; + } else { + return ReplacementOrErr.takeError(); + } +} +} // namespace internal + +RewriteRule::RewriteRule() + : Matcher(stmt()), Target(RootId), TargetPart(NodePart::kNode) {} + +constexpr char RewriteRule::RootId[]; 
+ +RewriteRule& RewriteRule::where( + std::function FilterFn) & { + Filter = MatchFilter(std::move(FilterFn)); + return *this; +} + +RewriteRule& RewriteRule::change(const NodeId& TargetId, NodePart Part) & { + Target = std::string(TargetId.id()); + TargetPart = Part; + return *this; +} + +RewriteRule& RewriteRule::replaceWith(Stencil S) & { + Replacement = std::move(S); + return *this; +} + +RewriteRule makeRule(StatementMatcher Matcher, Stencil Replacement, + std::string Explanation) { + return RewriteRule() + .matching(stmt(Matcher)) + .replaceWith(std::move(Replacement)) + .explain(std::move(Explanation)); +} + +void Transformer::registerMatchers(MatchFinder* MatchFinder) { + MatchFinder->addDynamicMatcher(Rule.matcher(), this); +} + +void Transformer::run(const MatchResult& Result) { + auto ChangeOrErr = internal::transform(Result, Rule); + if (auto Err = ChangeOrErr.takeError()) { + llvm::errs() << "Rewrite failed: " << llvm::toString(std::move(Err)) + << "\n"; + return; + } + auto& Change = *ChangeOrErr; + auto& Range = Change.Range; + if (Range.isInvalid()) { + // No rewrite applied (but no error encountered either). 
+ return; + } + AtomicChange AC(*Result.SourceManager, Range.getBegin()); + if (auto Err = AC.replace(*Result.SourceManager, Range, + Change.Replacement)) { + AC.setError(llvm::toString(std::move(Err))); + } else { + for (const auto& header : Rule.replacement().addedIncludes()) { + AC.addHeader(header); + } + for (const auto& header : Rule.replacement().removedIncludes()) { + AC.removeHeader(header); + } + } + Consumer(AC); +} + +MultiTransformer::MultiTransformer(std::vector Rules, + const Transformer::ChangeConsumer& Consumer, + MatchFinder* MF) { + for (auto& R : Rules) { + Transformers.emplace_back(std::move(R), Consumer); + Transformers.back().registerMatchers(MF); + } +} + +static llvm::SmallVector match( + const DynTypedMatcher& Matcher, + const clang::ast_type_traits::DynTypedNode& Node, + clang::ASTContext* Context) { + clang::ast_matchers::internal::CollectMatchesCallback Callback; + MatchFinder Finder; + Finder.addDynamicMatcher(Matcher, &Callback); + Finder.match(Node, *Context); + return std::move(Callback.Nodes); +} + +Expected> maybeTransform( + const RewriteRule& Rule, const clang::ast_type_traits::DynTypedNode& Node, + clang::ASTContext* Context) { + auto Matches = match(Rule.matcher(), Node, Context); + if (Matches.empty()) { + return llvm::None; + } + if (Matches.size() > 1) { + return invalidArgumentError("rule is ambiguous"); + } + auto ChangeOrErr = + internal::transform(MatchResult(Matches[0], Context), Rule); + if (auto Err = ChangeOrErr.takeError()) { + return std::move(Err); + } + auto& Change = *ChangeOrErr; + if (Change.Range.isInvalid()) { + return llvm::None; + } + return Change.Replacement; +} +} // namespace tooling +} // namespace clang diff --git a/clang/unittests/Tooling/CMakeLists.txt b/clang/unittests/Tooling/CMakeLists.txt --- a/clang/unittests/Tooling/CMakeLists.txt +++ b/clang/unittests/Tooling/CMakeLists.txt @@ -49,7 +49,9 @@ RefactoringTest.cpp ReplacementsYamlTest.cpp RewriterTest.cpp + StencilTest.cpp ToolingTest.cpp + 
TransformerTest.cpp ) target_link_libraries(ToolingTests diff --git a/clang/unittests/Tooling/StencilTest.cpp b/clang/unittests/Tooling/StencilTest.cpp new file mode 100644 --- /dev/null +++ b/clang/unittests/Tooling/StencilTest.cpp @@ -0,0 +1,632 @@ +//===- unittest/Tooling/StencilTest.cpp -----------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "clang/Tooling/Refactoring/Stencil.h" +#include "clang/ASTMatchers/ASTMatchers.h" +#include "clang/Tooling/FixIt.h" +#include "clang/Tooling/Tooling.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +namespace clang { +namespace tooling { +namespace { + +using ::clang::ast_matchers::compoundStmt; +using ::clang::ast_matchers::decl; +using ::clang::ast_matchers::declStmt; +using ::clang::ast_matchers::expr; +using ::clang::ast_matchers::hasAnySubstatement; +using ::clang::ast_matchers::hasCondition; +using ::clang::ast_matchers::hasDescendant; +using ::clang::ast_matchers::hasElse; +using ::clang::ast_matchers::hasInitializer; +using ::clang::ast_matchers::hasName; +using ::clang::ast_matchers::hasReturnValue; +using ::clang::ast_matchers::hasSingleDecl; +using ::clang::ast_matchers::hasThen; +using ::clang::ast_matchers::ifStmt; +using ::clang::ast_matchers::ignoringImplicit; +using ::clang::ast_matchers::returnStmt; +using ::clang::ast_matchers::stmt; +using ::clang::ast_matchers::varDecl; + +using MatchResult = ::clang::ast_matchers::MatchFinder::MatchResult; + +using ::clang::tooling::stencil_generators::addInclude; +using ::clang::tooling::stencil_generators::apply; +using ::clang::tooling::stencil_generators::args; +using ::clang::tooling::stencil_generators::asAddress; +using ::clang::tooling::stencil_generators::asValue; 
+using ::clang::tooling::stencil_generators::member; +using ::clang::tooling::stencil_generators::name; +using ::clang::tooling::stencil_generators::parens; +using ::clang::tooling::stencil_generators::removeInclude; +using ::clang::tooling::stencil_generators::text; + +using ::testing::Eq; + +// We can't directly match on llvm::Expected since its accessors mutate the +// object. So, we collapse it to an Optional. +llvm::Optional toOptional(llvm::Expected V) { + if (V) + return *V; + ADD_FAILURE() << "Losing error in conversion to IsSomething: " + << llvm::toString(V.takeError()); + return llvm::None; +} + +// A very simple matcher for llvm::Optional values. +MATCHER_P(IsSomething, ValueMatcher, "") { + if (!arg) + return false; + return ::testing::ExplainMatchResult(ValueMatcher, *arg, result_listener); +} + +// Create a valid translation-unit from a statement. +std::string wrapSnippet(llvm::StringRef StatementCode) { + return ("auto stencil_test_snippet = []{" + StatementCode + "};").str(); +} + +clang::ast_matchers::DeclarationMatcher +wrapMatcher(const clang::ast_matchers::StatementMatcher &Matcher) { + return varDecl(hasName("stencil_test_snippet"), + hasDescendant(compoundStmt(hasAnySubstatement(Matcher)))); +} + +struct TestMatch { + // The AST unit from which `result` is built. We bundle it because it backs + // the result. Users are not expected to access it. + std::unique_ptr AstUnit; + // The result to use in the test. References `ast_unit`. + MatchResult Result; +}; + +// Matches `matcher` against the statement `statement_code` and returns the +// result. Handles putting the statement inside a function and modifying the +// matcher correspondingly. `matcher` should match `statement_code` exactly -- +// that is, produce exactly one match. 
+llvm::Optional +matchStmt(llvm::StringRef StatementCode, + clang::ast_matchers::StatementMatcher Matcher) { + auto AstUnit = buildASTFromCode(wrapSnippet(StatementCode)); + if (AstUnit == nullptr) { + ADD_FAILURE() << "AST construction failed"; + return llvm::None; + } + clang::ASTContext &Context = AstUnit->getASTContext(); + auto Matches = clang::ast_matchers::match(wrapMatcher(Matcher), Context); + // We expect a single, exact match for the statement. + if (Matches.size() != 1) { + ADD_FAILURE() << "Wrong number of matches: " << Matches.size(); + return llvm::None; + } + return TestMatch{std::move(AstUnit), MatchResult(Matches[0], &Context)}; +} + +class StencilTest : public ::testing::Test { +public: + StencilTest() : Id0("id0"), Id1("id1") {} + +protected: + // Verifies that filling a single-parameter stencil from `context` will result + // in `expected`, assuming that the code in `context` contains a statement + // `return e` and "id0" is bound to `e`. + void testSingle(llvm::StringRef Snippet, const Stencil &Stencil, + llvm::StringRef Expected) { + auto StmtMatch = matchStmt( + Snippet, + returnStmt(hasReturnValue(ignoringImplicit(expr().bind(Id0.id()))))); + ASSERT_TRUE(StmtMatch); + EXPECT_THAT(toOptional(Stencil.eval(StmtMatch->Result)), + IsSomething(Expected)); + EXPECT_THAT(Stencil.addedIncludes(), testing::IsEmpty()); + EXPECT_THAT(Stencil.removedIncludes(), testing::IsEmpty()); + } + + // Verifies that the given stencil fails when evaluated on a valid match + // result. Binds a statement to "stmt", a (non-member) ctor-initializer to + // "init", an expression to "expr" and a (nameless) declaration to "decl". 
+ void testError(const Stencil &Stencil, + testing::Matcher Matcher) { + using ::clang::ast_matchers::cxxConstructExpr; + using ::clang::ast_matchers::cxxCtorInitializer; + using ::clang::ast_matchers::hasDeclaration; + using ::clang::ast_matchers::isBaseInitializer; + + const std::string Snippet = R"cc( + struct A {}; + class F : public A { + public: + F(int) {} + }; + F(1); + )cc"; + auto StmtMatch = matchStmt( + Snippet, + stmt(hasDescendant( + cxxConstructExpr( + hasDeclaration(decl(hasDescendant(cxxCtorInitializer( + isBaseInitializer()) + .bind("init"))) + .bind("decl"))) + .bind("expr"))) + .bind("stmt")); + ASSERT_TRUE(StmtMatch); + if (auto ResultOrErr = Stencil.eval(StmtMatch->Result)) { + ADD_FAILURE() << "Expected failure but succeeded: " << *ResultOrErr; + } else { + auto Err = llvm::handleErrors(ResultOrErr.takeError(), + [&Matcher](const llvm::StringError &Err) { + EXPECT_THAT(Err.getMessage(), Matcher); + }); + if (Err) { + ADD_FAILURE() << "Unhandled error: " << llvm::toString(std::move(Err)); + } + } + } + + // Tests failures caused by references to unbound nodes. `unbound_id` is the + // id that will cause the failure. 
+ void testUnboundNodeError(const Stencil &Stencil, llvm::StringRef UnboundId) { + testError(Stencil, testing::AllOf(testing::HasSubstr(UnboundId), + testing::HasSubstr("not bound"))); + } + + NodeId Id0; + NodeId Id1; +}; + +TEST_F(StencilTest, SingleStatement) { + using stencil_generators::id; + + const std::string Snippet = R"cc( + if (true) + return 1; + else + return 0; + )cc"; + auto StmtMatch = matchStmt(Snippet, ifStmt(hasCondition(stmt().bind("a1")), + hasThen(stmt().bind("a2")), + hasElse(stmt().bind("a3")))); + ASSERT_TRUE(StmtMatch); + auto Stencil = + Stencil::cat("if(!", id("a1"), ") ", id("a3"), "; else ", id("a2")); + EXPECT_THAT(toOptional(Stencil.eval(StmtMatch->Result)), + IsSomething(Eq("if(!true) return 0; else return 1"))); + EXPECT_THAT(Stencil.addedIncludes(), testing::IsEmpty()); + EXPECT_THAT(Stencil.removedIncludes(), testing::IsEmpty()); +} + +TEST_F(StencilTest, UnboundNode) { + using stencil_generators::id; + + const std::string Snippet = R"cc( + if (true) + return 1; + else + return 0; + )cc"; + auto StmtMatch = matchStmt(Snippet, ifStmt(hasCondition(stmt().bind("a1")), + hasThen(stmt().bind("a2")))); + ASSERT_TRUE(StmtMatch); + auto Stencil = Stencil::cat("if(!", id("a1"), ") ", id("UNBOUND"), ";"); + auto ResultOrErr = Stencil.eval(StmtMatch->Result); + EXPECT_TRUE(llvm::errorToBool(ResultOrErr.takeError())) + << "Expected unbound node, got " << *ResultOrErr; +} + +TEST_F(StencilTest, NodeOp) { + using stencil_generators::node; + + const std::string Snippet = R"cc( + int x; + return x; + )cc"; + testSingle(Snippet, Stencil::cat(node(Id0)), "x"); + testSingle(Snippet, Stencil::cat(node("id0")), "x"); +} + +TEST_F(StencilTest, MemberOpValue) { + const std::string Snippet = R"cc( + int x; + return x; + )cc"; + testSingle(Snippet, Stencil::cat(member(Id0, "field")), "x.field"); +} + +TEST_F(StencilTest, MemberOpValueExplicitText) { + const std::string Snippet = R"cc( + int x; + return x; + )cc"; + testSingle(Snippet, 
Stencil::cat(member(Id0, text("field"))), "x.field"); +} + +TEST_F(StencilTest, MemberOpValueAddress) { + const std::string Snippet = R"cc( + int x; + return &x; + )cc"; + testSingle(Snippet, Stencil::cat(member(Id0, "field")), "x.field"); +} + +TEST_F(StencilTest, MemberOpPointer) { + const std::string Snippet = R"cc( + int *x; + return x; + )cc"; + testSingle(Snippet, Stencil::cat(member(Id0, "field")), "x->field"); +} + +TEST_F(StencilTest, MemberOpPointerDereference) { + const std::string Snippet = R"cc( + int *x; + return *x; + )cc"; + testSingle(Snippet, Stencil::cat(member(Id0, "field")), "x->field"); +} + +TEST_F(StencilTest, MemberOpThis) { + using clang::ast_matchers::hasObjectExpression; + using clang::ast_matchers::memberExpr; + + const std::string Snippet = R"cc( + class C { + public: + int x; + int foo() { return x; } + }; + )cc"; + auto StmtMatch = + matchStmt(Snippet, returnStmt(hasReturnValue(ignoringImplicit(memberExpr( + hasObjectExpression(expr().bind("obj"))))))); + ASSERT_TRUE(StmtMatch); + const Stencil Stencil = Stencil::cat(member("obj", "field")); + EXPECT_THAT(toOptional(Stencil.eval(StmtMatch->Result)), + IsSomething(Eq("field"))); + EXPECT_THAT(Stencil.addedIncludes(), testing::IsEmpty()); + EXPECT_THAT(Stencil.removedIncludes(), testing::IsEmpty()); +} + +TEST_F(StencilTest, MemberOpUnboundNode) { + // Mistyped. + testUnboundNodeError(Stencil::cat(member("decl", "field")), "decl"); + testUnboundNodeError(Stencil::cat(member("unbound", "field")), "unbound"); +} + +TEST_F(StencilTest, ValueOpValue) { + const std::string Snippet = R"cc( + int x; + return x; + )cc"; + testSingle(Snippet, Stencil::cat(asValue(Id0)), "x"); +} + +TEST_F(StencilTest, ValueOpPointer) { + const std::string Snippet = R"cc( + int *x; + return x; + )cc"; + testSingle(Snippet, Stencil::cat(asValue(Id0)), "*x"); +} + +TEST_F(StencilTest, ValueOpUnboundNode) { + // Mistyped. 
+ testUnboundNodeError(Stencil::cat(asValue("decl")), "decl"); + testUnboundNodeError(Stencil::cat(asValue("unbound")), "unbound"); +} + +TEST_F(StencilTest, AddressOpValue) { + const std::string Snippet = R"cc( + int x; + return x; + )cc"; + testSingle(Snippet, Stencil::cat(asAddress(Id0)), "&x"); +} + +TEST_F(StencilTest, AddressOpPointer) { + const std::string Snippet = R"cc( + int *x; + return x; + )cc"; + testSingle(Snippet, Stencil::cat(asAddress(Id0)), "x"); +} + +TEST_F(StencilTest, AddressOpUnboundNode) { + // Mistyped. + testUnboundNodeError(Stencil::cat(asAddress("decl")), "decl"); + testUnboundNodeError(Stencil::cat(asAddress("unbound")), "unbound"); +} + +TEST_F(StencilTest, ParensOpVar) { + const std::string Snippet = R"cc( + int x; + return x; + )cc"; + testSingle(Snippet, Stencil::cat(parens(Id0)), "x"); +} + +TEST_F(StencilTest, ParensOpMinus) { + const std::string Snippet = R"cc( + int x; + return -x; + )cc"; + testSingle(Snippet, Stencil::cat(parens(Id0)), "(-x)"); +} + +TEST_F(StencilTest, ParensOpDeref) { + const std::string Snippet = R"cc( + int *x; + return *x; + )cc"; + testSingle(Snippet, Stencil::cat(parens(Id0)), "(*x)"); +} + +TEST_F(StencilTest, ParensOpExpr) { + const std::string Snippet = R"cc( + int x; + int y; + return x + y; + )cc"; + testSingle(Snippet, Stencil::cat(parens(Id0)), "(x + y)"); +} + +// Tests that parens are not added when the expression already has them. +TEST_F(StencilTest, ParensOpParens) { + const std::string Snippet = R"cc( + int x; + int y; + return (x + y); + )cc"; + testSingle(Snippet, Stencil::cat(parens(Id0)), "(x + y)"); +} + +TEST_F(StencilTest, ParensOpFun) { + const std::string Snippet = R"cc( + int bar(int); + int x; + int y; + return bar(x); + )cc"; + testSingle(Snippet, Stencil::cat(parens(Id0)), "bar(x)"); +} + +TEST_F(StencilTest, ParensOpUnboundNode) { + // Mistyped. 
+ testUnboundNodeError(Stencil::cat(parens("decl")), "decl"); + testUnboundNodeError(Stencil::cat(parens("unbound")), "unbound"); +} + +TEST_F(StencilTest, NameOp) { + const std::string Snippet = R"cc( + int x; + return x; + )cc"; + auto StmtMatch = + matchStmt(Snippet, declStmt(hasSingleDecl(decl().bind("d")))); + ASSERT_TRUE(StmtMatch); + const Stencil Stencil = Stencil::cat(name("d")); + EXPECT_THAT(toOptional(Stencil.eval(StmtMatch->Result)), + IsSomething(Eq("x"))); + EXPECT_THAT(Stencil.addedIncludes(), testing::IsEmpty()); + EXPECT_THAT(Stencil.removedIncludes(), testing::IsEmpty()); +} + +TEST_F(StencilTest, NameOpCtorInitializer) { + using clang::ast_matchers::cxxCtorInitializer; + + const std::string Snippet = R"cc( + class C { + public: + C() : field(3) {} + int field; + int foo() { return field; } + }; + )cc"; + auto StmtMatch = matchStmt( + Snippet, stmt(hasDescendant(cxxCtorInitializer().bind("init")))); + ASSERT_TRUE(StmtMatch); + const Stencil Stencil = Stencil::cat(name("init")); + EXPECT_THAT(toOptional(Stencil.eval(StmtMatch->Result)), + IsSomething(Eq("field"))); + EXPECT_THAT(Stencil.addedIncludes(), testing::IsEmpty()); + EXPECT_THAT(Stencil.removedIncludes(), testing::IsEmpty()); +} + +TEST_F(StencilTest, NameOpUnboundNode) { + // Decl has no name. + testError(Stencil::cat(name("decl")), testing::HasSubstr("not identifier")); + // Non-member (hence, no name) initializer. + testError(Stencil::cat(name("init")), + testing::HasSubstr("non-member initializer")); + // Mistyped. 
+ testUnboundNodeError(Stencil::cat(name("expr")), "expr"); + testUnboundNodeError(Stencil::cat(name("unbound")), "unbound"); +} + +TEST_F(StencilTest, ArgsOp) { + const std::string Snippet = R"cc( + struct C { + int bar(int, int); + }; + C x; + return x.bar(3, 4); + )cc"; + testSingle(Snippet, Stencil::cat(args(Id0)), "3, 4"); +} + +TEST_F(StencilTest, ArgsOpNoArgs) { + const std::string Snippet = R"cc( + struct C { + int bar(); + }; + C x; + return x.bar(); + )cc"; + testSingle(Snippet, Stencil::cat(args(Id0)), ""); +} + +TEST_F(StencilTest, ArgsOpNoArgsWithComments) { + const std::string Snippet = R"cc( + struct C { + int bar(); + }; + C x; + return x.bar(/*empty*/); + )cc"; + testSingle(Snippet, Stencil::cat(args(Id0)), "/*empty*/"); +} + +// Tests that arguments are extracted correctly when a temporary (with parens) +// is used. +TEST_F(StencilTest, ArgsOpWithParens) { + const std::string Snippet = R"cc( + struct C { + int bar(int, int) { return 3; } + }; + C x; + return C().bar(3, 4); + )cc"; + testSingle(Snippet, Stencil::cat(args(Id0)), "3, 4"); +} + +TEST_F(StencilTest, ArgsOpLeadingComments) { + const std::string Snippet = R"cc( + struct C { + int bar(int, int) { return 3; } + }; + C x; + return C().bar(/*leading*/ 3, 4); + )cc"; + testSingle(Snippet, Stencil::cat(args(Id0)), "/*leading*/ 3, 4"); +} + +TEST_F(StencilTest, ArgsOpTrailingComments) { + const std::string Snippet = R"cc( + struct C { + int bar(int, int) { return 3; } + }; + C x; + return C().bar(3 /*trailing*/, 4); + )cc"; + testSingle(Snippet, Stencil::cat(args(Id0)), "3 /*trailing*/, 4"); +} + +TEST_F(StencilTest, ArgsOpEolComments) { + const std::string Snippet = R"cc( + struct C { + int bar(int, int) { return 3; } + }; + C x; + return C().bar( // Header + 1, // foo + 2 // bar + ); + )cc"; + testSingle(Snippet, Stencil::cat(args(Id0)), R"( // Header + 1, // foo + 2 // bar + )"); +} + +TEST_F(StencilTest, ArgsOpUnboundNode) { + // Mistyped. 
+ testUnboundNodeError(Stencil::cat(args("stmt")), "stmt"); + testUnboundNodeError(Stencil::cat(args("unbound")), "unbound"); +} + +TEST_F(StencilTest, MemberOpWithNameOp) { + const std::string Snippet = R"cc( + int object; + int* method = &object; + (void)method; + return object; + )cc"; + auto StmtMatch = matchStmt( + Snippet, declStmt(hasSingleDecl( + varDecl(hasInitializer(expr().bind("e"))).bind("d")))); + ASSERT_TRUE(StmtMatch); + const Stencil Stencil = Stencil::cat(member("e", name("d"))); + EXPECT_THAT(toOptional(Stencil.eval(StmtMatch->Result)), + IsSomething(Eq("object.method"))); + EXPECT_THAT(Stencil.addedIncludes(), testing::IsEmpty()); + EXPECT_THAT(Stencil.removedIncludes(), testing::IsEmpty()); +} + +TEST_F(StencilTest, NodeFunctionOp) { + const std::string Snippet = R"cc( + int x; + return x; + )cc"; + auto SimpleFn = [](const ast_type_traits::DynTypedNode &Node, + const ASTContext &Context) { + return fixit::getText(Node, Context).str(); + }; + testSingle(Snippet, Stencil::cat(apply(SimpleFn, Id0)), "x"); + testSingle(Snippet, Stencil::cat(apply(SimpleFn, "id0")), "x"); +} + +TEST_F(StencilTest, StringFunctionOp) { + const std::string Snippet = R"cc( + int x; + return x; + )cc"; + auto SimpleFn = [](llvm::StringRef S) { return (S + " - 3").str(); }; + testSingle(Snippet, Stencil::cat(apply(SimpleFn, Id0)), "x - 3"); + testSingle(Snippet, Stencil::cat(apply(SimpleFn, "id0")), "x - 3"); +} + +TEST_F(StencilTest, StringFunctionOpNameOp) { + const std::string Snippet = R"cc( + int x; + return x; + )cc"; + auto SimpleFn = [](llvm::StringRef S) { return (S + " - 3").str(); }; + auto StmtMatch = + matchStmt(Snippet, declStmt(hasSingleDecl(decl().bind("d")))); + ASSERT_TRUE(StmtMatch); + const Stencil Stencil = Stencil::cat(apply(SimpleFn, name("d"))); + EXPECT_THAT(toOptional(Stencil.eval(StmtMatch->Result)), + IsSomething(Eq("x - 3"))); +} + +TEST_F(StencilTest, AddIncludeOp) { + const std::string Snippet = R"cc( + int x; + return -x; + )cc"; + auto 
StmtMatch = matchStmt(Snippet, stmt()); + ASSERT_TRUE(StmtMatch); + auto Stencil = Stencil::cat(addInclude("include/me.h"), "foo", + addInclude("include/metoo.h")); + EXPECT_THAT(toOptional(Stencil.eval(StmtMatch->Result)), + IsSomething(Eq("foo"))); + EXPECT_THAT(Stencil.addedIncludes(), + testing::UnorderedElementsAre("include/me.h", "include/metoo.h")); + EXPECT_THAT(Stencil.removedIncludes(), testing::IsEmpty()); +} + +TEST_F(StencilTest, RemoveIncludeOp) { + const std::string Snippet = R"cc( + int x; + return -x; + )cc"; + auto StmtMatch = matchStmt(Snippet, stmt()); + ASSERT_TRUE(StmtMatch); + auto Stencil = Stencil::cat(removeInclude("include/me.h"), "foo"); + EXPECT_THAT(toOptional(Stencil.eval(StmtMatch->Result)), + IsSomething(Eq("foo"))); + EXPECT_THAT(Stencil.addedIncludes(), testing::IsEmpty()); + EXPECT_THAT(Stencil.removedIncludes(), + testing::UnorderedElementsAre("include/me.h")); +} + +} // namespace +} // namespace tooling +} // namespace clang diff --git a/clang/unittests/Tooling/TransformerTest.cpp b/clang/unittests/Tooling/TransformerTest.cpp new file mode 100644 --- /dev/null +++ b/clang/unittests/Tooling/TransformerTest.cpp @@ -0,0 +1,885 @@ +//===- unittest/Tooling/TransformerTest.cpp -------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "clang/Tooling/Refactoring/Transformer.h"
+
+#include "clang/Tooling/Refactoring/Stencil.h"
+#include "clang/ASTMatchers/ASTMatchers.h"
+#include "clang/Tooling/Tooling.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+#include "clang/ASTMatchers/ASTMatchersMacros.h"
+
+namespace clang {
+namespace tooling {
+namespace {
+using ::clang::ast_matchers::anyOf;
+using ::clang::ast_matchers::argumentCountIs;
+using ::clang::ast_matchers::callee;
+using ::clang::ast_matchers::callExpr;
+using ::clang::ast_matchers::cxxMemberCallExpr;
+using ::clang::ast_matchers::cxxMethodDecl;
+using ::clang::ast_matchers::cxxOperatorCallExpr;
+using ::clang::ast_matchers::cxxRecordDecl;
+using ::clang::ast_matchers::declRefExpr;
+using ::clang::ast_matchers::eachOf;
+using ::clang::ast_matchers::expr;
+using ::clang::ast_matchers::functionDecl;
+using ::clang::ast_matchers::has;
+using ::clang::ast_matchers::hasAnyName;
+using ::clang::ast_matchers::hasArgument;
+using ::clang::ast_matchers::hasCondition;
+using ::clang::ast_matchers::hasDeclaration;
+using ::clang::ast_matchers::hasElse;
+using ::clang::ast_matchers::hasName;
+using ::clang::ast_matchers::hasOverloadedOperatorName;
+using ::clang::ast_matchers::hasReturnValue;
+using ::clang::ast_matchers::hasThen;
+using ::clang::ast_matchers::hasType;
+using ::clang::ast_matchers::id;
+using ::clang::ast_matchers::ifStmt;
+using ::clang::ast_matchers::ignoringImplicit;
+using ::clang::ast_matchers::member;
+using ::clang::ast_matchers::memberExpr;
+using ::clang::ast_matchers::namedDecl;
+using ::clang::ast_matchers::on;
+using ::clang::ast_matchers::pointsTo;
+using ::clang::ast_matchers::returnStmt;
+using ::clang::ast_matchers::stmt;
+using ::clang::ast_matchers::to;
+using ::clang::ast_matchers::unless;
+
+using stencil_generators::id;
+using stencil_generators::member;
+using stencil_generators::parens;
+using stencil_generators::name;
+using stencil_generators::addInclude;
+using stencil_generators::removeInclude;
+
+// Declarations shared by the test inputs: a minimal `string`, `strlen`, a
+// `log()` function with a streaming operator, and the `proto` flag types.
+constexpr char KHeaderContents[] = R"cc(
+  struct string {
+    string(const char *);
+    char* c_str();
+    int size();
+  };
+  int strlen(const char*);
+
+  class Logger {};
+  void operator<<(Logger& l, string msg);
+  Logger& log(int level);
+
+  namespace proto {
+  struct PCFProto {
+    int foo();
+  };
+  struct ProtoCommandLineFlag : PCFProto {
+    PCFProto& GetProto();
+  };
+  } // namespace proto
+)cc";
+}  // namespace
+
+// Matches a type that either is declared by `TypeMatcher` or points to a type
+// matched by it.
+static clang::ast_matchers::internal::Matcher<clang::QualType> isOrPointsTo(
+    const DeclarationMatcher& TypeMatcher) {
+  return anyOf(hasDeclaration(TypeMatcher), pointsTo(TypeMatcher));
+}
+
+// Reformats the whole of `Code` using the LLVM style. If the formatting
+// replacements cannot be applied, records a test failure and returns an empty
+// string.
+static std::string format(llvm::StringRef Code) {
+  // A single range covering all of `Code`.
+  const std::vector<Range> Ranges(1, Range(0, Code.size()));
+  auto Style = format::getLLVMStyle();
+  const auto Replacements = format::reformat(Style, Code, Ranges);
+  auto Formatted = applyAllReplacements(Code, Replacements);
+  if (!Formatted) {
+    ADD_FAILURE() << "Could not format code: "
+                  << llvm::toString(Formatted.takeError());
+    return std::string();
+  }
+  return *Formatted;
+}
+
+// Asserts that a rewrite produced a result and that the result equals
+// `Expected`. The first occurrence of the `#include "header.h"` line is removed
+// from the actual text, and both sides are passed through format() before
+// being compared, so layout differences do not cause failures.
+void compareSnippets(llvm::StringRef Expected,
+                     const llvm::Optional<std::string>& MaybeActual) {
+  ASSERT_TRUE(MaybeActual) << "Rewrite failed. Expecting: " << Expected;
+  auto Actual = *MaybeActual;
+  std::string HL = "#include \"header.h\"\n";
+  auto I = Actual.find(HL);
+  if (I != std::string::npos) {
+    Actual.erase(I, HL.size());
+  }
+  EXPECT_EQ(format(Expected), format(Actual));
+}
+
+// FIXME: consider separating this class into its own file(s).
+// Test fixture that runs the registered matchers over an input snippet and
+// applies the recorded atomic changes to it.
+class ClangRefactoringTestBase : public testing::Test {
+ protected:
+  // Appends `S` to the contents of the virtual "header.h" (always the first
+  // entry of FileContents).
+  void appendToHeader(llvm::StringRef S) { FileContents[0].second += S; }
+
+  // Registers an additional virtual file that the input may include.
+  void addFile(llvm::StringRef Filename, llvm::StringRef Content) {
+    FileContents.emplace_back(Filename, Content);
+  }
+
+  // Prefixes `Input` with `#include "header.h"`, runs the tool with the
+  // matchers registered on MatchFinder, and applies `Changes` to the code.
+  // Returns the changed code, or None if the tool run fails or the changes
+  // cannot be applied (the latter is also reported on stderr).
+  llvm::Optional<std::string> rewrite(llvm::StringRef Input) {
+    std::string Code = ("#include \"header.h\"\n" + Input).str();
+    auto Factory = newFrontendActionFactory(&MatchFinder);
+    if (!runToolOnCodeWithArgs(
+            Factory->create(), Code, std::vector<std::string>(), "input.cc",
+            "clang-tool", std::make_shared<PCHContainerOperations>(),
+            FileContents)) {
+      return None;
+    }
+    auto ChangedCodeOrErr =
+        applyAtomicChanges("input.cc", Code, Changes, ApplyChangesSpec());
+    if (auto Err = ChangedCodeOrErr.takeError()) {
+      llvm::errs() << "Change failed: " << llvm::toString(std::move(Err))
+                   << "\n";
+      return None;
+    }
+    return *ChangedCodeOrErr;
+  }
+
+  clang::ast_matchers::MatchFinder MatchFinder;
+  // Changes collected while the tool runs; consumed by rewrite().
+  AtomicChanges Changes;
+
+ private:
+  // Virtual files visible to the tool. Entry 0 is "header.h".
+  FileContentMappings FileContents = {{"header.h", ""}};
+};
+
+// Fixture for Transformer tests: seeds "header.h" with KHeaderContents and
+// provides a consumer that records changes into `Changes`.
+class TransformerTest : public ClangRefactoringTestBase {
+ protected:
+  TransformerTest() {
+    appendToHeader(KHeaderContents);
+  }
+
+  // Returns a consumer that appends each reported change to `Changes`. The
+  // lambda captures `this`, so it must not outlive the fixture.
+  Transformer::ChangeConsumer changeRecorder() {
+    return [this](const AtomicChange& C) { Changes.push_back(C); };
+  }
+};
+
+// Change strlen($s.c_str()) to $s.size().
+RewriteRule ruleStrlenSizeAny() {
+  ExprId S;
+  return RewriteRule()
+      .matching(callExpr(
+          callee(functionDecl(hasName("strlen"))),
+          hasArgument(
+              0, cxxMemberCallExpr(on(S.bind()),
+                                   callee(cxxMethodDecl(hasName("c_str")))))))
+      .replaceWith(S, ".size()")
+      .explain("Call size() method directly on object '", S, "'");
+}
+
+// Tests that code that looks the same not involving the canonical string type
+// is still transformed.
+TEST_F(TransformerTest, OtherStringTypeWithAny) {
+  // `foo::mystring` is not the header's `string`; the type-agnostic rule is
+  // still expected to rewrite the call.
+  const std::string Before =
+      R"cc(namespace foo {
+           struct mystring {
+             char* c_str();
+           };
+           int f(mystring s) { return strlen(s.c_str()); }
+           } // namespace foo)cc";
+  const std::string After =
+      R"cc(namespace foo {
+           struct mystring {
+             char* c_str();
+           };
+           int f(mystring s) { return s.size(); }
+           } // namespace foo)cc";
+
+  Transformer Rewriter(ruleStrlenSizeAny(), changeRecorder());
+  Rewriter.registerMatchers(&MatchFinder);
+  const auto Actual = rewrite(Before);
+  compareSnippets(After, Actual);
+}
+
+// Rewrites strlen($s.c_str()) to $s.size(), restricted to receivers whose type
+// is (or points to) a declaration named `::basic_string` or `::string`.
+RewriteRule ruleStrlenSize() {
+  ExprId StringExpr;
+  auto StringType = namedDecl(hasAnyName("::basic_string", "::string"));
+  // The `$s.c_str()` call, with `$s` bound to StringExpr.
+  auto CStrCall = cxxMemberCallExpr(
+      on(bind(StringExpr, expr(hasType(isOrPointsTo(StringType))))),
+      callee(cxxMethodDecl(hasName("c_str"))));
+  return makeRule(callExpr(callee(functionDecl(hasName("strlen"))),
+                           hasArgument(0, CStrCall)),
+                  Stencil::cat(member(StringExpr, "size()")),
+                  "Use size() method directly on string.");
+}
+
+// A by-value string receiver is rewritten with `.`.
+TEST_F(TransformerTest, StrlenSize) {
+  const std::string Before = "int f(string s) { return strlen(s.c_str()); }";
+  const std::string After = "int f(string s) { return s.size(); }";
+
+  Transformer Rewriter(ruleStrlenSize(), changeRecorder());
+  Rewriter.registerMatchers(&MatchFinder);
+  const auto Actual = rewrite(Before);
+  compareSnippets(After, Actual);
+}
+
+// A pointer receiver is rewritten with `->`.
+TEST_F(TransformerTest, StrlenSizePointer) {
+  const std::string Before = "int f(string* s) { return strlen(s->c_str()); }";
+  const std::string After = "int f(string* s) { return s->size(); }";
+
+  Transformer Rewriter(ruleStrlenSize(), changeRecorder());
+  Rewriter.registerMatchers(&MatchFinder);
+  const auto Actual = rewrite(Before);
+  compareSnippets(After, Actual);
+}
+
+// Variant of StrlenSizePointer where the source is more verbose. We check for
+// the same result.
+TEST_F(TransformerTest, StrlenSizePointerExplicit) {
+  // `(*s).c_str()` is expected to come out as `s->size()`.
+  const std::string Before =
+      "int f(string* s) { return strlen((*s).c_str()); }";
+  const std::string After = "int f(string* s) { return s->size(); }";
+
+  Transformer Rewriter(ruleStrlenSize(), changeRecorder());
+  Rewriter.registerMatchers(&MatchFinder);
+  const auto Actual = rewrite(Before);
+  compareSnippets(After, Actual);
+}
+
+// A type that merely shares the unqualified name "string" (here `foo::string`)
+// must not be matched by the type-restricted rule.
+TEST_F(TransformerTest, OtherStringType) {
+  const std::string Source =
+      R"cc(namespace foo {
+           struct string {
+             char* c_str();
+           };
+           int f(string s) { return strlen(s.c_str()); }
+           } // namespace foo)cc";
+
+  Transformer Rewriter(ruleStrlenSize(), changeRecorder());
+  Rewriter.registerMatchers(&MatchFinder);
+  // The output is expected to equal the input.
+  const auto Actual = rewrite(Source);
+  compareSnippets(Source, Actual);
+}
+
+// An expression passed as a macro argument is rewritten in place.
+TEST_F(TransformerTest, StrlenSizeMacro) {
+  const std::string Before = R"cc(
+#define ID(e) e
+    int f(string s) { return ID(strlen(s.c_str())); })cc";
+  const std::string After = R"cc(
+#define ID(e) e
+    int f(string s) { return ID(s.size()); })cc";
+
+  Transformer Rewriter(ruleStrlenSize(), changeRecorder());
+  Rewriter.registerMatchers(&MatchFinder);
+  const auto Actual = rewrite(Before);
+  compareSnippets(After, Actual);
+}
+
+// RuleStrlenSize, but where the user manually manages the AST node ids.
+RewriteRule ruleStrlenSizeManual() {
+  auto StringType = namedDecl(hasAnyName("::basic_string", "::string"));
+  // The receiver is bound under the hand-written id "s" instead of an ExprId.
+  auto CStrCall =
+      cxxMemberCallExpr(on(id("s", expr(hasType(isOrPointsTo(StringType))))),
+                        callee(cxxMethodDecl(hasName("c_str"))));
+  return makeRule(callExpr(callee(functionDecl(hasName("strlen"))),
+                           hasArgument(0, CStrCall)),
+                  Stencil::cat(member("s", "size()")),
+                  "Use size() method directly on string.");
+}
+
+TEST_F(TransformerTest, StrlenSizeManual) {
+  const std::string Before = "int f(string s) { return strlen(s.c_str()); }";
+  const std::string After = "int f(string s) { return s.size(); }";
+
+  Transformer Rewriter(ruleStrlenSizeManual(), changeRecorder());
+  Rewriter.registerMatchers(&MatchFinder);
+  const auto Actual = rewrite(Before);
+  compareSnippets(After, Actual);
+}
+
+// Renames update($arg) to updateAddress($arg) and adds "foo/updater.h".
+RewriteRule ruleRenameFunctionAddInclude() {
+  ExprId Arg;
+  auto FirstArg = hasArgument(0, Arg.bind());
+  return RewriteRule()
+      .matching(callExpr(callee(functionDecl(hasName("update"))), FirstArg))
+      .replaceWith(addInclude("foo/updater.h"), "updateAddress(", Arg, ")");
+}
+
+// Renames updateAddress($arg) to update($arg) and removes "foo/updater.h".
+RewriteRule ruleRenameFunctionRemoveInclude() {
+  ExprId Arg;
+  auto FirstArg = hasArgument(0, Arg.bind());
+  return RewriteRule()
+      .matching(
+          callExpr(callee(functionDecl(hasName("updateAddress"))), FirstArg))
+      .replaceWith(removeInclude("foo/updater.h"), "update(", Arg, ")");
+}
+
+// Renames update($arg) to updateAddress($arg), removing "bar/updater.h" and
+// adding "foo/updater.h".
+RewriteRule ruleRenameFunctionChangeInclude() {
+  ExprId Arg;
+  auto FirstArg = hasArgument(0, Arg.bind());
+  return RewriteRule()
+      .matching(callExpr(callee(functionDecl(hasName("update"))), FirstArg))
+      .replaceWith(removeInclude("bar/updater.h"), addInclude("foo/updater.h"),
+                   "updateAddress(", Arg, ")");
+}
+
+TEST_F(TransformerTest, AddInclude) {
+  const std::string Before = R"cc(
+    int update(int *i);
+    int f(int i) { return update(&i); })cc";
+  const std::string After =
+      R"cc(#include "foo/updater.h"
+
+    int update(int *i);
+    int f(int i) { return updateAddress(&i); })cc";
+
+  Transformer Rewriter(ruleRenameFunctionAddInclude(), changeRecorder());
+  Rewriter.registerMatchers(&MatchFinder);
+  const auto Actual = rewrite(Before);
+  compareSnippets(After, Actual);
+}
+
+TEST_F(TransformerTest, RemoveInclude) {
+  // The removed header has to exist for the input to compile.
+  addFile("foo/updater.h", "int updateAddress(int *i);");
+  const std::string Before =
+      R"cc(#include "foo/updater.h"
+
+    int update(int *i);
+    int f(int i) { return updateAddress(&i); })cc";
+  const std::string After = R"cc(
+    int update(int *i);
+    int f(int i) { return update(&i); })cc";
+
+  Transformer Rewriter(ruleRenameFunctionRemoveInclude(), changeRecorder());
+  Rewriter.registerMatchers(&MatchFinder);
+  const auto Actual = rewrite(Before);
+  compareSnippets(After, Actual);
+}
+
+TEST_F(TransformerTest, ChangeInclude) {
+  // The header being replaced has to exist for the input to compile.
+  addFile("bar/updater.h", "int update(int *i);");
+  const std::string Before =
+      R"cc(#include "bar/updater.h"
+    int f(int i) { return update(&i); })cc";
+  const std::string After =
+      R"cc(#include "foo/updater.h"
+    int f(int i) { return updateAddress(&i); })cc";
+
+  Transformer Rewriter(ruleRenameFunctionChangeInclude(), changeRecorder());
+  Rewriter.registerMatchers(&MatchFinder);
+  const auto Actual = rewrite(Before);
+  compareSnippets(After, Actual);
+}
+
+//
+// Inspired by a clang-tidy request:
+//
+// Change if ($e) {log($level) << $msg;} to
+// LOG_IF($level, $e) << $msg;
+//
+// We use a function, log(), rather than a macro, LOG(), to simplify the matcher
+// needed.
+RewriteRule ruleLogIf() {
+  ExprId Condition;
+  ExprId Level;
+  ExprId Msg;
+  auto LogCall = callExpr(callee(functionDecl(hasName("log"))),
+                          hasArgument(0, Level.bind()));
+  // `log($level) << $msg`, possibly wrapped in implicit nodes.
+  auto StreamToLog = expr(ignoringImplicit(cxxOperatorCallExpr(
+      hasOverloadedOperatorName("<<"), hasArgument(0, LogCall),
+      hasArgument(1, Msg.bind()))));
+  // Only `if` statements without an `else` branch qualify.
+  return makeRule(
+      ifStmt(hasCondition(Condition.bind()), hasThen(StreamToLog),
+             unless(hasElse(expr()))),
+      Stencil::cat("LOG_IF(", Level, ", ", parens(Condition), ") << ", Msg),
+      "Use LOG_IF() when LOG() is only member of if statement.");
+}
+
+// foo() has an else-less `if` and is rewritten; bar() has an `else` and is
+// expected to stay as written.
+TEST_F(TransformerTest, LogIf) {
+  const std::string Before = R"cc(
+    double x = 3.0;
+    void foo() {
+      if (x > 1.0) log(1) << "oh no!";
+    }
+    void bar() {
+      if (x > 1.0)
+        log(1) << "oh no!";
+      else
+        log(0) << "ok";
+    }
+  )cc";
+  const std::string After = R"cc(
+    double x = 3.0;
+    void foo() { LOG_IF(1, (x > 1.0)) << "oh no!"; }
+    void bar() {
+      if (x > 1.0)
+        log(1) << "oh no!";
+      else
+        log(0) << "ok";
+    }
+  )cc";
+
+  Transformer Rewriter(ruleLogIf(), changeRecorder());
+  Rewriter.registerMatchers(&MatchFinder);
+  const auto Actual = rewrite(Before);
+  compareSnippets(After, Actual);
+}
+
+// Same rule as LogIf, but the condition is a plain identifier, so the expected
+// output does not wrap it in parentheses.
+TEST_F(TransformerTest, LogIfNoParens) {
+  const std::string Before = R"cc(
+    double x = 3.0;
+    void foo() {
+      bool condition = x > 1.0;
+      if (condition) log(1) << "oh no!";
+    }
+    void bar() {
+      if (x > 1.0)
+        log(1) << "oh no!";
+      else
+        log(0) << "ok";
+    }
+  )cc";
+  const std::string After = R"cc(
+    double x = 3.0;
+    void foo() {
+      bool condition = x > 1.0;
+      LOG_IF(1, condition) << "oh no!";
+    }
+    void bar() {
+      if (x > 1.0)
+        log(1) << "oh no!";
+      else
+        log(0) << "ok";
+    }
+  )cc";
+
+  Transformer Rewriter(ruleLogIf(), changeRecorder());
+  Rewriter.registerMatchers(&MatchFinder);
+  const auto Actual = rewrite(Before);
+  compareSnippets(After, Actual);
+}
+
+// Change `if ($c) $t $e` to `if (!$c) $e $t`.
+//
+// N.B. This rule is oversimplified (since it is just for testing): it won't
+// construct the correct result if the input has compound statements.
+RewriteRule invertIf() {
+  ExprId Cond;
+  StmtId Then;
+  StmtId Else;
+  auto IfWithElse = ifStmt(hasCondition(Cond.bind()), hasThen(Then.bind()),
+                           hasElse(Else.bind()));
+  return RewriteRule().matching(IfWithElse).replaceWith(
+      "if(!(", Cond, ")) ", Else, "; else ", Then);
+}
+
+// Same rule as invertIf(), but built on a named RewriteRule object so that the
+// lvalue-ref overloads of the builder methods are the ones exercised.
+RewriteRule invertIfLvalue() {
+  ExprId Cond;
+  StmtId Then;
+  StmtId Else;
+  RewriteRule Rule;
+  Rule.matching(ifStmt(hasCondition(Cond.bind()), hasThen(Then.bind()),
+                       hasElse(Else.bind())))
+      .replaceWith("if(!(", Cond, ")) ", Else, "; else ", Then);
+  return Rule;
+}
+
+TEST_F(TransformerTest, InvertIf) {
+  const std::string Before = R"cc(
+    void foo() {
+      if (10 > 1.0)
+        log(1) << "oh no!";
+      else
+        log(0) << "ok";
+    }
+  )cc";
+  const std::string After = R"cc(
+    void foo() {
+      if (!(10 > 1.0))
+        log(0) << "ok";
+      else
+        log(1) << "oh no!";
+    }
+  )cc";
+
+  Transformer Rewriter(invertIf(), changeRecorder());
+  Rewriter.registerMatchers(&MatchFinder);
+  const auto Actual = rewrite(Before);
+  compareSnippets(After, Actual);
+}
+
+TEST_F(TransformerTest, InvertIfLvalue) {
+  const std::string Before = R"cc(
+    void foo() {
+      if (10 > 1.0)
+        log(1) << "oh no!";
+      else
+        log(0) << "ok";
+    }
+  )cc";
+  const std::string After = R"cc(
+    void foo() {
+      if (!(10 > 1.0))
+        log(0) << "ok";
+      else
+        log(1) << "oh no!";
+    }
+  )cc";
+
+  Transformer Rewriter(invertIfLvalue(), changeRecorder());
+  Rewriter.registerMatchers(&MatchFinder);
+  const auto Actual = rewrite(Before);
+  compareSnippets(After, Actual);
+}
+
+//
+// (ProtoCommandLineFlag f){ $f.foo() } => { $f.GetProto().foo() }
+//
+RewriteRule ruleFlag() {
+  ExprId Flag;
+  // The receiver of the member call, required to be a ProtoCommandLineFlag.
+  auto FlagObject = expr(
+      Flag.bind(), hasType(cxxRecordDecl(hasName("proto::ProtoCommandLineFlag"))));
+  // Only the receiver (`Flag`) is replaced, not the whole call; calls to
+  // GetProto() itself are excluded.
+  return RewriteRule()
+      .matching(cxxMemberCallExpr(
+          on(FlagObject), unless(callee(cxxMethodDecl(hasName("GetProto"))))))
+      .change(Flag)
+      .replaceWith(Flag, ".GetProto()")
+      .explain("Use GetProto() to access proto fields.");
+}
+
+TEST_F(TransformerTest, Flag) {
+  const std::string Before = R"cc(
+    proto::ProtoCommandLineFlag flag;
+    int x = flag.foo();
+    int y = flag.GetProto().foo();
+  )cc";
+  const std::string After = R"cc(
+    proto::ProtoCommandLineFlag flag;
+    int x = flag.GetProto().foo();
+    int y = flag.GetProto().foo();
+  )cc";
+
+  Transformer Rewriter(ruleFlag(), changeRecorder());
+  Rewriter.registerMatchers(&MatchFinder);
+  const auto Actual = rewrite(Before);
+  compareSnippets(After, Actual);
+}
+
+// Variant of RuleFlag that doesn't rely on specifying a particular
+// target. Instead, we explicitly bind the decl of the invoked method and refer
+// to it in the code using the Name() operator.
+RewriteRule ruleFlagNoTarget() {
+  ExprId Flag;
+  DeclId Method;
+  auto FlagObject = bind(
+      Flag, expr(hasType(cxxRecordDecl(hasName("proto::ProtoCommandLineFlag")))));
+  auto NonGetProtoMethod =
+      bind(Method, cxxMethodDecl(unless(hasName("GetProto"))));
+  // The whole zero-argument call is replaced, so the method name and the empty
+  // argument list are re-emitted by the stencil.
+  return makeRule(
+      cxxMemberCallExpr(on(FlagObject), callee(NonGetProtoMethod),
+                        argumentCountIs(0)),
+      Stencil::cat(addInclude("fake/for/test.h"), member(Flag, "GetProto()"),
+                   ".", name(Method), "()"),
+      "Use GetProto() to access proto fields.");
+}
+
+// Tests use of Name() operator.
+TEST_F(TransformerTest, FlagWithName) {
+  std::string Input = R"cc(
+    proto::ProtoCommandLineFlag flag;
+    int x = flag.foo();
+    int y = flag.GetProto().foo();
+  )cc";
+  // The rule also adds an include, which is expected at the top of the output.
+  std::string Expected =
+      R"cc(#include "fake/for/test.h"
+
+    proto::ProtoCommandLineFlag flag;
+    int x = flag.GetProto().foo();
+    int y = flag.GetProto().foo();
+  )cc";
+
+  Transformer T(ruleFlagNoTarget(), changeRecorder());
+  T.registerMatchers(&MatchFinder);
+  compareSnippets(Expected, rewrite(Input));
+}
+
+// Renames every declaration of the function `bad` to `good`, changing only the
+// name part of the matched declaration.
+RewriteRule ruleChangeFunctionName() {
+  DeclId Fun;
+  return RewriteRule()
+      .matching(functionDecl(hasName("bad"), Fun.bind()))
+      .change(Fun, NodePart::kName)
+      .replaceWith("good");
+}
+
+// Tests NodePart::kName on named declarations: both the prototype and the
+// definition are expected to be renamed.
+TEST_F(TransformerTest, NodePartNameNamedDecl) {
+  std::string Input = R"cc(
+    int bad(int x);
+    int bad(int x) { return x * x; }
+  )cc";
+  std::string Expected = R"cc(
+    int good(int x);
+    int good(int x) { return x * x; }
+  )cc";
+
+  Transformer T(ruleChangeFunctionName(), changeRecorder());
+  T.registerMatchers(&MatchFinder);
+  compareSnippets(Expected, rewrite(Input));
+}
+
+// Tests NodePart::kName on a reference to a function: only the use of `bad` is
+// expected to be renamed, not its declaration.
+TEST_F(TransformerTest, NodePartNameDeclRef) {
+  std::string Input = R"cc(
+    template <typename T>
+    T bad(T x) {
+      return x;
+    }
+    int neutral(int x) { return bad(x) * x; }
+  )cc";
+  std::string Expected = R"cc(
+    template <typename T>
+    T bad(T x) {
+      return x;
+    }
+    int neutral(int x) { return good(x) * x; }
+  )cc";
+
+  ExprId Ref;
+  Transformer T(
+      RewriteRule()
+          .matching(declRefExpr(to(functionDecl(hasName("bad"))), Ref.bind()))
+          .change(Ref, NodePart::kName)
+          .replaceWith("good"),
+      changeRecorder());
+  T.registerMatchers(&MatchFinder);
+  compareSnippets(Expected, rewrite(Input));
+}
+
+// Negative case for NodePart::kName: the only function referenced is an
+// overloaded operator, and the source is expected to come back unchanged.
+TEST_F(TransformerTest, NodePartNameDeclRefFailure) {
+  std::string Input = R"cc(
+    struct Y {};
+    int operator*(const Y&);
+    int neutral(int x) { Y y; return *y + x; }
+  )cc";
+
+  ExprId Ref;
+  Transformer T(RewriteRule()
+                    .matching(declRefExpr(to(functionDecl()), Ref.bind()))
+                    .change(Ref, NodePart::kName)
+                    .replaceWith("good"),
+                changeRecorder());
+  T.registerMatchers(&MatchFinder);
+  compareSnippets(Input, rewrite(Input));
+}
+
+// Renames uses of the member `bad` to `good`, changing only the member part of
+// the matched member expression.
+RewriteRule ruleChangeFieldName() {
+  ExprId E;
+  return RewriteRule()
+      .matching(memberExpr(member(hasName("bad")), E.bind()))
+      .change(E, NodePart::kMember)
+      .replaceWith("good");
+}
+
+// Tests NodePart::kMember: the member access is renamed while the field's
+// declaration is left alone.
+TEST_F(TransformerTest, NodePartMember) {
+  std::string Input = R"cc(
+    struct S { int bad; };
+    int g() { S s; return s.bad; }
+  )cc";
+  std::string Expected = R"cc(
+    struct S { int bad; };
+    int g() { S s; return s.good; }
+  )cc";
+
+  Transformer T(ruleChangeFieldName(), changeRecorder());
+  T.registerMatchers(&MatchFinder);
+  compareSnippets(Expected, rewrite(Input));
+}
+
+// Tests the MultiTransformer class and, generally, the combination of multiple
+// rules.
+TEST_F(TransformerTest, MultiRule) {
+  std::string Input = R"cc(
+    proto::ProtoCommandLineFlag flag;
+    int x = flag.foo();
+    int y = flag.GetProto().foo();
+    int f(string s) { return strlen(s.c_str()); }
+  )cc";
+  std::string Expected = R"cc(
+    proto::ProtoCommandLineFlag flag;
+    int x = flag.GetProto().foo();
+    int y = flag.GetProto().foo();
+    int f(string s) { return s.size(); }
+  )cc";
+
+  std::vector<RewriteRule> Rules;
+  Rules.emplace_back(ruleStrlenSize());
+  Rules.emplace_back(ruleFlag());
+  // MultiTransformer is handed the MatchFinder directly, so there is no
+  // separate registerMatchers() call.
+  MultiTransformer T(std::move(Rules), changeRecorder(), &MatchFinder);
+  compareSnippets(Expected, rewrite(Input));
+}
+
+// A rule that finds function calls with two arguments where the arguments are
+// the same identifier.
+RewriteRule ruleDuplicateArgs() {
+  ExprId Arg0, Arg1;
+  return RewriteRule()
+      .matching(callExpr(argumentCountIs(2), hasArgument(0, Arg0.bind()),
+                         hasArgument(1, Arg1.bind())))
+      // The filter accepts a match only when both arguments are references to
+      // declarations and both refer to the same declaration.
+      .where([Arg0, Arg1](
+                 const clang::ast_matchers::MatchFinder::MatchResult& result) {
+        auto* Ref0 = Arg0.getNodeAs<clang::DeclRefExpr>(result);
+        auto* Ref1 = Arg1.getNodeAs<clang::DeclRefExpr>(result);
+        return Ref0 != nullptr && Ref1 != nullptr &&
+               Ref0->getDecl() == Ref1->getDecl();
+      })
+      .replaceWith("42");
+}
+
+// The `where` filter accepts `foo(x, x)`, so the call is replaced.
+TEST_F(TransformerTest, FilterPassed) {
+  std::string Input = R"cc(
+    int foo(int x, int y);
+    int x = 3;
+    int z = foo(x, x);
+  )cc";
+  std::string Expected = R"cc(
+    int foo(int x, int y);
+    int x = 3;
+    int z = 42;
+  )cc";
+
+  Transformer T(ruleDuplicateArgs(), changeRecorder());
+  T.registerMatchers(&MatchFinder);
+  compareSnippets(Expected, rewrite(Input));
+}
+
+//
+// Negative tests (where we expect no transformation to occur).
+//
+
+// The `where` filter rejects both calls, so the source is expected to come
+// back unchanged.
+TEST_F(TransformerTest, FilterFailed) {
+  std::string Input = R"cc(
+    int foo(int x, int y);
+    int x = 3;
+    int y = 17;
+    // Different identifiers.
+    int z = foo(x, y);
+    // One identifier, one not.
+    int w = foo(x, 3);
+  )cc";
+
+  Transformer T(ruleDuplicateArgs(), changeRecorder());
+  T.registerMatchers(&MatchFinder);
+  compareSnippets(Input, rewrite(Input));
+}
+
+// The matching code appears in the body of a macro definition rather than in a
+// macro argument.
+TEST_F(TransformerTest, NoTransformationInMacro) {
+  std::string Input = R"cc(
+#define MACRO(str) strlen((str).c_str())
+    int f(string s) { return MACRO(s); })cc";
+
+  Transformer T(ruleStrlenSizeAny(), changeRecorder());
+  T.registerMatchers(&MatchFinder);
+  // The macro should be ignored.
+  compareSnippets(Input, rewrite(Input));
+}
+
+// This test handles the corner case where a macro called within another macro
+// expands to matching code, but the matched code is an argument to the nested
+// macro. A simple check of isMacroArgExpansion() vs. isMacroBodyExpansion()
+// will get this wrong, and transform the code. This test verifies that no such
+// transformation occurs.
+TEST_F(TransformerTest, NoTransformationInNestedMacro) {
+  std::string Input = R"cc(
+#define NESTED(e) e
+#define MACRO(str) NESTED(strlen((str).c_str()))
+    int f(string s) { return MACRO(s); })cc";
+
+  Transformer T(ruleStrlenSizeAny(), changeRecorder());
+  T.registerMatchers(&MatchFinder);
+  // The macro should be ignored.
+  compareSnippets(Input, rewrite(Input));
+}
+
+//
+// maybeTransform tests
+//
+
+// Formats an Optional for error messages.
+std::string errString(const llvm::Optional<std::string>& O) {
+  return O ? "Some " + *O : "None";
+}
+
+// Fixture for maybeTransform(): builds an AST from a fixed snippet and exposes
+// the call expression returned by `f` as `Node`.
+class MaybeTransformTest : public ::testing::Test {
+ protected:
+  // Node has to be set in the initializer list because its type has no default
+  // constructor.
+  MaybeTransformTest() : Node(init()) {}
+
+  // Builds the AST (also initializing AstUnit) and returns the node bound to
+  // "expr": the call expression that is the return value of the snippet's
+  // single return statement.
+  clang::ast_type_traits::DynTypedNode init() {
+    std::string Code = R"cc(
+      int strlen(char*);
+      namespace foo {
+      struct mystring {
+        char* c_str();
+      };
+      int f(mystring s) { return strlen(s.c_str()); }
+      } // namespace foo)cc";
+    AstUnit = buildASTFromCode(Code);
+    assert(AstUnit != nullptr && "AST not constructed");
+
+    auto Matches = clang::ast_matchers::match(
+        returnStmt(hasReturnValue(callExpr().bind("expr"))), *context());
+    assert(Matches.size() == 1);
+    auto It = Matches[0].getMap().find("expr");
+    assert(It != Matches[0].getMap().end() && "Match failure");
+    return It->second;
+  }
+
+  // Convenience method.
+  clang::ASTContext* context() { return &AstUnit->getASTContext(); }
+
+  // AstUnit is declared before Node, so it is already constructed when init()
+  // assigns to it while Node is being initialized.
+  std::unique_ptr<clang::ASTUnit> AstUnit;
+  clang::ast_type_traits::DynTypedNode Node;
+};
+
+// TODO: belongs in utility location or whatnot
+// A very simple matcher for llvm::Optional values: fails when the optional is
+// empty, and otherwise applies ValueMatcher to the contained value.
+MATCHER_P(IsSomething, ValueMatcher, "") {
+  if (!arg)
+    return false;
+  return ::testing::ExplainMatchResult(ValueMatcher, *arg, result_listener);
+}
+
+// Tests case where rewriting succeeds and the rule is applied.
+TEST_F(MaybeTransformTest, SuccessRuleApplies) {
+  auto Outcome = maybeTransform(ruleStrlenSizeAny(), Node, context());
+  if (auto Failure = Outcome.takeError()) {
+    GTEST_FAIL() << "Rewrite failed: " << llvm::toString(std::move(Failure));
+  }
+  EXPECT_THAT(*Outcome, IsSomething(testing::Eq("s.size()")));
+}
+
+// Tests the case where rewriting succeeds but the rule does not match the
+// node: no error is expected, and the result is expected to be empty.
+TEST_F(MaybeTransformTest, SuccessRuleDoesNotApply) {
+  auto NonMatchingRule =
+      RewriteRule()
+          .matching(callExpr(callee(functionDecl(hasName("foo")))))
+          .replaceWith("bar()");
+  auto Outcome = maybeTransform(NonMatchingRule, Node, context());
+  if (auto Failure = Outcome.takeError()) {
+    GTEST_FAIL() << "Rewrite failed: " << llvm::toString(std::move(Failure));
+  }
+  auto& Rewritten = *Outcome;
+  EXPECT_EQ(Rewritten, llvm::None)
+      << "Actual result is: " << errString(Rewritten);
+}
+
+// Tests that a replacement referring to an id the pattern never binds is
+// reported as an error.
+TEST_F(MaybeTransformTest, FailureUnbound) {
+  // Note: pattern needs to bind at least one id or the match will return no
+  // results.
+  auto UnboundRule =
+      RewriteRule()
+          .matching(expr().bind("e"))
+          .replaceWith(stencil_generators::id("unbound"), ".size()");
+  auto Outcome = maybeTransform(UnboundRule, Node, context());
+  // The streamed message is only evaluated when the expectation fails, i.e.
+  // when Outcome holds a value rather than an error.
+  EXPECT_TRUE(llvm::errorToBool(Outcome.takeError()))
+      << "Expected rewrite to fail on unbound node: " << errString(*Outcome);
+}
+
+// Tests that a pattern producing more than one match for the node is reported
+// as an error.
+TEST_F(MaybeTransformTest, FailureMultiMatch) {
+  auto MultiMatchRule =
+      RewriteRule()
+          .matching(stmt(eachOf(callExpr().bind("expr"),
+                                has(callExpr().bind("expr")))))
+          .replaceWith(stencil_generators::id("expr"), ".size()");
+  auto Outcome = maybeTransform(MultiMatchRule, Node, context());
+  EXPECT_TRUE(llvm::errorToBool(Outcome.takeError()))
+      << "Expected rewrite to fail on too many matches: "
+      << errString(*Outcome);
+}
+}  // namespace tooling
+}  // namespace clang