Index: clangd/Quality.h
===================================================================
--- clangd/Quality.h
+++ clangd/Quality.h
@@ -50,6 +50,16 @@
   bool Deprecated = false;
   unsigned References = 0;
 
+  /// Coarse symbol kind; evaluate() boosts or demotes the score by category.
+  enum SymbolCategory {
+    Variable,
+    Macro,
+    Type,
+    Function,
+    Namespace,
+    Unknown,
+  } Category = Unknown;
+
   void merge(const CodeCompletionResult &SemaCCResult);
   void merge(const Symbol &IndexResult);
 
Index: clangd/Quality.cpp
===================================================================
--- clangd/Quality.cpp
+++ clangd/Quality.cpp
@@ -9,6 +9,8 @@
 #include "Quality.h"
 #include "index/Index.h"
 #include "clang/AST/ASTContext.h"
+#include "clang/AST/DeclVisitor.h"
 #include "clang/Basic/SourceManager.h"
 #include "clang/Sema/CodeCompleteConsumer.h"
+#include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/FormatVariadic.h"
@@ -29,14 +31,91 @@
   return false;
 }
 
+/// Classifies a Sema declaration. Decl kinds with no MAP entry of their own
+/// are handled by the entry for their nearest mapped base class, with
+/// MAP(Decl, Unknown) as the catch-all.
+static SymbolQualitySignals::SymbolCategory categorize(const NamedDecl &ND) {
+  class Switch
+      : public ConstDeclVisitor<Switch, SymbolQualitySignals::SymbolCategory> {
+  public:
+#define MAP(DeclType, Category)                                                \
+  SymbolQualitySignals::SymbolCategory Visit##DeclType(const DeclType *) {     \
+    return SymbolQualitySignals::Category;                                     \
+  }
+
+    MAP(NamespaceDecl, Namespace);
+    MAP(NamespaceAliasDecl, Namespace);
+    MAP(TypeDecl, Type);
+    MAP(TypeAliasTemplateDecl, Type);
+    MAP(ClassTemplateDecl, Type);
+    MAP(ValueDecl, Variable);
+    MAP(VarTemplateDecl, Variable);
+    MAP(FunctionDecl, Function);
+    MAP(FunctionTemplateDecl, Function);
+    MAP(Decl, Unknown);
+#undef MAP
+  };
+  return Switch().Visit(&ND);
+}
+
+/// Classifies an index symbol by its index::SymbolKind.
+static SymbolQualitySignals::SymbolCategory
+categorize(const index::SymbolInfo &D) {
+  switch (D.Kind) {
+  case index::SymbolKind::Namespace:
+  case index::SymbolKind::NamespaceAlias:
+    return SymbolQualitySignals::Namespace;
+  case index::SymbolKind::Macro:
+    return SymbolQualitySignals::Macro;
+  case index::SymbolKind::Enum:
+  case index::SymbolKind::Struct:
+  case index::SymbolKind::Class:
+  case index::SymbolKind::Protocol:
+  case index::SymbolKind::Extension:
+  case index::SymbolKind::Union:
+  case index::SymbolKind::TypeAlias:
+    return SymbolQualitySignals::Type;
+  case index::SymbolKind::Function:
+  case index::SymbolKind::ClassMethod:
+  case index::SymbolKind::InstanceMethod:
+  case index::SymbolKind::StaticMethod:
+  case index::SymbolKind::InstanceProperty:
+  case index::SymbolKind::ClassProperty:
+  case index::SymbolKind::StaticProperty:
+  case index::SymbolKind::Constructor:
+  case index::SymbolKind::Destructor:
+  case index::SymbolKind::ConversionFunction:
+    return SymbolQualitySignals::Function;
+  case index::SymbolKind::Variable:
+  case index::SymbolKind::Field:
+  case index::SymbolKind::EnumConstant:
+  case index::SymbolKind::Parameter:
+    return SymbolQualitySignals::Variable;
+  case index::SymbolKind::Using:
+  case index::SymbolKind::Module:
+  case index::SymbolKind::Unknown:
+    return SymbolQualitySignals::Unknown;
+  }
+  // All enumerators are handled above; this guards against out-of-range
+  // values and keeps -Wreturn-type quiet.
+  llvm_unreachable("Unknown index::SymbolKind");
+}
+
 void SymbolQualitySignals::merge(const CodeCompletionResult &SemaCCResult) {
   SemaCCPriority = SemaCCResult.Priority;
   if (SemaCCResult.Availability == CXAvailability_Deprecated)
     Deprecated = true;
+
+  // Without a Declaration, fall back to the result kind to spot macros.
+  if (SemaCCResult.Declaration)
+    Category = categorize(*SemaCCResult.Declaration);
+  else if (SemaCCResult.Kind == CodeCompletionResult::RK_Macro)
+    Category = Macro;
 }
 
 void SymbolQualitySignals::merge(const Symbol &IndexResult) {
   References = std::max(IndexResult.References, References);
+  Category = categorize(IndexResult.SymInfo);
 }
 
 float SymbolQualitySignals::evaluate() const {
@@ -55,6 +134,20 @@
   if (Deprecated)
     Score *= 0.1f;
 
+  switch (Category) {
+  case Type:
+  case Function:
+  case Variable:
+    Score *= 1.1f;
+    break;
+  case Namespace:
+  case Macro:
+    Score *= 0.2f;
+    break;
+  case Unknown:
+    break;
+  }
+
   return Score;
 }
 
@@ -64,6 +157,7 @@
   OS << formatv("\tSemaCCPriority: {0}\n", S.SemaCCPriority);
   OS << formatv("\tReferences: {0}\n", S.References);
   OS << formatv("\tDeprecated: {0}\n", S.Deprecated);
+  OS << formatv("\tCategory: {0}\n", static_cast<int>(S.Category));
   return OS;
 }
 
Index: unittests/clangd/QualityTests.cpp
===================================================================
--- unittests/clangd/QualityTests.cpp
+++ unittests/clangd/QualityTests.cpp
@@ -41,6 +41,7 @@
   EXPECT_FALSE(Quality.Deprecated);
   EXPECT_EQ(Quality.SemaCCPriority, SymbolQualitySignals().SemaCCPriority);
   EXPECT_EQ(Quality.References, SymbolQualitySignals().References);
+  EXPECT_EQ(Quality.Category, SymbolQualitySignals::Variable);
 
   Symbol F = findSymbol(Symbols, "f");
   F.References = 24; // TestTU doesn't count references, so fake it.
@@ -49,12 +50,14 @@
   EXPECT_FALSE(Quality.Deprecated); // FIXME: Include deprecated bit in index.
   EXPECT_EQ(Quality.SemaCCPriority, SymbolQualitySignals().SemaCCPriority);
   EXPECT_EQ(Quality.References, 24u);
+  EXPECT_EQ(Quality.Category, SymbolQualitySignals::Function);
 
   Quality = {};
   Quality.merge(CodeCompletionResult(&findDecl(AST, "f"), /*Priority=*/42));
   EXPECT_TRUE(Quality.Deprecated);
   EXPECT_EQ(Quality.SemaCCPriority, 42u);
   EXPECT_EQ(Quality.References, SymbolQualitySignals().References);
+  EXPECT_EQ(Quality.Category, SymbolQualitySignals::Function);
 }
 
 TEST(QualityTests, SymbolRelevanceSignalExtraction) {
@@ -123,6 +126,14 @@
   HighPriority.SemaCCPriority = 20;
   EXPECT_GT(HighPriority.evaluate(), Default.evaluate());
   EXPECT_LT(LowPriority.evaluate(), Default.evaluate());
+
+  // Category alone should move the score relative to Default.
+  SymbolQualitySignals VariableQuality;
+  SymbolQualitySignals MacroQuality;
+  VariableQuality.Category = SymbolQualitySignals::Variable;
+  MacroQuality.Category = SymbolQualitySignals::Macro;
+  EXPECT_GT(VariableQuality.evaluate(), Default.evaluate());
+  EXPECT_LT(MacroQuality.evaluate(), Default.evaluate());
 }
 
 TEST(QualityTests, SymbolRelevanceSignalsSanity) {