Index: clangd/XRefs.h
===================================================================
--- clangd/XRefs.h
+++ clangd/XRefs.h
@@ -34,6 +34,13 @@
 
 /// Get the hover information when hovering at \p Pos.
 llvm::Optional<Hover> getHover(ParsedAST &AST, Position Pos);
+
+/// Returns reference locations of the symbol at a specified \p Pos.
+/// References in the main file are collected from \p AST. If \p Index is
+/// non-null it is also queried for references in other files, except when
+/// the symbol is function-local (the index is not consulted for those).
+std::vector<Location> findReferences(ParsedAST &AST, Position Pos,
+                                     const SymbolIndex *Index = nullptr);
 
 } // namespace clangd
 } // namespace clang
Index: clangd/XRefs.cpp
===================================================================
--- clangd/XRefs.cpp
+++ clangd/XRefs.cpp
@@ -174,30 +174,30 @@
   return {DeclMacrosFinder.takeDecls(), DeclMacrosFinder.takeMacroInfos()};
 }
 
-llvm::Optional<Location>
-makeLocation(ParsedAST &AST, const SourceRange &ValSourceRange) {
+/// Returns the range spanning the single token that starts at \p TokLoc.
+Range getTokenRange(ParsedAST &AST, SourceLocation TokLoc) {
   const SourceManager &SourceMgr = AST.getASTContext().getSourceManager();
-  const LangOptions &LangOpts = AST.getASTContext().getLangOpts();
-  SourceLocation LocStart = ValSourceRange.getBegin();
+  SourceLocation LocEnd = Lexer::getLocForEndOfToken(
+      TokLoc, 0, SourceMgr, AST.getASTContext().getLangOpts());
+  return {sourceLocToPosition(SourceMgr, TokLoc),
+          sourceLocToPosition(SourceMgr, LocEnd)};
+}
 
-  const FileEntry *F =
-      SourceMgr.getFileEntryForID(SourceMgr.getFileID(LocStart));
+/// Converts the token at \p TokLoc into an LSP location. Returns None if the
+/// location has no file entry or the file's real path cannot be resolved.
+llvm::Optional<Location> makeLocation(ParsedAST &AST, SourceLocation TokLoc) {
+  const SourceManager &SourceMgr = AST.getASTContext().getSourceManager();
+  const FileEntry *F = SourceMgr.getFileEntryForID(SourceMgr.getFileID(TokLoc));
   if (!F)
     return llvm::None;
-  SourceLocation LocEnd = Lexer::getLocForEndOfToken(ValSourceRange.getEnd(), 0,
-                                                     SourceMgr, LangOpts);
-  Position Begin = sourceLocToPosition(SourceMgr, LocStart);
-  Position End = sourceLocToPosition(SourceMgr, LocEnd);
-  Range R = {Begin, End};
-  Location L;
-
   auto FilePath = getRealPath(F, SourceMgr);
   if (!FilePath) {
     log("failed to get path!");
     return llvm::None;
   }
+  Location L;
   L.uri = URIForFile(*FilePath);
-  L.range = R;
+  L.range = getTokenRange(AST, TokLoc);
   return L;
 }
 
@@ -223,7 +223,7 @@
 
   for (auto Item : Symbols.Macros) {
     auto Loc = Item.Info->getDefinitionLoc();
-    auto L = makeLocation(AST, SourceRange(Loc, Loc));
+    auto L = makeLocation(AST, Loc);
     if (L)
       Result.push_back(*L);
   }
@@ -266,7 +266,7 @@
 
     auto &Candidate = ResultCandidates[Key];
     auto Loc = findNameLoc(D);
-    auto L = makeLocation(AST, SourceRange(Loc, Loc));
+    auto L = makeLocation(AST, Loc);
     // The declaration in the identified symbols is a definition if possible
     // otherwise it is declaration.
     bool IsDef = getDefinition(D) == D;
@@ -316,64 +316,123 @@
 
 namespace {
 
-/// Finds document highlights that a given list of declarations refers to.
-class DocumentHighlightsFinder : public index::IndexDataConsumer {
-  std::vector<const Decl *> &Decls;
-  std::vector<DocumentHighlight> DocumentHighlights;
-  const ASTContext &AST;
-
+/// Collects all occurrences (and related information) in the main file that a
+/// given whitelist of declarations refers to.
+class OccurrenceCollector : public index::IndexDataConsumer {
 public:
-  DocumentHighlightsFinder(ASTContext &AST, Preprocessor &PP,
-                           std::vector<const Decl *> &Decls)
-      : Decls(Decls), AST(AST) {}
-  std::vector<DocumentHighlight> takeHighlights() {
-    // Don't keep the same highlight multiple times.
-    // This can happen when nodes in the AST are visited twice.
-    std::sort(DocumentHighlights.begin(), DocumentHighlights.end());
-    auto Last =
-        std::unique(DocumentHighlights.begin(), DocumentHighlights.end());
-    DocumentHighlights.erase(Last, DocumentHighlights.end());
-    return std::move(DocumentHighlights);
-  }
-
   bool
   handleDeclOccurence(const Decl *D, index::SymbolRoleSet Roles,
                       ArrayRef<index::SymbolRelation> Relations,
                       SourceLocation Loc,
                       index::IndexDataConsumer::ASTNodeInfo ASTNode) override {
-    const SourceManager &SourceMgr = AST.getSourceManager();
-    SourceLocation HighlightStartLoc = SourceMgr.getFileLoc(Loc);
-    if (SourceMgr.getMainFileID() != SourceMgr.getFileID(HighlightStartLoc) ||
-        std::find(Decls.begin(), Decls.end(), D) == Decls.end()) {
+    const SourceManager &SourceMgr = AST.getASTContext().getSourceManager();
+    SourceLocation FileLoc = SourceMgr.getFileLoc(Loc);
+    // We only collect locations in current main file.
+    if (SourceMgr.getMainFileID() != SourceMgr.getFileID(FileLoc) ||
+        std::find(WhitelistDecls.begin(), WhitelistDecls.end(), D) ==
+            WhitelistDecls.end())
       return true;
+
+    DeclOccurrences[D].emplace_back(FileLoc, Roles);
+    return true;
+  }
+
+protected:
+  OccurrenceCollector(ParsedAST &AST,
+                      const std::vector<const Decl *> &WhitelistDecls)
+      : AST(AST), WhitelistDecls(WhitelistDecls) {
+
+    auto &SM = AST.getASTContext().getSourceManager();
+    auto MainFilePath =
+        getRealPath(SM.getFileEntryForID(SM.getMainFileID()), SM);
+    if (!MainFilePath) {
+      // Bail out: dereferencing an empty Optional is undefined behavior.
+      log("Fail to get real path!");
+      return;
     }
-    SourceLocation End;
-    const LangOptions &LangOpts = AST.getLangOpts();
-    End = Lexer::getLocForEndOfToken(HighlightStartLoc, 0, SourceMgr, LangOpts);
-    SourceRange SR(HighlightStartLoc, End);
+    MainFileURI = URIForFile(*MainFilePath);
+  }
 
-    DocumentHighlightKind Kind = DocumentHighlightKind::Text;
-    if (static_cast<index::SymbolRoleSet>(index::SymbolRole::Write) & Roles)
-      Kind = DocumentHighlightKind::Write;
-    else if (static_cast<index::SymbolRoleSet>(index::SymbolRole::Read) & Roles)
-      Kind = DocumentHighlightKind::Read;
+  ParsedAST &AST;
+  const std::vector<const Decl *> &WhitelistDecls;
 
-    DocumentHighlights.push_back(getDocumentHighlight(SR, Kind));
-    return true;
+  URIForFile MainFileURI;
+
+  using DeclOccurrence = std::pair<SourceLocation, index::SymbolRoleSet>;
+  llvm::DenseMap<const Decl *, std::vector<DeclOccurrence>> DeclOccurrences;
+};
+
+/// Find symbol occurrences that a given whitelist of declarations refers to.
+class OccurrencesFinder : public OccurrenceCollector {
+public:
+  OccurrencesFinder(ParsedAST &AST,
+                    const std::vector<const Decl *> &WhitelistDecls)
+      : OccurrenceCollector(AST, WhitelistDecls) {}
+
+  std::vector<Location> takeOccurrences() { return std::move(Occurrences); }
+
+  void finish() override {
+    OccurrenceCollector::finish();
+    Occurrences.clear();
+
+    for (const auto &It : DeclOccurrences) {
+      for (const auto &LocAndRole : It.second) {
+        Occurrences.push_back(
+            {MainFileURI, getTokenRange(AST, LocAndRole.first)});
+      }
+    }
+    // Deduplicate results.
+    std::sort(Occurrences.begin(), Occurrences.end());
+    auto Last = std::unique(Occurrences.begin(), Occurrences.end());
+    Occurrences.erase(Last, Occurrences.end());
   }
 
 private:
-  DocumentHighlight getDocumentHighlight(SourceRange SR,
-                                         DocumentHighlightKind Kind) {
-    const SourceManager &SourceMgr = AST.getSourceManager();
-    Position Begin = sourceLocToPosition(SourceMgr, SR.getBegin());
-    Position End = sourceLocToPosition(SourceMgr, SR.getEnd());
-    Range R = {Begin, End};
+  std::vector<Location> Occurrences;
+};
+
+/// Finds document highlights that a given whitelist of declarations refers to.
+class DocumentHighlightsFinder : public OccurrenceCollector {
+public:
+  DocumentHighlightsFinder(ParsedAST &AST,
+                           const std::vector<const Decl *> &WhitelistDecls)
+      : OccurrenceCollector(AST, WhitelistDecls) {}
+
+  std::vector<DocumentHighlight> takeHighlights() {
+    return std::move(DocumentHighlights);
+  }
+
+  void finish() override {
+    OccurrenceCollector::finish();
+    DocumentHighlights.clear();
+    for (const auto &It : DeclOccurrences) {
+      for (const auto &LocAndRole : It.second)
+        DocumentHighlights.push_back(
+            createHighlight(LocAndRole.first, LocAndRole.second));
+    }
+    // Don't keep the same highlight multiple times.
+    // This can happen when nodes in the AST are visited twice.
+    std::sort(DocumentHighlights.begin(), DocumentHighlights.end());
+    auto Last =
+        std::unique(DocumentHighlights.begin(), DocumentHighlights.end());
+    DocumentHighlights.erase(Last, DocumentHighlights.end());
+  }
+
+  /// Creates a highlight at \p Loc whose kind is derived from \p Roles.
+  DocumentHighlight createHighlight(SourceLocation Loc,
+                                    index::SymbolRoleSet Roles) {
     DocumentHighlight DH;
-    DH.range = R;
-    DH.kind = Kind;
+    DH.range = getTokenRange(AST, Loc);
+    DH.kind = DocumentHighlightKind::Text;
+    if (static_cast<index::SymbolRoleSet>(index::SymbolRole::Write) & Roles)
+      DH.kind = DocumentHighlightKind::Write;
+    else if (static_cast<index::SymbolRoleSet>(index::SymbolRole::Read) & Roles)
+      DH.kind = DocumentHighlightKind::Read;
     return DH;
   }
+
+private:
+  std::vector<DocumentHighlight> DocumentHighlights;
 };
 
 } // namespace
@@ -387,8 +446,7 @@
   auto Symbols = getSymbolAtPosition(AST, SourceLocationBeg);
   std::vector<const Decl *> SelectedDecls = Symbols.Decls;
 
-  DocumentHighlightsFinder DocHighlightsFinder(
-      AST.getASTContext(), AST.getPreprocessor(), SelectedDecls);
+  DocumentHighlightsFinder DocHighlightsFinder(AST, SelectedDecls);
 
   index::IndexingOptions IndexOpts;
   IndexOpts.SystemSymbolFilter =
@@ -659,5 +717,62 @@
   return None;
 }
 
+std::vector<Location> findReferences(ParsedAST &AST, Position Pos,
+                                     const SymbolIndex *Index) {
+  const SourceManager &SourceMgr = AST.getASTContext().getSourceManager();
+  SourceLocation SourceLocationBeg =
+      getBeginningOfIdentifier(AST, Pos, SourceMgr.getMainFileID());
+  // Identified symbols at a specific position.
+  auto Symbols = getSymbolAtPosition(AST, SourceLocationBeg);
+
+  // For local symbols (e.g. symbols that are only visible in main file,
+  // symbols defined in function body), we can get complete references by
+  // traversing the AST, and we don't want to make unnecessary queries to the
+  // index. However, we don't have a reliable way to distinguish file-local
+  // symbols. We conservatively consider function-local symbols.
+  llvm::DenseSet<SymbolID> PossiblyVisibleIDs;
+  for (const auto *D : Symbols.Decls) {
+    if (auto ID = getSymbolID(D)) {
+      // Ignore if it is a function-scope symbol.
+      if (D->getParentFunctionOrMethod())
+        continue;
+      PossiblyVisibleIDs.insert(*ID);
+    }
+  }
+
+  // Look in the AST for the references from current main file.
+  OccurrencesFinder FindOccurrences(AST, Symbols.Decls);
+  index::IndexingOptions IndexOpts;
+  IndexOpts.SystemSymbolFilter =
+      index::IndexingOptions::SystemSymbolFilterKind::All;
+  IndexOpts.IndexFunctionLocals = true;
+  indexTopLevelDecls(AST.getASTContext(), AST.getLocalTopLevelDecls(),
+                     FindOccurrences, IndexOpts);
+
+  auto Results = FindOccurrences.takeOccurrences();
+
+  // Consult the index for references in other files.
+  // We only need to consider symbols visible to other files.
+  if (Index && !PossiblyVisibleIDs.empty()) {
+    auto MainFilePath = getRealPath(
+        SourceMgr.getFileEntryForID(SourceMgr.getMainFileID()), SourceMgr);
+    if (!MainFilePath) {
+      log("Failed to get real path!");
+      return Results;
+    }
+    auto MainFileURI = URIForFile(*MainFilePath);
+    OccurrencesRequest Request;
+    Request.IDs = std::move(PossiblyVisibleIDs);
+    Request.Filter = AllOccurrenceKinds;
+    Index->findOccurrences(Request, [&](const SymbolOccurrence &O) {
+      if (auto LSPLoc = toLSPLocation(O.Location, /*HintPath=*/*MainFilePath)) {
+        if (MainFileURI != LSPLoc->uri)
+          Results.push_back(*LSPLoc);
+      }
+    });
+  }
+  return Results;
+}
+
 } // namespace clangd
 } // namespace clang
Index: unittests/clangd/XRefsTests.cpp
===================================================================
--- unittests/clangd/XRefsTests.cpp
+++ unittests/clangd/XRefsTests.cpp
@@ -26,6 +26,7 @@
 using namespace llvm;
 
 namespace {
+using testing::_;
 using testing::ElementsAre;
 using testing::Field;
 using testing::IsEmpty;
@@ -1068,6 +1069,181 @@
               ElementsAre(Location{FooCppUri, FooWithoutHeader.range()}));
 }
 
+TEST(FindReferences, AllWithoutIndex) {
+  const char *Tests[] = {
+      R"cpp(// Local variable
+        int main() {
+          int $foo[[foo]];
+          $foo[[^foo]] = 2;
+          int test1 = $foo[[foo]];
+        }
+      )cpp",
+
+      R"cpp(// Struct
+        namespace ns1 {
+        struct $foo[[Foo]] {};
+        } // namespace ns1
+        int main() {
+          ns1::$foo[[Fo^o]]* Params;
+        }
+      )cpp",
+
+      R"cpp(// Function
+        int $foo[[foo]](int) {}
+        int main() {
+          auto *X = &$foo[[^foo]];
+          $foo[[foo]](42);
+        }
+      )cpp",
+
+      R"cpp(// Field
+        struct Foo {
+          int $foo[[foo]];
+          Foo() : $foo[[foo]](0) {}
+        };
+        int main() {
+          Foo f;
+          f.$foo[[f^oo]] = 1;
+        }
+      )cpp",
+
+      R"cpp(// Method call
+        struct Foo { int $foo[[foo]](); };
+        int Foo::$foo[[foo]]() {}
+        int main() {
+          Foo f;
+          f.$foo[[^foo]]();
+        }
+      )cpp",
+
+      R"cpp(// Typedef
+        typedef int $foo[[Foo]];
+        int main() {
+          $foo[[^Foo]] bar;
+        }
+      )cpp",
+
+      R"cpp(// Namespace
+        namespace $foo[[ns]] {
+        struct Foo {};
+        } // namespace ns
+        int main() { $foo[[^ns]]::Foo foo; }
+      )cpp",
+  };
+  for (const char *Test : Tests) {
+    Annotations T(Test);
+    auto AST = TestTU::withCode(T.code()).build();
+    std::vector<Matcher<Location>> ExpectedLocations;
+    for (const auto &R : T.ranges("foo"))
+      ExpectedLocations.push_back(RangeIs(R));
+    EXPECT_THAT(findReferences(AST, T.point()),
+                ElementsAreArray(ExpectedLocations))
+        << Test;
+  }
+}
+
+class MockIndex : public SymbolIndex {
+public:
+  MOCK_CONST_METHOD2(fuzzyFind, bool(const FuzzyFindRequest &,
+                                     llvm::function_ref<void(const Symbol &)>));
+  MOCK_CONST_METHOD2(lookup, void(const LookupRequest &,
+                                  llvm::function_ref<void(const Symbol &)>));
+  MOCK_CONST_METHOD2(findOccurrences,
+                     void(const OccurrencesRequest &,
+                          llvm::function_ref<void(const SymbolOccurrence &)>));
+  MOCK_CONST_METHOD0(estimateMemoryUsage, size_t());
+};
+
+TEST(FindReferences, QueryIndex) {
+  const char *Tests[] = {
+      // Refers to symbols from headers.
+      R"cpp(
+        int main() {
+          F^oo foo;
+        }
+      )cpp",
+      R"cpp(
+        int main() {
+          f^unc();
+        }
+      )cpp",
+      R"cpp(
+        int main() {
+          return I^NT;
+        }
+      )cpp",
+
+      // These are cases of file-local but not function-local symbols, we still
+      // query the index.
+      R"cpp(
+        void MyF^unc() {}
+      )cpp",
+
+      R"cpp(
+        int My^Int = 2;
+      )cpp",
+  };
+
+  TestTU TU;
+  TU.HeaderCode = R"(
+  class Foo {};
+  static const int INT = 3;
+  inline void func() {};
+  )";
+  for (const char *Test : Tests) {
+    Annotations T(Test);
+    TU.Code = T.code();
+    auto AST = TU.build();
+    // Fresh mock per case: expectations must be set before the mock is used.
+    MockIndex Index;
+    EXPECT_CALL(Index, findOccurrences(_, _));
+    findReferences(AST, T.point(), &Index);
+  }
+}
+
+TEST(FindReferences, DontQueryIndex) {
+  // Don't query index for function-local symbols.
+  const char *Tests[] = {
+      R"cpp(// Local variable in function body
+        int main() {
+          int $foo[[foo]];
+          $foo[[^foo]] = 2;
+        }
+      )cpp",
+
+      R"cpp(// function parameter
+        int f(int fo^o) {
+        }
+      )cpp",
+
+      R"cpp(// function parameter in lambda
+        int f(int foo) {
+          auto func = [](int a, int b) {
+            return ^a = 2;
+          };
+        }
+      )cpp",
+
+      R"cpp(// capture in lambda
+        int f(int foo) {
+          int A;
+          auto func = [&A](int a, int b) {
+            return a = ^A;
+          };
+        }
+      )cpp",
+  };
+
+  for (const char *Test : Tests) {
+    Annotations T(Test);
+    auto AST = TestTU::withCode(T.code()).build();
+    // Fresh mock per case: expectations must be set before the mock is used.
+    MockIndex Index;
+    EXPECT_CALL(Index, findOccurrences(_, _)).Times(0);
+    findReferences(AST, T.point(), &Index);
+  }
+}
+
 } // namespace
 } // namespace clangd
 } // namespace clang