diff --git a/clang-tools-extra/clangd/XRefs.h b/clang-tools-extra/clangd/XRefs.h
--- a/clang-tools-extra/clangd/XRefs.h
+++ b/clang-tools-extra/clangd/XRefs.h
@@ -17,8 +17,8 @@
 #include "Path.h"
 #include "Protocol.h"
 #include "index/Index.h"
-#include "clang/AST/Type.h"
 #include "index/SymbolLocation.h"
+#include "clang/AST/Type.h"
 #include "clang/Format/Format.h"
 #include "clang/Index/IndexSymbol.h"
 #include "llvm/ADT/Optional.h"
@@ -158,6 +158,9 @@
 /// SourceLocationBeg must point to the first character of the token
 bool hasDeducedType(ParsedAST &AST, SourceLocation SourceLocationBeg);
 
+/// Returns all decls referenced in \p FD, excluding function-local symbols.
+llvm::DenseSet<const Decl *> getNonLocalDeclRefs(ParsedAST &AST,
+                                                 const FunctionDecl *FD);
 } // namespace clangd
 } // namespace clang
 
diff --git a/clang-tools-extra/clangd/XRefs.cpp b/clang-tools-extra/clangd/XRefs.cpp
--- a/clang-tools-extra/clangd/XRefs.cpp
+++ b/clang-tools-extra/clangd/XRefs.cpp
@@ -1299,5 +1299,34 @@
   return OS;
 }
 
+llvm::DenseSet<const Decl *> getNonLocalDeclRefs(ParsedAST &AST,
+                                                 const FunctionDecl *FD) {
+  class DeclRecorder : public index::IndexDataConsumer {
+  public:
+    bool handleDeclOccurence(const Decl *D, index::SymbolRoleSet Roles,
+                             ArrayRef<index::SymbolRelation> Relations,
+                             SourceLocation Loc, ASTNodeInfo ASTNode) override {
+      // We only care about referenced decls.
+      if (!(Roles & static_cast<unsigned>(index::SymbolRole::Reference)))
+        return true;
+
+      // Store the original decl, e.g. specialization rather than templated
+      // decl.
+      DeclRefs.insert(ASTNode.OrigD);
+      return true;
+    }
+
+    llvm::DenseSet<const Decl *> DeclRefs;
+  };
+  DeclRecorder Recorder;
+  index::IndexingOptions Opts;
+  Opts.SystemSymbolFilter =
+      index::IndexingOptions::SystemSymbolFilterKind::None;
+  // We don't want local symbols.
+  Opts.IndexFunctionLocals = false;
+  index::indexTopLevelDecls(AST.getASTContext(), AST.getPreprocessor(), {FD},
+                            Recorder, std::move(Opts));
+  return Recorder.DeclRefs;
+}
 } // namespace clangd
 } // namespace clang
diff --git a/clang-tools-extra/clangd/unittests/XRefsTests.cpp b/clang-tools-extra/clangd/unittests/XRefsTests.cpp
--- a/clang-tools-extra/clangd/unittests/XRefsTests.cpp
+++ b/clang-tools-extra/clangd/unittests/XRefsTests.cpp
@@ -10,6 +10,7 @@
 #include "Matchers.h"
 #include "ParsedAST.h"
 #include "Protocol.h"
+#include "SourceCode.h"
 #include "SyncAPI.h"
 #include "TestFS.h"
 #include "TestIndex.h"
@@ -18,13 +19,19 @@
 #include "index/FileIndex.h"
 #include "index/MemIndex.h"
 #include "index/SymbolCollector.h"
+#include "clang/AST/Decl.h"
+#include "clang/Basic/SourceLocation.h"
 #include "clang/Index/IndexingAction.h"
 #include "llvm/ADT/None.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/Error.h"
 #include "llvm/Support/Path.h"
 #include "llvm/Support/ScopedPrinter.h"
 #include "gmock/gmock.h"
 #include "gtest/gtest.h"
 #include <string>
+#include <vector>
 
 namespace clang {
 namespace clangd {
@@ -2187,6 +2194,84 @@
   }
 }
 
+TEST(GetNonLocalDeclRefs, All) {
+  struct Case {
+    llvm::StringRef AnnotatedCode;
+    std::vector<std::string> ExpectedDecls;
+  } Cases[] = {
+      {
+          // VarDecl and ParamVarDecl
+          R"cpp(
+            void bar();
+            void ^foo(int baz) {
+              int x = 10;
+              bar();
+            })cpp",
+          {"bar"},
+      },
+      {
+          // Method from class
+          R"cpp(
+            class Foo { public: void foo(); };
+            class Bar {
+              void foo();
+              void bar();
+            };
+            void Bar::^foo() {
+              Foo f;
+              bar();
+              f.foo();
+            })cpp",
+          {"Bar", "Bar::bar", "Foo", "Foo::foo"},
+      },
+      {
+          // Local types
+          R"cpp(
+            void ^foo() {
+              class Foo { public: void foo() {} };
+              class Bar { public: void bar() {} };
+              Foo f;
+              Bar b;
+              b.bar();
+              f.foo();
+            })cpp",
+          {},
+      },
+      {
+          // Template params
+          R"cpp(
+            template <typename T, template<typename> class Q>
+            void ^foo() {
+              T x;
+              Q<T> y;
+            })cpp",
+          {},
+      },
+  };
+  for (const Case &C : Cases) {
+    Annotations File(C.AnnotatedCode);
+    auto AST = TestTU::withCode(File.code()).build();
+    ASSERT_TRUE(AST.getDiagnostics().empty())
+        << AST.getDiagnostics().begin()->Message;
+    SourceLocation SL = llvm::cantFail(
+        sourceLocationInMainFile(AST.getSourceManager(), File.point()));
+
+    const FunctionDecl *FD =
+        llvm::dyn_cast<FunctionDecl>(&findDecl(AST, [SL](const NamedDecl &ND) {
+          return ND.getLocation() == SL && llvm::isa<FunctionDecl>(ND);
+        }));
+    ASSERT_NE(FD, nullptr);
+
+    auto NonLocalDeclRefs = getNonLocalDeclRefs(AST, FD);
+    std::vector<std::string> Names;
+    for (const Decl *D : NonLocalDeclRefs) {
+      if (const auto *ND = llvm::dyn_cast<NamedDecl>(D))
+        Names.push_back(ND->getQualifiedNameAsString());
+    }
+    EXPECT_THAT(Names, UnorderedElementsAreArray(C.ExpectedDecls));
+  }
+}
+
 } // namespace
 } // namespace clangd
 } // namespace clang