Index: clangd/ClangdUnit.h
===================================================================
--- clangd/ClangdUnit.h
+++ clangd/ClangdUnit.h
@@ -91,7 +91,8 @@
 
   /// This function returns top-level decls present in the main file of the AST.
   /// The result does not include the decls that come from the preamble.
-  ArrayRef<const Decl *> getLocalTopLevelDecls();
+  /// (These should be const, but RecursiveASTVisitor requires Decl*).
+  ArrayRef<Decl *> getLocalTopLevelDecls();
 
   const std::vector<Diag> &getDiagnostics() const;
 
@@ -104,8 +105,8 @@
   ParsedAST(std::shared_ptr<const PreambleData> Preamble,
             std::unique_ptr<CompilerInstance> Clang,
             std::unique_ptr<FrontendAction> Action,
-            std::vector<const Decl *> LocalTopLevelDecls,
-            std::vector<Diag> Diags, std::vector<Inclusion> Inclusions);
+            std::vector<Decl *> LocalTopLevelDecls, std::vector<Diag> Diags,
+            std::vector<Inclusion> Inclusions);
 
   // In-memory preambles must outlive the AST, it is important that this member
   // goes before Clang and Action.
@@ -122,7 +123,7 @@
   std::vector<Diag> Diags;
   // Top-level decls inside the current file. Not that this does not include
   // top-level decls from the preamble.
-  std::vector<const Decl *> LocalTopLevelDecls;
+  std::vector<Decl *> LocalTopLevelDecls;
   std::vector<Inclusion> Inclusions;
 };
 
Index: clangd/ClangdUnit.cpp
===================================================================
--- clangd/ClangdUnit.cpp
+++ clangd/ClangdUnit.cpp
@@ -51,11 +51,11 @@
 
 class DeclTrackingASTConsumer : public ASTConsumer {
 public:
-  DeclTrackingASTConsumer(std::vector<const Decl *> &TopLevelDecls)
+  DeclTrackingASTConsumer(std::vector<Decl *> &TopLevelDecls)
       : TopLevelDecls(TopLevelDecls) {}
 
   bool HandleTopLevelDecl(DeclGroupRef DG) override {
-    for (const Decl *D : DG) {
+    for (Decl *D : DG) {
       // ObjCMethodDecl are not actually top-level decls.
       if (isa<ObjCMethodDecl>(D))
         continue;
@@ -66,14 +66,12 @@
   }
 
 private:
-  std::vector<const Decl *> &TopLevelDecls;
+  std::vector<Decl *> &TopLevelDecls;
 };
 
 class ClangdFrontendAction : public SyntaxOnlyAction {
 public:
-  std::vector<const Decl *> takeTopLevelDecls() {
-    return std::move(TopLevelDecls);
-  }
+  std::vector<Decl *> takeTopLevelDecls() { return std::move(TopLevelDecls); }
 
 protected:
   std::unique_ptr<ASTConsumer> CreateASTConsumer(CompilerInstance &CI,
@@ -82,7 +80,7 @@
   }
 
 private:
-  std::vector<const Decl *> TopLevelDecls;
+  std::vector<Decl *> TopLevelDecls;
 };
 
 class CppFilePreambleCallbacks : public PreambleCallbacks {
@@ -174,7 +172,7 @@
   // CompilerInstance won't run this callback, do it directly.
   ASTDiags.EndSourceFile();
 
-  std::vector<const Decl *> ParsedDecls = Action->takeTopLevelDecls();
+  std::vector<Decl *> ParsedDecls = Action->takeTopLevelDecls();
   std::vector<Diag> Diags = ASTDiags.take();
   // Add diagnostics from the preamble, if any.
   if (Preamble)
@@ -210,7 +208,7 @@
   return Clang->getPreprocessor();
 }
 
-ArrayRef<const Decl *> ParsedAST::getLocalTopLevelDecls() {
+ArrayRef<Decl *> ParsedAST::getLocalTopLevelDecls() {
   return LocalTopLevelDecls;
 }
 
@@ -261,7 +259,7 @@
 ParsedAST::ParsedAST(std::shared_ptr<const PreambleData> Preamble,
                      std::unique_ptr<CompilerInstance> Clang,
                      std::unique_ptr<FrontendAction> Action,
-                     std::vector<const Decl *> LocalTopLevelDecls,
+                     std::vector<Decl *> LocalTopLevelDecls,
                      std::vector<Diag> Diags, std::vector<Inclusion> Inclusions)
     : Preamble(std::move(Preamble)), Clang(std::move(Clang)),
       Action(std::move(Action)), Diags(std::move(Diags)),
Index: clangd/CodeComplete.cpp
===================================================================
--- clangd/CodeComplete.cpp
+++ clangd/CodeComplete.cpp
@@ -1005,12 +1005,15 @@
 
     SymbolQualitySignals Quality;
     SymbolRelevanceSignals Relevance;
+    Relevance.Query = SymbolRelevanceSignals::CodeComplete;
     if (auto FuzzyScore = Filter->match(C.Name))
       Relevance.NameMatch = *FuzzyScore;
     else
       return;
-    if (IndexResult)
+    if (IndexResult) {
       Quality.merge(*IndexResult);
+      Relevance.merge(*IndexResult);
+    }
     if (SemaResult) {
       Quality.merge(*SemaResult);
       Relevance.merge(*SemaResult);
Index: clangd/Quality.h
===================================================================
--- clangd/Quality.h
+++ clangd/Quality.h
@@ -68,7 +68,23 @@
   /// where 1 is closest
   float ProximityScore = 0;
 
+  // An approximate measure of where we expect the symbol to be used.
+  // Enumerators are ordered narrowest-first; merging keeps the std::min.
+  enum AccessibleScope {
+    FunctionScope,
+    ClassScope,
+    FileScope,
+    GlobalScope,
+  } Scope = GlobalScope;
+
+  // The request kind. evaluate() only weighs Scope for CodeComplete.
+  enum QueryType {
+    CodeComplete,
+    Generic,
+  } Query = Generic;
+
   void merge(const CodeCompletionResult &SemaResult);
+  void merge(const Symbol &IndexResult);
 
   // Condense these signals down to a single number, higher is better.
   float evaluate() const;
Index: clangd/Quality.cpp
===================================================================
--- clangd/Quality.cpp
+++ clangd/Quality.cpp
@@ -67,6 +67,29 @@
   return OS;
 }
 
+// Classifies D by the narrowest kind of scope it is expected to be used from.
+static SymbolRelevanceSignals::AccessibleScope
+ComputeScope(const NamedDecl &D) {
+  // Must start false: it is OR-ed into and read even if the loop never runs.
+  bool InClass = false;
+  for (const DeclContext *DC = D.getDeclContext(); !DC->isFileContext();
+       DC = DC->getParent()) {
+    if (DC->isFunctionOrMethod())
+      return SymbolRelevanceSignals::FunctionScope;
+    InClass = InClass || DC->isRecord();
+  }
+  if (InClass)
+    return SymbolRelevanceSignals::ClassScope;
+  if (D.getLinkageInternal() < ExternalLinkage)
+    return SymbolRelevanceSignals::FileScope;
+  return SymbolRelevanceSignals::GlobalScope;
+}
+
+void SymbolRelevanceSignals::merge(const Symbol &IndexResult) {
+  // FIXME: Index results always assumed to be at global scope. If Scope becomes
+  // relevant to non-completion requests, we should recognize class members etc.
+}
+
 void SymbolRelevanceSignals::merge(const CodeCompletionResult &SemaCCResult) {
   if (SemaCCResult.Availability == CXAvailability_NotAvailable ||
       SemaCCResult.Availability == CXAvailability_NotAccessible)
@@ -79,16 +102,44 @@
         hasDeclInMainFile(*SemaCCResult.Declaration) ? 1.0 : 0.0;
     ProximityScore = std::max(DeclProximity, ProximityScore);
   }
+
+  // Declarations are scoped, others (like macros) are assumed global.
+  if (SemaCCResult.Kind == CodeCompletionResult::RK_Declaration)
+    Scope = std::min(Scope, ComputeScope(*SemaCCResult.Declaration));
 }
 
 float SymbolRelevanceSignals::evaluate() const {
+  float Score = 1;
+
   if (Forbidden)
     return 0;
 
-  float Score = NameMatch;
+  Score *= NameMatch;
+
   // Proximity scores are [0,1] and we translate them into a multiplier in the
   // range from 1 to 2.
   Score *= 1 + ProximityScore;
+
+  // Symbols like local variables may only be referenced within their scope.
+  // Conversely if we're in that scope, it's likely we'll reference them.
+  if (Query == CodeComplete) {
+    // The narrower the scope where a symbol is visible, the more likely it is
+    // to be relevant when it is available.
+    switch (Scope) {
+    case GlobalScope:
+      break;
+    case FileScope:
+      Score *= 1.5;
+      break;
+    case ClassScope:
+      Score *= 2;
+      break;
+    case FunctionScope:
+      Score *= 4;
+      break;
+    }
+  }
+
   return Score;
 }
 raw_ostream &operator<<(raw_ostream &OS, const SymbolRelevanceSignals &S) {
Index: unittests/clangd/QualityTests.cpp
===================================================================
--- unittests/clangd/QualityTests.cpp
+++ unittests/clangd/QualityTests.cpp
@@ -70,6 +70,8 @@
 
     [[deprecated]]
     int test_deprecated() { return 0; }
+
+    namespace { struct X { void y() { int z; } }; }
   )cpp";
   auto AST = Test.build();
 
@@ -79,6 +81,7 @@
                                         /*Accessible=*/false));
   EXPECT_EQ(Deprecated.NameMatch, SymbolRelevanceSignals().NameMatch);
   EXPECT_TRUE(Deprecated.Forbidden);
+  EXPECT_EQ(Deprecated.Scope, SymbolRelevanceSignals::GlobalScope);
 
   // Test proximity scores.
   SymbolRelevanceSignals FuncInCpp;
@@ -98,6 +101,16 @@
       &findDecl(AST, "test_func_in_header_and_cpp"), CCP_Declaration));
   /// Decls in both header **and** the main file get the same boost.
   EXPECT_FLOAT_EQ(FuncInHeaderAndCpp.ProximityScore, 1.0);
+
+  SymbolRelevanceSignals Relevance;
+  Relevance.merge(CodeCompletionResult(&findAnyDecl(AST, "X"), 42));
+  EXPECT_EQ(Relevance.Scope, SymbolRelevanceSignals::FileScope);
+  Relevance = {};
+  Relevance.merge(CodeCompletionResult(&findAnyDecl(AST, "y"), 42));
+  EXPECT_EQ(Relevance.Scope, SymbolRelevanceSignals::ClassScope);
+  Relevance = {};
+  Relevance.merge(CodeCompletionResult(&findAnyDecl(AST, "z"), 42));
+  EXPECT_EQ(Relevance.Scope, SymbolRelevanceSignals::FunctionScope);
 }
 
 // Do the signals move the scores in the direction we expect?
@@ -137,6 +150,19 @@
   SymbolRelevanceSignals WithProximity;
   WithProximity.ProximityScore = 0.2;
   EXPECT_LT(Default.evaluate(), WithProximity.evaluate());
+
+  SymbolRelevanceSignals Scoped;
+  Scoped.Scope = SymbolRelevanceSignals::FileScope;
+  EXPECT_EQ(Scoped.evaluate(), Default.evaluate());
+  Scoped.Query = SymbolRelevanceSignals::CodeComplete;
+  EXPECT_GT(Scoped.evaluate(), Default.evaluate());
+  // Narrower scopes must get strictly larger boosts.
+  SymbolRelevanceSignals ClassMember = Scoped;
+  ClassMember.Scope = SymbolRelevanceSignals::ClassScope;
+  EXPECT_GT(ClassMember.evaluate(), Scoped.evaluate());
+  SymbolRelevanceSignals Local = Scoped;
+  Local.Scope = SymbolRelevanceSignals::FunctionScope;
+  EXPECT_GT(Local.evaluate(), ClassMember.evaluate());
 }
 
 TEST(QualityTests, SortText) {
Index: unittests/clangd/TestTU.h
===================================================================
--- unittests/clangd/TestTU.h
+++ unittests/clangd/TestTU.h
@@ -53,6 +53,8 @@
 const Symbol &findSymbol(const SymbolSlab &, llvm::StringRef QName);
 // Look up an AST symbol by qualified name, which must be unique and top-level.
 const NamedDecl &findDecl(ParsedAST &AST, llvm::StringRef QName);
+// Look up a main-file AST symbol by unqualified name, which must be unique.
+const NamedDecl &findAnyDecl(ParsedAST &AST, llvm::StringRef Name);
 
 } // namespace clangd
 } // namespace clang
Index: unittests/clangd/TestTU.cpp
===================================================================
--- unittests/clangd/TestTU.cpp
+++ unittests/clangd/TestTU.cpp
@@ -10,6 +10,7 @@
 #include "TestFS.h"
 #include "index/FileIndex.h"
 #include "index/MemIndex.h"
+#include "clang/AST/RecursiveASTVisitor.h"
 #include "clang/Frontend/CompilerInvocation.h"
 #include "clang/Frontend/PCHContainerOperations.h"
 #include "clang/Frontend/Utils.h"
@@ -49,7 +50,6 @@
   return MemIndex::build(headerSymbols());
 }
 
-// Look up a symbol by qualified name, which must be unique.
 const Symbol &findSymbol(const SymbolSlab &Slab, llvm::StringRef QName) {
   const Symbol *Result = nullptr;
   for (const Symbol &S : Slab) {
@@ -92,5 +92,26 @@
   return LookupDecl(*Scope, Components.back());
 }
 
+const NamedDecl &findAnyDecl(ParsedAST &AST, llvm::StringRef Name) {
+  struct Visitor : RecursiveASTVisitor<Visitor> {
+    llvm::StringRef Name;
+    llvm::SmallVector<const NamedDecl *, 1> Decls;
+    bool VisitNamedDecl(const NamedDecl *ND) {
+      if (auto *ID = ND->getIdentifier())
+        if (ID->getName() == Name)
+          Decls.push_back(ND);
+      return true;
+    }
+  } Visitor;
+  Visitor.Name = Name;
+  for (Decl *D : AST.getLocalTopLevelDecls())
+    Visitor.TraverseDecl(D);
+  if (Visitor.Decls.size() != 1) {
+    ADD_FAILURE() << Visitor.Decls.size() << " symbols named " << Name;
+    assert(Visitor.Decls.size() == 1);
+  }
+  return *Visitor.Decls.front();
+}
+
 } // namespace clangd
 } // namespace clang