diff --git a/clang/include/clang/AST/ASTTypeTraits.h b/clang/include/clang/AST/ASTTypeTraits.h --- a/clang/include/clang/AST/ASTTypeTraits.h +++ b/clang/include/clang/AST/ASTTypeTraits.h @@ -77,10 +77,13 @@ /// Returns \c true only for the default \c ASTNodeKind() constexpr bool isNone() const { return KindId == NKI_None; } + /// Returns \c true if \c this is a base kind of (or same as) \c Other. + bool isBaseOf(ASTNodeKind Other) const; + /// Returns \c true if \c this is a base kind of (or same as) \c Other. /// \param Distance If non-null, used to return the distance between \c this /// and \c Other in the class hierarchy. - bool isBaseOf(ASTNodeKind Other, unsigned *Distance = nullptr) const; + bool isBaseOf(ASTNodeKind Other, unsigned *Distance) const; /// String representation of the kind. StringRef asStringRef() const; @@ -166,6 +169,10 @@ /// Use getFromNodeKind() to construct the kind. constexpr ASTNodeKind(NodeKindId KindId) : KindId(KindId) {} + /// Returns \c true if \c Base is a base kind of (or same as) \c + /// Derived. + static bool isBaseOf(NodeKindId Base, NodeKindId Derived); + /// Returns \c true if \c Base is a base kind of (or same as) \c /// Derived. /// \param Distance If non-null, used to return the distance between \c Base diff --git a/clang/lib/AST/ASTTypeTraits.cpp b/clang/lib/AST/ASTTypeTraits.cpp --- a/clang/lib/AST/ASTTypeTraits.cpp +++ b/clang/lib/AST/ASTTypeTraits.cpp @@ -56,10 +56,23 @@ {NKI_None, "ObjCProtocolLoc"}, }; +bool ASTNodeKind::isBaseOf(ASTNodeKind Other) const { + return isBaseOf(KindId, Other.KindId); +} + bool ASTNodeKind::isBaseOf(ASTNodeKind Other, unsigned *Distance) const { return isBaseOf(KindId, Other.KindId, Distance); } +bool ASTNodeKind::isBaseOf(NodeKindId Base, NodeKindId Derived) { + if (Base == NKI_None || Derived == NKI_None) + return false; + while (Derived != Base && Derived != NKI_None) { + Derived = AllKindInfo[Derived].ParentId; + } + return Derived == Base; +} + bool ASTNodeKind::isBaseOf(NodeKindId Base, NodeKindId Derived, unsigned *Distance) { if (Base == NKI_None || Derived == NKI_None) return false; @@ -96,7 +109,7 @@ ASTNodeKind ASTNodeKind::getMostDerivedCommonAncestor(ASTNodeKind Kind1, ASTNodeKind Kind2) { NodeKindId Parent = Kind1.KindId; - while (!isBaseOf(Parent, Kind2.KindId, nullptr) && Parent != NKI_None) { + while (!isBaseOf(Parent, Kind2.KindId) && Parent != NKI_None) { Parent = AllKindInfo[Parent].ParentId; } return ASTNodeKind(Parent);