Index: clangd/index/SymbolCollector.cpp
===================================================================
--- clangd/index/SymbolCollector.cpp
+++ clangd/index/SymbolCollector.cpp
@@ -345,16 +345,21 @@
       SM.getFileID(SpellingLoc) == SM.getMainFileID())
     ReferencedDecls.insert(ND);
 
-  if ((static_cast<unsigned>(Opts.RefFilter) & Roles) &&
-      SM.getFileID(SpellingLoc) == SM.getMainFileID())
-    DeclRefs[ND].emplace_back(SpellingLoc, Roles);
+  bool CollectRef = static_cast<unsigned>(Opts.RefFilter) & Roles;
+  bool IsOnlyRef =
+      !(Roles & (static_cast<unsigned>(index::SymbolRole::Declaration) |
+                 static_cast<unsigned>(index::SymbolRole::Definition)));
 
-  // Don't continue indexing if this is a mere reference.
-  if (!(Roles & static_cast<unsigned>(index::SymbolRole::Declaration) ||
-        Roles & static_cast<unsigned>(index::SymbolRole::Definition)))
+  if (IsOnlyRef && !CollectRef)
     return true;
   if (!shouldCollectSymbol(*ND, *ASTCtx, Opts))
     return true;
+  // Record main-file refs only for symbols that pass shouldCollectSymbol().
+  if (CollectRef && SM.getFileID(SpellingLoc) == SM.getMainFileID())
+    DeclRefs[ND].emplace_back(SpellingLoc, Roles);
+  // Don't continue indexing if this is a mere reference.
+  if (IsOnlyRef)
+    return true;
 
   auto ID = getSymbolID(ND);
   if (!ID)
@@ -476,17 +481,17 @@
   std::string MainURI = *MainFileURI;
   for (const auto &It : DeclRefs) {
     if (auto ID = getSymbolID(It.first)) {
-      if (Symbols.find(*ID)) {
-        for (const auto &LocAndRole : It.second) {
-          Ref R;
-          auto Range =
-              getTokenRange(LocAndRole.first, SM, ASTCtx->getLangOpts());
-          R.Location.Start = Range.first;
-          R.Location.End = Range.second;
-          R.Location.FileURI = MainURI;
-          R.Kind = toRefKind(LocAndRole.second);
-          Refs.insert(*ID, R);
-        }
+      // The symbol itself may not be in Symbols (e.g. it is declared in the
+      // preamble), but refs to it in the main file are still wanted.
+      for (const auto &LocAndRole : It.second) {
+        Ref R;
+        auto Range =
+            getTokenRange(LocAndRole.first, SM, ASTCtx->getLangOpts());
+        R.Location.Start = Range.first;
+        R.Location.End = Range.second;
+        R.Location.FileURI = MainURI;
+        R.Kind = toRefKind(LocAndRole.second);
+        Refs.insert(*ID, R);
       }
     }
   }
Index: unittests/clangd/FileIndexTests.cpp
===================================================================
--- unittests/clangd/FileIndexTests.cpp
+++ unittests/clangd/FileIndexTests.cpp
@@ -8,6 +8,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "Annotations.h"
+#include "AST.h"
 #include "ClangdUnit.h"
 #include "TestFS.h"
 #include "TestTU.h"
@@ -15,6 +16,7 @@
 #include "index/FileIndex.h"
 #include "clang/Frontend/CompilerInvocation.h"
 #include "clang/Frontend/PCHContainerOperations.h"
+#include "clang/Frontend/Utils.h"
 #include "clang/Index/IndexSymbol.h"
 #include "clang/Lex/Preprocessor.h"
 #include "clang/Tooling/CompilationDatabase.h"
@@ -346,6 +348,52 @@
   EXPECT_TRUE(SeenSymbol);
 }
 
+TEST(FileIndexTest, ReferencesInMainFileWithPreamble) {
+  const std::string Header = R"cpp(
+    class Foo {};
+  )cpp";
+  Annotations Main(R"cpp(
+    #include "foo.h"
+    void f() {
+      [[Foo]] foo;
+    }
+  )cpp");
+  auto MainFile = testPath("foo.cpp");
+  auto HeaderFile = testPath("foo.h");
+  std::vector<const char *> Cmd = {"clang", "-xc++", MainFile.c_str()};
+  // Preparse ParseInputs.
+  ParseInputs PI;
+  PI.CompileCommand.Directory = testRoot();
+  PI.CompileCommand.Filename = MainFile;
+  PI.CompileCommand.CommandLine = {Cmd.begin(), Cmd.end()};
+  PI.Contents = Main.code();
+  PI.FS = buildTestFS({{MainFile, Main.code()}, {HeaderFile, Header}});
+
+  // Prepare preamble.
+  auto CI = buildCompilerInvocation(PI);
+  ASSERT_TRUE(CI);
+  auto PreambleData = buildPreamble(
+      MainFile, *CI, /*OldPreamble=*/nullptr, tooling::CompileCommand(), PI,
+      std::make_shared<PCHContainerOperations>(), /*StoreInMemory=*/true,
+      [](ASTContext &, std::shared_ptr<Preprocessor>) {});
+  // Without a preamble the test would not exercise the preamble code path.
+  ASSERT_TRUE(PreambleData);
+  // Build AST for main file with preamble.
+  auto AST = ParsedAST::build(
+      createInvocationFromCommandLine(Cmd), PreambleData,
+      llvm::MemoryBuffer::getMemBufferCopy(Main.code()),
+      std::make_shared<PCHContainerOperations>(), PI.FS);
+  ASSERT_TRUE(AST);
+  FileIndex Index;
+  Index.updateMain(MainFile, *AST);
+
+  auto Foo =
+      findSymbol(TestTU::withHeaderCode(Header).headerSymbols(), "Foo");
+  // Expect to see the reference in the main file; references in headers are
+  // excluded because we only index the main AST.
+  EXPECT_THAT(getRefs(Index, Foo.ID), RefsAre({RefRange(Main.range())}));
+}
+
 } // namespace
 } // namespace clangd
 } // namespace clang