Index: clangd/CMakeLists.txt =================================================================== --- clangd/CMakeLists.txt +++ clangd/CMakeLists.txt @@ -40,6 +40,7 @@ FuzzyMatch.cpp GlobalCompilationDatabase.cpp Headers.cpp + IncludeFixer.cpp JSONTransport.cpp Logger.cpp Protocol.cpp Index: clangd/ClangdServer.cpp =================================================================== --- clangd/ClangdServer.cpp +++ clangd/ClangdServer.cpp @@ -152,8 +152,9 @@ // "PreparingBuild" status to inform users, it is non-trivial given the // current implementation. WorkScheduler.update(File, - ParseInputs{getCompileCommand(File), - FSProvider.getFileSystem(), Contents.str()}, + ParseInputs(getCompileCommand(File), + FSProvider.getFileSystem(), Contents.str(), + Index), WantDiags); } Index: clangd/ClangdUnit.h =================================================================== --- clangd/ClangdUnit.h +++ clangd/ClangdUnit.h @@ -15,6 +15,7 @@ #include "Headers.h" #include "Path.h" #include "Protocol.h" +#include "index/Index.h" #include "clang/Frontend/FrontendAction.h" #include "clang/Frontend/PrecompiledPreamble.h" #include "clang/Lex/Preprocessor.h" @@ -61,9 +62,19 @@ /// Information required to run clang, e.g. to parse AST or do code completion. struct ParseInputs { + ParseInputs(tooling::CompileCommand CompileCommand, + IntrusiveRefCntPtr FS, + std::string Contents, const SymbolIndex *Index) + : CompileCommand(std::move(CompileCommand)), FS(std::move(FS)), + Contents(std::move(Contents)), Index(Index) {} + + ParseInputs() = default; + tooling::CompileCommand CompileCommand; IntrusiveRefCntPtr FS; std::string Contents; + // Used to recover from diagnostics (e.g. find missing includes for symbol). + const SymbolIndex *Index = nullptr; }; /// Stores and provides access to parsed AST.
@@ -76,7 +87,8 @@ std::shared_ptr Preamble, std::unique_ptr Buffer, std::shared_ptr PCHs, - IntrusiveRefCntPtr VFS); + IntrusiveRefCntPtr VFS, + const SymbolIndex *Index); ParsedAST(ParsedAST &&Other); ParsedAST &operator=(ParsedAST &&Other); Index: clangd/ClangdUnit.cpp =================================================================== --- clangd/ClangdUnit.cpp +++ clangd/ClangdUnit.cpp @@ -11,9 +11,12 @@ #include "../clang-tidy/ClangTidyModuleRegistry.h" #include "Compiler.h" #include "Diagnostics.h" +#include "Headers.h" +#include "IncludeFixer.h" #include "Logger.h" #include "SourceCode.h" #include "Trace.h" +#include "index/Index.h" #include "clang/AST/ASTContext.h" #include "clang/Basic/LangOptions.h" #include "clang/Frontend/CompilerInstance.h" @@ -30,6 +33,7 @@ #include "clang/Serialization/ASTWriter.h" #include "clang/Tooling/CompilationDatabase.h" #include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallString.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/raw_ostream.h" @@ -227,7 +231,8 @@ std::shared_ptr Preamble, std::unique_ptr Buffer, std::shared_ptr PCHs, - llvm::IntrusiveRefCntPtr VFS) { + llvm::IntrusiveRefCntPtr VFS, + const SymbolIndex *Index) { assert(CI); // Command-line parsing sets DisableFree to true by default, but we don't want // to leak memory in clangd. @@ -236,9 +241,11 @@ Preamble ? &Preamble->Preamble : nullptr; StoreDiags ASTDiags; + std::string Content = Buffer->getBuffer(); + auto Clang = prepareCompilerInstance(std::move(CI), PreamblePCH, std::move(Buffer), - std::move(PCHs), std::move(VFS), ASTDiags); + std::move(PCHs), VFS, ASTDiags); if (!Clang) return None; @@ -285,6 +292,24 @@ } } + llvm::Optional FixIncludes; + auto BuildDir = VFS->getCurrentWorkingDirectory(); + // Add IncludeFixer if Index is provided.
+ if (Index && !BuildDir.getError()) { + auto Style = getFormatStyleForFile(MainInput.getFile(), Content, VFS.get()); + auto Inserter = llvm::make_unique( + MainInput.getFile(), Content, Style, BuildDir.get(), + Clang->getPreprocessor().getHeaderSearchInfo()); + if (Preamble) { + for (const auto &Inc : Preamble->Includes.MainFileIncludes) + Inserter->addExisting(Inc); + } + FixIncludes.emplace(*Clang, MainInput.getFile(), std::move(Inserter), + *Index); + ASTDiags.setIncludeFixer(*FixIncludes); + Clang->setExternalSemaSource(FixIncludes->typoRecorder()); + } + // Copy over the includes from the preamble, then combine with the // non-preamble includes below. auto Includes = Preamble ? Preamble->Includes : IncludeStructure{}; @@ -538,7 +563,7 @@ return ParsedAST::build(llvm::make_unique(*Invocation), Preamble, llvm::MemoryBuffer::getMemBufferCopy(Inputs.Contents), - PCHs, std::move(VFS)); + PCHs, std::move(VFS), Inputs.Index); } SourceLocation getBeginningOfIdentifier(ParsedAST &Unit, const Position &Pos, Index: clangd/CodeComplete.cpp =================================================================== --- clangd/CodeComplete.cpp +++ clangd/CodeComplete.cpp @@ -177,28 +177,6 @@ return Result; } -/// Creates a `HeaderFile` from \p Header which can be either a URI or a literal -/// include. -static llvm::Expected toHeaderFile(llvm::StringRef Header, - llvm::StringRef HintPath) { - if (isLiteralInclude(Header)) - return HeaderFile{Header.str(), /*Verbatim=*/true}; - auto U = URI::parse(Header); - if (!U) - return U.takeError(); - - auto IncludePath = URI::includeSpelling(*U); - if (!IncludePath) - return IncludePath.takeError(); - if (!IncludePath->empty()) - return HeaderFile{std::move(*IncludePath), /*Verbatim=*/true}; - - auto Resolved = URI::resolve(*U, HintPath); - if (!Resolved) - return Resolved.takeError(); - return HeaderFile{std::move(*Resolved), /*Verbatim=*/false}; -} - /// A code completion result, in clang-native form.
/// It may be promoted to a CompletionItem if it's among the top-ranked results. struct CompletionCandidate { @@ -1155,24 +1133,6 @@ return CachedReq; } -// Returns the most popular include header for \p Sym. If two headers are -// equally popular, prefer the shorter one. Returns empty string if \p Sym has -// no include header. -llvm::SmallVector getRankedIncludes(const Symbol &Sym) { - auto Includes = Sym.IncludeHeaders; - // Sort in descending order by reference count and header length. - llvm::sort(Includes, [](const Symbol::IncludeHeaderWithReferences &LHS, - const Symbol::IncludeHeaderWithReferences &RHS) { - if (LHS.References == RHS.References) - return LHS.IncludeHeader.size() < RHS.IncludeHeader.size(); - return LHS.References > RHS.References; - }); - llvm::SmallVector Headers; - for (const auto &Include : Includes) - Headers.push_back(Include.IncludeHeader); - return Headers; -} - // Runs Sema-based (AST) and Index-based completion, returns merged results. // // There are a few tricky considerations: @@ -1253,19 +1213,12 @@ CodeCompleteResult Output; auto RecorderOwner = llvm::make_unique(Opts, [&]() { assert(Recorder && "Recorder is not set"); - auto Style = - format::getStyle(format::DefaultFormatStyle, SemaCCInput.FileName, - format::DefaultFallbackStyle, SemaCCInput.Contents, - SemaCCInput.VFS.get()); - if (!Style) { - log("getStyle() failed for file {0}: {1}. Fallback is LLVM style.", - SemaCCInput.FileName, Style.takeError()); - Style = format::getLLVMStyle(); - } + auto Style = getFormatStyleForFile( + SemaCCInput.FileName, SemaCCInput.Contents, SemaCCInput.VFS.get()); // If preprocessor was run, inclusions from preprocessor callback should // already be added to Includes.
Inserter.emplace( - SemaCCInput.FileName, SemaCCInput.Contents, *Style, + SemaCCInput.FileName, SemaCCInput.Contents, Style, SemaCCInput.Command.Directory, Recorder->CCSema->getPreprocessor().getHeaderSearchInfo()); for (const auto &Inc : Includes.MainFileIncludes) Index: clangd/Diagnostics.h =================================================================== --- clangd/Diagnostics.h +++ clangd/Diagnostics.h @@ -86,6 +86,8 @@ /// Convert from clang diagnostic level to LSP severity. int getSeverity(DiagnosticsEngine::Level L); +class IncludeFixer; + /// StoreDiags collects the diagnostics that can later be reported by /// clangd. It groups all notes for a diagnostic into a single Diag /// and filters out diagnostics that don't mention the main file (i.e. neither @@ -99,9 +101,14 @@ void HandleDiagnostic(DiagnosticsEngine::Level DiagLevel, const clang::Diagnostic &Info) override; + /// If set, possibly adds fixes for diagnostics using \p Fixer. + void setIncludeFixer(const IncludeFixer &Fixer) { FixIncludes = &Fixer; } + private: void flushLastDiag(); + const IncludeFixer *FixIncludes = nullptr; + std::vector Output; llvm::Optional LangOpts; llvm::Optional LastDiag; Index: clangd/Diagnostics.cpp =================================================================== --- clangd/Diagnostics.cpp +++ clangd/Diagnostics.cpp @@ -8,6 +8,7 @@ #include "Diagnostics.h" #include "Compiler.h" +#include "IncludeFixer.h" #include "Logger.h" #include "SourceCode.h" #include "clang/Basic/SourceManager.h" @@ -374,6 +375,11 @@ if (!Info.getFixItHints().empty()) AddFix(true /* try to invent a message instead of repeating the diag */); + if (FixIncludes) { + auto ExtraFixes = FixIncludes->fix(DiagLevel, Info); + LastDiag->Fixes.insert(LastDiag->Fixes.end(), ExtraFixes.begin(), + ExtraFixes.end()); + } } else { // Handle a note to an existing diagnostic.
if (!LastDiag) { Index: clangd/Headers.h =================================================================== --- clangd/Headers.h +++ clangd/Headers.h @@ -12,10 +12,12 @@ #include "Path.h" #include "Protocol.h" #include "SourceCode.h" +#include "index/Index.h" #include "clang/Format/Format.h" #include "clang/Lex/HeaderSearch.h" #include "clang/Lex/PPCallbacks.h" #include "clang/Tooling/Inclusions/HeaderIncludes.h" +#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringSet.h" #include "llvm/Support/Error.h" @@ -37,6 +39,15 @@ bool valid() const; }; +/// Creates a `HeaderFile` from \p Header which can be either a URI or a literal +/// include. +llvm::Expected toHeaderFile(llvm::StringRef Header, + llvm::StringRef HintPath); + +// Returns include headers for \p Sym sorted by popularity. If two headers are +// equally popular, prefer the shorter one. +llvm::SmallVector getRankedIncludes(const Symbol &Sym); + // An #include directive that we found in the main file. struct Inclusion { Range R; // Inclusion range.
Index: clangd/Headers.cpp =================================================================== --- clangd/Headers.cpp +++ clangd/Headers.cpp @@ -73,6 +73,41 @@ (!Verbatim && llvm::sys::path::is_absolute(File)); } +llvm::Expected toHeaderFile(llvm::StringRef Header, + llvm::StringRef HintPath) { + if (isLiteralInclude(Header)) + return HeaderFile{Header.str(), /*Verbatim=*/true}; + auto U = URI::parse(Header); + if (!U) + return U.takeError(); + + auto IncludePath = URI::includeSpelling(*U); + if (!IncludePath) + return IncludePath.takeError(); + if (!IncludePath->empty()) + return HeaderFile{std::move(*IncludePath), /*Verbatim=*/true}; + + auto Resolved = URI::resolve(*U, HintPath); + if (!Resolved) + return Resolved.takeError(); + return HeaderFile{std::move(*Resolved), /*Verbatim=*/false}; +} + +llvm::SmallVector getRankedIncludes(const Symbol &Sym) { + auto Includes = Sym.IncludeHeaders; + // Sort in descending order by reference count and header length. + llvm::sort(Includes, [](const Symbol::IncludeHeaderWithReferences &LHS, + const Symbol::IncludeHeaderWithReferences &RHS) { + if (LHS.References == RHS.References) + return LHS.IncludeHeader.size() < RHS.IncludeHeader.size(); + return LHS.References > RHS.References; + }); + llvm::SmallVector Headers; + for (const auto &Include : Includes) + Headers.push_back(Include.IncludeHeader); + return Headers; +} + std::unique_ptr collectIncludeStructureCallback(const SourceManager &SM, IncludeStructure *Out) { Index: clangd/IncludeFixer.h =================================================================== --- /dev/null +++ clangd/IncludeFixer.h @@ -0,0 +1,102 @@ +//===- IncludeFixer.h - Add missing includes --------------------*- C++ -*-===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details.
+// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_CLANG_TOOLS_EXTRA_CLANGD_INCLUDE_FIXER_H +#define LLVM_CLANG_TOOLS_EXTRA_CLANGD_INCLUDE_FIXER_H + +#include "Diagnostics.h" +#include "Headers.h" +#include "index/Index.h" +#include "clang/AST/Type.h" +#include "clang/Basic/Diagnostic.h" +#include "clang/Basic/SourceLocation.h" +#include "clang/Frontend/CompilerInstance.h" +#include "clang/Sema/ExternalSemaSource.h" +#include "clang/Sema/Sema.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/IntrusiveRefCntPtr.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/StringRef.h" +#include + +namespace clang { +namespace clangd { + +/// Attempts to recover from error diagnostics by suggesting include insertion +/// fixes. For example, member access into incomplete type can be fixed by +/// including headers with the definition. +class IncludeFixer { +public: + IncludeFixer(CompilerInstance &Compiler, llvm::StringRef File, + std::unique_ptr Inserter, + const SymbolIndex &Index) + : File(File), Inserter(std::move(Inserter)), Index(Index), + Compiler(Compiler), RecordTypo(new TypoRecorder(Compiler)) {} + + /// Returns include insertions that can potentially recover the diagnostic. + std::vector fix(DiagnosticsEngine::Level DiagLevel, + const clang::Diagnostic &Info) const; + + /// Returns an ExternalSemaSource that records typos seen in Sema. It must be + /// used in the same Sema run as the IncludeFixer. + llvm::IntrusiveRefCntPtr typoRecorder() { + return RecordTypo; + } + +private: + std::vector fixInCompleteType(const Type &T) const; + + std::vector fixesForSymbol(const Symbol &Sym) const; + + struct TypoRecord { + std::string Typo; // The typo identifier e.g. "X" in ns::X. + SourceLocation Loc; // Location of the typo. + Scope *S; // Scope in which the typo is found. + llvm::Optional SS; // The scope qualifier before the typo. + Sema::LookupNameKind LookupKind; // LookupKind of the typo.
+ }; + + /// Records the last typo seen by Sema. + class TypoRecorder : public ExternalSemaSource { + public: + explicit TypoRecorder(CompilerInstance &Compiler) : Compiler(Compiler) {} + + // Captures the latest typo. + TypoCorrection CorrectTypo(const DeclarationNameInfo &Typo, int LookupKind, + Scope *S, CXXScopeSpec *SS, + CorrectionCandidateCallback &CCC, + DeclContext *MemberContext, bool EnteringContext, + const ObjCObjectPointerType *OPT) override; + + llvm::Optional lastTypo() const { return LastTypo; } + + private: + CompilerInstance &Compiler; + + llvm::Optional LastTypo; + }; + + /// Attempts to fix the typo associated with the current diagnostic. We assume + /// a diagnostic is caused by a typo when they have the same source location + /// and the typo is the last typo we've seen during the Sema run. + std::vector fixTypo(const TypoRecord &Typo) const; + + std::string File; + std::unique_ptr Inserter; + const SymbolIndex &Index; + CompilerInstance &Compiler; + // This collects the last typo so that we can associate it with the + // diagnostic. + llvm::IntrusiveRefCntPtr RecordTypo; +}; + +} // namespace clangd +} // namespace clang + +#endif // LLVM_CLANG_TOOLS_EXTRA_CLANGD_INCLUDE_FIXER_H Index: clangd/IncludeFixer.cpp =================================================================== --- /dev/null +++ clangd/IncludeFixer.cpp @@ -0,0 +1,250 @@ +//===--- IncludeFixer.cpp ----------------------------------------*- C++-*-===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details.
+// +//===----------------------------------------------------------------------===// + +#include "IncludeFixer.h" +#include "AST.h" +#include "Diagnostics.h" +#include "Logger.h" +#include "SourceCode.h" +#include "index/Index.h" +#include "clang/AST/Decl.h" +#include "clang/AST/DeclBase.h" +#include "clang/AST/NestedNameSpecifier.h" +#include "clang/AST/Type.h" +#include "clang/Basic/Diagnostic.h" +#include "clang/Basic/DiagnosticSema.h" +#include "clang/Sema/DeclSpec.h" +#include "clang/Sema/Lookup.h" +#include "clang/Sema/Scope.h" +#include "clang/Sema/Sema.h" +#include "clang/Sema/TypoCorrection.h" +#include "llvm/ADT/None.h" +#include "llvm/Support/Error.h" +#include "llvm/Support/FormatVariadic.h" +#include + +namespace clang { +namespace clangd { + +namespace { + +bool isIncompleteTypeDiag(unsigned int DiagID) { + return DiagID == diag::err_incomplete_type || + DiagID == diag::err_incomplete_member_access || + DiagID == diag::err_incomplete_base_class; +} + +// Collects contexts visited during a Sema name lookup. +class VisitedContextCollector : public VisibleDeclConsumer { +public: + void EnteredContext(DeclContext *Ctx) override { Visited.push_back(Ctx); } + + void FoundDecl(NamedDecl *ND, NamedDecl *Hiding, DeclContext *Ctx, + bool InBaseClass) override {} + + std::vector takeVisitedContexts() { + return std::move(Visited); + } + +private: + std::vector Visited; +}; + +} // namespace + +std::vector IncludeFixer::fix(DiagnosticsEngine::Level DiagLevel, + const clang::Diagnostic &Info) const { + if (isIncompleteTypeDiag(Info.getID())) { + // Incomplete type diagnostics should have a QualType argument for the + // incomplete type.
+ for (unsigned i = 0; i < Info.getNumArgs(); ++i) { + if (Info.getArgKind(i) == DiagnosticsEngine::ak_qualtype) { + auto QT = QualType::getFromOpaquePtr((void *)Info.getRawArg(i)); + if (const Type *T = QT.getTypePtrOrNull()) + if (T->isIncompleteType()) + return fixInCompleteType(*T); + } + } + } else if (auto LastTypo = RecordTypo->lastTypo()) { + // Try to fix typos caused by missing declaration. + // E.g. + // clang::SourceManager SM; + // ~~~~~~~~~~~~~ + // Typo + // or + // namespace clang { SourceManager SM; } + // ~~~~~~~~~~~~~ + // Typo + // We only attempt to recover a diagnostic if it has the same location as + // the last seen typo. + if (DiagLevel >= DiagnosticsEngine::Error && + LastTypo->Loc == Info.getLocation()) + return fixTypo(*LastTypo); + } + return {}; +} + +std::vector IncludeFixer::fixInCompleteType(const Type &T) const { + // Only handle incomplete TagDecl type. + const TagDecl *TD = T.getAsTagDecl(); + if (!TD) + return {}; + std::string IncompleteType = printQualifiedName(*TD); + + if (IncompleteType.empty()) { + vlog("No incomplete type name is found in diagnostic. Ignore."); + return {}; + } + vlog("Trying to fix include for incomplete type {0}", IncompleteType); + FuzzyFindRequest Req; + Req.AnyScope = false; + auto ScopeAndName = splitQualifiedName(IncompleteType); + Req.Scopes.push_back(ScopeAndName.first); + Req.Query = ScopeAndName.second; + // Only code completion symbols insert includes. + Req.RestrictForCodeCompletion = true; + llvm::Optional Matched; + Index.fuzzyFind(Req, [&](const Symbol &Sym) { + // FIXME: support multiple matched symbols.
+ if (Matched || Sym.Name != Req.Query) + return; + Matched = Sym; + }); + + if (!Matched || Matched->IncludeHeaders.empty()) + return {}; + return fixesForSymbol(*Matched); +} + +std::vector IncludeFixer::fixesForSymbol(const Symbol &Sym) const { + auto Inserted = [&](llvm::StringRef Header) + -> llvm::Expected> { + auto ResolvedDeclaring = + toHeaderFile(Sym.CanonicalDeclaration.FileURI, File); + if (!ResolvedDeclaring) + return ResolvedDeclaring.takeError(); + auto ResolvedInserted = toHeaderFile(Header, File); + if (!ResolvedInserted) + return ResolvedInserted.takeError(); + return std::make_pair( + Inserter->calculateIncludePath(*ResolvedDeclaring, *ResolvedInserted), + Inserter->shouldInsertInclude(*ResolvedDeclaring, *ResolvedInserted)); + }; + + std::vector Fixes; + for (const auto &Inc : getRankedIncludes(Sym)) { + if (auto ToInclude = Inserted(Inc)) { + if (ToInclude->second) + if (auto Edit = Inserter->insert(ToInclude->first)) + Fixes.push_back( + Fix{llvm::formatv("Add include {0} for symbol {1}{2}", + ToInclude->first, Sym.Scope, Sym.Name), + {std::move(*Edit)}}); + } else { + vlog("Failed to calculate include insertion for {0} into {1}: {2}", Inc, + File, llvm::toString(ToInclude.takeError())); + } + } + return Fixes; +} + +TypoCorrection IncludeFixer::TypoRecorder::CorrectTypo( + const DeclarationNameInfo &Typo, int LookupKind, Scope *S, CXXScopeSpec *SS, + CorrectionCandidateCallback &CCC, DeclContext *MemberContext, + bool EnteringContext, const ObjCObjectPointerType *OPT) { + if (Compiler.getSema().isSFINAEContext()) + return TypoCorrection(); + if (!Compiler.getSourceManager().isWrittenInMainFile(Typo.getLoc())) + return TypoCorrection(); + + TypoRecord Record; + Record.Typo = Typo.getAsString(); + Record.Loc = Typo.getBeginLoc(); + assert(S); + Record.S = S; + Record.LookupKind = static_cast(LookupKind); + + // FIXME: support invalid scope before a type name.
In the following example, + // namespace "clang::tidy::" hasn't been declared/imported. + // namespace clang { + // void f() { + // tidy::Check c; + // ~~~~ + // // or + // clang::tidy::Check c; + // ~~~~ + // } + // } + // For both cases, the typo and the diagnostic are both on "tidy", and no + // diagnostic is generated for "Check". However, what we want to fix is + // "clang::tidy::Check". + if (SS && SS->isNotEmpty()) { // "::" or "ns::" + if (auto *Nested = SS->getScopeRep()) { + if (Nested->getKind() == NestedNameSpecifier::Global) + Record.SS = ""; + else if (const auto *NS = Nested->getAsNamespace()) + Record.SS = printNamespaceScope(*NS); + else + // We don't fix symbols in scopes that are not top-level e.g. class + // members, as we don't collect includes for them. + return TypoCorrection(); + } + } + + LastTypo = std::move(Record); + + return TypoCorrection(); +} + +std::vector IncludeFixer::fixTypo(const TypoRecord &Typo) const { + std::vector Scopes; + if (Typo.SS) { + Scopes.push_back(*Typo.SS); + } else { + // No scope qualifier is specified. Collect all accessible scopes in the + // context.
+ VisitedContextCollector Collector; + Compiler.getSema().LookupVisibleDecls(Typo.S, Typo.LookupKind, Collector, + /*IncludeGlobalScope=*/false, + /*LoadExternal=*/false); + + Scopes.push_back(""); + for (const auto *Ctx : Collector.takeVisitedContexts()) + if (isa(Ctx)) + Scopes.push_back(printNamespaceScope(*Ctx)); + } + vlog("Trying to fix typo \"{0}\" in scopes: [{1}]", Typo.Typo, + llvm::join(Scopes.begin(), Scopes.end(), ", ")); + + FuzzyFindRequest Req; + Req.AnyScope = false; + Req.Query = Typo.Typo; + Req.Scopes = std::move(Scopes); + Req.RestrictForCodeCompletion = true; + + SymbolSlab::Builder Matches; + Index.fuzzyFind(Req, [&](const Symbol &Sym) { + if (Sym.Name != Req.Query) + return; + if (!Sym.IncludeHeaders.empty()) + Matches.insert(Sym); + }); + auto Syms = std::move(Matches).build(); + if (Syms.empty()) + return {}; + std::vector Results; + for (const auto &Sym : Syms) { + auto Fixes = fixesForSymbol(Sym); + Results.insert(Results.end(), Fixes.begin(), Fixes.end()); + } + return Results; +} + +} // namespace clangd +} // namespace clang Index: clangd/SourceCode.h =================================================================== --- clangd/SourceCode.h +++ clangd/SourceCode.h @@ -16,7 +16,9 @@ #include "clang/Basic/Diagnostic.h" #include "clang/Basic/SourceLocation.h" #include "clang/Basic/SourceManager.h" +#include "clang/Format/Format.h" #include "clang/Tooling/Core/Replacement.h" +#include "llvm/ADT/StringRef.h" #include "llvm/Support/SHA1.h" namespace clang { @@ -91,6 +93,11 @@ const SourceManager &SourceMgr); bool IsRangeConsecutive(const Range &Left, const Range &Right); + +format::FormatStyle getFormatStyleForFile(llvm::StringRef File, + llvm::StringRef Content, + llvm::vfs::FileSystem *FS); + } // namespace clangd } // namespace clang #endif Index: clangd/SourceCode.cpp =================================================================== --- clangd/SourceCode.cpp +++ clangd/SourceCode.cpp @@ -248,5 +248,18 @@ return digest(Content); }
+format::FormatStyle getFormatStyleForFile(llvm::StringRef File, + llvm::StringRef Content, + llvm::vfs::FileSystem *FS) { + auto Style = format::getStyle(format::DefaultFormatStyle, File, + format::DefaultFallbackStyle, Content, FS); + if (!Style) { + log("getStyle() failed for file {0}: {1}. Fallback is LLVM style.", File, + Style.takeError()); + Style = format::getLLVMStyle(); + } + return *Style; +} + } // namespace clangd } // namespace clang Index: unittests/clangd/CMakeLists.txt =================================================================== --- unittests/clangd/CMakeLists.txt +++ unittests/clangd/CMakeLists.txt @@ -28,6 +28,7 @@ FuzzyMatchTests.cpp GlobalCompilationDatabaseTests.cpp HeadersTests.cpp + IncludeFixerTests.cpp IndexActionTests.cpp IndexTests.cpp JSONTransportTests.cpp Index: unittests/clangd/FileIndexTests.cpp =================================================================== --- unittests/clangd/FileIndexTests.cpp +++ unittests/clangd/FileIndexTests.cpp @@ -360,10 +360,10 @@ /*StoreInMemory=*/true, [&](ASTContext &Ctx, std::shared_ptr PP) {}); // Build AST for main file with preamble. - auto AST = - ParsedAST::build(createInvocationFromCommandLine(Cmd), PreambleData, - llvm::MemoryBuffer::getMemBufferCopy(Main.code()), - std::make_shared(), PI.FS); + auto AST = ParsedAST::build( + createInvocationFromCommandLine(Cmd), PreambleData, + llvm::MemoryBuffer::getMemBufferCopy(Main.code()), + std::make_shared(), PI.FS, /*Index=*/nullptr); ASSERT_TRUE(AST); FileIndex Index; Index.updateMain(MainFile, *AST); Index: unittests/clangd/IncludeFixerTests.cpp =================================================================== --- /dev/null +++ unittests/clangd/IncludeFixerTests.cpp @@ -0,0 +1,156 @@ +//===-- IncludeFixerTests.cpp - IncludeFixer tests --------------*- C++ -*-===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details.
+// +//===----------------------------------------------------------------------===// + +#include "Annotations.h" +#include "ClangdUnit.h" +#include "IncludeFixer.h" +#include "TestIndex.h" +#include "TestTU.h" +#include "index/MemIndex.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/ScopedPrinter.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +namespace clang { +namespace clangd { +namespace { + +using testing::UnorderedElementsAre; + +testing::Matcher WithFix(testing::Matcher FixMatcher) { + return Field(&Diag::Fixes, UnorderedElementsAre(FixMatcher)); +} + +testing::Matcher WithFix(testing::Matcher FixMatcher1, + testing::Matcher FixMatcher2) { + return Field(&Diag::Fixes, UnorderedElementsAre(FixMatcher1, FixMatcher2)); +} + +MATCHER_P2(Diag, Range, Message, + "Diag at " + llvm::to_string(Range) + " = [" + Message + "]") { + return arg.Range == Range && arg.Message == Message; +} + +MATCHER_P3(Fix, Range, Replacement, Message, + "Fix " + llvm::to_string(Range) + " => " + + testing::PrintToString(Replacement) + " = [" + Message + "]") { + return arg.Message == Message && arg.Edits.size() == 1 && + arg.Edits[0].range == Range && arg.Edits[0].newText == Replacement; +} + +struct SymbolWithHeader { + std::string QName; + std::string DeclaringFile; + std::string IncludeHeader; +}; + +std::unique_ptr buildIndexWithSymbol(llvm::ArrayRef Syms) { + SymbolSlab::Builder Slab; + for (const auto &S : Syms) { + Symbol Sym = symbol(S.QName); + Sym.Flags |= Symbol::IndexedForCodeCompletion; + Sym.CanonicalDeclaration.FileURI = S.DeclaringFile.c_str(); + Sym.IncludeHeaders.emplace_back(S.IncludeHeader, 1); + Slab.insert(Sym); + } + return MemIndex::build(std::move(Slab).build(), RefSlab()); +} + +TEST(IncludeFixerTest, IncompleteType) { + Annotations Test(R"cpp( +$insert[[]]namespace ns { + class X; +} +class Y : $base[[public ns::X]] {}; +int main() { + ns::X *x; + x$access[[->]]f(); +} + )cpp"); + auto TU =
TestTU::withCode(Test.code()); + auto Index = buildIndexWithSymbol( + {SymbolWithHeader{"ns::X", "unittest:///x.h", "\"x.h\""}}); + TU.ExternalIndex = Index.get(); + + EXPECT_THAT( + TU.build().getDiagnostics(), + UnorderedElementsAre( + AllOf(Diag(Test.range("base"), "base class has incomplete type"), + WithFix(Fix(Test.range("insert"), "#include \"x.h\"\n", + "Add include \"x.h\" for symbol ns::X"))), + AllOf(Diag(Test.range("access"), + "member access into incomplete type 'ns::X'"), + WithFix(Fix(Test.range("insert"), "#include \"x.h\"\n", + "Add include \"x.h\" for symbol ns::X"))))); +} + +TEST(IncludeFixerTest, Typo) { + Annotations Test(R"cpp( +$insert[[]]namespace ns { +void foo() { + $unqualified[[X]] x; +} +} +void bar() { + ns::$qualified[[X]] x; // ns:: is valid. + ::$global[[Global]] glob; +} + )cpp"); + auto TU = TestTU::withCode(Test.code()); + auto Index = buildIndexWithSymbol( + {SymbolWithHeader{"ns::X", "unittest:///x.h", "\"x.h\""}, + SymbolWithHeader{"Global", "unittest:///global.h", "\"global.h\""}}); + TU.ExternalIndex = Index.get(); + + EXPECT_THAT( + TU.build().getDiagnostics(), + UnorderedElementsAre( + AllOf(Diag(Test.range("unqualified"), "unknown type name 'X'"), + WithFix(Fix(Test.range("insert"), "#include \"x.h\"\n", + "Add include \"x.h\" for symbol ns::X"))), + AllOf(Diag(Test.range("qualified"), + "no type named 'X' in namespace 'ns'"), + WithFix(Fix(Test.range("insert"), "#include \"x.h\"\n", + "Add include \"x.h\" for symbol ns::X"))), + AllOf(Diag(Test.range("global"), + "no type named 'Global' in the global namespace"), + WithFix(Fix(Test.range("insert"), "#include \"global.h\"\n", + "Add include \"global.h\" for symbol Global"))))); +} + +TEST(IncludeFixerTest, MultipleMatchedSymbols) { + Annotations Test(R"cpp( +$insert[[]]namespace na { +namespace nb { +void foo() { + $unqualified[[X]] x; +} +} +} + )cpp"); + auto TU = TestTU::withCode(Test.code()); + auto Index = buildIndexWithSymbol( + {SymbolWithHeader{"na::X",
"unittest:///a.h", "\"a.h\""}, + SymbolWithHeader{"na::nb::X", "unittest:///b.h", "\"b.h\""}}); + TU.ExternalIndex = Index.get(); + + EXPECT_THAT(TU.build().getDiagnostics(), + UnorderedElementsAre(AllOf( + Diag(Test.range("unqualified"), "unknown type name 'X'"), + WithFix(Fix(Test.range("insert"), "#include \"a.h\"\n", + "Add include \"a.h\" for symbol na::X"), + Fix(Test.range("insert"), "#include \"b.h\"\n", + "Add include \"b.h\" for symbol na::nb::X"))))); +} + +} // namespace +} // namespace clangd +} // namespace clang Index: unittests/clangd/TUSchedulerTests.cpp =================================================================== --- unittests/clangd/TUSchedulerTests.cpp +++ unittests/clangd/TUSchedulerTests.cpp @@ -37,8 +37,9 @@ class TUSchedulerTests : public ::testing::Test { protected: ParseInputs getInputs(PathRef File, std::string Contents) { - return ParseInputs{*CDB.getCompileCommand(File), - buildTestFS(Files, Timestamps), std::move(Contents)}; + return ParseInputs(*CDB.getCompileCommand(File), + buildTestFS(Files, Timestamps), std::move(Contents), + /*Index=*/nullptr); } void updateWithCallback(TUScheduler &S, PathRef File, Index: unittests/clangd/TestTU.h =================================================================== --- unittests/clangd/TestTU.h +++ unittests/clangd/TestTU.h @@ -48,6 +48,9 @@ // Extra arguments for the compiler invocation. std::vector ExtraArgs; + // Index to use when building AST.
+ const SymbolIndex *ExternalIndex = nullptr; + ParsedAST build() const; SymbolSlab headerSymbols() const; std::unique_ptr index() const; Index: unittests/clangd/TestTU.cpp =================================================================== --- unittests/clangd/TestTU.cpp +++ unittests/clangd/TestTU.cpp @@ -35,6 +35,7 @@ Inputs.CompileCommand.Directory = testRoot(); Inputs.Contents = Code; Inputs.FS = buildTestFS({{FullFilename, Code}, {FullHeaderName, HeaderCode}}); + Inputs.Index = ExternalIndex; auto PCHs = std::make_shared(); auto CI = buildCompilerInvocation(Inputs); assert(CI && "Failed to build compilation invocation.");