Index: include/clang/AST/ASTTypeTraits.h =================================================================== --- include/clang/AST/ASTTypeTraits.h +++ include/clang/AST/ASTTypeTraits.h @@ -45,7 +45,7 @@ class ASTNodeKind { public: /// \brief Empty identifier. It matches nothing. - ASTNodeKind() : KindId(NKI_None) {} + ASTNodeKind() : KindId(NKI_Nothing) {} /// \brief Construct an identifier for T. template @@ -53,6 +53,16 @@ return ASTNodeKind(KindToKindId::Id); } + /// \brief Return a kind that represents no type. + /// + /// \c nothing().isBaseOf(X) is always false for any X. + static ASTNodeKind nothing() { return ASTNodeKind(NKI_Nothing); } + + /// \brief Return a kind that represents any type. + /// + /// \c anything().isBaseOf(X) is always true for any X other than nothing(). + static ASTNodeKind anything() { return ASTNodeKind(NKI_Anything); } + /// \{ /// \brief Construct an identifier for the dynamic type of the node static ASTNodeKind getFromNode(const Decl &D); @@ -76,12 +86,24 @@ return KindId < Other.KindId; } + /// \brief Return the most derived type between \p Kind1 and \p Kind2. + /// + /// Return nothing() if they are not related. + static ASTNodeKind getMostDerivedType(ASTNodeKind Kind1, ASTNodeKind Kind2); + + /// \brief Return the most derived common ancestor of \p Kind1 and \p Kind2. + /// + /// Return anything() if they are not related. + static ASTNodeKind getMostDerivedCommonAnscestor(ASTNodeKind Kind1, + ASTNodeKind Kind2); + private: /// \brief Kind ids. /// /// Includes all possible base and derived kinds. enum NodeKindId { - NKI_None, + NKI_Nothing, + NKI_Anything, NKI_CXXCtorInitializer, NKI_TemplateArgument, NKI_NestedNameSpecifier, @@ -113,7 +135,7 @@ /// /// This struct is specialized below for all known kinds. template struct KindToKindId { - static const NodeKindId Id = NKI_None; + static const NodeKindId Id = NKI_Nothing; }; /// \brief Per kind info.
Index: lib/AST/ASTTypeTraits.cpp =================================================================== --- lib/AST/ASTTypeTraits.cpp +++ lib/AST/ASTTypeTraits.cpp @@ -21,20 +21,21 @@ namespace ast_type_traits { const ASTNodeKind::KindInfo ASTNodeKind::AllKindInfo[] = { - { NKI_None, "" }, - { NKI_None, "CXXCtorInitializer" }, - { NKI_None, "TemplateArgument" }, - { NKI_None, "NestedNameSpecifier" }, - { NKI_None, "NestedNameSpecifierLoc" }, - { NKI_None, "QualType" }, - { NKI_None, "TypeLoc" }, - { NKI_None, "Decl" }, + { NKI_Nothing, "" }, + { NKI_Anything, "" }, + { NKI_Anything, "CXXCtorInitializer" }, + { NKI_Anything, "TemplateArgument" }, + { NKI_Anything, "NestedNameSpecifier" }, + { NKI_Anything, "NestedNameSpecifierLoc" }, + { NKI_Anything, "QualType" }, + { NKI_Anything, "TypeLoc" }, + { NKI_Anything, "Decl" }, #define DECL(DERIVED, BASE) { NKI_##BASE, #DERIVED "Decl" }, #include "clang/AST/DeclNodes.inc" - { NKI_None, "Stmt" }, + { NKI_Anything, "Stmt" }, #define STMT(DERIVED, BASE) { NKI_##BASE, #DERIVED }, #include "clang/AST/StmtNodes.inc" - { NKI_None, "Type" }, + { NKI_Anything, "Type" }, #define TYPE(DERIVED, BASE) { NKI_##BASE, #DERIVED "Type" }, #include "clang/AST/TypeNodes.def" }; @@ -44,24 +45,48 @@ } bool ASTNodeKind::isSame(ASTNodeKind Other) const { - return KindId != NKI_None && KindId == Other.KindId; + return KindId != NKI_Nothing && KindId == Other.KindId; } bool ASTNodeKind::isBaseOf(NodeKindId Base, NodeKindId Derived, unsigned *Distance) { - if (Base == NKI_None || Derived == NKI_None) return false; + if (Base == NKI_Nothing) return false; unsigned Dist = 0; - while (Derived != Base && Derived != NKI_None) { + // Every kind except NKI_Nothing and NKI_Anything (each its own parent) has + // a higher enum value than its parent, so stop once Derived <= Base.
+ while (Derived > Base) { Derived = AllKindInfo[Derived].ParentId; ++Dist; } - if (Distance) - *Distance = Dist; - return Derived == Base; + if (Derived != Base) return false; + if (Distance) *Distance = Dist; + return true; } StringRef ASTNodeKind::asStringRef() const { return AllKindInfo[KindId].Name; } +ASTNodeKind ASTNodeKind::getMostDerivedType(ASTNodeKind Kind1, + ASTNodeKind Kind2) { + if (Kind1.isBaseOf(Kind2)) return Kind2; + if (Kind2.isBaseOf(Kind1)) return Kind1; + return nothing(); +} + +ASTNodeKind ASTNodeKind::getMostDerivedCommonAnscestor(ASTNodeKind Kind1, + ASTNodeKind Kind2) { + // nothing() is not related to any kind (isBaseOf() is false whenever either + // side is NKI_Nothing), so the walk below would never terminate for it. + // Answer anything(), as documented for unrelated kinds. + if (Kind1.KindId == NKI_Nothing || Kind2.KindId == NKI_Nothing) + return anything(); + // Every other kind chains up to NKI_Anything, a base of all of them. + NodeKindId Parent = Kind1.KindId; + while (!isBaseOf(Parent, Kind2.KindId, nullptr)) { + Parent = AllKindInfo[Parent].ParentId; + } + return ASTNodeKind(Parent); +} + ASTNodeKind ASTNodeKind::getFromNode(const Decl &D) { switch (D.getKind()) { #define DECL(DERIVED, BASE) \ @@ -74,7 +99,7 @@ ASTNodeKind ASTNodeKind::getFromNode(const Stmt &S) { switch (S.getStmtClass()) { - case Stmt::NoStmtClass: return NKI_None; + case Stmt::NoStmtClass: return NKI_Nothing; #define STMT(CLASS, PARENT) \ case Stmt::CLASS##Class: return ASTNodeKind(NKI_##CLASS); #define ABSTRACT_STMT(S) Index: lib/ASTMatchers/ASTMatchersInternal.cpp =================================================================== --- lib/ASTMatchers/ASTMatchersInternal.cpp +++ lib/ASTMatchers/ASTMatchersInternal.cpp @@ -64,28 +64,6 @@ const IntrusiveRefCntPtr InnerMatcher; }; -/// \brief Return the most derived type between \p Kind1 and \p Kind2. -/// -/// Return the null type if they are not related. -ast_type_traits::ASTNodeKind getMostDerivedType( - const ast_type_traits::ASTNodeKind Kind1, - const ast_type_traits::ASTNodeKind Kind2) { - if (Kind1.isBaseOf(Kind2)) return Kind2; - if (Kind2.isBaseOf(Kind1)) return Kind1; - return ast_type_traits::ASTNodeKind(); -} - -/// \brief Return the least derived type between \p Kind1 and \p Kind2. -/// -/// Return the null type if they are not related.
-static ast_type_traits::ASTNodeKind getLeastDerivedType( - const ast_type_traits::ASTNodeKind Kind1, - const ast_type_traits::ASTNodeKind Kind2) { - if (Kind1.isBaseOf(Kind2)) return Kind1; - if (Kind2.isBaseOf(Kind1)) return Kind2; - return ast_type_traits::ASTNodeKind(); -} - } // namespace DynTypedMatcher DynTypedMatcher::constructVariadic( @@ -98,7 +76,8 @@ assert(Result.SupportedKind.isSame(M.SupportedKind) && "SupportedKind must match!"); Result.RestrictKind = - getLeastDerivedType(Result.RestrictKind, M.RestrictKind); + ast_type_traits::ASTNodeKind::getMostDerivedCommonAnscestor( + Result.RestrictKind, M.RestrictKind); } Result.Implementation = new VariadicMatcher(Func, std::move(InnerMatchers)); return Result; @@ -108,7 +87,8 @@ const ast_type_traits::ASTNodeKind Kind) const { auto Copy = *this; Copy.SupportedKind = Kind; - Copy.RestrictKind = getMostDerivedType(Kind, RestrictKind); + Copy.RestrictKind = + ast_type_traits::ASTNodeKind::getMostDerivedType(Kind, RestrictKind); return Copy; } Index: unittests/AST/ASTTypeTraitsTest.cpp =================================================================== --- unittests/AST/ASTTypeTraitsTest.cpp +++ unittests/AST/ASTTypeTraitsTest.cpp @@ -17,15 +17,39 @@ namespace clang { namespace ast_type_traits { -TEST(ASTNodeKind, NoKind) { - EXPECT_FALSE(ASTNodeKind().isBaseOf(ASTNodeKind())); - EXPECT_FALSE(ASTNodeKind().isSame(ASTNodeKind())); -} - template static ASTNodeKind DNT() { return ASTNodeKind::getFromNodeKind(); } +TEST(ASTNodeKind, NothingAnything) { + auto nothing = ASTNodeKind::nothing(); + auto anything = ASTNodeKind::anything(); + + // nothing(): never the same as, a base of, or derived from any kind. + EXPECT_FALSE(nothing.isBaseOf(nothing)); + EXPECT_FALSE(nothing.isSame(nothing)); + EXPECT_FALSE(nothing.isBaseOf(DNT())); + EXPECT_FALSE(DNT().isBaseOf(nothing)); + EXPECT_FALSE(nothing.isBaseOf(DNT())); + EXPECT_FALSE(DNT().isBaseOf(nothing)); + + // anything(): base of itself and of every concrete kind; same only as itself. + EXPECT_TRUE(anything.isBaseOf(anything)); + EXPECT_TRUE(anything.isSame(anything)); +
EXPECT_TRUE(anything.isBaseOf(DNT())); + EXPECT_TRUE(anything.isBaseOf(DNT())); + EXPECT_FALSE(DNT().isBaseOf(anything)); + EXPECT_FALSE(DNT().isBaseOf(anything)); + EXPECT_FALSE(anything.isSame(DNT())); + EXPECT_FALSE(anything.isSame(DNT())); + + // nothing() and anything() are unrelated in both directions. + EXPECT_FALSE(nothing.isSame(anything)); + EXPECT_FALSE(anything.isSame(nothing)); + EXPECT_FALSE(nothing.isBaseOf(anything)); + EXPECT_FALSE(anything.isBaseOf(nothing)); +} + TEST(ASTNodeKind, Bases) { EXPECT_TRUE(DNT().isBaseOf(DNT())); EXPECT_FALSE(DNT().isSame(DNT())); @@ -60,6 +84,40 @@ EXPECT_FALSE(DNT().isSame(DNT())); } +TEST(ASTNodeKind, MostDerivedType) { + EXPECT_TRUE(DNT().isSame( + ASTNodeKind::getMostDerivedType(DNT(), DNT()))); + EXPECT_TRUE(DNT().isSame( + ASTNodeKind::getMostDerivedType(DNT(), DNT()))); + EXPECT_TRUE(DNT().isSame( + ASTNodeKind::getMostDerivedType(DNT(), DNT()))); + + // Not related: returns nothing(), the one kind anything() is not a base of. + EXPECT_FALSE(ASTNodeKind::anything().isBaseOf( + ASTNodeKind::getMostDerivedType(DNT(), DNT()))); + EXPECT_FALSE(ASTNodeKind::anything().isBaseOf( + ASTNodeKind::getMostDerivedType(DNT(), DNT()))); +} + +TEST(ASTNodeKind, MostDerivedCommonAnscestor) { + EXPECT_TRUE(DNT().isSame(ASTNodeKind::getMostDerivedCommonAnscestor( + DNT(), DNT()))); + EXPECT_TRUE(DNT().isSame(ASTNodeKind::getMostDerivedCommonAnscestor( + DNT(), DNT()))); + EXPECT_TRUE(DNT().isSame(ASTNodeKind::getMostDerivedCommonAnscestor( + DNT(), DNT()))); + + // Related only through a shared ancestor: returns the closest one. + EXPECT_TRUE( + DNT().isSame(ASTNodeKind::getMostDerivedCommonAnscestor( + DNT(), DNT()))); + + // Not related: returns anything().
+ EXPECT_TRUE( + ASTNodeKind::anything().isSame(ASTNodeKind::getMostDerivedCommonAnscestor( + DNT(), DNT()))); +} + struct Foo {}; TEST(ASTNodeKind, UnknownKind) { @@ -71,7 +129,8 @@ EXPECT_EQ("Decl", DNT().asStringRef()); EXPECT_EQ("CallExpr", DNT().asStringRef()); EXPECT_EQ("ConstantArrayType", DNT().asStringRef()); - EXPECT_EQ("", ASTNodeKind().asStringRef()); + EXPECT_EQ("", ASTNodeKind().asStringRef()); + EXPECT_EQ("", ASTNodeKind::anything().asStringRef()); } TEST(DynTypedNode, DeclSourceRange) { Index: unittests/ASTMatchers/ASTMatchersTest.cpp =================================================================== --- unittests/ASTMatchers/ASTMatchersTest.cpp +++ unittests/ASTMatchers/ASTMatchersTest.cpp @@ -460,6 +460,11 @@ EXPECT_TRUE(matches("class U {};", XOrYOrZOrUOrV)); EXPECT_TRUE(matches("class V {};", XOrYOrZOrUOrV)); EXPECT_TRUE(notMatches("class A {};", XOrYOrZOrUOrV)); + // anyOf() of mutually unrelated Stmt kinds must match nodes of each kind. + StatementMatcher MixedTypes = stmt(anyOf(ifStmt(), binaryOperator())); + EXPECT_TRUE(matches("int F() { return 1 + 2; }", MixedTypes)); + EXPECT_TRUE(matches("int F() { if (true) return 1; }", MixedTypes)); + EXPECT_TRUE(notMatches("int F() { return 1; }", MixedTypes)); } TEST(DeclarationMatcher, MatchHas) { Index: unittests/ASTMatchers/Dynamic/RegistryTest.cpp =================================================================== --- unittests/ASTMatchers/Dynamic/RegistryTest.cpp +++ unittests/ASTMatchers/Dynamic/RegistryTest.cpp @@ -347,7 +347,7 @@ "anyOf", constructMatcher("recordDecl", constructMatcher("hasName", std::string("Foo"))), - constructMatcher("namedDecl", + constructMatcher("functionDecl", constructMatcher("hasName", std::string("foo")))) .getTypedMatcher();