diff --git a/clang-tools-extra/clangd/CodeComplete.h b/clang-tools-extra/clangd/CodeComplete.h
--- a/clang-tools-extra/clangd/CodeComplete.h
+++ b/clang-tools-extra/clangd/CodeComplete.h
@@ -154,14 +154,21 @@
     DecisionForest,
   } RankingModel = Heuristics;
 
+  /// Callback used to score a CompletionCandidate if DecisionForest ranking
+  /// model is enabled; it is passed DecisionForestBase as `Base`.
+  /// This allows us to inject experimental models and compare them with
+  /// baseline model using A/B testing.
+  std::function<DecisionForestScores(
+      const SymbolQualitySignals &, const SymbolRelevanceSignals &, float Base)>
+      DecisionForestScorer = &evaluateDecisionForest;
   /// Weight for combining NameMatch and Prediction of DecisionForest.
   /// CompletionScore is NameMatch * pow(Base, Prediction).
   /// The optimal value of Base largely depends on the semantics of the model
   /// and prediction score (e.g. algorithm used during training, number of
   /// trees, etc.). Usually if the range of Prediciton is [-20, 20] then a Base
   /// in [1.2, 1.7] works fine.
-  /// Semantics: E.g. the completion score reduces by 50% if the Prediciton
-  /// score is reduced by 2.6 points for Base = 1.3.
+  /// Semantics: E.g. For Base = 1.3, if the Prediction score reduces by 2.6
+  /// points then the completion score reduces by 50% (1.3^(-2.6) ~= 0.5).
   float DecisionForestBase = 1.3f;
 };
 
diff --git a/clang-tools-extra/clangd/CodeComplete.cpp b/clang-tools-extra/clangd/CodeComplete.cpp
--- a/clang-tools-extra/clangd/CodeComplete.cpp
+++ b/clang-tools-extra/clangd/CodeComplete.cpp
@@ -1644,19 +1644,10 @@
       return Scores;
 
     case RM::DecisionForest:
-      Scores.Quality = 0;
-      Scores.Relevance = 0;
-      // Exponentiating DecisionForest prediction makes the score of each tree a
-      // multiplciative boost (like NameMatch). This allows us to weigh the
-      // prediciton score and NameMatch appropriately.
-      Scores.ExcludingName = pow(Opts.DecisionForestBase,
-                                 evaluateDecisionForest(Quality, Relevance));
-      // NeedsFixIts is not part of the DecisionForest as generating training
-      // data that needs fixits is not-feasible.
-      if (Relevance.NeedsFixIts)
-        Scores.ExcludingName *= 0.5;
-      // NameMatch should be a multiplier on total score to support rescoring.
-      Scores.Total = Relevance.NameMatch * Scores.ExcludingName;
+      DecisionForestScores DFScores = Opts.DecisionForestScorer(
+          Quality, Relevance, Opts.DecisionForestBase);
+      Scores.ExcludingName = DFScores.ExcludingName;
+      Scores.Total = DFScores.Total;
       return Scores;
     }
     llvm_unreachable("Unhandled CodeCompletion ranking model.");
diff --git a/clang-tools-extra/clangd/Quality.h b/clang-tools-extra/clangd/Quality.h
--- a/clang-tools-extra/clangd/Quality.h
+++ b/clang-tools-extra/clangd/Quality.h
@@ -165,8 +165,21 @@
 /// Combine symbol quality and relevance into a single score.
 float evaluateSymbolAndRelevance(float SymbolQuality, float SymbolRelevance);
 
-float evaluateDecisionForest(const SymbolQualitySignals &Quality,
-                             const SymbolRelevanceSignals &Relevance);
+/// Same semantics as CodeCompletion::Scores. The Quality and Relevance scores
+/// are omitted because DecisionForest cannot assign individual scores to
+/// Quality and Relevance signals.
+struct DecisionForestScores {
+  float Total = 0.f;
+  float ExcludingName = 0.f;
+};
+
+/// Scores a completion candidate with the DecisionForest model:
+/// ExcludingName = pow(Base, prediction) (halved if fixits are needed) and
+/// Total = NameMatch * ExcludingName.
+DecisionForestScores
+evaluateDecisionForest(const SymbolQualitySignals &Quality,
+                       const SymbolRelevanceSignals &Relevance, float Base);
+
 /// TopN<T> is a lossy container that preserves only the "best" N elements.
 template <typename T, typename Compare = std::greater<T>> class TopN {
 public:
diff --git a/clang-tools-extra/clangd/Quality.cpp b/clang-tools-extra/clangd/Quality.cpp
--- a/clang-tools-extra/clangd/Quality.cpp
+++ b/clang-tools-extra/clangd/Quality.cpp
@@ -487,8 +487,9 @@
   return SymbolQuality * SymbolRelevance;
 }
 
-float evaluateDecisionForest(const SymbolQualitySignals &Quality,
-                             const SymbolRelevanceSignals &Relevance) {
+DecisionForestScores
+evaluateDecisionForest(const SymbolQualitySignals &Quality,
+                       const SymbolRelevanceSignals &Relevance, float Base) {
   Example E;
   E.setIsDeprecated(Quality.Deprecated);
   E.setIsReservedName(Quality.ReservedName);
@@ -512,7 +513,19 @@
   E.setHadSymbolType(Relevance.HadSymbolType);
   E.setTypeMatchesPreferred(Relevance.TypeMatchesPreferred);
   E.setFilterLength(Relevance.FilterLength);
-  return Evaluate(E);
+
+  DecisionForestScores Scores;
+  // Exponentiating DecisionForest prediction makes the score of each tree a
+  // multiplicative boost (like NameMatch). This allows us to weigh the
+  // prediction score and NameMatch appropriately.
+  Scores.ExcludingName = pow(Base, Evaluate(E));
+  // NeedsFixIts is not part of the DecisionForest as generating training
+  // data that needs fixits is not feasible.
+  if (Relevance.NeedsFixIts)
+    Scores.ExcludingName *= 0.5;
+  // NameMatch should be a multiplier on total score to support rescoring.
+  Scores.Total = Relevance.NameMatch * Scores.ExcludingName;
+  return Scores;
 }
 
 // Produces an integer that sorts in the same order as F.
diff --git a/clang-tools-extra/clangd/unittests/CodeCompleteTests.cpp b/clang-tools-extra/clangd/unittests/CodeCompleteTests.cpp
--- a/clang-tools-extra/clangd/unittests/CodeCompleteTests.cpp
+++ b/clang-tools-extra/clangd/unittests/CodeCompleteTests.cpp
@@ -194,6 +194,34 @@
       ElementsAre(Named("clangA"), Named("clangD")));
 }
 
+TEST(DecisionForestRankingModel, DecisionForestScorerCallbackTest) {
+  clangd::CodeCompleteOptions Opts;
+  constexpr float MagicNumber = 1234.5678f;
+  Opts.RankingModel = CodeCompleteOptions::DecisionForest;
+  // Inject a scorer whose output is easy to recognize in the results.
+  Opts.DecisionForestScorer = [](const SymbolQualitySignals &,
+                                 const SymbolRelevanceSignals &, float) {
+    DecisionForestScores Scores;
+    Scores.Total = MagicNumber;
+    Scores.ExcludingName = MagicNumber;
+    return Scores;
+  };
+  llvm::StringRef Code = "int func() { int xyz; xy^ }";
+  auto Results = completions(Code,
+                             /*IndexSymbols=*/{}, Opts);
+  ASSERT_EQ(Results.Completions.size(), 1u);
+  EXPECT_EQ(Results.Completions[0].Score.Total, MagicNumber);
+  EXPECT_EQ(Results.Completions[0].Score.ExcludingName, MagicNumber);
+
+  // Do not use DecisionForestScorer for heuristics model.
+  Opts.RankingModel = CodeCompleteOptions::Heuristics;
+  Results = completions(Code,
+                        /*IndexSymbols=*/{}, Opts);
+  ASSERT_EQ(Results.Completions.size(), 1u);
+  EXPECT_NE(Results.Completions[0].Score.Total, MagicNumber);
+  EXPECT_NE(Results.Completions[0].Score.ExcludingName, MagicNumber);
+}
+
 TEST(CompletionTest, Limit) {
   clangd::CodeCompleteOptions Opts;
   Opts.Limit = 2;