[libTooling] Transformer: let `applyFirst` admit rules whose matchers have
different base kinds.

`tooling::detail::buildMatchers` partitions a rule's cases into buckets whose
matcher kinds are related (same kind, or one is a base of the other). It emits
one `anyOf` matcher per bucket, bound to `RewriteRule::RootID`. Each case keeps
its original index in its "Tag<N>" binding, so ids stay aligned with
`Rule.Cases`. `Transformer::registerMatchers` registers every bucket's matcher.
`buildMatcher` is kept for existing callers and asserts that the rule produced
exactly one bucket.

NOTE(review): within a bucket the joined kind is the most-derived kind among
its cases (as with the removed `findCommonKind`), so a broader matcher grouped
with a narrower one presumably only fires on nodes of the narrower kind --
confirm this is the intended `applyFirst` semantics.

diff --git a/clang/include/clang/Tooling/Refactoring/Transformer.h b/clang/include/clang/Tooling/Refactoring/Transformer.h
--- a/clang/include/clang/Tooling/Refactoring/Transformer.h
+++ b/clang/include/clang/Tooling/Refactoring/Transformer.h
@@ -160,11 +160,9 @@
 void addInclude(RewriteRule &Rule, llvm::StringRef Header,
                 IncludeFormat Format = IncludeFormat::Quoted);
 
-/// Applies the first rule whose pattern matches; other rules are ignored.
-///
-/// N.B. All of the rules must use the same kind of matcher (that is, share a
-/// base class in the AST hierarchy). However, this constraint is caused by an
-/// implementation detail and should be lifted in the future.
+/// Applies the first rule whose pattern matches; other rules are ignored. If
+/// the matchers are independent then order doesn't matter. In that case,
+/// `applyFirst` simply joins the set of rules into one.
 //
 // `applyFirst` is like an `anyOf` matcher with an edit action attached to each
 // of its cases. Anywhere you'd use `anyOf(m1.bind("id1"), m2.bind("id2"))` and
@@ -243,8 +241,16 @@
 // public and well-supported and move them out of `detail`.
 namespace detail {
 /// Builds a single matcher for the rule, covering all of the rule's cases.
+/// Only supports Rules whose cases' matchers all share the same base "kind"
+/// (`Stmt`, `Decl`, etc.). Deprecated: use `buildMatchers` instead, which
+/// supports mixing matchers of different kinds.
 ast_matchers::internal::DynTypedMatcher buildMatcher(const RewriteRule &Rule);
 
+/// Builds a set of matchers that cover the rule (one for each distinct matcher
+/// base kind: Stmt, Decl, etc.).
+std::vector<ast_matchers::internal::DynTypedMatcher>
+buildMatchers(const RewriteRule &Rule);
+
 /// Returns the \c Case of \c Rule that was selected in the match result.
-/// Assumes a matcher built with \c buildMatcher.
+/// Assumes a matcher built with \c buildMatchers (or \c buildMatcher).
 const RewriteRule::Case &
diff --git a/clang/lib/Tooling/Refactoring/Transformer.cpp b/clang/lib/Tooling/Refactoring/Transformer.cpp
--- a/clang/lib/Tooling/Refactoring/Transformer.cpp
+++ b/clang/lib/Tooling/Refactoring/Transformer.cpp
@@ -91,41 +91,53 @@
   return (A.isSame(TypeKind) && B.isSame(QualKind)) || A.isBaseOf(B);
 }
 
-// Try to find a common kind to which all of the rule's matchers can be
-// converted.
-static ASTNodeKind
-findCommonKind(const SmallVectorImpl<RewriteRule::Case> &Cases) {
-  assert(!Cases.empty() && "Rule must have at least one case.");
-  ASTNodeKind JoinKind = Cases[0].Matcher.getSupportedKind();
-  // Find a (least) Kind K, for which M.canConvertTo(K) holds, for all matchers
-  // M in Rules.
-  for (const auto &Case : Cases) {
-    auto K = Case.Matcher.getSupportedKind();
-    if (isBaseOf(JoinKind, K)) {
-      JoinKind = K;
-      continue;
+namespace {
+struct KindBucket {
+  // Common kind which all of the bucket's matchers can support. Specifically, a
+  // (least) Kind K, for which M.canConvertTo(K) holds, for all matchers M in
+  // `Cases`.
+  ASTNodeKind Kind;
+  // Each case is paired with an id that is used to bind it in the matcher.
+  SmallVector<std::pair<size_t, RewriteRule::Case>, 1> Cases;
+};
+} // namespace
+
+// Finds a bucket that is compatible with the given case and adds the case to
+// that bucket; otherwise, creates a new bucket for the case. Compatibility
+// means that the kind of the case's matcher is the same as, a subclass of, or
+// a superclass of the bucket's current kind. Buckets are searched newest-first.
+static void insertCase(std::vector<KindBucket> &Buckets, size_t CaseId,
+                       RewriteRule::Case Case) {
+  ASTNodeKind CaseKind = Case.Matcher.getSupportedKind();
+  for (size_t I = Buckets.size(); I-- > 0;) {
+    KindBucket &Bucket = Buckets[I];
+    if (isBaseOf(Bucket.Kind, CaseKind)) {
+      // Case belongs in this bucket and bucket's new kind is `CaseKind`.
+      Bucket.Kind = CaseKind;
+      Bucket.Cases.emplace_back(CaseId, std::move(Case));
+      return;
+    }
+    if (CaseKind.isSame(Bucket.Kind) || isBaseOf(CaseKind, Bucket.Kind)) {
+      // Case belongs in this bucket; no update to bucket needed.
+      Bucket.Cases.emplace_back(CaseId, std::move(Case));
+      return;
     }
-    if (K.isSame(JoinKind) || isBaseOf(K, JoinKind))
-      // JoinKind is already the lowest.
-      continue;
-    // K and JoinKind are unrelated -- there is no least common kind.
-    return ASTNodeKind();
-  }
-  return JoinKind;
+  }
+  // Couldn't find a compatible bucket. Start a new one.
+  Buckets.push_back(KindBucket{CaseKind, {}});
+  Buckets.back().Cases.emplace_back(CaseId, std::move(Case));
 }
 
 // Binds each rule's matcher to a unique (and deterministic) tag based on
-// `TagBase`.
-static std::vector<DynTypedMatcher>
-taggedMatchers(StringRef TagBase,
-               const SmallVectorImpl<RewriteRule::Case> &Cases) {
+// `TagBase` and the id paired with the case.
+static std::vector<DynTypedMatcher> taggedMatchers(
+    StringRef TagBase,
+    const SmallVectorImpl<std::pair<size_t, RewriteRule::Case>> &Cases) {
   std::vector<DynTypedMatcher> Matchers;
   Matchers.reserve(Cases.size());
-  size_t count = 0;
   for (const auto &Case : Cases) {
-    std::string Tag = (TagBase + Twine(count)).str();
-    ++count;
-    auto M = Case.Matcher.tryBind(Tag);
+    std::string Tag = (TagBase + Twine(Case.first)).str();
+    auto M = Case.second.Matcher.tryBind(Tag);
     assert(M && "RewriteRule matchers should be bindable.");
     Matchers.push_back(*std::move(M));
   }
@@ -142,22 +154,29 @@
   return R;
 }
 
-static DynTypedMatcher joinCaseMatchers(const RewriteRule &Rule) {
-  assert(!Rule.Cases.empty() && "Rule must have at least one case.");
-  if (Rule.Cases.size() == 1)
-    return Rule.Cases[0].Matcher;
+std::vector<DynTypedMatcher>
+tooling::detail::buildMatchers(const RewriteRule &Rule) {
+  // Sort the cases into buckets of compatible matchers.
+  std::vector<KindBucket> Buckets;
+  for (size_t I = 0, N = Rule.Cases.size(); I < N; ++I)
+    insertCase(Buckets, I, Rule.Cases[I]);
 
-  auto CommonKind = findCommonKind(Rule.Cases);
-  assert(!CommonKind.isNone() && "Cases must have compatible matchers.");
-  return DynTypedMatcher::constructVariadic(
-      DynTypedMatcher::VO_AnyOf, CommonKind, taggedMatchers("Tag", Rule.Cases));
+  std::vector<DynTypedMatcher> Matchers;
+  for (const auto &Bucket : Buckets) {
+    DynTypedMatcher M = DynTypedMatcher::constructVariadic(
+        DynTypedMatcher::VO_AnyOf, Bucket.Kind,
+        taggedMatchers("Tag", Bucket.Cases));
+    M.setAllowBind(true);
+    // `tryBind` is guaranteed to succeed, because `AllowBind` was set to true.
+    Matchers.push_back(*M.tryBind(RewriteRule::RootID));
+  }
+  return Matchers;
 }
 
 DynTypedMatcher tooling::detail::buildMatcher(const RewriteRule &Rule) {
-  DynTypedMatcher M = joinCaseMatchers(Rule);
-  M.setAllowBind(true);
-  // `tryBind` is guaranteed to succeed, because `AllowBind` was set to true.
-  return *M.tryBind(RewriteRule::RootID);
+  std::vector<DynTypedMatcher> Ms = buildMatchers(Rule);
+  assert(Ms.size() == 1 && "Cases must have compatible matchers.");
+  return Ms[0];
 }
 
 // Finds the case that was "selected" -- that is, whose matcher triggered the
@@ -180,7 +199,8 @@
 constexpr llvm::StringLiteral RewriteRule::RootID;
 
 void Transformer::registerMatchers(MatchFinder *MatchFinder) {
-  MatchFinder->addDynamicMatcher(tooling::detail::buildMatcher(Rule), this);
+  for (auto &Matcher : tooling::detail::buildMatchers(Rule))
+    MatchFinder->addDynamicMatcher(Matcher, this);
 }
 
 void Transformer::run(const MatchResult &Result) {
@@ -222,12 +242,12 @@
   for (const auto &I : Case.AddedIncludes) {
     auto &Header = I.first;
     switch (I.second) {
-      case IncludeFormat::Quoted:
-        AC.addHeader(Header);
-        break;
-      case IncludeFormat::Angled:
-        AC.addHeader((llvm::Twine("<") + Header + ">").str());
-        break;
+    case IncludeFormat::Quoted:
+      AC.addHeader(Header);
+      break;
+    case IncludeFormat::Angled:
+      AC.addHeader((llvm::Twine("<") + Header + ">").str());
+      break;
     }
   }
 
diff --git a/clang/unittests/Tooling/TransformerTest.cpp b/clang/unittests/Tooling/TransformerTest.cpp
--- a/clang/unittests/Tooling/TransformerTest.cpp
+++ b/clang/unittests/Tooling/TransformerTest.cpp
@@ -482,16 +482,17 @@
   testRule(applyFirst({ruleStrlenSize(), FlagRule}), Input, Expected);
 }
 
-// Version of ruleStrlenSizeAny that inserts a method with a different name than
+// Version of ruleStrlenSize that doesn't check the type of the targeted
+// object. Note that we insert different replacement text than
 // ruleStrlenSize, so we can tell their effect apart.
-RewriteRule ruleStrlenSizeDistinct() {
+RewriteRule ruleStrlenSizeAny() {
   StringRef S;
   return makeRule(
       callExpr(callee(functionDecl(hasName("strlen"))),
                hasArgument(0, cxxMemberCallExpr(
                                   on(expr().bind(S)),
                                   callee(cxxMethodDecl(hasName("c_str")))))),
-      change(text("DISTINCT")));
+      change(text("ANY")));
 }
 
 TEST_F(TransformerTest, OrderedRuleRelated) {
@@ -509,16 +510,18 @@
     struct mystring {
       char* c_str();
     };
-    int f(mystring s) { return DISTINCT; }
+    int f(mystring s) { return ANY; }
     }  // namespace foo
     int g(string s) { return REPLACED; }
   )cc";
 
-  testRule(applyFirst({ruleStrlenSize(), ruleStrlenSizeDistinct()}), Input,
+  testRule(applyFirst({ruleStrlenSize(), ruleStrlenSizeAny()}), Input,
            Expected);
 }
 
-// Change the order of the rules to get a different result.
+// Change the order of the rules to get a different result. When
+// `ruleStrlenSizeAny` comes first, it applies for both uses, so
+// `ruleStrlenSize` never applies.
 TEST_F(TransformerTest, OrderedRuleRelatedSwapped) {
   std::string Input = R"cc(
     namespace foo {
@@ -534,15 +537,51 @@
     struct mystring {
       char* c_str();
     };
-    int f(mystring s) { return DISTINCT; }
+    int f(mystring s) { return ANY; }
     }  // namespace foo
-    int g(string s) { return DISTINCT; }
+    int g(string s) { return ANY; }
   )cc";
 
-  testRule(applyFirst({ruleStrlenSizeDistinct(), ruleStrlenSize()}), Input,
+  testRule(applyFirst({ruleStrlenSizeAny(), ruleStrlenSize()}), Input,
            Expected);
 }
 
+// Verify that a set of rules whose matchers have different base kinds works
+// properly, including that `applyFirst` produces multiple matchers.
+TEST_F(TransformerTest, OrderedRuleMultipleKinds) {
+  RewriteRule DeclRule = makeRule(functionDecl(hasName("bad")).bind("fun"),
+                                  change(name("fun"), text("good")));
+  // We test two different kinds of rules: Expr and Decl. We place the Decl rule
+  // in the middle to test that `buildMatchers` works even when the kinds aren't
+  // grouped together.
+  RewriteRule Rule =
+      applyFirst({ruleStrlenSize(), DeclRule, ruleStrlenSizeAny()});
+
+  std::string Input = R"cc(
+    namespace foo {
+    struct mystring {
+      char* c_str();
+    };
+    int f(mystring s) { return strlen(s.c_str()); }
+    }  // namespace foo
+    int g(string s) { return strlen(s.c_str()); }
+    int bad(int x);
+  )cc";
+  std::string Expected = R"cc(
+    namespace foo {
+    struct mystring {
+      char* c_str();
+    };
+    int f(mystring s) { return ANY; }
+    }  // namespace foo
+    int g(string s) { return REPLACED; }
+    int good(int x);
+  )cc";
+
+  EXPECT_EQ(tooling::detail::buildMatchers(Rule).size(), 2UL);
+  testRule(Rule, Input, Expected);
+}
+
 //
 // Negative tests (where we expect no transformation to occur).
 //