diff --git a/clang-tools-extra/clangd/unittests/SelectionTests.cpp b/clang-tools-extra/clangd/unittests/SelectionTests.cpp --- a/clang-tools-extra/clangd/unittests/SelectionTests.cpp +++ b/clang-tools-extra/clangd/unittests/SelectionTests.cpp @@ -473,7 +473,19 @@ [[@property(retain, nonnull) <:[My^Object2]:> *x]]; // error-ok @end )cpp", - "ObjCPropertyDecl"}}; + "ObjCPropertyDecl"}, + + {R"cpp( + typedef int Foo; + enum Bar : [[Fo^o]] {}; + )cpp", + "TypedefTypeLoc"}, + {R"cpp( + typedef int Foo; + enum Bar : [[Fo^o]]; + )cpp", + "TypedefTypeLoc"}, + }; for (const Case &C : Cases) { trace::TestTracer Tracer; diff --git a/clang-tools-extra/clangd/unittests/SemanticHighlightingTests.cpp b/clang-tools-extra/clangd/unittests/SemanticHighlightingTests.cpp --- a/clang-tools-extra/clangd/unittests/SemanticHighlightingTests.cpp +++ b/clang-tools-extra/clangd/unittests/SemanticHighlightingTests.cpp @@ -780,6 +780,16 @@ $LocalVariable_decl[[d]]($LocalVariable[[b]]) ]() {}(); } )cpp", + // Enum base specifier naming a `using` alias + R"cpp( + using $Primitive_decl[[MyTypedef]] = int; + enum $Enum_decl[[MyEnum]] : $Primitive[[MyTypedef]] {}; + )cpp", + // Enum base specifier naming a typedef + R"cpp( + typedef int $Primitive_decl[[MyTypedef]]; + enum $Enum_decl[[MyEnum]] : $Primitive[[MyTypedef]] {}; + )cpp", }; for (const auto &TestCase : TestCases) // Mask off scope modifiers to keep the tests manageable. 
diff --git a/clang-tools-extra/clangd/unittests/XRefsTests.cpp b/clang-tools-extra/clangd/unittests/XRefsTests.cpp --- a/clang-tools-extra/clangd/unittests/XRefsTests.cpp +++ b/clang-tools-extra/clangd/unittests/XRefsTests.cpp @@ -880,6 +880,19 @@ }; )cpp", + R"cpp(// Enum base + typedef int $decl[[MyTypeDef]]; + enum Foo : My^TypeDef {}; + )cpp", + R"cpp(// Enum base + typedef int $decl[[MyTypeDef]]; + enum Foo : My^TypeDef; + )cpp", + R"cpp(// Enum base + using $decl[[MyTypeDef]] = int; + enum Foo : My^TypeDef {}; + )cpp", + R"objc( @protocol Dog; @protocol $decl[[Dog]] diff --git a/clang/include/clang/AST/RecursiveASTVisitor.h b/clang/include/clang/AST/RecursiveASTVisitor.h --- a/clang/include/clang/AST/RecursiveASTVisitor.h +++ b/clang/include/clang/AST/RecursiveASTVisitor.h @@ -1867,6 +1867,8 @@ TRY_TO(TraverseType(QualType(D->getTypeForDecl(), 0))); TRY_TO(TraverseNestedNameSpecifierLoc(D->getQualifierLoc())); + if (auto *TSI = D->getIntegerTypeSourceInfo()) // Base `T` of `enum E : T`. + TRY_TO(TraverseTypeLoc(TSI->getTypeLoc())); // The enumerators are already traversed by // decls_begin()/decls_end(). 
}) diff --git a/clang/unittests/AST/RecursiveASTVisitorTest.cpp b/clang/unittests/AST/RecursiveASTVisitorTest.cpp --- a/clang/unittests/AST/RecursiveASTVisitorTest.cpp +++ b/clang/unittests/AST/RecursiveASTVisitorTest.cpp @@ -10,6 +10,8 @@ #include "clang/AST/ASTConsumer.h" #include "clang/AST/ASTContext.h" #include "clang/AST/Attr.h" +#include "clang/AST/Decl.h" +#include "clang/AST/TypeLoc.h" #include "clang/Frontend/FrontendAction.h" #include "clang/Tooling/Tooling.h" #include "llvm/ADT/FunctionExtras.h" @@ -53,7 +55,11 @@ StartTraverseFunction, EndTraverseFunction, StartTraverseAttr, - EndTraverseAttr + EndTraverseAttr, + StartTraverseEnum, + EndTraverseEnum, + StartTraverseTypedefType, + EndTraverseTypedefType, }; class CollectInterestingEvents @@ -75,6 +81,22 @@ return Ret; } + bool TraverseEnumDecl(EnumDecl *D) { + Events.push_back(VisitEvent::StartTraverseEnum); + bool Ret = RecursiveASTVisitor::TraverseEnumDecl(D); + Events.push_back(VisitEvent::EndTraverseEnum); + + return Ret; + } + + bool TraverseTypedefTypeLoc(TypedefTypeLoc TL) { + Events.push_back(VisitEvent::StartTraverseTypedefType); + bool Ret = RecursiveASTVisitor::TraverseTypedefTypeLoc(TL); + Events.push_back(VisitEvent::EndTraverseTypedefType); + + return Ret; + } + std::vector<VisitEvent> takeEvents() && { return std::move(Events); } private: @@ -103,3 +125,17 @@ VisitEvent::EndTraverseAttr, VisitEvent::EndTraverseFunction)); } + +TEST(RecursiveASTVisitorTest, EnumDeclWithBase) { + /// Check the enum's written base type is traversed inside TraverseEnumDecl. + llvm::StringRef Code = R"cpp( + typedef int Foo; + enum Bar : Foo; + )cpp"; + + EXPECT_THAT(collectEvents(Code), + ElementsAre(VisitEvent::StartTraverseEnum, + VisitEvent::StartTraverseTypedefType, + VisitEvent::EndTraverseTypedefType, + VisitEvent::EndTraverseEnum)); +}