[libTooling] Let ASTEdits add include directives: Edit/ASTEdit gain an EditKind, and addInclude edits insert the header only into the file of their target range (previously, rule-level includes were added to every changed file).
diff --git a/clang/include/clang/Tooling/Transformer/RewriteRule.h b/clang/include/clang/Tooling/Transformer/RewriteRule.h --- a/clang/include/clang/Tooling/Transformer/RewriteRule.h +++ b/clang/include/clang/Tooling/Transformer/RewriteRule.h @@ -31,14 +31,30 @@ namespace clang { namespace transformer { +/// Specifies how to interpret an `Edit` or `ASTEdit`. +enum class EditKind { + /// Replaces the selected source range with the replacement text. + Range, + /// Inserts an include into the file of the edit's range. The replacement + /// text is the header path, wrapped in `<>` for angled includes. + AddInclude, +}; + /// A concrete description of a source edit, represented by a character range in /// the source to be replaced and a corresponding replacement string. struct Edit { + EditKind Kind = EditKind::Range; CharSourceRange Range; std::string Replacement; llvm::Any Metadata; }; +/// Format of the path in an include directive -- angle brackets or quotes. +enum class IncludeFormat { + Quoted, + Angled, +}; + /// Maps a match result to a list of concrete edits (with possible /// failure). This type is a building block of rewrite rules, but users will /// generally work in terms of `ASTEdit`s (below) rather than directly in terms @@ -86,6 +102,7 @@ // changeTo(cat("different_expr")) // \endcode struct ASTEdit { + EditKind Kind = EditKind::Range; RangeSelector TargetRange; TextGenerator Replacement; TextGenerator Note; @@ -185,6 +202,18 @@ /// Removes the source selected by \p S. ASTEdit remove(RangeSelector S); +/// Adds an include directive for the given header to the file of `Target`. The +/// particular location specified by `Target` is ignored. +ASTEdit addInclude(RangeSelector Target, StringRef Header, + IncludeFormat Format = IncludeFormat::Quoted); + +/// Adds an include directive for the given header to the file of the node +/// bound to `RootID`. Shorthand for `addInclude(before(node(RootID)), ...)`.
+inline ASTEdit addInclude(StringRef Header, + IncludeFormat Format = IncludeFormat::Quoted) { + return addInclude(before(node(RootID)), Header, Format); +} + // FIXME: If `Metadata` returns an `llvm::Expected` the `AnyGenerator` will // construct an `llvm::Expected` where no error is present but the // `llvm::Any` holds the error. This is unlikely but potentially surprising. @@ -216,12 +245,6 @@ remove(enclose(after(inner), after(outer)))}); } -/// Format of the path in an include directive -- angle brackets or quotes. -enum class IncludeFormat { - Quoted, - Angled, -}; - /// Description of a source-code transformation. // // A *rewrite rule* describes a transformation of source code. A simple rule @@ -250,10 +273,6 @@ ast_matchers::internal::DynTypedMatcher Matcher; EditGenerator Edits; TextGenerator Explanation; - /// Include paths to add to the file affected by this case. These are - /// bundled with the `Case`, rather than the `RewriteRule`, because each - /// case might have different associated changes to the includes. - std::vector> AddedIncludes; }; // We expect RewriteRules will most commonly include only one case. 
SmallVector Cases; diff --git a/clang/lib/Tooling/Transformer/RewriteRule.cpp b/clang/lib/Tooling/Transformer/RewriteRule.cpp --- a/clang/lib/Tooling/Transformer/RewriteRule.cpp +++ b/clang/lib/Tooling/Transformer/RewriteRule.cpp @@ -52,6 +52,7 @@ if (!Metadata) return Metadata.takeError(); transformer::Edit T; + T.Kind = E.Kind; T.Range = *EditRange; T.Replacement = std::move(*Replacement); T.Metadata = std::move(*Metadata); @@ -115,14 +116,38 @@ }; } // namespace +static TextGenerator makeText(std::string S) { + return std::make_shared(std::move(S)); +} + ASTEdit transformer::remove(RangeSelector S) { - return change(std::move(S), std::make_shared("")); + return change(std::move(S), makeText("")); +} + +/// Spells `Header` for an include directive, wrapping it in `<>` if angled. +static std::string formatHeaderPath(StringRef Header, IncludeFormat Format) { + switch (Format) { + case transformer::IncludeFormat::Quoted: + return Header.str(); + case transformer::IncludeFormat::Angled: + return ("<" + Header + ">").str(); + } + llvm_unreachable("Unknown transformer::IncludeFormat enum"); +} + +ASTEdit transformer::addInclude(RangeSelector Target, StringRef Header, + IncludeFormat Format) { + ASTEdit E; + E.Kind = EditKind::AddInclude; + E.TargetRange = std::move(Target); + E.Replacement = makeText(formatHeaderPath(Header, Format)); + return E; } RewriteRule transformer::makeRule(DynTypedMatcher M, EditGenerator Edits, TextGenerator Explanation) { - return RewriteRule{{RewriteRule::Case{ - std::move(M), std::move(Edits), std::move(Explanation), {}}}}; + return RewriteRule{{RewriteRule::Case{std::move(M), std::move(Edits), + std::move(Explanation)}}}; } namespace { @@ -258,7 +283,7 @@ void transformer::addInclude(RewriteRule &Rule, StringRef Header, IncludeFormat Format) { for (auto &Case : Rule.Cases) - Case.AddedIncludes.emplace_back(Header.str(), Format); + Case.Edits = flatten(std::move(Case.Edits), addInclude(Header, Format)); } #ifndef NDEBUG diff --git a/clang/lib/Tooling/Transformer/Transformer.cpp b/clang/lib/Tooling/Transformer/Transformer.cpp --- a/clang/lib/Tooling/Transformer/Transformer.cpp +++ 
b/clang/lib/Tooling/Transformer/Transformer.cpp @@ -51,29 +51,20 @@ T.Range.getBegin(), T.Metadata)) .first; auto &AC = Iter->second; - if (auto Err = AC.replace(*Result.SourceManager, T.Range, T.Replacement)) { - Consumer(std::move(Err)); - return; - } - } - - for (auto &IDChangePair : ChangesByFileID) { - auto &AC = IDChangePair.second; - // FIXME: this will add includes to *all* changed files, which may not be - // the intent. We should upgrade the representation to allow associating - // headers with specific edits. - for (const auto &I : Case.AddedIncludes) { - auto &Header = I.first; - switch (I.second) { - case transformer::IncludeFormat::Quoted: - AC.addHeader(Header); - break; - case transformer::IncludeFormat::Angled: - AC.addHeader((llvm::Twine("<") + Header + ">").str()); - break; + switch (T.Kind) { + case transformer::EditKind::Range: + if (auto Err = + AC.replace(*Result.SourceManager, T.Range, T.Replacement)) { + Consumer(std::move(Err)); + return; } + break; + case transformer::EditKind::AddInclude: + AC.addHeader(T.Replacement); // Angled paths already carry "<>". + break; } - - Consumer(std::move(AC)); } + + for (auto &IDChangePair : ChangesByFileID) + Consumer(std::move(IDChangePair.second)); } diff --git a/clang/unittests/Tooling/TransformerTest.cpp b/clang/unittests/Tooling/TransformerTest.cpp --- a/clang/unittests/Tooling/TransformerTest.cpp +++ b/clang/unittests/Tooling/TransformerTest.cpp @@ -21,6 +21,7 @@ using namespace tooling; using namespace ast_matchers; namespace { +using ::testing::ElementsAre; using ::testing::IsEmpty; using transformer::cat; using transformer::changeTo; @@ -194,6 +195,43 @@ } TEST_F(TransformerTest, AddIncludeQuoted) { + RewriteRule Rule = + makeRule(callExpr(callee(functionDecl(hasName("f")))), + {addInclude("clang/OtherLib.h"), changeTo(cat("other()"))}); + + std::string Input = R"cc( + int f(int x); + int h(int x) { return f(x); } + )cc"; + std::string Expected = R"cc(#include "clang/OtherLib.h" + + int f(int x); + int h(int x) { return other(); 
} + )cc"; + + testRule(Rule, Input, Expected); +} + +TEST_F(TransformerTest, AddIncludeAngled) { + RewriteRule Rule = makeRule( + callExpr(callee(functionDecl(hasName("f")))), + {addInclude("clang/OtherLib.h", transformer::IncludeFormat::Angled), + changeTo(cat("other()"))}); + + std::string Input = R"cc( + int f(int x); + int h(int x) { return f(x); } + )cc"; + std::string Expected = R"cc(#include + + int f(int x); + int h(int x) { return other(); } + )cc"; + + testRule(Rule, Input, Expected); +} + +TEST_F(TransformerTest, AddIncludeQuotedForRule) { RewriteRule Rule = makeRule(callExpr(callee(functionDecl(hasName("f")))), changeTo(cat("other()"))); addInclude(Rule, "clang/OtherLib.h"); @@ -211,7 +249,7 @@ testRule(Rule, Input, Expected); } -TEST_F(TransformerTest, AddIncludeAngled) { +TEST_F(TransformerTest, AddIncludeAngledForRule) { RewriteRule Rule = makeRule(callExpr(callee(functionDecl(hasName("f")))), changeTo(cat("other()"))); addInclude(Rule, "clang/OtherLib.h", transformer::IncludeFormat::Angled); @@ -1180,4 +1218,32 @@ EXPECT_EQ(format(*UpdatedCode), format(R"cc(#include "input.h" ;)cc")); } + +TEST_F(TransformerTest, AddIncludeMultipleFiles) { + std::string Header = R"cc(void RemoveThisFunction();)cc"; + std::string Source = R"cc(#include "input.h" + void Foo() {RemoveThisFunction();})cc"; + Transformer T( + makeRule(callExpr(callee( + functionDecl(hasName("RemoveThisFunction")).bind("fun"))), + addInclude(node("fun"), "header.h")), + consumer()); + T.registerMatchers(&MatchFinder); + auto Factory = newFrontendActionFactory(&MatchFinder); + EXPECT_TRUE(runToolOnCodeWithArgs( + Factory->create(), Source, std::vector(), "input.cc", + "clang-tool", std::make_shared(), + {{"input.h", Header}})); + + ASSERT_EQ(Changes.size(), 1U); // input.cc itself gets no change. + ASSERT_EQ(Changes[0].getFilePath(), "./input.h"); // File declaring "fun". + EXPECT_THAT(Changes[0].getInsertedHeaders(), ElementsAre("header.h")); + EXPECT_THAT(Changes[0].getRemovedHeaders(), IsEmpty()); + llvm::Expected UpdatedCode = 
clang::tooling::applyAllReplacements(Header, + Changes[0].getReplacements()); + ASSERT_TRUE(static_cast(UpdatedCode)) + << "Could not update code: " << llvm::toString(UpdatedCode.takeError()); + EXPECT_EQ(format(*UpdatedCode), format(Header)); +} } // namespace