diff --git a/clang/include/clang/ASTMatchers/ASTMatchersInternal.h b/clang/include/clang/ASTMatchers/ASTMatchersInternal.h
--- a/clang/include/clang/ASTMatchers/ASTMatchersInternal.h
+++ b/clang/include/clang/ASTMatchers/ASTMatchersInternal.h
@@ -379,6 +379,12 @@
   constructRestrictedWrapper(const DynTypedMatcher &InnerMatcher,
                              ASTNodeKind RestrictKind);
 
+  /// Creates a new matcher that is identical to the old one, but sets the
+  /// traversal kind. If `InnerMatcher` had already set a traversal kind, then
+  /// the new one overrides it.
+  static DynTypedMatcher
+  constructWithTraversalKind(DynTypedMatcher InnerMatcher, TraversalKind TK);
+
   /// Get a "true" matcher for \p NodeKind.
   ///
   /// It only checks that the node is of the right kind.
@@ -458,6 +464,14 @@
   /// If it is not compatible, then this matcher will never match anything.
   template <typename T> Matcher<T> unconditionalConvertTo() const;
 
+  /// Returns the \c TraversalKind respected by calls to `match()`, if any.
+  ///
+  /// Most matchers will not have a traversal kind set, instead relying on the
+  /// surrounding context. For those, \c llvm::None is returned.
+  llvm::Optional<clang::TraversalKind> getTraversalKind() const {
+    return Implementation->TraversalKind();
+  }
+
 private:
   DynTypedMatcher(ASTNodeKind SupportedKind, ASTNodeKind RestrictKind,
                   IntrusiveRefCntPtr<DynMatcherInterface> Implementation)
diff --git a/clang/lib/ASTMatchers/ASTMatchersInternal.cpp b/clang/lib/ASTMatchers/ASTMatchersInternal.cpp
--- a/clang/lib/ASTMatchers/ASTMatchersInternal.cpp
+++ b/clang/lib/ASTMatchers/ASTMatchersInternal.cpp
@@ -136,6 +136,31 @@
   }
 };
 
+/// A matcher that specifies a particular \c TraversalKind.
+///
+/// The kind provided to the constructor overrides any kind that may be
+/// specified by the `InnerMatcher`.
+class DynTraversalMatcherImpl : public DynMatcherInterface {
+public:
+  explicit DynTraversalMatcherImpl(
+      clang::TraversalKind TK,
+      IntrusiveRefCntPtr<DynMatcherInterface> InnerMatcher)
+      : TK(TK), InnerMatcher(std::move(InnerMatcher)) {}
+
+  bool dynMatches(const DynTypedNode &DynNode, ASTMatchFinder *Finder,
+                  BoundNodesTreeBuilder *Builder) const override {
+    return this->InnerMatcher->dynMatches(DynNode, Finder, Builder);
+  }
+
+  llvm::Optional<clang::TraversalKind> TraversalKind() const override {
+    return TK;
+  }
+
+private:
+  clang::TraversalKind TK;
+  IntrusiveRefCntPtr<DynMatcherInterface> InnerMatcher;
+};
+
 } // namespace
 
 static llvm::ManagedStatic<TrueMatcherImpl> TrueMatcherInstance;
@@ -204,6 +229,14 @@
   return Copy;
 }
 
+DynTypedMatcher
+DynTypedMatcher::constructWithTraversalKind(DynTypedMatcher InnerMatcher,
+                                            TraversalKind TK) {
+  InnerMatcher.Implementation =
+      new DynTraversalMatcherImpl(TK, std::move(InnerMatcher.Implementation));
+  return InnerMatcher;
+}
+
 DynTypedMatcher DynTypedMatcher::trueMatcher(ASTNodeKind NodeKind) {
   return DynTypedMatcher(NodeKind, NodeKind, &*TrueMatcherInstance);
 }
diff --git a/clang/unittests/ASTMatchers/ASTMatchersInternalTest.cpp b/clang/unittests/ASTMatchers/ASTMatchersInternalTest.cpp
--- a/clang/unittests/ASTMatchers/ASTMatchersInternalTest.cpp
+++ b/clang/unittests/ASTMatchers/ASTMatchersInternalTest.cpp
@@ -17,6 +17,7 @@
 
 namespace clang {
 namespace ast_matchers {
+using internal::DynTypedMatcher;
 
 #if GTEST_HAS_DEATH_TEST
 TEST(HasNameDeathTest, DiesOnEmptyName) {
@@ -171,6 +172,34 @@
   EXPECT_NE(nullptr, PT);
 }
 
+TEST(DynTypedMatcherTest, TraversalKindForwardsToImpl) {
+  auto M = DynTypedMatcher(decl());
+  EXPECT_FALSE(M.getTraversalKind().hasValue());
+
+  M = DynTypedMatcher(traverse(TK_AsIs, decl()));
+  llvm::Optional<TraversalKind> TK = M.getTraversalKind();
+  ASSERT_TRUE(TK.hasValue());
+  EXPECT_EQ(*TK, TK_AsIs);
+}
+
+TEST(DynTypedMatcherTest, ConstructWithTraversalKindSetsTK) {
+  auto M = DynTypedMatcher::constructWithTraversalKind(DynTypedMatcher(decl()),
+                                                       TK_AsIs);
+  llvm::Optional<TraversalKind> TK = M.getTraversalKind();
+  ASSERT_TRUE(TK.hasValue());
+  EXPECT_EQ(*TK, TK_AsIs);
+}
+
+TEST(DynTypedMatcherTest, ConstructWithTraversalKindOverridesNestedTK) {
+  auto M = DynTypedMatcher::constructWithTraversalKind(DynTypedMatcher(decl()),
+                                                       TK_AsIs);
+  auto M2 = DynTypedMatcher::constructWithTraversalKind(
+      M, TK_IgnoreUnlessSpelledInSource);
+  llvm::Optional<TraversalKind> TK = M2.getTraversalKind();
+  ASSERT_TRUE(TK.hasValue());
+  EXPECT_EQ(*TK, TK_IgnoreUnlessSpelledInSource);
+}
+
 TEST(IsInlineMatcher, IsInline) {
   EXPECT_TRUE(matches("void g(); inline void f();",
                       functionDecl(isInline(), hasName("f"))));