Index: clangd/IncludeFixer.h =================================================================== --- clangd/IncludeFixer.h +++ clangd/IncludeFixer.h @@ -18,9 +18,9 @@ #include "clang/Sema/ExternalSemaSource.h" #include "clang/Sema/Sema.h" #include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/IntrusiveRefCntPtr.h" #include "llvm/ADT/Optional.h" +#include "llvm/ADT/StringMap.h" #include "llvm/ADT/StringRef.h" #include @@ -79,6 +79,15 @@ // These collect the last unresolved name so that we can associate it with the // diagnostic. llvm::Optional LastUnresolvedName; + + // There can be multiple diagnostics that are caused by the same unresolved + // name or incomplete type in one parse, especially when code is + // copy-and-pasted without #includes. As fixes are purely dependent on index + // requests and index results at this point, we can cache fixes by index + // requests to avoid repetitive index queries (assuming index results are + // consistent during the single AST parse). + mutable llvm::StringMap> FuzzyReqToFixesCache; + mutable llvm::StringMap> LookupIDToFixesCache; }; } // namespace clangd Index: clangd/IncludeFixer.cpp =================================================================== --- clangd/IncludeFixer.cpp +++ clangd/IncludeFixer.cpp @@ -57,8 +57,6 @@ std::vector IncludeFixer::fix(DiagnosticsEngine::Level DiagLevel, const clang::Diagnostic &Info) const { - if (IndexRequestCount >= IndexRequestLimit) - return {}; // Avoid querying index too many times in a single parse. switch (Info.getID()) { case diag::err_incomplete_type: case diag::err_incomplete_member_access: @@ -118,11 +116,17 @@ auto ID = getSymbolID(TD); if (!ID) return {}; - ++IndexRequestCount; - // FIXME: consider batching the requests for all diagnostics. - // FIXME: consider caching the lookup results. LookupRequest Req; Req.IDs.insert(*ID); + + std::string IDStr = ID->str(); + auto I = LookupIDToFixesCache.find(IDStr); + if (I != LookupIDToFixesCache.end()) + return I->second; + + if (IndexRequestCount++ >= IndexRequestLimit) + return {}; + // FIXME: consider batching the requests for all diagnostics. llvm::Optional Matched; Index.lookup(Req, [&](const Symbol &Sym) { if (Matched) @@ -130,10 +134,12 @@ Matched = Sym; }); - if (!Matched || Matched->IncludeHeaders.empty() || !Matched->Definition || - Matched->CanonicalDeclaration.FileURI != Matched->Definition.FileURI) - return {}; - return fixesForSymbols({*Matched}); + std::vector Fixes; + if (Matched && !Matched->IncludeHeaders.empty() && Matched->Definition && + Matched->CanonicalDeclaration.FileURI == Matched->Definition.FileURI) + Fixes = fixesForSymbols({*Matched}); + LookupIDToFixesCache[IDStr] = Fixes; + return Fixes; } std::vector @@ -289,6 +295,14 @@ Req.RestrictForCodeCompletion = true; Req.Limit = 100; + auto ReqStr = llvm::formatv("{0}", toJSON(Req)).str(); + auto I = FuzzyReqToFixesCache.find(ReqStr); + if (I != FuzzyReqToFixesCache.end()) + return I->second; + + if (IndexRequestCount++ >= IndexRequestLimit) + return {}; + SymbolSlab::Builder Matches; Index.fuzzyFind(Req, [&](const Symbol &Sym) { if (Sym.Name != Req.Query) @@ -297,7 +311,10 @@ Matches.insert(Sym); }); auto Syms = std::move(Matches).build(); - return fixesForSymbols(std::vector(Syms.begin(), Syms.end())); + auto Fixes = fixesForSymbols(std::vector(Syms.begin(), Syms.end())); + + FuzzyReqToFixesCache[ReqStr] = Fixes; + return Fixes; } } // namespace clangd Index: unittests/clangd/DiagnosticsTests.cpp =================================================================== --- unittests/clangd/DiagnosticsTests.cpp +++ unittests/clangd/DiagnosticsTests.cpp @@ -449,6 +449,44 @@ UnorderedElementsAre(Diag(Test.range(), "no member named 'xy' in 'X'"))); } +TEST(IncludeFixerTest, UseCachedIndexResults) { + // As index results for the identical request are cached, more than 5 fixes + // are generated. + Annotations Test(R"cpp( +$insert[[]]void foo() { + $x1[[X]] x; + $x2[[X]] x; + $x3[[X]] x; + $x4[[X]] x; + $x5[[X]] x; + $x6[[X]] x; + $x7[[X]] x; +} + +class X; +void bar(X *x) { + x$a1[[->]]f(); + x$a2[[->]]f(); + x$a3[[->]]f(); + x$a4[[->]]f(); + x$a5[[->]]f(); + x$a6[[->]]f(); + x$a7[[->]]f(); +} + )cpp"); + auto TU = TestTU::withCode(Test.code()); + auto Index = + buildIndexWithSymbol(SymbolWithHeader{"X", "unittest:///a.h", "\"a.h\""}); + TU.ExternalIndex = Index.get(); + + auto Parsed = TU.build(); + for (const auto &D : Parsed.getDiagnostics()) { + EXPECT_EQ(D.Fixes.size(), 1u); + EXPECT_EQ(D.Fixes[0].Message, + std::string("Add include \"a.h\" for symbol X")); + } +} + } // namespace } // namespace clangd } // namespace clang