diff --git a/clang-tools-extra/clangd/index/SymbolCollector.h b/clang-tools-extra/clangd/index/SymbolCollector.h
--- a/clang-tools-extra/clangd/index/SymbolCollector.h
+++ b/clang-tools-extra/clangd/index/SymbolCollector.h
@@ -110,6 +110,7 @@
 
   SymbolSlab takeSymbols() { return std::move(Symbols).build(); }
   RefSlab takeRefs() { return std::move(Refs).build(); }
+  RelationSlab takeRelations() { return std::move(Relations).build(); }
 
   void finish() override;
 
@@ -117,6 +118,10 @@
   const Symbol *addDeclaration(const NamedDecl &, SymbolID,
                                bool IsMainFileSymbol);
   void addDefinition(const NamedDecl &, const Symbol &DeclSymbol);
+  // Records BaseOf relations from \p Relations (reported for an occurrence
+  // of \p ND, whose symbol ID is \p ID). Non-tag decls are ignored.
+  void processRelations(const NamedDecl &ND, const SymbolID &ID,
+                        ArrayRef<index::SymbolRelation> Relations);
 
   llvm::Optional<std::string> getIncludeHeader(llvm::StringRef QName, FileID);
   bool isSelfContainedHeader(FileID);
@@ -135,6 +140,8 @@
   // Only symbols declared in preamble (from #include) and referenced from the
   // main file will be included.
   RefSlab::Builder Refs;
+  // All relations collected from the AST.
+  RelationSlab::Builder Relations;
   ASTContext *ASTCtx;
   std::shared_ptr<Preprocessor> PP;
   std::shared_ptr<GlobalCodeCompletionAllocator> CompletionAllocator;
diff --git a/clang-tools-extra/clangd/index/SymbolCollector.cpp b/clang-tools-extra/clangd/index/SymbolCollector.cpp
--- a/clang-tools-extra/clangd/index/SymbolCollector.cpp
+++ b/clang-tools-extra/clangd/index/SymbolCollector.cpp
@@ -193,6 +193,11 @@
   return static_cast<RefKind>(static_cast<unsigned>(RefKind::All) & Roles);
 }
 
+bool shouldIndexRelation(const index::SymbolRelation &R) {
+  // We currently only index BaseOf relations, for type hierarchy subtypes.
+  return R.Roles & static_cast<unsigned>(index::SymbolRole::RelationBaseOf);
+}
+
 } // namespace
 
 SymbolCollector::SymbolCollector(Options Opts) : Opts(std::move(Opts)) {}
@@ -291,6 +296,16 @@
       SM.getFileID(SpellingLoc) == SM.getMainFileID())
     ReferencedDecls.insert(ND);
 
+  auto ID = getSymbolID(ND);
+  if (!ID)
+    return true;
+
+  // Relations must be processed for every occurrence, including pure
+  // references (see the IsOnlyRef early return below): the indexer attaches
+  // RelationBaseOf to the reference to Base inside the base-specifier of
+  // `class Derived : public Base`, not to Base's declaration.
+  processRelations(*ND, *ID, Relations);
+
   bool CollectRef = static_cast<unsigned>(Opts.RefFilter) & Roles;
   bool IsOnlyRef =
       !(Roles & (static_cast<unsigned>(index::SymbolRole::Declaration) |
@@ -315,10 +330,6 @@
   if (IsOnlyRef)
     return true;
 
-  auto ID = getSymbolID(ND);
-  if (!ID)
-    return true;
-
   // FIXME: ObjCPropertyDecl are not properly indexed here:
   // - ObjCPropertyDecl may have an OrigD of ObjCPropertyImplDecl, which is
   //   not a NamedDecl.
@@ -416,8 +427,39 @@
   return true;
 }
 
-void SymbolCollector::setIncludeLocation(const Symbol &S,
-                                         SourceLocation Loc) {
+void SymbolCollector::processRelations(
+    const NamedDecl &ND, const SymbolID &ID,
+    ArrayRef<index::SymbolRelation> Relations) {
+  // Only tag types (records) can take part in BaseOf (subtype) relations.
+  if (!isa<TagDecl>(&ND))
+    return;
+
+  for (const auto &R : Relations) {
+    if (!shouldIndexRelation(R)) {
+      continue;
+    }
+    const Decl *Object = R.RelatedSymbol;
+
+    auto ObjectID = getSymbolID(Object);
+    if (!ObjectID) {
+      continue;
+    }
+
+    // Record the relation in the member builder (the parameter shadows it).
+    // TODO: There may be cases where the object decl is not indexed for some
+    // reason. Those cases should probably be removed in due course, but for
+    // now there are two possible ways to handle it:
+    //   (A) Avoid storing the relation in such cases.
+    //   (B) Store it anyways. Clients will likely lookup() the SymbolID
+    //       in the index and find nothing, but that's a situation they
+    //       probably need to handle for other reasons anyways.
+    // We currently do (B) because it's simpler.
+    this->Relations.insert(
+        Relation{ID, index::SymbolRole::RelationBaseOf, *ObjectID});
+  }
+}
+
+void SymbolCollector::setIncludeLocation(const Symbol &S, SourceLocation Loc) {
   if (Opts.CollectIncludePath)
     if (shouldCollectIncludePath(S.SymInfo.Kind))
       // Use the expansion location to get the #include header since this is
@@ -681,7 +723,7 @@
   if (!Line.consume_front("#"))
     return false;
   Line = Line.ltrim();
-  if (! Line.startswith("error"))
+  if (!Line.startswith("error"))
     return false;
   return Line.contains_lower("includ"); // Matches "include" or "including".
 }
@@ -689,7 +731,7 @@
 bool SymbolCollector::isDontIncludeMeHeader(llvm::StringRef Content) {
   llvm::StringRef Line;
   // Only sniff up to 100 lines or 10KB.
-  Content = Content.take_front(100*100);
+  Content = Content.take_front(100 * 100);
   for (unsigned I = 0; I < 100 && !Content.empty(); ++I) {
     std::tie(Line, Content) = Content.split('\n');
     if (isIf(Line) && isErrorAboutInclude(Content.split('\n').first))
diff --git a/clang-tools-extra/clangd/unittests/SymbolCollectorTests.cpp b/clang-tools-extra/clangd/unittests/SymbolCollectorTests.cpp
--- a/clang-tools-extra/clangd/unittests/SymbolCollectorTests.cpp
+++ b/clang-tools-extra/clangd/unittests/SymbolCollectorTests.cpp
@@ -123,8 +123,9 @@
     assert(AST.hasValue());
     const NamedDecl &ND =
         Qualified ? findDecl(*AST, Name) : findUnqualifiedDecl(*AST, Name);
-    const SourceManager& SM = AST->getSourceManager();
-    bool MainFile = SM.isWrittenInMainFile(SM.getExpansionLoc(ND.getBeginLoc()));
+    const SourceManager &SM = AST->getSourceManager();
+    bool MainFile =
+        SM.isWrittenInMainFile(SM.getExpansionLoc(ND.getBeginLoc()));
     return SymbolCollector::shouldCollectSymbol(
         ND, AST->getASTContext(), SymbolCollector::Options(), MainFile);
   }
@@ -272,13 +273,14 @@
         Args, Factory->create(), Files.get(),
         std::make_shared<PCHContainerOperations>());
 
-    InMemoryFileSystem->addFile(
-        TestHeaderName, 0, llvm::MemoryBuffer::getMemBuffer(HeaderCode));
+    InMemoryFileSystem->addFile(TestHeaderName, 0,
+                                llvm::MemoryBuffer::getMemBuffer(HeaderCode));
     InMemoryFileSystem->addFile(TestFileName, 0,
                                 llvm::MemoryBuffer::getMemBuffer(MainCode));
     Invocation.run();
     Symbols = Factory->Collector->takeSymbols();
     Refs = Factory->Collector->takeRefs();
+    Relations = Factory->Collector->takeRelations();
     return true;
   }
 
@@ -290,6 +292,7 @@
   std::string TestFileURI;
   SymbolSlab Symbols;
   RefSlab Refs;
+  RelationSlab Relations;
   SymbolCollector::Options CollectorOpts;
   std::unique_ptr<CommentHandler> PragmaHandler;
 };
@@ -634,6 +637,19 @@
                                              HaveRanges(Header.ranges())))));
 }
 
+TEST_F(SymbolCollectorTest, Relations) {
+  const std::string Header = R"(
+  class Base {};
+  class Derived : public Base {};
+  )";
+  runSymbolCollector(Header, /*Main=*/"");
+  const Symbol &Base = findSymbol(Symbols, "Base");
+  const Symbol &Derived = findSymbol(Symbols, "Derived");
+  EXPECT_THAT(Relations,
+              Contains(Relation{Base.ID, index::SymbolRole::RelationBaseOf,
+                                Derived.ID}));
+}
+
 TEST_F(SymbolCollectorTest, References) {
   const std::string Header = R"(
     class W;
@@ -783,10 +799,9 @@
     void f1() {}
   )";
   runSymbolCollector(/*Header=*/"", Main);
-  EXPECT_THAT(Symbols,
-              UnorderedElementsAre(QName("Foo"), QName("f1"), QName("f2"),
-                                   QName("ff"), QName("foo"), QName("foo::Bar"),
-                                   QName("main_f")));
+  EXPECT_THAT(Symbols, UnorderedElementsAre(
+                           QName("Foo"), QName("f1"), QName("f2"), QName("ff"),
+                           QName("foo"), QName("foo::Bar"), QName("main_f")));
 }
 
 TEST_F(SymbolCollectorTest, Documentation) {