diff --git a/clang/include/clang/Tooling/Transformer/RewriteRule.h b/clang/include/clang/Tooling/Transformer/RewriteRule.h --- a/clang/include/clang/Tooling/Transformer/RewriteRule.h +++ b/clang/include/clang/Tooling/Transformer/RewriteRule.h @@ -332,6 +332,23 @@ remove(enclose(after(inner), after(outer)))}); } +/// Applies `Rule` to all descendants of the node bound to `NodeId`. `Rule` can +/// refer to nodes bound by the calling rule. `Rule` is not applied to the node +/// itself. +/// +/// For example, +/// ``` +/// auto InlineX = +/// makeRule(declRefExpr(to(varDecl(hasName("x")))), changeTo(cat("3"))); +/// makeRule(functionDecl(hasName("f"), hasBody(stmt().bind("body"))).bind("f"), +/// flatten( +/// changeTo(name("f"), cat("newName")), +/// rewriteDescendants("body", InlineX))); +/// ``` +/// Here, we find the function `f`, change its name to `newName` and change all +/// appearances of `x` in its body to `3`. +EditGenerator rewriteDescendants(std::string NodeId, RewriteRule Rule); + /// The following three functions are a low-level part of the RewriteRule /// API. 
We expose them for use in implementing the fixtures that interpret /// RewriteRule, like Transformer and TransfomerTidy, or for more advanced diff --git a/clang/lib/Tooling/Transformer/RewriteRule.cpp b/clang/lib/Tooling/Transformer/RewriteRule.cpp --- a/clang/lib/Tooling/Transformer/RewriteRule.cpp +++ b/clang/lib/Tooling/Transformer/RewriteRule.cpp @@ -7,6 +7,8 @@ //===----------------------------------------------------------------------===// #include "clang/Tooling/Transformer/RewriteRule.h" +#include "clang/AST/ASTTypeTraits.h" +#include "clang/AST/Stmt.h" #include "clang/ASTMatchers/ASTMatchFinder.h" #include "clang/ASTMatchers/ASTMatchers.h" #include "clang/Basic/SourceLocation.h" @@ -115,15 +117,144 @@ return change(std::move(S), std::make_shared("")); } -RewriteRule transformer::makeRule(ast_matchers::internal::DynTypedMatcher M, - EditGenerator Edits, +RewriteRule transformer::makeRule(DynTypedMatcher M, EditGenerator Edits, TextGenerator Explanation) { return RewriteRule{{RewriteRule::Case{ std::move(M), std::move(Edits), std::move(Explanation), {}}}}; } +namespace { + +/// Unconditionally binds the given node set before trying `InnerMatcher` and +/// keeps the bound nodes on a successful match. 
+template +class BindingsMatcher : public ast_matchers::internal::MatcherInterface { + ast_matchers::BoundNodes Nodes; + const ast_matchers::internal::Matcher InnerMatcher; + +public: + explicit BindingsMatcher(ast_matchers::BoundNodes Nodes, + ast_matchers::internal::Matcher InnerMatcher) + : Nodes(std::move(Nodes)), InnerMatcher(std::move(InnerMatcher)) {} + + bool matches( + const T &Node, ast_matchers::internal::ASTMatchFinder *Finder, + ast_matchers::internal::BoundNodesTreeBuilder *Builder) const override { + ast_matchers::internal::BoundNodesTreeBuilder Result(*Builder); + for (const auto &N : Nodes.getMap()) + Result.setBinding(N.first, N.second); + if (InnerMatcher.matches(Node, Finder, &Result)) { + *Builder = std::move(Result); + return true; + } + return false; + } +}; + +/// Matches nodes of type T that have at least one descendant node for which the +/// given inner matcher matches. Will match for each descendant node that +/// matches. Based on ForEachDescendantMatcher, but takes a dynamic matcher, +/// instead of a static one, because it is used by RewriteRule, which carries +/// (only top-level) dynamic matchers. 
+template +class DynamicForEachDescendantMatcher + : public ast_matchers::internal::MatcherInterface { + const DynTypedMatcher DescendantMatcher; + +public: + explicit DynamicForEachDescendantMatcher(DynTypedMatcher DescendantMatcher) + : DescendantMatcher(std::move(DescendantMatcher)) {} + + bool matches( + const T &Node, ast_matchers::internal::ASTMatchFinder *Finder, + ast_matchers::internal::BoundNodesTreeBuilder *Builder) const override { + return Finder->matchesDescendantOf( + Node, this->DescendantMatcher, Builder, + ast_matchers::internal::ASTMatchFinder::BK_All); + } +}; + +template +ast_matchers::internal::Matcher +forEachDescendantDynamically(ast_matchers::BoundNodes Nodes, + DynTypedMatcher M) { + return ast_matchers::internal::makeMatcher(new BindingsMatcher( + std::move(Nodes), + ast_matchers::internal::makeMatcher( + new DynamicForEachDescendantMatcher(std::move(M))))); +} + +class ApplyRuleCallback : public MatchFinder::MatchCallback { +public: + ApplyRuleCallback(RewriteRule Rule) : Rule(std::move(Rule)) {} + + template + void registerMatchers(const ast_matchers::BoundNodes &Nodes, + MatchFinder *MF) { + for (auto &Matcher : transformer::detail::buildMatchers(Rule)) + MF->addMatcher(forEachDescendantDynamically(Nodes, Matcher), this); + } + + void run(const MatchFinder::MatchResult &Result) override { + if (!Edits) + return; + transformer::RewriteRule::Case Case = + transformer::detail::findSelectedCase(Result, Rule); + auto Transformations = Case.Edits(Result); + if (!Transformations) { + Edits = Transformations.takeError(); + return; + } + Edits->append(Transformations->begin(), Transformations->end()); + } + + RewriteRule Rule; + + // Initialize to a non-error state. 
+ Expected> Edits = SmallVector(); +}; +} // namespace + +template +llvm::Expected> +rewriteDescendantsImpl(const T &Node, RewriteRule Rule, + const MatchResult &Result) { + ApplyRuleCallback Callback(std::move(Rule)); + MatchFinder Finder; + Callback.registerMatchers(Result.Nodes, &Finder); + Finder.match(Node, *Result.Context); + return std::move(Callback.Edits); +} + +EditGenerator transformer::rewriteDescendants(std::string NodeId, + RewriteRule Rule) { + // FIXME: warn or return error if `Rule` contains any `AddedIncludes`, since + // these will be dropped. + return [NodeId = std::move(NodeId), + Rule = std::move(Rule)](const MatchResult &Result) + -> llvm::Expected> { + const ast_matchers::BoundNodes::IDToNodeMap &NodesMap = + Result.Nodes.getMap(); + auto It = NodesMap.find(NodeId); + if (It == NodesMap.end()) + return llvm::make_error(llvm::errc::invalid_argument, + "ID not bound: " + NodeId); + if (auto *Node = It->second.get()) + return rewriteDescendantsImpl(*Node, std::move(Rule), Result); + if (auto *Node = It->second.get()) + return rewriteDescendantsImpl(*Node, std::move(Rule), Result); + if (auto *Node = It->second.get()) + return rewriteDescendantsImpl(*Node, std::move(Rule), Result); + + return llvm::make_error( + llvm::errc::invalid_argument, + "type unsupported for recursive rewriting, ID=\"" + NodeId + + "\", Kind=" + It->second.getNodeKind().asStringRef()); + }; +} + void transformer::addInclude(RewriteRule &Rule, StringRef Header, - IncludeFormat Format) { + IncludeFormat Format) { for (auto &Case : Rule.Cases) Case.AddedIncludes.emplace_back(Header.str(), Format); } diff --git a/clang/lib/Tooling/Transformer/Transformer.cpp b/clang/lib/Tooling/Transformer/Transformer.cpp --- a/clang/lib/Tooling/Transformer/Transformer.cpp +++ b/clang/lib/Tooling/Transformer/Transformer.cpp @@ -38,13 +38,8 @@ return; } - if (Transformations->empty()) { - // No rewrite applied (but no error encountered either). 
- transformer::detail::getRuleMatchLoc(Result).print( - llvm::errs() << "note: skipping match at loc ", *Result.SourceManager); - llvm::errs() << "\n"; + if (Transformations->empty()) return; - } // Group the transformations, by file, into AtomicChanges, each anchored by // the location of the first change in that file. diff --git a/clang/unittests/Tooling/TransformerTest.cpp b/clang/unittests/Tooling/TransformerTest.cpp --- a/clang/unittests/Tooling/TransformerTest.cpp +++ b/clang/unittests/Tooling/TransformerTest.cpp @@ -114,7 +114,9 @@ if (C) { Changes.push_back(std::move(*C)); } else { - consumeError(C.takeError()); + // FIXME: stash this error rather than printing. + llvm::errs() << "Error generating changes: " + << llvm::toString(C.takeError()) << "\n"; ++ErrorCount; } }; @@ -414,6 +416,105 @@ Input, Expected); } +// Rewrite various Stmts inside a Decl. +TEST_F(TransformerTest, RewriteDescendantsDeclChangeStmt) { + std::string Input = + "int f(int x) { int y = x; { int z = x * x; } return x; }"; + std::string Expected = + "int f(int x) { int y = 3; { int z = 3 * 3; } return 3; }"; + auto InlineX = + makeRule(declRefExpr(to(varDecl(hasName("x")))), changeTo(cat("3"))); + testRule(makeRule(functionDecl(hasName("f")).bind("fun"), + rewriteDescendants("fun", InlineX)), + Input, Expected); +} + +// Rewrite various TypeLocs inside a Decl. +TEST_F(TransformerTest, RewriteDescendantsDeclChangeTypeLoc) { + std::string Input = "int f(int *x) { return *x; }"; + std::string Expected = "char f(char *x) { return *x; }"; + auto IntToChar = makeRule(typeLoc(loc(qualType(isInteger(), builtinType()))), + changeTo(cat("char"))); + testRule(makeRule(functionDecl(hasName("f")).bind("fun"), + rewriteDescendants("fun", IntToChar)), + Input, Expected); +} + +TEST_F(TransformerTest, RewriteDescendantsStmt) { + // Add an unrelated definition to the header that also has a variable named + // "x", to test that the rewrite is limited to the scope we intend.
+ appendToHeader(R"cc(int g(int x) { return x; })cc"); + std::string Input = + "int f(int x) { int y = x; { int z = x * x; } return x; }"; + std::string Expected = + "int f(int x) { int y = 3; { int z = 3 * 3; } return 3; }"; + auto InlineX = + makeRule(declRefExpr(to(varDecl(hasName("x")))), changeTo(cat("3"))); + testRule(makeRule(functionDecl(hasName("f"), hasBody(stmt().bind("body"))), + rewriteDescendants("body", InlineX)), + Input, Expected); +} + +TEST_F(TransformerTest, RewriteDescendantsTypeLoc) { + std::string Input = "int f(int *x) { return *x; }"; + std::string Expected = "int f(char *x) { return *x; }"; + auto IntToChar = + makeRule(typeLoc(loc(qualType(isInteger(), builtinType()))).bind("loc"), + changeTo(cat("char"))); + testRule( + makeRule(functionDecl(hasName("f"), + hasParameter(0, varDecl(hasTypeLoc( + typeLoc().bind("parmType"))))), + rewriteDescendants("parmType", IntToChar)), + Input, Expected); +} + +TEST_F(TransformerTest, RewriteDescendantsReferToParentBinding) { + std::string Input = + "int f(int p) { int y = p; { int z = p * p; } return p; }"; + std::string Expected = + "int f(int p) { int y = 3; { int z = 3 * 3; } return 3; }"; + std::string VarId = "var"; + auto InlineVar = makeRule(declRefExpr(to(varDecl(equalsBoundNode(VarId)))), + changeTo(cat("3"))); + testRule(makeRule(functionDecl(hasName("f"), + hasParameter(0, varDecl().bind(VarId))) + .bind("fun"), + rewriteDescendants("fun", InlineVar)), + Input, Expected); +} + +TEST_F(TransformerTest, RewriteDescendantsUnboundNode) { + std::string Input = + "int f(int x) { int y = x; { int z = x * x; } return x; }"; + auto InlineX = + makeRule(declRefExpr(to(varDecl(hasName("x")))), changeTo(cat("3"))); + Transformer T(makeRule(functionDecl(hasName("f")), + rewriteDescendants("UNBOUND", InlineX)), + consumer()); + T.registerMatchers(&MatchFinder); + EXPECT_FALSE(rewrite(Input)); + EXPECT_THAT(Changes, IsEmpty()); + EXPECT_EQ(ErrorCount, 1); +} + +TEST_F(TransformerTest, 
RewriteDescendantsInvalidNodeType) { + std::string Input = + "int f(int x) { int y = x; { int z = x * x; } return x; }"; + auto IntToChar = + makeRule(qualType(isInteger(), builtinType()), changeTo(cat("char"))); + Transformer T( + makeRule(functionDecl( + hasName("f"), + hasParameter(0, varDecl(hasType(qualType().bind("type"))))), + rewriteDescendants("type", IntToChar)), + consumer()); + T.registerMatchers(&MatchFinder); + EXPECT_FALSE(rewrite(Input)); + EXPECT_THAT(Changes, IsEmpty()); + EXPECT_EQ(ErrorCount, 1); +} + TEST_F(TransformerTest, InsertBeforeEdit) { std::string Input = R"cc( int f() {