diff --git a/clang/include/clang/Tooling/Refactoring/Transformer.h b/clang/include/clang/Tooling/Refactoring/Transformer.h
--- a/clang/include/clang/Tooling/Refactoring/Transformer.h
+++ b/clang/include/clang/Tooling/Refactoring/Transformer.h
@@ -221,6 +221,27 @@
 translateEdits(const ast_matchers::MatchFinder::MatchResult &Result,
                llvm::ArrayRef<ASTEdit> Edits);
 
+/// Composes multiple rules into a single object that can be registered with a
+/// single matcher. Upon match, the tags in said matcher can be used to
+/// determine which rule in \c Rules to apply.
+struct CompositeRewriteRule {
+  /// Matcher that multiplexes the composed rules. Demultiplexing is done with
+  /// \c findSelectedRule.
+  ast_matchers::internal::DynTypedMatcher Matcher;
+  std::vector<RewriteRule> Rules;
+};
+
+/// Creates a composite rule that applies the first rule in \p Rules whose
+/// pattern matches a given node. All of the rules must use the same kind of
+/// matcher (that is, share a base class in the AST hierarchy). \p Rules must
+/// be non-empty.
+CompositeRewriteRule makeOrderedRule(std::vector<RewriteRule> Rules);
+
+/// Returns the subrule of \c Rule that was selected in the given match result.
+const RewriteRule &
+findSelectedRule(const CompositeRewriteRule &Rule,
+                 const ast_matchers::MatchFinder::MatchResult &Result);
+
 /// Handles the matcher and callback registration for a single rewrite rule, as
 /// defined by the arguments of the constructor.
 class Transformer : public ast_matchers::MatchFinder::MatchCallback {
@@ -233,9 +254,13 @@
   /// because of macros, but doesn't fail. Note that clients are responsible
   /// for handling the case that independent \c AtomicChanges conflict with each
   /// other.
-  Transformer(RewriteRule Rule, ChangeConsumer Consumer)
+  Transformer(CompositeRewriteRule Rule, ChangeConsumer Consumer)
       : Rule(std::move(Rule)), Consumer(std::move(Consumer)) {}
 
+  /// Overload that applies a single \c RewriteRule (no tagging is needed).
+  Transformer(RewriteRule Rule, ChangeConsumer Consumer)
+      : Rule{Rule.Matcher, {std::move(Rule)}}, Consumer(std::move(Consumer)) {}
+
   /// N.B. Passes `this` pointer to `MatchFinder`. So, this object should not
   /// be moved after this call.
   void registerMatchers(ast_matchers::MatchFinder *MatchFinder);
@@ -245,7 +270,7 @@
   void run(const ast_matchers::MatchFinder::MatchResult &Result) override;
 
 private:
-  RewriteRule Rule;
+  CompositeRewriteRule Rule;
   /// Receives each successful rewrites as an \c AtomicChange.
   ChangeConsumer Consumer;
 };
diff --git a/clang/lib/Tooling/Refactoring/Transformer.cpp b/clang/lib/Tooling/Refactoring/Transformer.cpp
--- a/clang/lib/Tooling/Refactoring/Transformer.cpp
+++ b/clang/lib/Tooling/Refactoring/Transformer.cpp
@@ -28,6 +28,7 @@
 using namespace tooling;
 
 using ast_matchers::MatchFinder;
+using ast_matchers::internal::DynTypedMatcher;
 using ast_type_traits::ASTNodeKind;
 using ast_type_traits::DynTypedNode;
 using llvm::Error;
@@ -171,7 +172,7 @@
   return Transformations;
 }
 
-RewriteRule tooling::makeRule(ast_matchers::internal::DynTypedMatcher M,
+RewriteRule tooling::makeRule(DynTypedMatcher M,
                               SmallVector<ASTEdit, 1> Edits) {
   M.setAllowBind(true);
   // `tryBind` is guaranteed to succeed, because `AllowBind` was set to true.
@@ -181,6 +182,80 @@
 
 constexpr llvm::StringLiteral RewriteRule::RootId;
 
+// Determines whether A is higher than B in the class hierarchy.
+static bool isHigher(ASTNodeKind A, ASTNodeKind B) {
+  static auto QualKind = ASTNodeKind::getFromNodeKind<QualType>();
+  static auto TypeKind = ASTNodeKind::getFromNodeKind<Type>();
+  // Mimic the implicit conversions of Matcher<>.
+  // - From Matcher<Type> to Matcher<QualType>
+  // - From Matcher<Base> to Matcher<Derived>
+  return (A.isSame(TypeKind) && B.isSame(QualKind)) || A.isBaseOf(B);
+}
+
+// Try to find a common kind to which all of the rule's matchers can be
+// converted.
+static ASTNodeKind findCommonKind(const std::vector<RewriteRule> &Rules) {
+  assert(!Rules.empty() && "Rules must be non-empty.");
+  ASTNodeKind JoinKind = Rules[0].Matcher.getSupportedKind();
+  // Find a (least) Kind K, for which M.canConvertTo(K) holds, for all matchers
+  // M in Rules.
+  for (const auto &R : Rules) {
+    auto K = R.Matcher.getSupportedKind();
+    if (isHigher(JoinKind, K)) {
+      JoinKind = K;
+      continue;
+    }
+    if (K.isSame(JoinKind) || isHigher(K, JoinKind))
+      // JoinKind is already the lowest.
+      continue;
+    // K and JoinKind are unrelated -- there is no least common kind.
+    return ASTNodeKind();
+  }
+  return JoinKind;
+}
+
+// Binds each rule's matcher to a unique (and deterministic) tag based on
+// `TagBase`.
+static std::vector<DynTypedMatcher>
+taggedMatchers(StringRef TagBase, const std::vector<RewriteRule> &Rules) {
+  std::vector<DynTypedMatcher> Matchers;
+  Matchers.reserve(Rules.size());
+  size_t Count = 0;
+  for (const auto &R : Rules) {
+    std::string Tag = (TagBase + Twine(Count)).str();
+    ++Count;
+    auto M = R.Matcher.tryBind(Tag);
+    assert(M && "RewriteRule matchers should be bindable.");
+    Matchers.push_back(*std::move(M));
+  }
+  return Matchers;
+}
+
+CompositeRewriteRule tooling::makeOrderedRule(std::vector<RewriteRule> Rules) {
+  auto CommonKind = findCommonKind(Rules);
+  assert(!CommonKind.isNone() && "Rules must have compatible matchers.");
+  // Explicitly bind `M` to ensure we use `Rules` before it is moved.
+  auto M = DynTypedMatcher::constructVariadic(
+      DynTypedMatcher::VO_AnyOf, CommonKind, taggedMatchers("Tag", Rules));
+  return {std::move(M), std::move(Rules)};
+}
+
+// Finds the rule that was "selected" -- that is, whose matcher triggered the
+// `MatchResult`.
+const RewriteRule &tooling::findSelectedRule(const CompositeRewriteRule &Rule,
+                                             const MatchResult &Result) {
+  if (Rule.Rules.size() == 1)
+    return Rule.Rules[0];
+
+  auto &NodesMap = Result.Nodes.getMap();
+  for (size_t I = 0, N = Rule.Rules.size(); I < N; ++I) {
+    std::string Tag = ("Tag" + Twine(I)).str();
+    if (NodesMap.find(Tag) != NodesMap.end())
+      return Rule.Rules[I];
+  }
+  llvm_unreachable("No tag found for rule set.");
+}
+
 void Transformer::registerMatchers(MatchFinder *MatchFinder) {
   MatchFinder->addDynamicMatcher(Rule.Matcher, this);
 }
@@ -197,7 +272,8 @@
       Root->second.getSourceRange().getBegin());
   assert(RootLoc.isValid() && "Invalid location for Root node of match.");
 
-  auto Transformations = translateEdits(Result, Rule.Edits);
+  auto Transformations =
+      translateEdits(Result, findSelectedRule(Rule, Result).Edits);
   if (!Transformations) {
     Consumer(Transformations.takeError());
     return;
diff --git a/clang/unittests/Tooling/TransformerTest.cpp b/clang/unittests/Tooling/TransformerTest.cpp
--- a/clang/unittests/Tooling/TransformerTest.cpp
+++ b/clang/unittests/Tooling/TransformerTest.cpp
@@ -116,7 +116,8 @@
     };
   }
 
-  void testRule(RewriteRule Rule, StringRef Input, StringRef Expected) {
+  template <typename R>
+  void testRule(R Rule, StringRef Input, StringRef Expected) {
     Transformer T(std::move(Rule), consumer());
     T.registerMatchers(&MatchFinder);
     compareSnippets(Expected, rewrite(Input));
@@ -375,6 +376,91 @@
            Input, Expected);
 }
 
+TEST_F(TransformerTest, OrderedRuleUnrelated) {
+  StringRef Flag = "flag";
+  RewriteRule FlagRule = makeRule(
+      cxxMemberCallExpr(on(expr(hasType(cxxRecordDecl(
+                                    hasName("proto::ProtoCommandLineFlag"))))
+                               .bind(Flag)),
+                        unless(callee(cxxMethodDecl(hasName("GetProto"))))),
+      change(Flag, "PROTO"));
+
+  std::string Input = R"cc(
+    proto::ProtoCommandLineFlag flag;
+    int x = flag.foo();
+    int y = flag.GetProto().foo();
+    int f(string s) { return strlen(s.c_str()); }
+  )cc";
+  std::string Expected = R"cc(
+    proto::ProtoCommandLineFlag flag;
+    int x = PROTO.foo();
+    int y = flag.GetProto().foo();
+    int f(string s) { return REPLACED; }
+  )cc";
+
+  testRule(makeOrderedRule({ruleStrlenSize(), FlagRule}), Input, Expected);
+}
+
+// Variant of ruleStrlenSize that matches `c_str()` on an object of any type
+// and rewrites to "DISTINCT", so the effects of the two rules can be told apart.
+RewriteRule ruleStrlenSizeDistinct() {
+  return makeRule(
+      callExpr(callee(functionDecl(hasName("strlen"))),
+               hasArgument(0, cxxMemberCallExpr(
+                                  on(expr()),
+                                  callee(cxxMethodDecl(hasName("c_str")))))),
+      change("DISTINCT"));
+}
+
+TEST_F(TransformerTest, OrderedRuleRelated) {
+  std::string Input = R"cc(
+    namespace foo {
+    struct mystring {
+      char* c_str();
+    };
+    int f(mystring s) { return strlen(s.c_str()); }
+    }  // namespace foo
+    int g(string s) { return strlen(s.c_str()); }
+  )cc";
+  std::string Expected = R"cc(
+    namespace foo {
+    struct mystring {
+      char* c_str();
+    };
+    int f(mystring s) { return DISTINCT; }
+    }  // namespace foo
+    int g(string s) { return REPLACED; }
+  )cc";
+
+  testRule(makeOrderedRule({ruleStrlenSize(), ruleStrlenSizeDistinct()}), Input,
+           Expected);
+}
+
+// Change the order of the rules to get a different result.
+TEST_F(TransformerTest, OrderedRuleRelatedSwapped) {
+  std::string Input = R"cc(
+    namespace foo {
+    struct mystring {
+      char* c_str();
+    };
+    int f(mystring s) { return strlen(s.c_str()); }
+    }  // namespace foo
+    int g(string s) { return strlen(s.c_str()); }
+  )cc";
+  std::string Expected = R"cc(
+    namespace foo {
+    struct mystring {
+      char* c_str();
+    };
+    int f(mystring s) { return DISTINCT; }
+    }  // namespace foo
+    int g(string s) { return DISTINCT; }
+  )cc";
+
+  testRule(makeOrderedRule({ruleStrlenSizeDistinct(), ruleStrlenSize()}), Input,
+           Expected);
+}
+
 //
 // Negative tests (where we expect no transformation to occur).
 //