diff --git a/clang-tools-extra/clangd/ClangdServer.cpp b/clang-tools-extra/clangd/ClangdServer.cpp
--- a/clang-tools-extra/clangd/ClangdServer.cpp
+++ b/clang-tools-extra/clangd/ClangdServer.cpp
@@ -255,6 +255,7 @@
     ParseInput.Opts.BuildRecoveryAST = BuildRecoveryAST;
     ParseInput.Opts.PreserveRecoveryASTType = PreserveRecoveryASTType;
 
+    CodeCompleteOpts.MainFileSignals = IP->Signals;
     // FIXME(ibiryukov): even if Preamble is non-null, we may want to check
     // both the old and the new version in case only one of them matches.
     CodeCompleteResult Result = clangd::codeComplete(
diff --git a/clang-tools-extra/clangd/CodeComplete.h b/clang-tools-extra/clangd/CodeComplete.h
--- a/clang-tools-extra/clangd/CodeComplete.h
+++ b/clang-tools-extra/clangd/CodeComplete.h
@@ -15,6 +15,7 @@
 #ifndef LLVM_CLANG_TOOLS_EXTRA_CLANGD_CODECOMPLETE_H
 #define LLVM_CLANG_TOOLS_EXTRA_CLANGD_CODECOMPLETE_H
 
+#include "ASTSignals.h"
 #include "Compiler.h"
 #include "Headers.h"
 #include "Protocol.h"
@@ -89,6 +90,10 @@
   /// clangd.
   const SymbolIndex *Index = nullptr;
 
+  /// Signals derived from the main file's AST (e.g. referenced symbols and
+  /// related namespaces), forwarded to SymbolRelevanceSignals::MainFileSignals.
+  /// May be null, in which case no AST-based signals are computed.
+  const ASTSignals *MainFileSignals = nullptr;
   /// Include completions that require small corrections, e.g. change '.' to
   /// '->' on member access etc.
   bool IncludeFixIts = false;
diff --git a/clang-tools-extra/clangd/CodeComplete.cpp b/clang-tools-extra/clangd/CodeComplete.cpp
--- a/clang-tools-extra/clangd/CodeComplete.cpp
+++ b/clang-tools-extra/clangd/CodeComplete.cpp
@@ -1685,6 +1685,7 @@
     if (PreferredType)
       Relevance.HadContextType = true;
     Relevance.ContextWords = &ContextWords;
+    Relevance.MainFileSignals = Opts.MainFileSignals;
 
     auto &First = Bundle.front();
     if (auto FuzzyScore = fuzzyScore(First))
diff --git a/clang-tools-extra/clangd/Quality.h b/clang-tools-extra/clangd/Quality.h
--- a/clang-tools-extra/clangd/Quality.h
+++ b/clang-tools-extra/clangd/Quality.h
@@ -29,6 +29,7 @@
 
+#include "ASTSignals.h"
 #include "ExpectedTypes.h"
 #include "FileDistance.h"
 #include "clang/Sema/CodeCompleteConsumer.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/StringRef.h"
@@ -140,6 +141,16 @@
   /// CompletionPrefix.
   unsigned FilterLength = 0;
 
+  /// Signals derived from the main file's AST, used to populate MainFileRefs
+  /// and ScopeRefsInFile. May be null.
+  const ASTSignals *MainFileSignals = nullptr;
+  /// Number of references to the candidate in the main file.
+  unsigned MainFileRefs = 0;
+  /// Number of unique symbols in the main file which belong to the candidate's
+  /// namespace. This indicates how relevant the namespace is in the current
+  /// file.
+  unsigned ScopeRefsInFile = 0;
+
   /// Set of derived signals computed by calculateDerivedSignals(). Must not be
   /// set explicitly.
   struct DerivedSignals {
@@ -155,6 +166,9 @@
 
   void merge(const CodeCompletionResult &SemaResult);
   void merge(const Symbol &IndexResult);
+  /// Raises MainFileRefs and ScopeRefsInFile (keeping the maximum) based on
+  /// MainFileSignals for a Sema result. No-op when MainFileSignals is null.
+  void computeASTSignals(const CodeCompletionResult &SemaResult);
 
   // Condense these signals down to a single number, higher is better.
   float evaluateHeuristics() const;
diff --git a/clang-tools-extra/clangd/Quality.cpp b/clang-tools-extra/clangd/Quality.cpp
--- a/clang-tools-extra/clangd/Quality.cpp
+++ b/clang-tools-extra/clangd/Quality.cpp
@@ -294,6 +294,44 @@
   if (!(IndexResult.Flags & Symbol::VisibleOutsideFile)) {
     Scope = AccessibleScope::FileScope;
   }
+  if (MainFileSignals) {
+    MainFileRefs =
+        std::max(MainFileRefs,
+                 MainFileSignals->ReferencedSymbols.lookup(IndexResult.ID));
+    ScopeRefsInFile =
+        std::max(ScopeRefsInFile,
+                 MainFileSignals->RelatedNamespaces.lookup(IndexResult.Scope));
+  }
+}
+
+void SymbolRelevanceSignals::computeASTSignals(
+    const CodeCompletionResult &SemaResult) {
+  if (!MainFileSignals)
+    return;
+  // CodeCompletionResult::getDeclaration() may only be called on declaration
+  // and pattern results.
+  if ((SemaResult.Kind != CodeCompletionResult::RK_Declaration) &&
+      (SemaResult.Kind != CodeCompletionResult::RK_Pattern))
+    return;
+  if (const NamedDecl *ND = SemaResult.getDeclaration()) {
+    auto ID = getSymbolID(ND);
+    if (!ID)
+      return;
+    MainFileRefs =
+        std::max(MainFileRefs, MainFileSignals->ReferencedSymbols.lookup(ID));
+    // Only declarations directly inside a named namespace contribute to
+    // ScopeRefsInFile.
+    if (const auto *NSD = dyn_cast<NamespaceDecl>(ND->getDeclContext())) {
+      if (NSD->isAnonymousNamespace())
+        return;
+      // Named NSScope rather than Scope to avoid shadowing the Scope member.
+      std::string NSScope = printNamespaceScope(*NSD);
+      if (!NSScope.empty())
+        ScopeRefsInFile =
+            std::max(ScopeRefsInFile,
+                     MainFileSignals->RelatedNamespaces.lookup(NSScope));
+    }
+  }
 }
 
 void SymbolRelevanceSignals::merge(const CodeCompletionResult &SemaCCResult) {
@@ -315,6 +353,7 @@
     InBaseClass |= SemaCCResult.InBaseClass;
   }
 
+  computeASTSignals(SemaCCResult);
   // Declarations are scoped, others (like macros) are assumed global.
   if (SemaCCResult.Declaration)
     Scope = std::min(Scope, computeScope(SemaCCResult.Declaration));
diff --git a/clang-tools-extra/clangd/unittests/CodeCompleteTests.cpp b/clang-tools-extra/clangd/unittests/CodeCompleteTests.cpp
--- a/clang-tools-extra/clangd/unittests/CodeCompleteTests.cpp
+++ b/clang-tools-extra/clangd/unittests/CodeCompleteTests.cpp
@@ -6,6 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "ASTSignals.h"
 #include "Annotations.h"
 #include "ClangdServer.h"
 #include "CodeComplete.h"
@@ -51,6 +52,8 @@
 
 // GMock helpers for matching completion items.
 MATCHER_P(Named, Name, "") { return arg.Name == Name; }
+MATCHER_P(MainFileRefs, Refs, "") { return arg.MainFileRefs == Refs; }
+MATCHER_P(ScopeRefs, Refs, "") { return arg.ScopeRefsInFile == Refs; }
 MATCHER_P(NameStartsWith, Prefix, "") {
   return llvm::StringRef(arg.Name).startswith(Prefix);
 }
@@ -1110,6 +1113,49 @@
               UnorderedElementsAre(Named("xy1"), Named("xy2")));
 }
 
+TEST(CompletionTest, ASTSignals) {
+  struct Completion {
+    std::string Name;
+    unsigned MainFileRefs;
+    unsigned ScopeRefsInFile;
+  };
+  CodeCompleteOptions Opts;
+  std::vector<Completion> RecordedCompletions;
+  Opts.RecordCCResult = [&RecordedCompletions](const CodeCompletion &CC,
+                                               const SymbolQualitySignals &,
+                                               const SymbolRelevanceSignals &R,
+                                               float Score) {
+    RecordedCompletions.push_back({CC.Name, R.MainFileRefs, R.ScopeRefsInFile});
+  };
+  ASTSignals MainFileSignals;
+  MainFileSignals.ReferencedSymbols[var("xy1").ID] = 3;
+  MainFileSignals.ReferencedSymbols[var("xy2").ID] = 1;
+  MainFileSignals.ReferencedSymbols[var("xyindex").ID] = 10;
+  MainFileSignals.RelatedNamespaces["tar::"] = 5;
+  MainFileSignals.RelatedNamespaces["bar::"] = 3;
+  Opts.MainFileSignals = &MainFileSignals;
+  Opts.AllScopes = true;
+  completions(
+      R"cpp(
+      int xy1;
+      int xy2;
+      namespace bar {
+      int xybar = 1;
+      int a = xy^
+      }
+      )cpp",
+      /*IndexSymbols=*/{var("xyindex"), var("tar::xytar"), var("bar::xybar")},
+      Opts);
+  EXPECT_THAT(RecordedCompletions,
+              UnorderedElementsAre(
+                  AllOf(Named("xy1"), MainFileRefs(3u), ScopeRefs(0u)),
+                  AllOf(Named("xy2"), MainFileRefs(1u), ScopeRefs(0u)),
+                  AllOf(Named("xyindex"), MainFileRefs(10u), ScopeRefs(0u)),
+                  AllOf(Named("xytar"), MainFileRefs(0u), ScopeRefs(5u)),
+                  AllOf(/*both from sema and index*/ Named("xybar"),
+                        MainFileRefs(0u), ScopeRefs(3u))));
+}
+
 SignatureHelp signatures(llvm::StringRef Text, Position Point,
                          std::vector<Symbol> IndexSymbols = {}) {
   std::unique_ptr<SymbolIndex> Index;