Index: clangd/index/CanonicalIncludes.h
===================================================================
--- clangd/index/CanonicalIncludes.h
+++ clangd/index/CanonicalIncludes.h
@@ -43,9 +43,17 @@
   /// Maps all files matching \p RE to \p CanonicalPath
   void addRegexMapping(llvm::StringRef RE, llvm::StringRef CanonicalPath);
 
+  /// Sets the canonical include for any symbol with \p QualifiedName.
+  /// Symbol mappings take precedence over header mappings.
+  void addSymbolMapping(llvm::StringRef QualifiedName,
+                        llvm::StringRef CanonicalPath);
+
   /// \return \p Header itself if there is no mapping for it; otherwise, return
   /// a canonical header name.
-  llvm::StringRef mapHeader(llvm::StringRef Header) const;
+  /// \p QualifiedName of a symbol declared in \p Header can be provided to
+  /// check against the symbol mapping.
+  llvm::StringRef mapHeader(llvm::StringRef Header,
+                            llvm::StringRef QualifiedName = "") const;
 
 private:
   // A map from header patterns to header names. This needs to be mutable so
@@ -55,6 +63,8 @@
   // arbitrary regexes.
   mutable std::vector<std::pair<llvm::Regex, std::string>>
       RegexHeaderMappingTable;
+  // A map from fully qualified symbol names to header names.
+  llvm::StringMap<std::string> SymbolMapping;
   // Guards Regex matching as it's not thread-safe.
   mutable std::mutex RegexMutex;
 };
@@ -68,8 +78,9 @@
 std::unique_ptr<CommentHandler>
 collectIWYUHeaderMaps(CanonicalIncludes *Includes);
 
-/// Adds mapping for system headers. Approximately, the following system headers
-/// are handled:
+/// Adds mapping for system headers and some special symbols (e.g. STL symbols
+/// in <iosfwd> need to be mapped individually). Approximately, the following
+/// system headers are handled:
 ///   - C++ standard library e.g. bits/basic_string.h$ -> <string>
 ///   - Posix library e.g. bits/pthreadtypes.h$ -> <pthread.h>
 ///   - Compiler extensions, e.g. include/avx512bwintrin.h$ -> <immintrin.h>
Index: clangd/index/CanonicalIncludes.cpp
===================================================================
--- clangd/index/CanonicalIncludes.cpp
+++ clangd/index/CanonicalIncludes.cpp
@@ -27,7 +27,19 @@
   this->RegexHeaderMappingTable.emplace_back(llvm::Regex(RE), CanonicalPath);
 }
 
-llvm::StringRef CanonicalIncludes::mapHeader(llvm::StringRef Header) const {
+void CanonicalIncludes::addSymbolMapping(llvm::StringRef QualifiedName,
+                                         llvm::StringRef CanonicalPath) {
+  this->SymbolMapping[QualifiedName] = CanonicalPath;
+}
+
+llvm::StringRef
+CanonicalIncludes::mapHeader(llvm::StringRef Header,
+                             llvm::StringRef QualifiedName) const {
+  if (!QualifiedName.empty()) {
+    auto SE = SymbolMapping.find(QualifiedName);
+    if (SE != SymbolMapping.end())
+      return SE->second;
+  }
   std::lock_guard<std::mutex> Lock(RegexMutex);
   for (auto &Entry : RegexHeaderMappingTable) {
 #ifndef NDEBUG
@@ -67,6 +79,53 @@
 }
 
 void addSystemHeadersMapping(CanonicalIncludes *Includes) {
+  static const std::vector<std::pair<const char *, const char *>> SymbolMap = {
+      // Map symbols in <iosfwd> to their preferred includes.
+      {"std::basic_filebuf", "<fstream>"},
+      {"std::basic_fstream", "<fstream>"},
+      {"std::basic_ifstream", "<fstream>"},
+      {"std::basic_ofstream", "<fstream>"},
+      {"std::filebuf", "<fstream>"},
+      {"std::fstream", "<fstream>"},
+      {"std::ifstream", "<fstream>"},
+      {"std::ofstream", "<fstream>"},
+      {"std::wfilebuf", "<fstream>"},
+      {"std::wfstream", "<fstream>"},
+      {"std::wifstream", "<fstream>"},
+      {"std::wofstream", "<fstream>"},
+      {"std::basic_ios", "<ios>"},
+      {"std::ios", "<ios>"},
+      {"std::wios", "<ios>"},
+      {"std::basic_iostream", "<iostream>"},
+      {"std::iostream", "<iostream>"},
+      {"std::wiostream", "<iostream>"},
+      {"std::basic_istream", "<istream>"},
+      {"std::istream", "<istream>"},
+      {"std::wistream", "<istream>"},
+      {"std::istreambuf_iterator", "<iterator>"},
+      {"std::ostreambuf_iterator", "<iterator>"},
+      {"std::basic_ostream", "<ostream>"},
+      {"std::ostream", "<ostream>"},
+      {"std::wostream", "<ostream>"},
+      {"std::basic_istringstream", "<sstream>"},
+      {"std::basic_ostringstream", "<sstream>"},
+      {"std::basic_stringbuf", "<sstream>"},
+      {"std::basic_stringstream", "<sstream>"},
+      {"std::istringstream", "<sstream>"},
+      {"std::ostringstream", "<sstream>"},
+      {"std::stringbuf", "<sstream>"},
+      {"std::stringstream", "<sstream>"},
+      {"std::wistringstream", "<sstream>"},
+      {"std::wostringstream", "<sstream>"},
+      {"std::wstringbuf", "<sstream>"},
+      {"std::wstringstream", "<sstream>"},
+      {"std::basic_streambuf", "<streambuf>"},
+      {"std::streambuf", "<streambuf>"},
+      {"std::wstreambuf", "<streambuf>"},
+  };
+  for (const auto &Pair : SymbolMap)
+    Includes->addSymbolMapping(Pair.first, Pair.second);
+
   static const std::vector<std::pair<const char *, const char *>>
       SystemHeaderMap = {
           {"include/__stddef_max_align_t.h$", "<cstddef>"},
Index: clangd/index/SymbolCollector.cpp
===================================================================
--- clangd/index/SymbolCollector.cpp
+++ clangd/index/SymbolCollector.cpp
@@ -154,13 +154,13 @@
 /// FIXME: we should handle .inc files whose symbols are expected be exported by
 /// their containing headers.
 llvm::Optional<std::string>
-getIncludeHeader(const SourceManager &SM, SourceLocation Loc,
-                 const SymbolCollector::Options &Opts) {
+getIncludeHeader(llvm::StringRef QName, const SourceManager &SM,
+                 SourceLocation Loc, const SymbolCollector::Options &Opts) {
   llvm::StringRef FilePath = SM.getFilename(Loc);
   if (FilePath.empty())
     return llvm::None;
   if (Opts.Includes) {
-    llvm::StringRef Mapped = Opts.Includes->mapHeader(FilePath);
+    llvm::StringRef Mapped = Opts.Includes->mapHeader(FilePath, QName);
     if (Mapped != FilePath)
       return (Mapped.startswith("<") || Mapped.startswith("\""))
                  ? Mapped.str()
@@ -316,8 +316,8 @@
   if (Opts.CollectIncludePath && shouldCollectIncludePath(S.SymInfo.Kind)) {
     // Use the expansion location to get the #include header since this is
     // where the symbol is exposed.
-    if (auto Header =
-            getIncludeHeader(SM, SM.getExpansionLoc(ND.getLocation()), Opts))
+    if (auto Header = getIncludeHeader(
+            QName, SM, SM.getExpansionLoc(ND.getLocation()), Opts))
       Include = std::move(*Header);
   }
   S.CompletionFilterText = FilterText;
Index: unittests/clangd/SymbolCollectorTests.cpp
===================================================================
--- unittests/clangd/SymbolCollectorTests.cpp
+++ unittests/clangd/SymbolCollectorTests.cpp
@@ -547,6 +547,32 @@
 }
 #endif
 
+TEST_F(SymbolCollectorTest, STLiosfwd) {
+  CollectorOpts.CollectIncludePath = true;
+  CanonicalIncludes Includes;
+  addSystemHeadersMapping(&Includes);
+  CollectorOpts.Includes = &Includes;
+  // Symbols from <iosfwd> should be mapped individually.
+  TestHeaderName = testPath("iosfwd");
+  TestFileName = testPath("iosfwd.cpp");
+  std::string Header = R"(
+    namespace std {
+      class no_map {};
+      class ios {};
+      class ostream {};
+      class filebuf {};
+    } // namespace std
+  )";
+  runSymbolCollector(Header, /*Main=*/"");
+  EXPECT_THAT(Symbols,
+              UnorderedElementsAre(
+                  QName("std"),
+                  AllOf(QName("std::no_map"), IncludeHeader("<iosfwd>")),
+                  AllOf(QName("std::ios"), IncludeHeader("<ios>")),
+                  AllOf(QName("std::ostream"), IncludeHeader("<ostream>")),
+                  AllOf(QName("std::filebuf"), IncludeHeader("<fstream>"))));
+}
+
 TEST_F(SymbolCollectorTest, IWYUPragma) {
   CollectorOpts.CollectIncludePath = true;
   CanonicalIncludes Includes;