diff --git a/clang-tools-extra/clangd/Hover.cpp b/clang-tools-extra/clangd/Hover.cpp
--- a/clang-tools-extra/clangd/Hover.cpp
+++ b/clang-tools-extra/clangd/Hover.cpp
@@ -1172,20 +1172,12 @@
             UsedSymbols.contains(Ref.Target))
           return;
 
-        for (const include_cleaner::Header &H : Providers) {
-          auto MatchingIncludes = ConvertedMainFileIncludes.match(H);
-          // No match for this provider in the main file.
-          if (MatchingIncludes.empty())
-            continue;
-
-          // Check if the hovered include matches this provider.
-          if (!HoveredInclude.match(H).empty())
-            UsedSymbols.insert(Ref.Target);
-
-          // Don't look for rest of the providers once we've found a match
-          // in the main file.
-          break;
-        }
+        auto Provider =
+            firstMatchedProvider(ConvertedMainFileIncludes, Providers);
+        if (!Provider || HoveredInclude.match(*Provider).empty())
+          return;
+
+        UsedSymbols.insert(Ref.Target);
       });
 
   for (const auto &UsedSymbolDecl : UsedSymbols)
diff --git a/clang-tools-extra/clangd/IncludeCleaner.h b/clang-tools-extra/clangd/IncludeCleaner.h
--- a/clang-tools-extra/clangd/IncludeCleaner.h
+++ b/clang-tools-extra/clangd/IncludeCleaner.h
@@ -81,6 +81,12 @@
 
 std::vector<include_cleaner::SymbolReference>
 collectMacroReferences(ParsedAST &AST);
+
+/// Find the first provider in the list that is matched by the includes.
+/// Returns std::nullopt if none of the providers is matched.
+std::optional<include_cleaner::Header>
+firstMatchedProvider(const include_cleaner::Includes &Includes,
+                     llvm::ArrayRef<include_cleaner::Header> Providers);
 } // namespace clangd
 } // namespace clang
 
diff --git a/clang-tools-extra/clangd/IncludeCleaner.cpp b/clang-tools-extra/clangd/IncludeCleaner.cpp
--- a/clang-tools-extra/clangd/IncludeCleaner.cpp
+++ b/clang-tools-extra/clangd/IncludeCleaner.cpp
@@ -444,5 +444,15 @@
   return Result;
 }
 
+std::optional<include_cleaner::Header>
+firstMatchedProvider(const include_cleaner::Includes &Includes,
+                     llvm::ArrayRef<include_cleaner::Header> Providers) {
+  for (const auto &H : Providers) {
+    if (!Includes.match(H).empty())
+      return H;
+  }
+  // None of the providers is matched by the includes list.
+  return std::nullopt;
+}
 } // namespace clangd
 } // namespace clang
diff --git a/clang-tools-extra/clangd/XRefs.cpp b/clang-tools-extra/clangd/XRefs.cpp
--- a/clang-tools-extra/clangd/XRefs.cpp
+++ b/clang-tools-extra/clangd/XRefs.cpp
@@ -9,13 +9,17 @@
 #include "AST.h"
 #include "FindSymbols.h"
 #include "FindTarget.h"
+#include "Headers.h"
 #include "HeuristicResolver.h"
+#include "IncludeCleaner.h"
 #include "ParsedAST.h"
 #include "Protocol.h"
 #include "Quality.h"
 #include "Selection.h"
 #include "SourceCode.h"
 #include "URI.h"
+#include "clang-include-cleaner/Analysis.h"
+#include "clang-include-cleaner/Types.h"
 #include "index/Index.h"
 #include "index/Merge.h"
 #include "index/Relation.h"
@@ -48,6 +52,7 @@
 #include "clang/Index/IndexingAction.h"
 #include "clang/Index/IndexingOptions.h"
 #include "clang/Index/USRGeneration.h"
+#include "clang/Lex/Lexer.h"
 #include "clang/Tooling/Syntax/Tokens.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/DenseMap.h"
@@ -61,6 +66,7 @@
 #include "llvm/Support/Path.h"
 #include "llvm/Support/raw_ostream.h"
 #include <optional>
+#include <string>
 #include <vector>
 
 namespace clang {
@@ -1310,6 +1316,72 @@
     return printQualifiedName(*ND);
   return {};
 }
+
+/// If Pos is on the line of a main-file #include directive, collects the
+/// explicit references to symbols whose first matched provider is that
+/// include, followed by the range of the directive itself. Returns
+/// std::nullopt when Pos is not on an #include line or nothing is found.
+std::optional<ReferencesResult>
+maybeFindIncludeReferences(ParsedAST &AST, Position Pos,
+                           URIForFile URIMainFile) {
+  const auto &Includes = AST.getIncludeStructure().MainFileIncludes;
+  auto IncludeOnLine = llvm::find_if(Includes, [&Pos](const Inclusion &Inc) {
+    return Inc.HashLine == Pos.line;
+  });
+  if (IncludeOnLine == Includes.end())
+    return std::nullopt;
+
+  const auto &Inc = *IncludeOnLine;
+  const SourceManager &SM = AST.getSourceManager();
+  ReferencesResult Results;
+  auto ConvertedMainFileIncludes = convertIncludes(SM, Includes);
+  auto ReferencedInclude = convertIncludes(SM, Inc);
+  include_cleaner::walkUsed(
+      AST.getLocalTopLevelDecls(), collectMacroReferences(AST),
+      AST.getPragmaIncludes(), SM,
+      [&](const include_cleaner::SymbolReference &Ref,
+          llvm::ArrayRef<include_cleaner::Header> Providers) {
+        if (Ref.RT != include_cleaner::RefType::Explicit)
+          return;
+
+        auto Provider =
+            firstMatchedProvider(ConvertedMainFileIncludes, Providers);
+        if (!Provider || ReferencedInclude.match(*Provider).empty())
+          return;
+
+        auto Loc = SM.getFileLoc(Ref.RefLocation);
+        // File locations can be outside of the main file if macro is
+        // expanded through an #include.
+        while (SM.getFileID(Loc) != SM.getMainFileID())
+          Loc = SM.getIncludeLoc(SM.getFileID(Loc));
+
+        const auto *Token = AST.getTokens().spelledTokenAt(Loc);
+        // spelledTokenAt() yields null when no spelled token starts at Loc;
+        // skip such references instead of dereferencing a null pointer.
+        if (!Token)
+          return;
+
+        ReferencesResult::Reference Result;
+        Result.Loc.range = Range{sourceLocToPosition(SM, Token->location()),
+                                 sourceLocToPosition(SM, Token->endLocation())};
+        Result.Loc.uri = URIMainFile;
+        Results.References.push_back(std::move(Result));
+      });
+  if (Results.References.empty())
+    return std::nullopt;
+
+  // Add the #include line to the references list.
+  // NOTE(review): the range assumes the directive starts at column 0 with a
+  // single space between "#include" and the header name -- confirm.
+  auto IncludeLen = std::string{"#include"}.length() + Inc.Written.length() + 1;
+  ReferencesResult::Reference Result;
+  Result.Loc.range =
+      clangd::Range{Position{Inc.HashLine, 0},
+                    Position{Inc.HashLine, static_cast<int>(IncludeLen)}};
+  Result.Loc.uri = URIMainFile;
+  Results.References.push_back(std::move(Result));
+  return Results;
+}
 } // namespace
 
 ReferencesResult findReferences(ParsedAST &AST, Position Pos, uint32_t Limit,
@@ -1324,6 +1396,11 @@
     return {};
   }
 
+  // A request on an #include line is answered with that include's references.
+  if (auto IncludeReferences =
+          maybeFindIncludeReferences(AST, Pos, URIMainFile))
+    return std::move(*IncludeReferences);
+
   llvm::DenseSet<SymbolID> IDsToQuery, OverriddenMethods;
 
   const auto *IdentifierAtCursor =
diff --git a/clang-tools-extra/clangd/unittests/HoverTests.cpp b/clang-tools-extra/clangd/unittests/HoverTests.cpp
--- a/clang-tools-extra/clangd/unittests/HoverTests.cpp
+++ b/clang-tools-extra/clangd/unittests/HoverTests.cpp
@@ -2999,36 +2999,7 @@
                   #in^clude <vector>
                   std::vector<int> vec;
                 )cpp",
-                [](HoverInfo &HI) { HI.UsedSymbolNames = {"vector"}; }},
-               {R"cpp(
-                  #in^clude "public.h"
-                  #include "private.h"
-                  int fooVar = foo();
-                )cpp",
-                [](HoverInfo &HI) { HI.UsedSymbolNames = {"foo"}; }},
-               {R"cpp(
-                  #include "bar.h"
-                  #include "for^ward.h"
-                  Bar *x;
-                )cpp",
-                [](HoverInfo &HI) {
-                  HI.UsedSymbolNames = {
-                      // No used symbols, since bar.h is a higher-ranked
-                      // provider for Bar.
-                  };
-                }},
-               {R"cpp(
-                  #include "b^ar.h"
-                  #define DEF(X) const Bar *X
-                  DEF(a);
-                )cpp",
-                [](HoverInfo &HI) { HI.UsedSymbolNames = {"Bar"}; }},
-               {R"cpp(
-                  #in^clude "bar.h"
-                  #define BAZ(X) const X x
-                  BAZ(Bar);
-                )cpp",
-                [](HoverInfo &HI) { HI.UsedSymbolNames = {"Bar"}; }}};
+                [](HoverInfo &HI) { HI.UsedSymbolNames = {"vector"}; }}};
   for (const auto &Case : Cases) {
     Annotations Code{Case.Code};
     SCOPED_TRACE(Code.code());
@@ -3042,18 +3013,12 @@
       int bar2();
       class Bar {};
     )cpp");
-    TU.AdditionalFiles["private.h"] = guard(R"cpp(
-      // IWYU pragma: private, include "public.h"
-      int foo();
-    )cpp");
-    TU.AdditionalFiles["public.h"] = guard("");
     TU.AdditionalFiles["system/vector"] = guard(R"cpp(
       namespace std {
         template<typename>
         class vector{};
       }
     )cpp");
-    TU.AdditionalFiles["forward.h"] = guard("class Bar;");
     TU.ExtraArgs.push_back("-isystem" + testPath("system"));
 
     auto AST = TU.build();
diff --git a/clang-tools-extra/clangd/unittests/IncludeCleanerTests.cpp b/clang-tools-extra/clangd/unittests/IncludeCleanerTests.cpp
--- a/clang-tools-extra/clangd/unittests/IncludeCleanerTests.cpp
+++ b/clang-tools-extra/clangd/unittests/IncludeCleanerTests.cpp
@@ -29,6 +29,7 @@
 #include "gmock/gmock.h"
 #include "gtest/gtest.h"
 #include <cstddef>
+#include <optional>
 #include <string>
 #include <utility>
 #include <vector>
@@ -435,6 +436,48 @@
             MainCode.range());
 }
 
+TEST(IncludeCleaner, FirstMatchedProvider) {
+  struct {
+    const char *Code;
+    const std::vector<include_cleaner::Header> Providers;
+    const std::optional<include_cleaner::Header> ExpectedProvider;
+  } Cases[] = {
+      {R"cpp(
+        #include "bar.h"
+        #include "foo.h"
+      )cpp",
+       {include_cleaner::Header{"bar.h"}, include_cleaner::Header{"foo.h"}},
+       include_cleaner::Header{"bar.h"}},
+      {R"cpp(
+        #include "bar.h"
+        #include "foo.h"
+      )cpp",
+       {include_cleaner::Header{"foo.h"}, include_cleaner::Header{"bar.h"}},
+       include_cleaner::Header{"foo.h"}},
+      {"#include \"bar.h\"",
+       {include_cleaner::Header{"bar.h"}},
+       include_cleaner::Header{"bar.h"}},
+      {"#include \"bar.h\"", {include_cleaner::Header{"foo.h"}}, std::nullopt},
+      {"#include \"bar.h\"", {}, std::nullopt}};
+  for (const auto &Case : Cases) {
+    Annotations Code{Case.Code};
+    SCOPED_TRACE(Code.code());
+
+    TestTU TU;
+    TU.Code = Code.code();
+    TU.AdditionalFiles["bar.h"] = "";
+    TU.AdditionalFiles["foo.h"] = "";
+
+    auto AST = TU.build();
+    std::optional<include_cleaner::Header> MatchedProvider =
+        firstMatchedProvider(
+            convertIncludes(AST.getSourceManager(),
+                            AST.getIncludeStructure().MainFileIncludes),
+            Case.Providers);
+    EXPECT_EQ(MatchedProvider, Case.ExpectedProvider);
+  }
+}
+
 } // namespace
 } // namespace clangd
 } // namespace clang
diff --git a/clang-tools-extra/clangd/unittests/XRefsTests.cpp b/clang-tools-extra/clangd/unittests/XRefsTests.cpp
--- a/clang-tools-extra/clangd/unittests/XRefsTests.cpp
+++ b/clang-tools-extra/clangd/unittests/XRefsTests.cpp
@@ -43,6 +43,10 @@
 using ::testing::UnorderedElementsAreArray;
 using ::testing::UnorderedPointwise;
 
+std::string guard(llvm::StringRef Code) {
+  return "#pragma once\n" + Code.str();
+}
+
 MATCHER_P2(FileRange, File, Range, "") {
   return Location{URIForFile::canonicalize(File, testRoot()), Range} == arg;
 }
@@ -2293,6 +2297,50 @@
   checkFindRefs(Test);
 }
 
+TEST(FindReferences, UsedSymbolsFromInclude) {
+  const char *Tests[] = {
+      R"cpp([[#include ^"bar.h"]]
+        #include <vector>
+        int fstBar = [[bar1]]();
+        int sndBar = [[bar2]]();
+        [[Bar]] bar;
+        int macroBar = [[BAR]];
+        std::vector<int> vec;
+      )cpp",
+
+      R"cpp([[#in^clude <vector>]]
+        std::[[vector]]<int> vec;
+      )cpp"};
+  for (const char *Test : Tests) {
+    Annotations T(Test);
+    auto TU = TestTU::withCode(T.code());
+    TU.ExtraArgs.push_back("-std=c++20");
+    TU.AdditionalFiles["bar.h"] = guard(R"cpp(
+      #define BAR 5
+      int bar1();
+      int bar2();
+      class Bar {};
+    )cpp");
+    TU.AdditionalFiles["system/vector"] = guard(R"cpp(
+      namespace std {
+        template<typename>
+        class vector{};
+      }
+    )cpp");
+    TU.ExtraArgs.push_back("-isystem" + testPath("system"));
+
+    auto AST = TU.build();
+    std::vector<Matcher<ReferencesResult::Reference>> ExpectedLocations;
+    for (const auto &R : T.ranges())
+      ExpectedLocations.push_back(AllOf(rangeIs(R), attrsAre(0u)));
+    for (const auto &P : T.points())
+      EXPECT_THAT(findReferences(AST, P, 0).References,
+                  UnorderedElementsAreArray(ExpectedLocations))
+          << "Failed for Refs at " << P << "\n"
+          << Test;
+  }
+}
+
 TEST(FindReferences, NeedsIndexForSymbols) {
   const char *Header = "int foo();";
   Annotations Main("int main() { [[f^oo]](); }");