diff --git a/clang/include/clang/AST/ASTFwd.h b/clang/include/clang/AST/ASTFwd.h --- a/clang/include/clang/AST/ASTFwd.h +++ b/clang/include/clang/AST/ASTFwd.h @@ -33,6 +33,7 @@ class Attr; #define ATTR(A) class A##Attr; #include "clang/Basic/AttrList.inc" +class ObjCProtocolLoc; } // end namespace clang diff --git a/clang/include/clang/AST/ASTTypeTraits.h b/clang/include/clang/AST/ASTTypeTraits.h --- a/clang/include/clang/AST/ASTTypeTraits.h +++ b/clang/include/clang/AST/ASTTypeTraits.h @@ -160,6 +160,7 @@ NKI_Attr, #define ATTR(A) NKI_##A##Attr, #include "clang/Basic/AttrList.inc" + NKI_ObjCProtocolLoc, NKI_NumberOfKinds }; @@ -213,6 +214,7 @@ KIND_TO_KIND_ID(Type) KIND_TO_KIND_ID(OMPClause) KIND_TO_KIND_ID(Attr) +KIND_TO_KIND_ID(ObjCProtocolLoc) KIND_TO_KIND_ID(CXXBaseSpecifier) #define DECL(DERIVED, BASE) KIND_TO_KIND_ID(DERIVED##Decl) #include "clang/AST/DeclNodes.inc" @@ -499,7 +501,7 @@ /// have storage or unique pointers and thus need to be stored by value. llvm::AlignedCharArrayUnion + QualType, TypeLoc, ObjCProtocolLoc> Storage; }; @@ -570,6 +572,10 @@ struct DynTypedNode::BaseConverter : public PtrConverter {}; +template <> +struct DynTypedNode::BaseConverter + : public ValueConverter {}; + // The only operation we allow on unsupported types is \c get. // This allows to conveniently use \c DynTypedNode when having an arbitrary // AST node that is not supported, but prevents misuse - a user cannot create diff --git a/clang/include/clang/AST/RecursiveASTVisitor.h b/clang/include/clang/AST/RecursiveASTVisitor.h --- a/clang/include/clang/AST/RecursiveASTVisitor.h +++ b/clang/include/clang/AST/RecursiveASTVisitor.h @@ -324,6 +324,12 @@ /// \returns false if the visitation was terminated early, true otherwise. bool TraverseConceptReference(const ConceptReference &C); + /// Recursively visit an Objective-C protocol reference with location + /// information. + /// + /// \returns false if the visitation was terminated early, true otherwise. + bool TraverseObjCProtocolLoc(ObjCProtocolLoc ProtocolLoc); + // ---- Methods on Attrs ---- // Visit an attribute. @@ -1340,7 +1346,12 @@ DEF_TRAVERSE_TYPELOC(PackExpansionType, { TRY_TO(TraverseTypeLoc(TL.getPatternLoc())); }) -DEF_TRAVERSE_TYPELOC(ObjCTypeParamType, {}) +DEF_TRAVERSE_TYPELOC(ObjCTypeParamType, { + for (unsigned I = 0, N = TL.getNumProtocols(); I != N; ++I) { + ObjCProtocolLoc ProtocolLoc(TL.getProtocol(I), TL.getProtocolLoc(I)); + TRY_TO(TraverseObjCProtocolLoc(ProtocolLoc)); + } +}) DEF_TRAVERSE_TYPELOC(ObjCInterfaceType, {}) @@ -1351,6 +1362,10 @@ TRY_TO(TraverseTypeLoc(TL.getBaseLoc())); for (unsigned i = 0, n = TL.getNumTypeArgs(); i != n; ++i) TRY_TO(TraverseTypeLoc(TL.getTypeArgTInfo(i)->getTypeLoc())); + for (unsigned I = 0, N = TL.getNumProtocols(); I != N; ++I) { + ObjCProtocolLoc ProtocolLoc(TL.getProtocol(I), TL.getProtocolLoc(I)); + TRY_TO(TraverseObjCProtocolLoc(ProtocolLoc)); + } }) DEF_TRAVERSE_TYPELOC(ObjCObjectPointerType, @@ -1541,12 +1556,16 @@ DEF_TRAVERSE_DECL(ObjCCompatibleAliasDecl, {// FIXME: implement }) -DEF_TRAVERSE_DECL(ObjCCategoryDecl, {// FIXME: implement +DEF_TRAVERSE_DECL(ObjCCategoryDecl, { if (ObjCTypeParamList *typeParamList = D->getTypeParamList()) { for (auto typeParam : *typeParamList) { TRY_TO(TraverseObjCTypeParamDecl(typeParam)); } } + for (auto It : llvm::zip(D->protocols(), D->protocol_locs())) { + ObjCProtocolLoc ProtocolLoc(std::get<0>(It), std::get<1>(It)); + TRY_TO(TraverseObjCProtocolLoc(ProtocolLoc)); + } }) DEF_TRAVERSE_DECL(ObjCCategoryImplDecl, {// FIXME: implement @@ -1555,7 +1574,7 @@ DEF_TRAVERSE_DECL(ObjCImplementationDecl, {// FIXME: implement }) -DEF_TRAVERSE_DECL(ObjCInterfaceDecl, {// FIXME: implement +DEF_TRAVERSE_DECL(ObjCInterfaceDecl, { if (ObjCTypeParamList *typeParamList = D->getTypeParamListAsWritten()) { for (auto typeParam : *typeParamList) { TRY_TO(TraverseObjCTypeParamDecl(typeParam)); @@ -1565,10 +1584,22 @@ if (TypeSourceInfo *superTInfo = D->getSuperClassTInfo()) { TRY_TO(TraverseTypeLoc(superTInfo->getTypeLoc())); } + if (D->isThisDeclarationADefinition()) { + for (auto It : llvm::zip(D->protocols(), D->protocol_locs())) { + ObjCProtocolLoc ProtocolLoc(std::get<0>(It), std::get<1>(It)); + TRY_TO(TraverseObjCProtocolLoc(ProtocolLoc)); + } + } }) -DEF_TRAVERSE_DECL(ObjCProtocolDecl, {// FIXME: implement - }) +DEF_TRAVERSE_DECL(ObjCProtocolDecl, { + if (D->isThisDeclarationADefinition()) { + for (auto It : llvm::zip(D->protocols(), D->protocol_locs())) { + ObjCProtocolLoc ProtocolLoc(std::get<0>(It), std::get<1>(It)); + TRY_TO(TraverseObjCProtocolLoc(ProtocolLoc)); + } + } +}) DEF_TRAVERSE_DECL(ObjCMethodDecl, { if (D->getReturnTypeSourceInfo()) { @@ -2409,6 +2440,12 @@ return true; } +template +bool RecursiveASTVisitor::TraverseObjCProtocolLoc( + ObjCProtocolLoc ProtocolLoc) { + return true; +} + // If shouldVisitImplicitCode() returns false, this method traverses only the // syntactic form of InitListExpr. // If shouldVisitImplicitCode() return true, this method is called once for diff --git a/clang/include/clang/AST/TypeLoc.h b/clang/include/clang/AST/TypeLoc.h --- a/clang/include/clang/AST/TypeLoc.h +++ b/clang/include/clang/AST/TypeLoc.h @@ -2607,6 +2607,22 @@ : public InheritingConcreteTypeLoc {}; +class ObjCProtocolLoc { + ObjCProtocolDecl *Protocol = nullptr; + SourceLocation Loc = SourceLocation(); + +public: + ObjCProtocolLoc(ObjCProtocolDecl *protocol, SourceLocation loc) + : Protocol(protocol), Loc(loc) {} + ObjCProtocolDecl *getProtocol() const { return Protocol; } + SourceLocation getLocation() const { return Loc; } + + /// The source range is just the protocol name. + SourceRange getSourceRange() const LLVM_READONLY { + return SourceRange(Loc, Loc); + } +}; + } // namespace clang #endif // LLVM_CLANG_AST_TYPELOC_H diff --git a/clang/lib/AST/ASTTypeTraits.cpp b/clang/lib/AST/ASTTypeTraits.cpp --- a/clang/lib/AST/ASTTypeTraits.cpp +++ b/clang/lib/AST/ASTTypeTraits.cpp @@ -16,6 +16,7 @@ #include "clang/AST/ASTContext.h" #include "clang/AST/Attr.h" #include "clang/AST/DeclCXX.h" +#include "clang/AST/DeclObjC.h" #include "clang/AST/NestedNameSpecifier.h" #include "clang/AST/OpenMPClause.h" #include "clang/AST/TypeLoc.h" @@ -52,6 +53,7 @@ {NKI_None, "Attr"}, #define ATTR(A) {NKI_Attr, #A "Attr"}, #include "clang/Basic/AttrList.inc" + {NKI_None, "ObjCProtocolLoc"}, }; bool ASTNodeKind::isBaseOf(ASTNodeKind Other, unsigned *Distance) const { @@ -193,6 +195,8 @@ QualType(T, 0).print(OS, PP); else if (const Attr *A = get()) A->printPretty(OS, PP); + else if (const ObjCProtocolLoc *P = get()) + P->getProtocol()->print(OS, PP); else OS << "Unable to print values of type " << NodeKind.asStringRef() << "\n"; } @@ -228,5 +232,7 @@ return CBS->getSourceRange(); if (const auto *A = get()) return A->getRange(); + if (const ObjCProtocolLoc *P = get()) + return P->getSourceRange(); return SourceRange(); } diff --git a/clang/lib/AST/ParentMapContext.cpp b/clang/lib/AST/ParentMapContext.cpp --- a/clang/lib/AST/ParentMapContext.cpp +++ b/clang/lib/AST/ParentMapContext.cpp @@ -330,6 +330,9 @@ DynTypedNode createDynTypedNode(const NestedNameSpecifierLoc &Node) { return DynTypedNode::create(Node); } +template <> DynTypedNode createDynTypedNode(const ObjCProtocolLoc &Node) { + return DynTypedNode::create(Node); +} /// @} /// A \c RecursiveASTVisitor that builds a map from nodes to their @@ -398,11 +401,14 @@ } } + template static bool isNull(T Node) { return !Node; } + static bool isNull(ObjCProtocolLoc Node) { return false; } + template bool TraverseNode(T Node, MapNodeTy MapNode, BaseTraverseFn BaseTraverse, MapTy *Parents) { - if (!Node) + if (isNull(Node)) return true; addParent(MapNode, Parents); ParentStack.push_back(createDynTypedNode(Node)); @@ -433,6 +439,12 @@ AttrNode, AttrNode, [&] { return VisitorBase::TraverseAttr(AttrNode); }, &Map.PointerParents); } + bool TraverseObjCProtocolLoc(ObjCProtocolLoc ProtocolLocNode) { + return TraverseNode( + ProtocolLocNode, DynTypedNode::create(ProtocolLocNode), + [&] { return VisitorBase::TraverseObjCProtocolLoc(ProtocolLocNode); }, + &Map.OtherParents); + } // Using generic TraverseNode for Stmt would prevent data-recursion. bool dataTraverseStmtPre(Stmt *StmtNode) { diff --git a/clang/unittests/AST/RecursiveASTVisitorTest.cpp b/clang/unittests/AST/RecursiveASTVisitorTest.cpp --- a/clang/unittests/AST/RecursiveASTVisitorTest.cpp +++ b/clang/unittests/AST/RecursiveASTVisitorTest.cpp @@ -60,6 +60,12 @@ EndTraverseEnum, StartTraverseTypedefType, EndTraverseTypedefType, + StartTraverseObjCInterface, + EndTraverseObjCInterface, + StartTraverseObjCProtocol, + EndTraverseObjCProtocol, + StartTraverseObjCProtocolLoc, + EndTraverseObjCProtocolLoc, }; class CollectInterestingEvents @@ -97,18 +103,43 @@ return Ret; } + bool TraverseObjCInterfaceDecl(ObjCInterfaceDecl *ID) { + Events.push_back(VisitEvent::StartTraverseObjCInterface); + bool Ret = RecursiveASTVisitor::TraverseObjCInterfaceDecl(ID); + Events.push_back(VisitEvent::EndTraverseObjCInterface); + + return Ret; + } + + bool TraverseObjCProtocolDecl(ObjCProtocolDecl *PD) { + Events.push_back(VisitEvent::StartTraverseObjCProtocol); + bool Ret = RecursiveASTVisitor::TraverseObjCProtocolDecl(PD); + Events.push_back(VisitEvent::EndTraverseObjCProtocol); + + return Ret; + } + + bool TraverseObjCProtocolLoc(ObjCProtocolLoc ProtocolLoc) { + Events.push_back(VisitEvent::StartTraverseObjCProtocolLoc); + bool Ret = RecursiveASTVisitor::TraverseObjCProtocolLoc(ProtocolLoc); + Events.push_back(VisitEvent::EndTraverseObjCProtocolLoc); + + return Ret; + } + std::vector takeEvents() && { return std::move(Events); } private: std::vector Events; }; -std::vector collectEvents(llvm::StringRef Code) { +std::vector collectEvents(llvm::StringRef Code, + const Twine &FileName = "input.cc") { CollectInterestingEvents Visitor; clang::tooling::runToolOnCode( std::make_unique( [&](clang::ASTContext &Ctx) { Visitor.TraverseAST(Ctx); }), - Code); + Code, FileName); return std::move(Visitor).takeEvents(); } } // namespace @@ -139,3 +170,28 @@ VisitEvent::EndTraverseTypedefType, VisitEvent::EndTraverseEnum)); } + +TEST(RecursiveASTVisitorTest, InterfaceDeclWithProtocols) { + // Check interface and its protocols are visited. + llvm::StringRef Code = R"cpp( + @protocol Foo + @end + @protocol Bar + @end + + @interface SomeObject + @end + )cpp"; + + EXPECT_THAT(collectEvents(Code, "input.m"), + ElementsAre(VisitEvent::StartTraverseObjCProtocol, + VisitEvent::EndTraverseObjCProtocol, + VisitEvent::StartTraverseObjCProtocol, + VisitEvent::EndTraverseObjCProtocol, + VisitEvent::StartTraverseObjCInterface, + VisitEvent::StartTraverseObjCProtocolLoc, + VisitEvent::EndTraverseObjCProtocolLoc, + VisitEvent::StartTraverseObjCProtocolLoc, + VisitEvent::EndTraverseObjCProtocolLoc, + VisitEvent::EndTraverseObjCInterface)); +}