diff --git a/clang/include/clang/ASTMatchers/ASTMatchersInternal.h b/clang/include/clang/ASTMatchers/ASTMatchersInternal.h
--- a/clang/include/clang/ASTMatchers/ASTMatchersInternal.h
+++ b/clang/include/clang/ASTMatchers/ASTMatchersInternal.h
@@ -379,6 +379,14 @@
   constructRestrictedWrapper(const DynTypedMatcher &InnerMatcher,
                              ASTNodeKind RestrictKind);
 
+  /// Wrap \p InnerMatcher so that it reports \p TK as its traversal kind.
+  ///
+  /// The returned matcher forwards matching to \p InnerMatcher unchanged; only
+  /// the value returned by its TraversalKind() differs.
+  static DynTypedMatcher
+  constructTraversalWrapper(const DynTypedMatcher &InnerMatcher,
+                            ast_type_traits::TraversalKind TK);
+
   /// Get a "true" matcher for \p NodeKind.
   ///
   /// It only checks that the node is of the right kind.
diff --git a/clang/include/clang/ASTMatchers/Dynamic/VariantValue.h b/clang/include/clang/ASTMatchers/Dynamic/VariantValue.h
--- a/clang/include/clang/ASTMatchers/Dynamic/VariantValue.h
+++ b/clang/include/clang/ASTMatchers/Dynamic/VariantValue.h
@@ -138,6 +138,11 @@
   /// Clones the provided matcher.
   static VariantMatcher SingleMatcher(const DynTypedMatcher &Matcher);
 
+  /// Clones the provided matcher, to be wrapped with the traversal kind
+  /// \p TK when it is converted to a typed matcher.
+  static VariantMatcher TraversalMatcher(ast_type_traits::TraversalKind TK,
+                                         const DynTypedMatcher &Matcher);
+
   /// Clones the provided matchers.
   ///
   /// They should be the result of a polymorphic matcher.
@@ -214,6 +217,7 @@
   template <typename T> struct TypedMatcherOps;
 
   class SinglePayload;
+  class TraversalPayload;
   class PolymorphicPayload;
   class VariadicOpPayload;
 
diff --git a/clang/lib/ASTMatchers/ASTMatchersInternal.cpp b/clang/lib/ASTMatchers/ASTMatchersInternal.cpp
--- a/clang/lib/ASTMatchers/ASTMatchersInternal.cpp
+++ b/clang/lib/ASTMatchers/ASTMatchersInternal.cpp
@@ -204,6 +204,45 @@
   return Copy;
 }
 
+namespace {
+
+/// Matcher implementation that pairs an inner matcher with a fixed traversal
+/// kind.
+///
+/// Matching is delegated to the inner matcher unchanged; only the value
+/// reported by TraversalKind() is supplied by this wrapper.
+class DynTraversalMatcher : public DynMatcherInterface {
+public:
+  explicit DynTraversalMatcher(
+      clang::TraversalKind TK,
+      IntrusiveRefCntPtr<DynMatcherInterface> ChildMatcher)
+      : Traversal(TK), InnerMatcher(std::move(ChildMatcher)) {}
+
+  bool dynMatches(const DynTypedNode &DynNode, ASTMatchFinder *Finder,
+                  BoundNodesTreeBuilder *Builder) const override {
+    return InnerMatcher->dynMatches(DynNode, Finder, Builder);
+  }
+
+  llvm::Optional<clang::TraversalKind> TraversalKind() const override {
+    return Traversal;
+  }
+
+private:
+  clang::TraversalKind Traversal;
+  IntrusiveRefCntPtr<DynMatcherInterface> InnerMatcher;
+};
+
+} // namespace
+
+DynTypedMatcher
+DynTypedMatcher::constructTraversalWrapper(const DynTypedMatcher &InnerMatcher,
+                                           TraversalKind TK) {
+  DynTypedMatcher Copy = InnerMatcher;
+  Copy.Implementation =
+      new DynTraversalMatcher(TK, std::move(Copy.Implementation));
+  return Copy;
+}
+
 DynTypedMatcher DynTypedMatcher::trueMatcher(ASTNodeKind NodeKind) {
   return DynTypedMatcher(NodeKind, NodeKind, &*TrueMatcherInstance);
 }
diff --git a/clang/lib/ASTMatchers/Dynamic/Marshallers.h b/clang/lib/ASTMatchers/Dynamic/Marshallers.h
--- a/clang/lib/ASTMatchers/Dynamic/Marshallers.h
+++ b/clang/lib/ASTMatchers/Dynamic/Marshallers.h
@@ -244,6 +244,36 @@
   static llvm::Optional<std::string> getBestGuess(const VariantValue &Value);
 };
 
+/// Argument traits for a traversal kind written as a string.
+///
+/// Recognized spellings are "TK::AsIs", "TK::IgnoreUnlessSpelledInSource" and
+/// "TK::IgnoreImplicitCastsAndParentheses"; any other string is rejected by
+/// is().
+template <> struct ArgTypeTraits<ast_type_traits::TraversalKind> {
+private:
+  static Optional<ast_type_traits::TraversalKind>
+  getTraversalKind(llvm::StringRef TK) {
+    return llvm::StringSwitch<Optional<ast_type_traits::TraversalKind>>(TK)
+        .Case("TK::AsIs", ast_type_traits::TK_AsIs)
+        .Case("TK::IgnoreUnlessSpelledInSource",
+              ast_type_traits::TK_IgnoreUnlessSpelledInSource)
+        .Case("TK::IgnoreImplicitCastsAndParentheses",
+              ast_type_traits::TK_IgnoreImplicitCastsAndParentheses)
+        .Default(llvm::None);
+  }
+
+public:
+  static bool is(const VariantValue &Value) {
+    return Value.isString() && getTraversalKind(Value.getString());
+  }
+
+  static ast_type_traits::TraversalKind get(const VariantValue &Value) {
+    return *getTraversalKind(Value.getString());
+  }
+
+  static ArgKind getKind() { return ArgKind(ArgKind::AK_String); }
+};
+
 /// Matcher descriptor interface.
 ///
 /// Provides a \c create() method that constructs the matcher from the provided
@@ -299,6 +329,85 @@
   return false;
 }
 
+/// Matcher descriptor for the \c traverse matcher.
+///
+/// \c traverse takes two arguments: the name of a traversal kind as a string
+/// ("AsIs", "IgnoreImplicitCastsAndParentheses" or
+/// "IgnoreUnlessSpelledInSource") and a single matcher. The created matcher
+/// pairs the inner matcher with that traversal kind.
+class TraverseMatcherDescriptor : public MatcherDescriptor {
+public:
+  VariantMatcher create(SourceRange NameRange, ArrayRef<ParserValue> Args,
+                        Diagnostics *Error) const override {
+    if (Args.size() != 2) {
+      Error->addError(NameRange, Error->ET_RegistryWrongArgCount)
+          << 2 << Args.size();
+      return {};
+    }
+
+    const VariantValue &TraversalArg = Args[0].Value;
+    if (!TraversalArg.isString()) {
+      Error->addError(Args[0].Range, Error->ET_RegistryWrongArgType)
+          << 1 << ArgKind(ArgKind::AK_String).asString()
+          << TraversalArg.getTypeAsString();
+      return {};
+    }
+    const auto TK =
+        llvm::StringSwitch<llvm::Optional<ast_type_traits::TraversalKind>>(
+            TraversalArg.getString())
+            .Case("AsIs", ast_type_traits::TK_AsIs)
+            .Case("IgnoreImplicitCastsAndParentheses",
+                  ast_type_traits::TK_IgnoreImplicitCastsAndParentheses)
+            .Case("IgnoreUnlessSpelledInSource",
+                  ast_type_traits::TK_IgnoreUnlessSpelledInSource)
+            .Default(llvm::None);
+    if (!TK) {
+      Error->addError(Args[0].Range, Error->ET_RegistryValueNotFound)
+          << TraversalArg.getString();
+      return {};
+    }
+
+    // Only a matcher that resolves to a single DynTypedMatcher can be
+    // wrapped; anything else is reported as a type error on argument 2.
+    const VariantValue &MatcherArg = Args[1].Value;
+    llvm::Optional<DynTypedMatcher> Inner;
+    if (MatcherArg.isMatcher() && !MatcherArg.getMatcher().isNull())
+      Inner = MatcherArg.getMatcher().getSingleMatcher();
+    if (!Inner) {
+      Error->addError(Args[1].Range, Error->ET_RegistryWrongArgType)
+          << 2 << "Matcher" << MatcherArg.getTypeAsString();
+      return {};
+    }
+    return VariantMatcher::TraversalMatcher(*TK, *Inner);
+  }
+
+  bool isVariadic() const override { return false; }
+  unsigned getNumArgs() const override { return 2; }
+
+  void getArgKinds(ast_type_traits::ASTNodeKind ThisKind, unsigned ArgNo,
+                   std::vector<ArgKind> &Kinds) const override {
+    // Argument 0 is the traversal kind name; the remaining argument is the
+    // inner matcher, reported with the kind requested by the caller.
+    if (ArgNo == 0) {
+      Kinds.push_back(ArgKind::AK_String);
+      return;
+    }
+    Kinds.push_back(ThisKind);
+  }
+
+  bool isConvertibleTo(
+      ast_type_traits::ASTNodeKind Kind, unsigned *Specificity,
+      ast_type_traits::ASTNodeKind *LeastDerivedKind) const override {
+    // Convertible to every kind; still fill in the out-parameters so callers
+    // never read them uninitialized.
+    if (Specificity)
+      *Specificity = 1;
+    if (LeastDerivedKind)
+      *LeastDerivedKind = Kind;
+    return true;
+  }
+};
+
 /// Simple callback implementation. Marshaller and function are provided.
 ///
 /// This class wraps a function of arbitrary signature and a marshaller
diff --git a/clang/lib/ASTMatchers/Dynamic/Registry.cpp b/clang/lib/ASTMatchers/Dynamic/Registry.cpp
--- a/clang/lib/ASTMatchers/Dynamic/Registry.cpp
+++ b/clang/lib/ASTMatchers/Dynamic/Registry.cpp
@@ -101,6 +101,9 @@
   // Other:
   // equalsNode
 
+  registerMatcher("traverse",
+                  std::make_unique<internal::TraverseMatcherDescriptor>());
+
   REGISTER_OVERLOADED_2(callee);
   REGISTER_OVERLOADED_2(hasAnyCapture);
   REGISTER_OVERLOADED_2(hasPrefix);
diff --git a/clang/lib/ASTMatchers/Dynamic/VariantValue.cpp b/clang/lib/ASTMatchers/Dynamic/VariantValue.cpp
--- a/clang/lib/ASTMatchers/Dynamic/VariantValue.cpp
+++ b/clang/lib/ASTMatchers/Dynamic/VariantValue.cpp
@@ -80,6 +80,44 @@
 
 VariantMatcher::Payload::~Payload() {}
 
+/// Payload holding a single matcher together with the traversal kind it
+/// should be wrapped with.
+class VariantMatcher::TraversalPayload : public VariantMatcher::Payload {
+public:
+  TraversalPayload(ast_type_traits::TraversalKind TK,
+                   const DynTypedMatcher &Matcher)
+      : TK(TK), Matcher(Matcher) {}
+
+  llvm::Optional<DynTypedMatcher> getSingleMatcher() const override {
+    // Hand out the wrapped matcher so that the traversal kind is not lost
+    // when the single matcher is extracted.
+    return DynTypedMatcher::constructTraversalWrapper(Matcher, TK);
+  }
+
+  std::string getTypeAsString() const override {
+    return (Twine("Matcher<") + Matcher.getSupportedKind().asStringRef() + ">")
+        .str();
+  }
+
+  llvm::Optional<DynTypedMatcher>
+  getTypedMatcher(const MatcherOps &Ops) const override {
+    bool Ignore;
+    if (Ops.canConstructFrom(Matcher, Ignore))
+      return DynTypedMatcher::constructTraversalWrapper(Matcher, TK);
+    return llvm::None;
+  }
+
+  bool isConvertibleTo(ast_type_traits::ASTNodeKind Kind,
+                       unsigned *Specificity) const override {
+    return ArgKind(Matcher.getSupportedKind())
+        .isConvertibleTo(Kind, Specificity);
+  }
+
+private:
+  ast_type_traits::TraversalKind TK;
+  const DynTypedMatcher Matcher;
+};
+
 class VariantMatcher::SinglePayload : public VariantMatcher::Payload {
 public:
   SinglePayload(const DynTypedMatcher &Matcher) : Matcher(Matcher) {}
@@ -219,6 +257,12 @@
   return VariantMatcher(std::make_shared<SinglePayload>(Matcher));
 }
 
+VariantMatcher
+VariantMatcher::TraversalMatcher(ast_type_traits::TraversalKind TK,
+                                 const DynTypedMatcher &Matcher) {
+  return VariantMatcher(std::make_shared<TraversalPayload>(TK, Matcher));
+}
+
 VariantMatcher
 VariantMatcher::PolymorphicMatcher(std::vector<DynTypedMatcher> Matchers) {
   return VariantMatcher(
diff --git a/clang/unittests/ASTMatchers/Dynamic/RegistryTest.cpp b/clang/unittests/ASTMatchers/Dynamic/RegistryTest.cpp
--- a/clang/unittests/ASTMatchers/Dynamic/RegistryTest.cpp
+++ b/clang/unittests/ASTMatchers/Dynamic/RegistryTest.cpp
@@ -505,6 +505,25 @@
   EXPECT_FALSE(matches("struct X {};", Value));
 }
 
+TEST_F(RegistryTest, Traverse) {
+  EXPECT_TRUE(
+      matches(R"cpp(
+int foo()
+{
+  double d = 0;
+  return d;
+}
+)cpp",
+              constructMatcher(
+                  "returnStmt",
+                  constructMatcher(
+                      "hasReturnValue",
+                      constructMatcher("traverse",
+                                       StringRef("IgnoreUnlessSpelledInSource"),
+                                       constructMatcher("declRefExpr"))))
+                  .getTypedMatcher<Stmt>()));
+}
+
 TEST_F(RegistryTest, ParenExpr) {
   Matcher<Stmt> Value = constructMatcher("parenExpr").getTypedMatcher<Stmt>();
   EXPECT_TRUE(