diff --git a/clang/include/clang/Tooling/Refactoring/Transformer.h b/clang/include/clang/Tooling/Refactoring/Transformer.h
new file mode 100644
--- /dev/null
+++ b/clang/include/clang/Tooling/Refactoring/Transformer.h
@@ -0,0 +1,261 @@
+//===--- Transformer.h - Clang source-rewriting library ---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// Defines a library supporting the concise specification of clang-based
+/// source-to-source transformations.
+///
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CLANG_TOOLING_REFACTOR_TRANSFORMER_H_
+#define LLVM_CLANG_TOOLING_REFACTOR_TRANSFORMER_H_
+
+#include "NodeId.h"
+#include "clang/ASTMatchers/ASTMatchFinder.h"
+#include "clang/ASTMatchers/ASTMatchers.h"
+#include "clang/ASTMatchers/ASTMatchersInternal.h"
+#include "clang/Tooling/Refactoring/AtomicChange.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/Error.h"
+#include <deque>
+#include <functional>
+#include <string>
+#include <type_traits>
+#include <utility>
+#include <vector>
+
+namespace clang {
+namespace tooling {
+/// Determines the part of the AST node to replace. We support this to work
+/// around the fact that the AST does not differentiate various syntactic
+/// elements into their own nodes, so users can specify them relative to a node,
+/// instead.
+enum class NodePart {
+  /// The node itself.
+  Node,
+  /// Given a \c MemberExpr, selects the member's token.
+  Member,
+  /// Given a \c NamedDecl, \c DeclRefExpr or \c CXXCtorInitializer, selects
+  /// the token of the relevant name, not including qualifiers.
+  Name,
+};
+
+/// Produces text (for example, a replacement or an explanation) from the
+/// result of a match.
+using TextGenerator =
+    std::function<std::string(const ast_matchers::MatchFinder::MatchResult &)>;
+
+/// Description of a source-code transformation.
+//
+// A *rewrite rule* describes a transformation of source code. It has the
+// following components:
+//
+// * Matcher: the pattern term, expressed as clang matchers (with Transformer
+// extensions).
+//
+// * Target: the source code impacted by the rule. This identifies an AST node,
+// or part thereof (\c TargetPart), whose source range indicates the extent of
+// the replacement applied by the replacement term. By default, the extent is
+// the node matched by the pattern term (\c NodePart::Node). Targets are
+// typed (\c TargetKind), which guides the determination of the node extent
+// and might, in the future, statically constrain the set of eligible
+// NodeParts for a given node.
+//
+// * Replacement: a function that produces a replacement string for the target,
+// based on the match result.
+//
+// * Explanation: explanation of the rewrite.
+//
+// Rules have an additional, implicit, component: the parameters. These are
+// portions of the pattern which are left unspecified, yet named so that we can
+// reference them in the replacement term. The structure of parameters can be
+// partially or even fully specified, in which case they serve just to identify
+// matched nodes for later reference rather than abstract over portions of the
+// AST. However, in all cases, we refer to named portions of the pattern as
+// parameters.
+//
+// Parameters can be declared explicitly using the NodeId type and its
+// derivatives or left implicit by using the native support for binding ids in
+// the clang matchers.
+//
+// RewriteRule is constructed in a "fluent" style, by creating a builder and
+// chaining setters of individual components. We provide ref-qualified
+// overloads of the setters to avoid an unnecessary copy when a RewriteRule is
+// initialized from a temporary (which we expect to be the common case).
+// \code
+//   RewriteRule R = buildRule(functionDecl(...)).replaceWith(...);
+// \endcode
+class RewriteRule {
+  // Id used as the default target of each match. The node matched by the
+  // rule's matcher is always bound to this id (see `makeMatcher`).
+  static constexpr char RootId[] = "___root___";
+
+  // Supports any (top-level node) matcher type.
+  ast_matchers::internal::DynTypedMatcher Matcher;
+  // The (bound) id of the node whose source will be replaced. This id should
+  // never be the empty string.
+  std::string Target;
+  ast_type_traits::ASTNodeKind TargetKind;
+  NodePart TargetPart;
+  TextGenerator Replacement;
+  TextGenerator Explanation;
+
+  static ast_matchers::internal::DynTypedMatcher
+  makeMatcher(ast_matchers::internal::DynTypedMatcher M) {
+    M.setAllowBind(true);
+    // RewriteRule guarantees that the node described by the matcher will always
+    // be accessible as `matchedNode()`, so we bind it here. `tryBind` is
+    // guaranteed to succeed, because `AllowBind` was set to true.
+    return *M.tryBind(matchedNode());
+  }
+
+public:
+  /// Creates a rule whose target is the whole node matched by \p M, which is
+  /// bound to `matchedNode()`.
+  RewriteRule(ast_matchers::internal::DynTypedMatcher M)
+      : Matcher(makeMatcher(std::move(M))), Target(RootId),
+        TargetKind(Matcher.getSupportedKind()), TargetPart(NodePart::Node) {
+    Matcher.setAllowBind(true);
+  }
+
+  /// The bound id of the node corresponding to the matcher.
+  static llvm::StringRef matchedNode() { return RootId; }
+
+  void setTarget(StringRef S) { Target = S; }
+  void setTargetKind(ast_type_traits::ASTNodeKind K) { TargetKind = K; }
+  void setTargetPart(NodePart P) { TargetPart = P; }
+  void setReplacement(TextGenerator T) { Replacement = std::move(T); }
+  void setExplanation(TextGenerator T) { Explanation = std::move(T); }
+
+  const ast_matchers::internal::DynTypedMatcher &matcher() const {
+    return Matcher;
+  }
+  llvm::StringRef target() const { return Target; }
+  ast_type_traits::ASTNodeKind targetKind() const { return TargetKind; }
+  NodePart targetPart() const { return TargetPart; }
+
+  /// Returns the replacement text for match \p R. Requires that a replacement
+  /// has been set via `setReplacement`.
+  std::string
+  replacement(const ast_matchers::MatchFinder::MatchResult &R) const {
+    return Replacement(R);
+  }
+
+  /// Returns the explanation for match \p R. Requires that an explanation has
+  /// been set via `setExplanation`.
+  std::string
+  explanation(const ast_matchers::MatchFinder::MatchResult &R) const {
+    return Explanation(R);
+  }
+};
+
+/// A fluent, builder class for \c RewriteRule. See comments on \c RewriteRule,
+/// above.
+class RewriteRuleBuilder {
+  RewriteRule Rule;
+
+public:
+  RewriteRuleBuilder(ast_matchers::internal::DynTypedMatcher M)
+      : Rule(std::move(M)) {}
+
+  /// (Implicit) "build" operator to build a RewriteRule from this builder.
+  operator RewriteRule() && { return std::move(Rule); }
+
+  /// Sets the target kind based on a clang AST node type.
+  template <typename T> RewriteRuleBuilder &&as() &&;
+
+  /// Targets \p Part of the node bound to \p Target, and sets the target kind
+  /// to \c T.
+  template <typename T>
+  RewriteRuleBuilder &&change(const TypedNodeId<T> &Target,
+                              NodePart Part = NodePart::Node) &&;
+
+  /// Sets the replacement term of the rule.
+  RewriteRuleBuilder &&replaceWith(TextGenerator Replacement) &&;
+  RewriteRuleBuilder &&replaceWith(std::string Replacement) && {
+    return std::move(*this).replaceWith(text(std::move(Replacement)));
+  }
+
+  /// Sets the explanation of the rule.
+  RewriteRuleBuilder &&because(TextGenerator Explanation) &&;
+  RewriteRuleBuilder &&because(std::string Explanation) && {
+    return std::move(*this).because(text(std::move(Explanation)));
+  }
+
+private:
+  // Wraps a string as a TextGenerator.
+  static TextGenerator text(std::string M) {
+    return [M](const ast_matchers::MatchFinder::MatchResult &) { return M; };
+  }
+};
+
+/// Convenience factory functions for starting construction of a \c RewriteRule.
+inline RewriteRuleBuilder buildRule(ast_matchers::internal::DynTypedMatcher M) {
+  return RewriteRuleBuilder(std::move(M));
+}
+template <typename T>
+RewriteRuleBuilder buildRule(ast_matchers::internal::Matcher<T> M) {
+  return RewriteRuleBuilder(M);
+}
+
+template <typename T> RewriteRuleBuilder &&RewriteRuleBuilder::as() && {
+  Rule.setTargetKind(ast_type_traits::ASTNodeKind::getFromNodeKind<T>());
+  return std::move(*this);
+}
+
+template <typename T>
+RewriteRuleBuilder &&RewriteRuleBuilder::change(const TypedNodeId<T> &TargetId,
+                                                NodePart Part) && {
+  Rule.setTarget(TargetId.id());
+  Rule.setTargetKind(ast_type_traits::ASTNodeKind::getFromNodeKind<T>());
+  Rule.setTargetPart(Part);
+  return std::move(*this);
+}
+
+/// A source "transformation," represented by a character range in the source to
+/// be replaced and a corresponding replacement string.
+struct Transformation {
+  CharSourceRange Range;
+  std::string Replacement;
+};
+
+/// Attempts to apply a rule to a match. Fails if the match is not eligible for
+/// rewriting or, for example, if any invariants are violated relating to bound
+/// nodes in the match.
+Expected<Transformation>
+applyRewriteRule(const RewriteRule &Rule,
+                 const ast_matchers::MatchFinder::MatchResult &Match);
+
+/// Handles the matcher and callback registration for a single rewrite rule, as
+/// defined by the arguments of the constructor.
+class Transformer : public ast_matchers::MatchFinder::MatchCallback {
+public:
+  using ChangeConsumer =
+      std::function<void(const clang::tooling::AtomicChange &Change)>;
+
+  /// \param Consumer Receives each successful rewrite as an \c AtomicChange.
+  Transformer(RewriteRule Rule, ChangeConsumer Consumer)
+      : Rule(std::move(Rule)), Consumer(std::move(Consumer)) {}
+
+  /// N.B. Passes `this` pointer to `MatchFinder`. So, this object should not
+  /// be moved after this call.
+  void registerMatchers(ast_matchers::MatchFinder *MatchFinder);
+
+  /// Not called directly by users -- called by the framework, via base class
+  /// pointer.
+  void run(const ast_matchers::MatchFinder::MatchResult &Result) override;
+
+private:
+  RewriteRule Rule;
+  /// Receives each successful rewrite as an \c AtomicChange.
+  ChangeConsumer Consumer;
+};
+} // namespace tooling
+} // namespace clang
+
+#endif // LLVM_CLANG_TOOLING_REFACTOR_TRANSFORMER_H_
diff --git a/clang/lib/Tooling/Refactoring/CMakeLists.txt b/clang/lib/Tooling/Refactoring/CMakeLists.txt
--- a/clang/lib/Tooling/Refactoring/CMakeLists.txt
+++ b/clang/lib/Tooling/Refactoring/CMakeLists.txt
@@ -13,6 +13,7 @@
   Rename/USRFindingAction.cpp
   Rename/USRLocFinder.cpp
   NodeId.cpp
+  Transformer.cpp
 
   LINK_LIBS
   clangAST
diff --git a/clang/lib/Tooling/Refactoring/Transformer.cpp b/clang/lib/Tooling/Refactoring/Transformer.cpp
new file mode 100644
--- /dev/null
+++ b/clang/lib/Tooling/Refactoring/Transformer.cpp
@@ -0,0 +1,197 @@
+//===--- Transformer.cpp - Transformer library implementation ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "clang/Tooling/Refactoring/Transformer.h"
+#include "clang/AST/Expr.h"
+#include "clang/ASTMatchers/ASTMatchFinder.h"
+#include "clang/ASTMatchers/ASTMatchers.h"
+#include "clang/Basic/Diagnostic.h"
+#include "clang/Basic/SourceLocation.h"
+#include "clang/Rewrite/Core/Rewriter.h"
+#include "clang/Tooling/FixIt.h"
+#include "clang/Tooling/Refactoring.h"
+#include "clang/Tooling/Refactoring/AtomicChange.h"
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/Errc.h"
+#include "llvm/Support/Error.h"
+#include <deque>
+#include <string>
+#include <utility>
+#include <vector>
+
+using namespace clang;
+using namespace tooling;
+
+using ::clang::ast_matchers::MatchFinder;
+using ::clang::ast_type_traits::ASTNodeKind;
+using ::clang::ast_type_traits::DynTypedNode;
+using ::llvm::Error;
+using ::llvm::Expected;
+using ::llvm::Optional;
+using ::llvm::StringError;
+using ::llvm::StringRef;
+
+using MatchResult = MatchFinder::MatchResult;
+
+static bool isOriginMacroBody(const clang::SourceManager &SM,
+                              clang::SourceLocation Loc) {
+  while (Loc.isMacroID()) {
+    if (SM.isMacroBodyExpansion(Loc))
+      return true;
+    // Otherwise, it must be in an argument, so we continue searching up the
+    // invocation stack. getImmediateMacroCallerLoc() gives the location of the
+    // argument text, inside the call text.
+    Loc = SM.getImmediateMacroCallerLoc(Loc);
+  }
+  return false;
+}
+
+static llvm::Error invalidArgumentError(llvm::Twine Message) {
+  return llvm::make_error<StringError>(llvm::errc::invalid_argument, Message);
+}
+
+static llvm::Error unboundNodeError(StringRef Role, StringRef Id) {
+  return invalidArgumentError(Role + " (=" + Id + ") references unbound node");
+}
+
+static llvm::Error typeError(llvm::Twine Message, const ASTNodeKind &Kind) {
+  return invalidArgumentError(Message + " (node kind is " + Kind.asStringRef() +
+                              ")");
+}
+
+static llvm::Error missingPropertyError(llvm::Twine Description,
+                                        StringRef Property) {
+  return invalidArgumentError(Description + " requires property '" + Property +
+                              "'");
+}
+
+// Computes the range of `TargetPart` within `Node`. `Kind` (the rule's target
+// kind) only affects `NodePart::Node`. Errors if the part does not apply.
+static Expected<CharSourceRange> getTarget(const DynTypedNode &Node,
+                                           ASTNodeKind Kind,
+                                           NodePart TargetPart,
+                                           ASTContext &Context) {
+  switch (TargetPart) {
+  case NodePart::Node: {
+    // For non-expression statements, associate any trailing semicolon with the
+    // statement text. However, if the target was intended as an expression (as
+    // indicated by its kind) then we do not associate any trailing semicolon
+    // with it. We only associate the exact expression text.
+    if (Node.get<Stmt>() != nullptr) {
+      auto ExprKind = ASTNodeKind::getFromNodeKind<clang::Expr>();
+      if (!ExprKind.isBaseOf(Kind))
+        return fixit::getExtendedRange(Node, tok::TokenKind::semi, Context);
+    }
+    return CharSourceRange::getTokenRange(Node.getSourceRange());
+  }
+  case NodePart::Member:
+    if (auto *M = Node.get<clang::MemberExpr>())
+      return CharSourceRange::getTokenRange(
+          M->getMemberNameInfo().getSourceRange());
+    return typeError("NodePart::Member applied to non-MemberExpr",
+                     Node.getNodeKind());
+  case NodePart::Name:
+    if (const auto *D = Node.get<clang::NamedDecl>()) {
+      if (!D->getDeclName().isIdentifier())
+        return missingPropertyError("NodePart::Name", "identifier");
+      auto R = CharSourceRange::getTokenRange(D->getLocation());
+      // Verify that the range covers exactly the name.
+      // FIXME: extend this code to support cases like `operator +` or
+      // `foo<int>` for which this range will be too short.  Doing so will
+      // require subcasing `NamedDecl`, because it doesn't provide virtual
+      // access to the \c DeclarationNameInfo.
+      if (fixit::internal::getText(R, Context) != D->getName())
+        return CharSourceRange();
+      return R;
+    }
+    if (const auto *E = Node.get<clang::DeclRefExpr>()) {
+      if (!E->getNameInfo().getName().isIdentifier())
+        return missingPropertyError("NodePart::Name", "identifier");
+      return CharSourceRange::getTokenRange(E->getLocation());
+    }
+    if (const auto *I = Node.get<CXXCtorInitializer>()) {
+      // Requires a member initializer that is explicitly written in the source.
+      if (!I->isMemberInitializer() || !I->isWritten())
+        return missingPropertyError("NodePart::Name",
+                                    "explicit member initializer");
+      return CharSourceRange::getTokenRange(I->getMemberLocation());
+    }
+    return typeError(
+        "NodePart::Name applied to neither DeclRefExpr, NamedDecl nor "
+        "CXXCtorInitializer",
+        Node.getNodeKind());
+  }
+  llvm_unreachable("Unexpected case in NodePart type.");
+}
+
+namespace clang {
+namespace tooling {
+Expected<Transformation>
+applyRewriteRule(const RewriteRule &Rule,
+                 const ast_matchers::MatchFinder::MatchResult &Match) {
+  if (Match.Context->getDiagnostics().hasErrorOccurred())
+    return Transformation();
+
+  auto &NodesMap = Match.Nodes.getMap();
+  auto It = NodesMap.find(Rule.target());
+  if (It == NodesMap.end())
+    return unboundNodeError("rule.target", Rule.target());
+
+  Expected<CharSourceRange> TargetOrErr = getTarget(
+      It->second, Rule.targetKind(), Rule.targetPart(), *Match.Context);
+  if (auto Err =
+          llvm::handleErrors(TargetOrErr.takeError(), [&Rule](StringError &E) {
+            return invalidArgumentError("Failure targeting node " +
+                                        Rule.target() + ": " + E.getMessage());
+          }))
+    return std::move(Err);
+  auto &Target = *TargetOrErr;
+  if (Target.isInvalid() ||
+      isOriginMacroBody(*Match.SourceManager, Target.getBegin()))
+    return Transformation();
+
+  return Transformation{Target, Rule.replacement(Match)};
+}
+} // namespace tooling
+} // namespace clang
+
+constexpr char RewriteRule::RootId[];
+
+RewriteRuleBuilder &&RewriteRuleBuilder::replaceWith(TextGenerator T) && {
+  Rule.setReplacement(std::move(T));
+  return std::move(*this);
+}
+
+RewriteRuleBuilder &&RewriteRuleBuilder::because(TextGenerator T) && {
+  Rule.setExplanation(std::move(T));
+  return std::move(*this);
+}
+
+void Transformer::registerMatchers(MatchFinder *MatchFinder) {
+  MatchFinder->addDynamicMatcher(Rule.matcher(), this);
+}
+
+void Transformer::run(const MatchResult &Result) {
+  auto ChangeOrErr = applyRewriteRule(Rule, Result);
+  if (auto Err = ChangeOrErr.takeError()) {
+    llvm::errs() << "Rewrite failed: " << llvm::toString(std::move(Err))
+                 << "\n";
+    return;
+  }
+  auto &Change = *ChangeOrErr;
+  auto &Range = Change.Range;
+  if (Range.isInvalid()) {
+    // No rewrite applied (but no error encountered either).
+    return;
+  }
+  AtomicChange AC(*Result.SourceManager, Range.getBegin());
+  if (auto Err = AC.replace(*Result.SourceManager, Range, Change.Replacement))
+    AC.setError(llvm::toString(std::move(Err)));
+  Consumer(AC);
+}
diff --git a/clang/unittests/Tooling/CMakeLists.txt b/clang/unittests/Tooling/CMakeLists.txt
--- a/clang/unittests/Tooling/CMakeLists.txt
+++ b/clang/unittests/Tooling/CMakeLists.txt
@@ -50,6 +50,7 @@
   ReplacementsYamlTest.cpp
   RewriterTest.cpp
   ToolingTest.cpp
+  TransformerTest.cpp
   )
 
 target_link_libraries(ToolingTests
diff --git a/clang/unittests/Tooling/TransformerTest.cpp b/clang/unittests/Tooling/TransformerTest.cpp
new file mode 100644
--- /dev/null
+++ b/clang/unittests/Tooling/TransformerTest.cpp
@@ -0,0 +1,418 @@
+//===- unittest/Tooling/TransformerTest.cpp -------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "clang/Tooling/Refactoring/Transformer.h"
+
+#include "clang/ASTMatchers/ASTMatchers.h"
+#include "clang/Tooling/Tooling.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace clang {
+namespace tooling {
+namespace {
+using ::clang::ast_matchers::anyOf;
+using ::clang::ast_matchers::argumentCountIs;
+using ::clang::ast_matchers::callee;
+using ::clang::ast_matchers::callExpr;
+using ::clang::ast_matchers::cxxMemberCallExpr;
+using ::clang::ast_matchers::cxxMethodDecl;
+using ::clang::ast_matchers::cxxRecordDecl;
+using ::clang::ast_matchers::declRefExpr;
+using ::clang::ast_matchers::expr;
+using ::clang::ast_matchers::functionDecl;
+using ::clang::ast_matchers::hasAnyName;
+using ::clang::ast_matchers::hasArgument;
+using ::clang::ast_matchers::hasDeclaration;
+using ::clang::ast_matchers::hasElse;
+using ::clang::ast_matchers::hasName;
+using ::clang::ast_matchers::hasType;
+using ::clang::ast_matchers::ifStmt;
+using ::clang::ast_matchers::member;
+using ::clang::ast_matchers::memberExpr;
+using ::clang::ast_matchers::namedDecl;
+using ::clang::ast_matchers::on;
+using ::clang::ast_matchers::pointsTo;
+using ::clang::ast_matchers::to;
+using ::clang::ast_matchers::unless;
+
+constexpr char KHeaderContents[] = R"cc(
+  struct string {
+    string(const char*);
+    char* c_str();
+    int size();
+  };
+  int strlen(const char*);
+
+  namespace proto {
+  struct PCFProto {
+    int foo();
+  };
+  struct ProtoCommandLineFlag : PCFProto {
+    PCFProto& GetProto();
+  };
+  }  // namespace proto
+)cc";
+} // namespace
+
+static clang::ast_matchers::internal::Matcher<clang::QualType>
+isOrPointsTo(const clang::ast_matchers::DeclarationMatcher &TypeMatcher) {
+  return anyOf(hasDeclaration(TypeMatcher), pointsTo(TypeMatcher));
+}
+
+static std::string format(llvm::StringRef Code) {
+  const std::vector<Range> Ranges(1, Range(0, Code.size()));
+  auto Style = format::getLLVMStyle();
+  const auto Replacements = format::reformat(Style, Code, Ranges);
+  auto Formatted = applyAllReplacements(Code, Replacements);
+  if (!Formatted) {
+    ADD_FAILURE() << "Could not format code: "
+                  << llvm::toString(Formatted.takeError());
+    return std::string();
+  }
+  return *Formatted;
+}
+
+// Expects `MaybeActual` to hold rewritten code that, once the injected header
+// include is removed, formats identically to `Expected`.
+static void compareSnippets(llvm::StringRef Expected,
+                            const llvm::Optional<std::string> &MaybeActual) {
+  ASSERT_TRUE(MaybeActual) << "Rewrite failed. Expecting: " << Expected;
+  auto Actual = *MaybeActual;
+  std::string HL = "#include \"header.h\"\n";
+  auto I = Actual.find(HL);
+  if (I != std::string::npos)
+    Actual.erase(I, HL.size());
+  EXPECT_EQ(format(Expected), format(Actual));
+}
+
+// FIXME: consider separating this class into its own file(s).
+class ClangRefactoringTestBase : public testing::Test {
+protected:
+  void appendToHeader(llvm::StringRef S) { FileContents[0].second += S; }
+
+  void addFile(llvm::StringRef Filename, llvm::StringRef Content) {
+    FileContents.emplace_back(Filename, Content);
+  }
+
+  llvm::Optional<std::string> rewrite(llvm::StringRef Input) {
+    std::string Code = ("#include \"header.h\"\n" + Input).str();
+    auto Factory = newFrontendActionFactory(&MatchFinder);
+    if (!runToolOnCodeWithArgs(
+            Factory->create(), Code, std::vector<std::string>(), "input.cc",
+            "clang-tool", std::make_shared<PCHContainerOperations>(),
+            FileContents))
+      return None;
+    auto ChangedCodeOrErr =
+        applyAtomicChanges("input.cc", Code, Changes, ApplyChangesSpec());
+    if (auto Err = ChangedCodeOrErr.takeError()) {
+      llvm::errs() << "Change failed: " << llvm::toString(std::move(Err))
+                   << "\n";
+      return None;
+    }
+    return *ChangedCodeOrErr;
+  }
+
+  clang::ast_matchers::MatchFinder MatchFinder;
+  AtomicChanges Changes;
+
+private:
+  FileContentMappings FileContents = {{"header.h", ""}};
+};
+
+class TransformerTest : public ClangRefactoringTestBase {
+protected:
+  TransformerTest() { appendToHeader(KHeaderContents); }
+
+  Transformer::ChangeConsumer changeRecorder() {
+    return [this](const AtomicChange &C) { Changes.push_back(C); };
+  }
+};
+
+// Given string s, change strlen($s.c_str()) to $s.size().
+static RewriteRule ruleStrlenSize() {
+  ExprId StringExpr;
+  auto StringType = namedDecl(hasAnyName("::basic_string", "::string"));
+  return buildRule(
+             callExpr(
+                 callee(functionDecl(hasName("strlen"))),
+                 hasArgument(0, cxxMemberCallExpr(
+                                    on(expr(StringExpr.bind(),
+                                            hasType(isOrPointsTo(StringType)))),
+                                    callee(cxxMethodDecl(hasName("c_str")))))))
+      // Specify the intended type explicitly, because the matcher "type" of
+      // `callExpr()` is `Stmt`, not `Expr`.
+      .as<clang::Expr>()
+      .replaceWith("REPLACED")
+      .because("Use size() method directly on string.");
+}
+
+TEST_F(TransformerTest, StrlenSize) {
+  std::string Input = "int f(string s) { return strlen(s.c_str()); }";
+  std::string Expected = "int f(string s) { return REPLACED; }";
+
+  Transformer T(ruleStrlenSize(), changeRecorder());
+  T.registerMatchers(&MatchFinder);
+  compareSnippets(Expected, rewrite(Input));
+}
+
+// Tests that no change is applied when a match is not expected.
+TEST_F(TransformerTest, NoMatch) {
+  std::string Input = "int f(string s) { return s.size(); }";
+
+  Transformer T(ruleStrlenSize(), changeRecorder());
+  T.registerMatchers(&MatchFinder);
+  // Input should not be changed.
+  compareSnippets(Input, rewrite(Input));
+}
+
+// Tests that expressions in macro arguments are rewritten (when applicable).
+TEST_F(TransformerTest, StrlenSizeMacro) {
+  std::string Input = R"cc(
+#define ID(e) e
+    int f(string s) { return ID(strlen(s.c_str())); })cc";
+  std::string Expected = R"cc(
+#define ID(e) e
+    int f(string s) { return ID(REPLACED); })cc";
+
+  Transformer T(ruleStrlenSize(), changeRecorder());
+  T.registerMatchers(&MatchFinder);
+  compareSnippets(Expected, rewrite(Input));
+}
+
+// Tests replacing a bound subexpression rather than the whole match.
+TEST_F(TransformerTest, Flag) {
+  ExprId Flag;
+  RewriteRule Rule =
+      buildRule(cxxMemberCallExpr(
+                    on(expr(Flag.bind(), hasType(cxxRecordDecl(hasName(
+                                             "proto::ProtoCommandLineFlag"))))),
+                    unless(callee(cxxMethodDecl(hasName("GetProto"))))))
+          .change(Flag)
+          .replaceWith("EXPR")
+          .because("Use GetProto() to access proto fields.");
+
+  std::string Input = R"cc(
+    proto::ProtoCommandLineFlag flag;
+    int x = flag.foo();
+    int y = flag.GetProto().foo();
+  )cc";
+  std::string Expected = R"cc(
+    proto::ProtoCommandLineFlag flag;
+    int x = EXPR.foo();
+    int y = flag.GetProto().foo();
+  )cc";
+
+  Transformer T(Rule, changeRecorder());
+  T.registerMatchers(&MatchFinder);
+  compareSnippets(Expected, rewrite(Input));
+}
+
+TEST_F(TransformerTest, NodePartNameNamedDecl) {
+  DeclId Fun;
+  RewriteRule Rule = buildRule(functionDecl(hasName("bad"), Fun.bind()))
+                         .change(Fun, NodePart::Name)
+                         .replaceWith("good");
+
+  std::string Input = R"cc(
+    int bad(int x);
+    int bad(int x) { return x * x; }
+  )cc";
+  std::string Expected = R"cc(
+    int good(int x);
+    int good(int x) { return x * x; }
+  )cc";
+
+  Transformer T(Rule, changeRecorder());
+  T.registerMatchers(&MatchFinder);
+  compareSnippets(Expected, rewrite(Input));
+}
+
+TEST_F(TransformerTest, NodePartNameDeclRef) {
+  std::string Input = R"cc(
+    template <typename T>
+    T bad(T x) {
+      return x;
+    }
+    int neutral(int x) { return bad<int>(x) * x; }
+  )cc";
+  std::string Expected = R"cc(
+    template <typename T>
+    T bad(T x) {
+      return x;
+    }
+    int neutral(int x) { return good<int>(x) * x; }
+  )cc";
+
+  ExprId Ref;
+  Transformer T(
+      buildRule(declRefExpr(to(functionDecl(hasName("bad"))), Ref.bind()))
+          .change(Ref, NodePart::Name)
+          .replaceWith("good"),
+      changeRecorder());
+  T.registerMatchers(&MatchFinder);
+  compareSnippets(Expected, rewrite(Input));
+}
+
+// Tests that NodePart::Name leaves the code unchanged when the referenced
+// function's name is not an identifier (here, `operator*`).
+TEST_F(TransformerTest, NodePartNameDeclRefFailure) {
+  std::string Input = R"cc(
+    struct Y {
+      int operator*();
+    };
+    int neutral(int x) {
+      Y y;
+      int (Y::*ptr)() = &Y::operator*;
+      return *y + x;
+    }
+  )cc";
+
+  ExprId Ref;
+  Transformer T(buildRule(declRefExpr(to(functionDecl()), Ref.bind()))
+                    .change(Ref, NodePart::Name)
+                    .replaceWith("good"),
+                changeRecorder());
+  T.registerMatchers(&MatchFinder);
+  compareSnippets(Input, rewrite(Input));
+}
+
+TEST_F(TransformerTest, NodePartMember) {
+  ExprId E;
+  RewriteRule Rule = buildRule(memberExpr(member(hasName("bad")), E.bind()))
+                         .change(E, NodePart::Member)
+                         .replaceWith("good");
+
+  std::string Input = R"cc(
+    struct S {
+      int bad;
+    };
+    int g() {
+      S s;
+      return s.bad;
+    }
+  )cc";
+  std::string Expected = R"cc(
+    struct S {
+      int bad;
+    };
+    int g() {
+      S s;
+      return s.good;
+    }
+  )cc";
+
+  Transformer T(Rule, changeRecorder());
+  T.registerMatchers(&MatchFinder);
+  compareSnippets(Expected, rewrite(Input));
+}
+
+TEST_F(TransformerTest, NodePartMemberQualified) {
+  ExprId E;
+  RewriteRule Rule = buildRule(memberExpr(E.bind()))
+                         .change(E, NodePart::Member)
+                         .replaceWith("good");
+
+  std::string Input = R"cc(
+    struct S {
+      int bad;
+      int good;
+    };
+    struct T : public S {
+      int bad;
+    };
+    int g() {
+      T t;
+      return t.S::bad;
+    }
+  )cc";
+  std::string Expected = R"cc(
+    struct S {
+      int bad;
+      int good;
+    };
+    struct T : public S {
+      int bad;
+    };
+    int g() {
+      T t;
+      return t.S::good;
+    }
+  )cc";
+
+  Transformer T(Rule, changeRecorder());
+  T.registerMatchers(&MatchFinder);
+  compareSnippets(Expected, rewrite(Input));
+}
+
+TEST_F(TransformerTest, NodePartMemberMultiToken) {
+  std::string Input = R"cc(
+    struct Y {
+      int operator*();
+      int good();
+      template <typename T> void foo(T t);
+    };
+    int neutral(int x) {
+      Y y;
+      y.template foo<int>(3);
+      return y.operator *();
+    }
+  )cc";
+  std::string Expected = R"cc(
+    struct Y {
+      int operator*();
+      int good();
+      template <typename T> void foo(T t);
+    };
+    int neutral(int x) {
+      Y y;
+      y.template good<int>(3);
+      return y.good();
+    }
+  )cc";
+
+  ExprId MemExpr;
+  Transformer T(buildRule(memberExpr(MemExpr.bind()))
+                    .change(MemExpr, NodePart::Member)
+                    .replaceWith("good"),
+                changeRecorder());
+  T.registerMatchers(&MatchFinder);
+  compareSnippets(Expected, rewrite(Input));
+}
+
+// Negative tests (where we expect no transformation to occur).
+
+TEST_F(TransformerTest, NoTransformationInMacro) {
+  std::string Input = R"cc(
+#define MACRO(str) strlen((str).c_str())
+    int f(string s) { return MACRO(s); })cc";
+
+  Transformer T(ruleStrlenSize(), changeRecorder());
+  T.registerMatchers(&MatchFinder);
+  // The macro should be ignored.
+  compareSnippets(Input, rewrite(Input));
+}
+
+// This test handles the corner case where a macro called within another macro
+// expands to matching code, but the matched code is an argument to the nested
+// macro.  A simple check of isMacroArgExpansion() vs. isMacroBodyExpansion()
+// will get this wrong, and transform the code. This test verifies that no such
+// transformation occurs.
+TEST_F(TransformerTest, NoTransformationInNestedMacro) {
+  std::string Input = R"cc(
+#define NESTED(e) e
+#define MACRO(str) NESTED(strlen((str).c_str()))
+    int f(string s) { return MACRO(s); })cc";
+
+  Transformer T(ruleStrlenSize(), changeRecorder());
+  T.registerMatchers(&MatchFinder);
+  // The macro should be ignored.
+  compareSnippets(Input, rewrite(Input));
+}
+} // namespace tooling
+} // namespace clang