diff --git a/clang/include/clang/Tooling/Refactoring/Transformer.h b/clang/include/clang/Tooling/Refactoring/Transformer.h new file mode 100644 --- /dev/null +++ b/clang/include/clang/Tooling/Refactoring/Transformer.h @@ -0,0 +1,285 @@ +//===--- Transformer.h - Clang source-rewriting library ---------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// Defines a library supporting the concise specification of clang-based +/// source-to-source transformations. +/// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_CLANG_TOOLING_REFACTOR_TRANSFORMER_H_ +#define LLVM_CLANG_TOOLING_REFACTOR_TRANSFORMER_H_ + +#include "NodeId.h" +#include "clang/ASTMatchers/ASTMatchFinder.h" +#include "clang/ASTMatchers/ASTMatchers.h" +#include "clang/ASTMatchers/ASTMatchersInternal.h" +#include "clang/Tooling/Refactoring/AtomicChange.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/Error.h" +#include +#include +#include +#include +#include +#include + +namespace clang { +namespace tooling { + +/// \name Matcher-type abbreviations for all top-level classes in the +/// AST class hierarchy. +/// @{ +using ast_matchers::CXXCtorInitializerMatcher; +using ast_matchers::DeclarationMatcher; +using ast_matchers::NestedNameSpecifierLocMatcher; +using ast_matchers::NestedNameSpecifierMatcher; +using ast_matchers::StatementMatcher; +using ast_matchers::TypeLocMatcher; +using ast_matchers::TypeMatcher; +using TemplateArgumentMatcher = + ast_matchers::internal::Matcher; +using TemplateNameMatcher = ast_matchers::internal::Matcher; +using ast_matchers::internal::DynTypedMatcher; +/// @} + +/// A simple abstraction of a filter for match results. Currently, it simply +/// wraps a predicate, but we may extend the functionality to support a simple +/// boolean expression language for constructing filters. +class MatchFilter { +public: + using Predicate = + std::function; + + MatchFilter() + : Filter([](const ast_matchers::MatchFinder::MatchResult &) { + return true; + }) {} + explicit MatchFilter(Predicate P) : Filter(std::move(P)) {} + + MatchFilter(const MatchFilter &) = default; + MatchFilter(MatchFilter &&) = default; + MatchFilter &operator=(const MatchFilter &) = default; + MatchFilter &operator=(MatchFilter &&) = default; + + bool matches(const ast_matchers::MatchFinder::MatchResult &Result) const { + return Filter(Result); + } + +private: + Predicate Filter; +}; + +/// Determines the part of the AST node to replace. We support this to work +/// around the fact that the AST does not differentiate various syntactic +/// elements into their own nodes, so users can specify them relative to a node, +/// instead. +enum class NodePart { + /// The node itself. + Node, + /// Given a \c MemberExpr, selects the member's token. + Member, + /// Given a \c NamedDecl or \c CxxCtorInitializer, selects that token of the + /// relevant name, not including qualifiers. + Name, +}; + +using TextGenerator = std::function( + const ast_matchers::MatchFinder::MatchResult &)>; + +/// Description of a source-code transformation. +// +// A *rewrite rule* describes a transformation of source code. It has the +// following components: +// +// * Matcher: the pattern term, expressed as clang matchers (with Transformer +// extensions). +// +// * Where: a "where clause" -- that is, a predicate over (matched) AST nodes +// that restricts matches beyond what is (easily) expressable as a pattern. +// +// * Target: the source code impacted by the rule. This identifies an AST node, +// or part thereof, whose source range indicates the extent of the replacement +// applied by the replacement term. By default, the extent is the node +// matched by the pattern term. +// +// * Replacement: a function that produces a replacement string for the target, +// based on the match result. +// +// * Explanation: explanation of the rewrite. +// +// Rules have an additional, implicit, component: the parameters. These are +// portions of the pattern which are left unspecified, yet named so that we can +// reference them in the replacement term. The structure of parameters can be +// partially or even fully specified, in which case they serve just to identify +// matched nodes for later reference rather than abstract over portions of the +// AST. However, in all cases, we refer to named portions of the pattern as +// parameters. +// +// Parameters can be declared explicitly using the NodeId type and its +// derivatives or left implicit by using the native support for binding ids in +// the clang matchers. +// +// RewriteRule is constructed in a "fluent" style, by chaining setters of +// individual components. We provide ref-qualified overloads of the setters to +// avoid an unnecessary copy when a RewriteRule is initialized from a temporary, +// like: +// \code +// RewriteRule R = RewriteRule().matching(functionDecl(...)).replaceWith(...); +// \endcode +class RewriteRule { +public: + RewriteRule(DynTypedMatcher M) + : Matcher(std::move(M)), TargetKind(Matcher.getSupportedKind()) { + Matcher.setAllowBind(true); + } + template + RewriteRule(ast_matchers::internal::Matcher M) + : RewriteRule(makeMatcher(std::move(M))) {} + + RewriteRule(const RewriteRule &) = default; + RewriteRule(RewriteRule &&) = default; + RewriteRule &operator=(const RewriteRule &) = default; + RewriteRule &operator=(RewriteRule &&) = default; + + RewriteRule &where(MatchFilter::Predicate Filter) &; + RewriteRule &&where(MatchFilter::Predicate Filter) && { + return std::move(where(std::move(Filter))); + } + + template RewriteRule &as() &; + template RewriteRule &&as() && { return std::move(as()); } + + RewriteRule &change(const NodeId &Target, NodePart Part = NodePart::Node) &; + RewriteRule &&change(const NodeId &Target, + NodePart Part = NodePart::Node) && { + return std::move(change(Target, Part)); + } + template + RewriteRule &change(const TypedNodeId &Target, + NodePart Part = NodePart::Node) &; + template + RewriteRule &&change(const TypedNodeId &Target, + NodePart Part = NodePart::Node) && { + return std::move(change(Target, Part)); + } + + RewriteRule &replaceWith(TextGenerator Replacement) &; + RewriteRule &&replaceWith(TextGenerator Replacement) && { + return std::move(replaceWith(std::move(Replacement))); + } + + RewriteRule &because(TextGenerator Explanation) &; + RewriteRule &&because(TextGenerator Explanation) && { + return std::move(because(std::move(Explanation))); + } + + const DynTypedMatcher &matcher() const { return Matcher; } + const MatchFilter &filter() const { return Filter; } + llvm::StringRef target() const { return Target; } + ast_type_traits::ASTNodeKind targetKind() const { return TargetKind; } + NodePart targetPart() const { return TargetPart; } + + llvm::Expected + replacement(const ast_matchers::MatchFinder::MatchResult &R) const { + return Replacement(R); + } + + llvm::Expected + explanation(const ast_matchers::MatchFinder::MatchResult &R) const { + return Explanation(R); + } + +private: + template static DynTypedMatcher makeMatcher(MatcherT M) { + // Copy `M`'s (underlying) `DynTypedMatcher`. + DynTypedMatcher DM = M; + DM.setAllowBind(true); + // RewriteRule guarantees that the node described by the matcher will always + // be accessible as `RootId`, so we bind it here. `tryBind` is guaranteed to + // succeed, because `AllowBind` is true. + return *DM.tryBind(RootId); + } + + // Id used as the default target of each match. + static constexpr char RootId[] = "___root___"; + + // Supports any (top-level node) matcher type. + DynTypedMatcher Matcher; + MatchFilter Filter; + // The (bound) id of the node whose source will be replaced. This id should + // never be the empty string. By default, refers to the node matched by + // `Matcher`. + std::string Target = RootId; + ast_type_traits::ASTNodeKind TargetKind; + NodePart TargetPart = NodePart::Node; + TextGenerator Replacement; + TextGenerator Explanation; +}; + +template RewriteRule &RewriteRule::as() & { + TargetKind = ast_type_traits::ASTNodeKind::getFromNodeKind(); + return *this; +} + +template +RewriteRule &RewriteRule::change(const TypedNodeId &TargetId, + NodePart Part) & { + Target = std::string(TargetId.id()); + TargetKind = ast_type_traits::ASTNodeKind::getFromNodeKind(); + TargetPart = Part; + return *this; +} + +// Convenience factory function for the common case where a rule has a statement +// matcher, template and explanation. +RewriteRule makeRule(StatementMatcher Matcher, TextGenerator Replacement, + const std::string &Explanation); + +/// A source "transformation," represented by a character range in the source to +/// be replaced and a corresponding replacement string. +struct Transformation { + CharSourceRange Range; + std::string Replacement; +}; + +/// Attempts to apply a rule to a match. Fails if the match is not eligible for +/// rewriting or, for example, if any invariants are violated relating to bound +/// nodes in the match. +Expected +applyRewriteRule(const RewriteRule &Rule, + const ast_matchers::MatchFinder::MatchResult &Match); + +/// Handles the matcher and callback registration for a single rewrite rule, as +/// defined by the arguments of the constructor. +class Transformer : public ast_matchers::MatchFinder::MatchCallback { +public: + using ChangeConsumer = + std::function; + + /// \param Consumer Receives each successful rewrites as an \c AtomicChange. + Transformer(RewriteRule Rule, ChangeConsumer Consumer) + : Rule(std::move(Rule)), Consumer(std::move(Consumer)) {} + + /// N.B. Passes `this` pointer to `MatchFinder`. So, this object should not + /// be moved after this call. + void registerMatchers(ast_matchers::MatchFinder *MatchFinder); + + /// Not called directly by users -- called by the framework, via base class + /// pointer. + void run(const ast_matchers::MatchFinder::MatchResult &Result) override; + +private: + RewriteRule Rule; + /// Receives each successful rewrites as an \c AtomicChange. + ChangeConsumer Consumer; +}; +} // namespace tooling +} // namespace clang + +#endif // LLVM_CLANG_TOOLING_REFACTOR_TRANSFORMER_H_ diff --git a/clang/lib/Tooling/Refactoring/CMakeLists.txt b/clang/lib/Tooling/Refactoring/CMakeLists.txt --- a/clang/lib/Tooling/Refactoring/CMakeLists.txt +++ b/clang/lib/Tooling/Refactoring/CMakeLists.txt @@ -13,6 +13,7 @@ Rename/USRFindingAction.cpp Rename/USRLocFinder.cpp NodeId.cpp + Transformer.cpp LINK_LIBS clangAST diff --git a/clang/lib/Tooling/Refactoring/Transformer.cpp b/clang/lib/Tooling/Refactoring/Transformer.cpp new file mode 100644 --- /dev/null +++ b/clang/lib/Tooling/Refactoring/Transformer.cpp @@ -0,0 +1,239 @@ +//===--- Transformer.cpp - Transformer library implementation ---*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "clang/Tooling/Refactoring/Transformer.h" +#include "clang/AST/Expr.h" +#include "clang/ASTMatchers/ASTMatchFinder.h" +#include "clang/ASTMatchers/ASTMatchers.h" +#include "clang/Basic/Diagnostic.h" +#include "clang/Basic/SourceLocation.h" +#include "clang/Rewrite/Core/Rewriter.h" +#include "clang/Tooling/FixIt.h" +#include "clang/Tooling/Refactoring.h" +#include "clang/Tooling/Refactoring/AtomicChange.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Errc.h" +#include "llvm/Support/Error.h" +#include +#include +#include +#include + +namespace clang { +namespace tooling { +namespace { +using ::clang::ast_matchers::MatchFinder; +using ::clang::ast_matchers::stmt; +using ::clang::ast_type_traits::ASTNodeKind; +using ::clang::ast_type_traits::DynTypedNode; +using ::llvm::Error; +using ::llvm::Expected; +using ::llvm::Optional; +using ::llvm::StringError; +using ::llvm::StringRef; + +using MatchResult = MatchFinder::MatchResult; +} // namespace + +static bool isOriginMacroBody(const clang::SourceManager &source_manager, + clang::SourceLocation loc) { + while (loc.isMacroID()) { + if (source_manager.isMacroBodyExpansion(loc)) + return true; + // Otherwise, it must be in an argument, so we continue searching up the + // invocation stack. getImmediateMacroCallerLoc() gives the location of the + // argument text, inside the call text. + loc = source_manager.getImmediateMacroCallerLoc(loc); + } + return false; +} + +static llvm::Error invalidArgumentError(llvm::Twine Message) { + return llvm::make_error(llvm::errc::invalid_argument, Message); +} + +static llvm::Error unboundNodeError(StringRef Role, StringRef Id) { + return invalidArgumentError(Role + " (=" + Id + ") references unbound node"); +} + +static llvm::Error typeError(llvm::Twine Message, const ASTNodeKind &Kind) { + return invalidArgumentError(Message + " (node kind is " + Kind.asStringRef() + + ")"); +} + +static llvm::Error missingPropertyError(llvm::Twine Description, + StringRef Property) { + return invalidArgumentError(Description + " requires property '" + Property + + "'"); +} + +// Verifies that `node` is appropriate for the given `target_part`. +static Error verifyTarget(const DynTypedNode &Node, NodePart TargetPart) { + switch (TargetPart) { + case NodePart::Node: + return Error::success(); + case NodePart::Member: + if (Node.get() != nullptr) + return Error::success(); + return typeError("NodePart::Member applied to non-MemberExpr", + Node.getNodeKind()); + case NodePart::Name: + if (const auto *D = Node.get()) { + if (D->getDeclName().isIdentifier()) + return Error::success(); + return missingPropertyError("NodePart::Name", "identifier"); + } + if (const auto *E = Node.get()) { + if (E->getNameInfo().getName().isIdentifier()) + return Error::success(); + return missingPropertyError("NodePart::Name", "identifier"); + } + if (const auto *I = Node.get()) { + if (I->isMemberInitializer()) + return Error::success(); + return missingPropertyError("NodePart::Name", "member initializer"); + } + return typeError( + "NodePart::Name applied to neither DeclRefExpr, NamedDecl nor " + "CXXCtorInitializer", + Node.getNodeKind()); + } + llvm_unreachable("Unexpected case in NodePart type."); +} + +// Requires VerifyTarget(node, target_part) == success. +static CharSourceRange getTarget(const DynTypedNode &Node, ASTNodeKind Kind, + NodePart TargetPart, ASTContext &Context) { + SourceLocation TokenLoc; + switch (TargetPart) { + case NodePart::Node: { + // For non-expression statements, associate any trailing semicolon with the + // statement text. However, if the target was intended as an expression (as + // indicated by its kind) then we do not associate any trailing semicolon + // with it. We only associate the exact expression text. + if (Node.get() != nullptr) { + auto ExprKind = ASTNodeKind::getFromNodeKind(); + if (!ExprKind.isBaseOf(Kind)) + return fixit::getExtendedRange(Node, tok::TokenKind::semi, Context); + } + return CharSourceRange::getTokenRange(Node.getSourceRange()); + } + case NodePart::Member: + TokenLoc = Node.get()->getMemberLoc(); + break; + case NodePart::Name: + if (const auto *D = Node.get()) { + TokenLoc = D->getLocation(); + break; + } + if (const auto *E = Node.get()) { + TokenLoc = E->getLocation(); + break; + } + if (const auto *I = Node.get()) { + TokenLoc = I->getMemberLocation(); + break; + } + // This should be unreachable if the target was already verified. + llvm_unreachable("NodePart::Name applied to neither NamedDecl nor " + "CXXCtorInitializer"); + } + return CharSourceRange::getTokenRange(TokenLoc, TokenLoc); +} + +Expected applyRewriteRule(const RewriteRule &Rule, + const MatchResult &Result) { + // Ignore results in failing TUs or those rejected by the where clause. + if (Result.Context->getDiagnostics().hasErrorOccurred() || + !Rule.filter().matches(Result)) + return Transformation(); + + auto &NodesMap = Result.Nodes.getMap(); + auto It = NodesMap.find(Rule.target()); + if (It == NodesMap.end()) + return unboundNodeError("rule.target", Rule.target()); + if (auto Err = llvm::handleErrors( + verifyTarget(It->second, Rule.targetPart()), [&Rule](StringError &E) { + return invalidArgumentError("Failure targeting node" + + Rule.target() + ": " + E.getMessage()); + })) { + return std::move(Err); + } + CharSourceRange Target = getTarget(It->second, Rule.targetKind(), + Rule.targetPart(), *Result.Context); + if (Target.isInvalid() || + isOriginMacroBody(*Result.SourceManager, Target.getBegin())) + return Transformation(); + + auto ReplacementOrErr = Rule.replacement(Result); + if (auto Err = ReplacementOrErr.takeError()) + return std::move(Err); + return Transformation{Target, std::move(*ReplacementOrErr)}; +} + +constexpr char RewriteRule::RootId[]; + +RewriteRule & +RewriteRule::where(std::function FilterFn) & { + Filter = MatchFilter(std::move(FilterFn)); + return *this; +} + +RewriteRule &RewriteRule::change(const NodeId &TargetId, NodePart Part) & { + Target = std::string(TargetId.id()); + TargetKind = ASTNodeKind(); + TargetPart = Part; + return *this; +} + +RewriteRule &RewriteRule::replaceWith(TextGenerator TG) & { + Replacement = std::move(TG); + return *this; +} + +RewriteRule &RewriteRule::because(TextGenerator TG) & { + Explanation = std::move(TG); + return *this; +} + +// `Explanation` is a `string&`, rather than a `string` or `StringRef` to save +// an extra copy needed to intialize the captured lambda variable. After C++14, +// we can use intializers to do this properly. +RewriteRule makeRule(StatementMatcher Matcher, TextGenerator Replacement, + const std::string &Explanation) { + return RewriteRule(Matcher) + .replaceWith(std::move(Replacement)) + .because([Explanation](const MatchResult &) { return Explanation; }); +} + +void Transformer::registerMatchers(MatchFinder *MatchFinder) { + MatchFinder->addDynamicMatcher(Rule.matcher(), this); +} + +void Transformer::run(const MatchResult &Result) { + auto ChangeOrErr = applyRewriteRule(Rule, Result); + if (auto Err = ChangeOrErr.takeError()) { + llvm::errs() << "Rewrite failed: " << llvm::toString(std::move(Err)) + << "\n"; + return; + } + auto &Change = *ChangeOrErr; + auto &Range = Change.Range; + if (Range.isInvalid()) { + // No rewrite applied (but no error encountered either). + return; + } + AtomicChange AC(*Result.SourceManager, Range.getBegin()); + if (auto Err = AC.replace(*Result.SourceManager, Range, Change.Replacement)) { + AC.setError(llvm::toString(std::move(Err))); + } + Consumer(AC); +} +} // namespace tooling +} // namespace clang diff --git a/clang/unittests/Tooling/CMakeLists.txt b/clang/unittests/Tooling/CMakeLists.txt --- a/clang/unittests/Tooling/CMakeLists.txt +++ b/clang/unittests/Tooling/CMakeLists.txt @@ -50,6 +50,7 @@ ReplacementsYamlTest.cpp RewriterTest.cpp ToolingTest.cpp + TransformerTest.cpp ) target_link_libraries(ToolingTests diff --git a/clang/unittests/Tooling/TransformerTest.cpp b/clang/unittests/Tooling/TransformerTest.cpp new file mode 100644 --- /dev/null +++ b/clang/unittests/Tooling/TransformerTest.cpp @@ -0,0 +1,428 @@ +//===- unittest/Tooling/TransformerTest.cpp -------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "clang/Tooling/Refactoring/Transformer.h" + +#include "clang/ASTMatchers/ASTMatchers.h" +#include "clang/Tooling/Tooling.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +namespace clang { +namespace tooling { +namespace { +using ::clang::ast_matchers::anyOf; +using ::clang::ast_matchers::argumentCountIs; +using ::clang::ast_matchers::callee; +using ::clang::ast_matchers::callExpr; +using ::clang::ast_matchers::cxxMemberCallExpr; +using ::clang::ast_matchers::cxxMethodDecl; +using ::clang::ast_matchers::cxxRecordDecl; +using ::clang::ast_matchers::declRefExpr; +using ::clang::ast_matchers::expr; +using ::clang::ast_matchers::functionDecl; +using ::clang::ast_matchers::hasAnyName; +using ::clang::ast_matchers::hasArgument; +using ::clang::ast_matchers::hasDeclaration; +using ::clang::ast_matchers::hasElse; +using ::clang::ast_matchers::hasName; +using ::clang::ast_matchers::hasType; +using ::clang::ast_matchers::ifStmt; +using ::clang::ast_matchers::member; +using ::clang::ast_matchers::memberExpr; +using ::clang::ast_matchers::namedDecl; +using ::clang::ast_matchers::on; +using ::clang::ast_matchers::pointsTo; +using ::clang::ast_matchers::to; +using ::clang::ast_matchers::unless; + +constexpr char KHeaderContents[] = R"cc( + struct string { + string(const char*); + char* c_str(); + int size(); + }; + int strlen(const char*); + + namespace proto { + struct PCFProto { + int foo(); + }; + struct ProtoCommandLineFlag : PCFProto { + PCFProto& GetProto(); + }; + } // namespace proto +)cc"; +} // namespace + +static clang::ast_matchers::internal::Matcher +isOrPointsTo(const DeclarationMatcher &TypeMatcher) { + return anyOf(hasDeclaration(TypeMatcher), pointsTo(TypeMatcher)); +} + +static std::string format(llvm::StringRef Code) { + const std::vector Ranges(1, Range(0, Code.size())); + auto Style = format::getLLVMStyle(); + const auto Replacements = format::reformat(Style, Code, Ranges); + auto Formatted = applyAllReplacements(Code, Replacements); + if (!Formatted) { + ADD_FAILURE() << "Could not format code: " + << llvm::toString(Formatted.takeError()); + return std::string(); + } + return *Formatted; +} + +void compareSnippets(llvm::StringRef Expected, + const llvm::Optional &MaybeActual) { + ASSERT_TRUE(MaybeActual) << "Rewrite failed. Expecting: " << Expected; + auto Actual = *MaybeActual; + std::string HL = "#include \"header.h\"\n"; + auto I = Actual.find(HL); + if (I != std::string::npos) { + Actual.erase(I, HL.size()); + } + EXPECT_EQ(format(Expected), format(Actual)); +} + +// FIXME: consider separating this class into its own file(s). +class ClangRefactoringTestBase : public testing::Test { +protected: + void appendToHeader(llvm::StringRef S) { FileContents[0].second += S; } + + void addFile(llvm::StringRef Filename, llvm::StringRef Content) { + FileContents.emplace_back(Filename, Content); + } + + llvm::Optional rewrite(llvm::StringRef Input) { + std::string Code = ("#include \"header.h\"\n" + Input).str(); + auto Factory = newFrontendActionFactory(&MatchFinder); + if (!runToolOnCodeWithArgs( + Factory->create(), Code, std::vector(), "input.cc", + "clang-tool", std::make_shared(), + FileContents)) { + return None; + } + auto ChangedCodeOrErr = + applyAtomicChanges("input.cc", Code, Changes, ApplyChangesSpec()); + if (auto Err = ChangedCodeOrErr.takeError()) { + llvm::errs() << "Change failed: " << llvm::toString(std::move(Err)) + << "\n"; + return None; + } + return *ChangedCodeOrErr; + } + + clang::ast_matchers::MatchFinder MatchFinder; + AtomicChanges Changes; + +private: + FileContentMappings FileContents = {{"header.h", ""}}; +}; + +class TransformerTest : public ClangRefactoringTestBase { +protected: + TransformerTest() { appendToHeader(KHeaderContents); } + + Transformer::ChangeConsumer changeRecorder() { + return [this](const AtomicChange &C) { Changes.push_back(C); }; + } +}; + +// Wraps a (simple) string as a TextGenerator. +static TextGenerator text(const std::string &M) { + return + [M](const clang::ast_matchers::MatchFinder::MatchResult &) { return M; }; +} + +// Given string s, change strlen($s.c_str()) to $s.size() TODO: my type +// inference from matchers doesn't work since Matcher types are broken: callExpr +// is a statement matcher, which i'm pretty sure it shouldn't be. +RewriteRule ruleStrlenSize() { + ExprId StringExpr; + auto StringType = namedDecl(hasAnyName("::basic_string", "::string")); + return RewriteRule( + callExpr( + callee(functionDecl(hasName("strlen"))), + hasArgument(0, cxxMemberCallExpr( + on(expr(StringExpr.bind(), + hasType(isOrPointsTo(StringType)))), + callee(cxxMethodDecl(hasName("c_str"))))))) + .as() + .replaceWith(text("REPLACED")) + .because(text("Use size() method directly on string.")); +} + +TEST_F(TransformerTest, StrlenSize) { + std::string Input = "int f(string s) { return strlen(s.c_str()); }"; + std::string Expected = "int f(string s) { return REPLACED; }"; + + Transformer T(ruleStrlenSize(), changeRecorder()); + T.registerMatchers(&MatchFinder); + compareSnippets(Expected, rewrite(Input)); +} + +// Tests that no change is applied when a match is not expected. +TEST_F(TransformerTest, NoMatch) { + std::string Input = "int f(string s) { return s.size(); }"; + + Transformer T(ruleStrlenSize(), changeRecorder()); + T.registerMatchers(&MatchFinder); + // Input should not be changed. + compareSnippets(Input, rewrite(Input)); +} + +// Tests that expressions in macro arguments are rewritten (when applicable). +TEST_F(TransformerTest, StrlenSizeMacro) { + std::string Input = R"cc( +#define ID(e) e + int f(string s) { return ID(strlen(s.c_str())); })cc"; + std::string Expected = R"cc( +#define ID(e) e + int f(string s) { return ID(REPLACED); })cc"; + + Transformer T(ruleStrlenSize(), changeRecorder()); + T.registerMatchers(&MatchFinder); + compareSnippets(Expected, rewrite(Input)); +} + +// Use the lvalue-ref overloads of the RewriteRule builder methods. +TEST_F(TransformerTest, LvalueRefOverloads) { + StmtId E; + RewriteRule Rule(ifStmt(hasElse(E.bind()))); + Rule.change(E).replaceWith(text("bar();")); + + std::string Input = R"cc( + void foo() { + if (10 > 1.0) + return; + else + foo(); + } + )cc"; + std::string Expected = R"cc( + void foo() { + if (10 > 1.0) + return; + else + bar(); + } + )cc"; + + Transformer T(Rule, changeRecorder()); + T.registerMatchers(&MatchFinder); + compareSnippets(Expected, rewrite(Input)); +} + +// Tests replacing an expression. +TEST_F(TransformerTest, Flag) { + ExprId Flag; + auto Rule = + RewriteRule( + cxxMemberCallExpr( + on(expr(Flag.bind(), hasType(cxxRecordDecl(hasName( + "proto::ProtoCommandLineFlag"))))), + unless(callee(cxxMethodDecl(hasName("GetProto")))))) + .change(Flag) + .replaceWith(text("EXPR")) + .because(text("Use GetProto() to access proto fields.")); + + std::string Input = R"cc( + proto::ProtoCommandLineFlag flag; + int x = flag.foo(); + int y = flag.GetProto().foo(); + )cc"; + std::string Expected = R"cc( + proto::ProtoCommandLineFlag flag; + int x = EXPR.foo(); + int y = flag.GetProto().foo(); + )cc"; + + Transformer T(Rule, changeRecorder()); + T.registerMatchers(&MatchFinder); + compareSnippets(Expected, rewrite(Input)); +} + +TEST_F(TransformerTest, NodePartNameNamedDecl) { + DeclId Fun; + auto Rule = RewriteRule(functionDecl(hasName("bad"), Fun.bind())) + .change(Fun, NodePart::Name) + .replaceWith(text("good")); + + std::string Input = R"cc( + int bad(int x); + int bad(int x) { return x * x; } + )cc"; + std::string Expected = R"cc( + int good(int x); + int good(int x) { return x * x; } + )cc"; + + Transformer T(Rule, changeRecorder()); + T.registerMatchers(&MatchFinder); + compareSnippets(Expected, rewrite(Input)); +} + +TEST_F(TransformerTest, NodePartNameDeclRef) { + std::string Input = R"cc( + template + T bad(T x) { + return x; + } + int neutral(int x) { return bad(x) * x; } + )cc"; + std::string Expected = R"cc( + template + T bad(T x) { + return x; + } + int neutral(int x) { return good(x) * x; } + )cc"; + + ExprId Ref; + Transformer T( + RewriteRule(declRefExpr(to(functionDecl(hasName("bad"))), Ref.bind())) + .change(Ref, NodePart::Name) + .replaceWith(text("good")), + changeRecorder()); + T.registerMatchers(&MatchFinder); + compareSnippets(Expected, rewrite(Input)); +} + +TEST_F(TransformerTest, NodePartNameDeclRefFailure) { + std::string Input = R"cc( + struct Y {}; + int operator*(const Y&); + int neutral(int x) { + Y y; + return *y + x; + } + )cc"; + + ExprId Ref; + Transformer T(RewriteRule(declRefExpr(to(functionDecl()), Ref.bind())) + .change(Ref, NodePart::Name) + .replaceWith(text("good")), + changeRecorder()); + T.registerMatchers(&MatchFinder); + compareSnippets(Input, rewrite(Input)); +} + +TEST_F(TransformerTest, NodePartMember) { + ExprId E; + auto Rule = RewriteRule(memberExpr(member(hasName("bad")), E.bind())) + .change(E, NodePart::Member) + .replaceWith(text("good")); + + std::string Input = R"cc( + struct S { + int bad; + }; + int g() { + S s; + return s.bad; + } + )cc"; + std::string Expected = R"cc( + struct S { + int bad; + }; + int g() { + S s; + return s.good; + } + )cc"; + + Transformer T(Rule, changeRecorder()); + T.registerMatchers(&MatchFinder); + compareSnippets(Expected, rewrite(Input)); +} + +// A rule that finds function calls with two arguments where the arguments are +// the same identifier. +RewriteRule ruleDuplicateArgs() { + ExprId Arg0, Arg1; + return RewriteRule(callExpr(argumentCountIs(2), hasArgument(0, Arg0.bind()), + hasArgument(1, Arg1.bind()))) + .where([Arg0, Arg1]( + const clang::ast_matchers::MatchFinder::MatchResult &result) { + auto *Ref0 = Arg0.getNodeAs(result); + auto *Ref1 = Arg1.getNodeAs(result); + return Ref0 != nullptr && Ref1 != nullptr && + Ref0->getDecl() == Ref1->getDecl(); + }) + .as() + .replaceWith(text("42")); +} + +TEST_F(TransformerTest, FilterPassed) { + std::string Input = R"cc( + int foo(int x, int y); + int x = 3; + int z = foo(x, x); + )cc"; + std::string Expected = R"cc( + int foo(int x, int y); + int x = 3; + int z = 42; + )cc"; + + Transformer T(ruleDuplicateArgs(), changeRecorder()); + T.registerMatchers(&MatchFinder); + compareSnippets(Expected, rewrite(Input)); +} + +// +// Negative tests (where we expect no transformation to occur). +// + +TEST_F(TransformerTest, FilterFailed) { + std::string Input = R"cc( + int foo(int x, int y); + int x = 3; + int y = 17; + // Different identifiers. + int z = foo(x, y); + // One identifier, one not. + int w = foo(x, 3); + )cc"; + + Transformer T(ruleDuplicateArgs(), changeRecorder()); + T.registerMatchers(&MatchFinder); + compareSnippets(Input, rewrite(Input)); +} + +TEST_F(TransformerTest, NoTransformationInMacro) { + std::string Input = R"cc( +#define MACRO(str) strlen((str).c_str()) + int f(string s) { return MACRO(s); })cc"; + + Transformer T(ruleStrlenSize(), changeRecorder()); + T.registerMatchers(&MatchFinder); + // The macro should be ignored. + compareSnippets(Input, rewrite(Input)); +} + +// This test handles the corner case where a macro called within another macro +// expands to matching code, but the matched code is an argument to the nested +// macro. A simple check of isMacroArgExpansion() vs. isMacroBodyExpansion() +// will get this wrong, and transform the code. This test verifies that no such +// transformation occurs. +TEST_F(TransformerTest, NoTransformationInNestedMacro) { + std::string Input = R"cc( +#define NESTED(e) e +#define MACRO(str) NESTED(strlen((str).c_str())) + int f(string s) { return MACRO(s); })cc"; + + Transformer T(ruleStrlenSize(), changeRecorder()); + T.registerMatchers(&MatchFinder); + // The macro should be ignored. + compareSnippets(Input, rewrite(Input)); +} +} // namespace tooling +} // namespace clang