diff --git a/clang-tools-extra/clangd/XRefs.cpp b/clang-tools-extra/clangd/XRefs.cpp
--- a/clang-tools-extra/clangd/XRefs.cpp
+++ b/clang-tools-extra/clangd/XRefs.cpp
@@ -10,12 +10,15 @@
 #include "FindSymbols.h"
 #include "FindTarget.h"
 #include "HeuristicResolver.h"
+#include "IncludeCleaner.h"
 #include "ParsedAST.h"
 #include "Protocol.h"
 #include "Quality.h"
 #include "Selection.h"
 #include "SourceCode.h"
 #include "URI.h"
+#include "clang-include-cleaner/Analysis.h"
+#include "clang-include-cleaner/Types.h"
 #include "index/Index.h"
 #include "index/Merge.h"
 #include "index/Relation.h"
@@ -48,6 +51,7 @@
 #include "clang/Index/IndexingAction.h"
 #include "clang/Index/IndexingOptions.h"
 #include "clang/Index/USRGeneration.h"
+#include "clang/Lex/Lexer.h"
 #include "clang/Tooling/Syntax/Tokens.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/DenseMap.h"
@@ -1324,6 +1328,73 @@
     return {};
   }
 
+  const auto &Includes = AST.getIncludeStructure().MainFileIncludes;
+  auto ConvertedMainFileIncludes = convertIncludes(SM, Includes);
+  for (const auto &Inc : Includes) {
+    if (Inc.HashLine != Pos.line)
+      continue;
+
+    auto ReferencedInclude = convertIncludes(SM, Inc);
+    include_cleaner::walkUsed(
+        AST.getLocalTopLevelDecls(), collectMacroReferences(AST),
+        AST.getPragmaIncludes(), SM,
+        [&](const include_cleaner::SymbolReference &Ref,
+            llvm::ArrayRef<include_cleaner::Header> Providers) {
+          if (Ref.RT != include_cleaner::RefType::Explicit)
+            return;
+
+          auto Loc = SM.getFileLoc(Ref.RefLocation);
+          for (const auto &H : Providers) {
+            auto MatchingIncludes = ConvertedMainFileIncludes.match(H);
+            // No match for this provider in the main file.
+            if (MatchingIncludes.empty())
+              continue;
+
+            // Check if the referenced include matches this provider.
+            if (!ReferencedInclude.match(H).empty()) {
+              ReferencesResult::Reference Result;
+              auto TokLen =
+                  Lexer::MeasureTokenLength(Loc, SM, AST.getLangOpts());
+              Result.Loc.range =
+                  halfOpenToRange(SM, CharSourceRange::getCharRange(
+                                          Loc, Loc.getLocWithOffset(TokLen)));
+              Result.Loc.uri = URIMainFile;
+              Results.References.push_back(std::move(Result));
+            }
+
+            // Don't look for rest of the providers once we've found a match
+            // in the main file.
+            return;
+          }
+        });
+    // No explicit reference in the main file is attributed to this header.
+    if (Results.References.empty())
+      return {};
+
+    // Add the #include directive itself to the references list. The range
+    // runs from the '#' to the end of the written header name rather than
+    // assuming the exact text `#include "x.h"`, so `#  include`, `#import`
+    // and indented directives are measured correctly. If the written name
+    // is not on the line (e.g. `#include MACRO`), the rest of the line is
+    // used instead.
+    llvm::StringRef MainFileCode = SM.getBufferData(SM.getMainFileID());
+    llvm::StringRef Directive = MainFileCode.substr(Inc.HashOffset);
+    Directive = Directive.take_until([](char C) { return C == '\n'; });
+    size_t WrittenPos = Directive.find(Inc.Written);
+    size_t DirectiveLen = WrittenPos == llvm::StringRef::npos
+                              ? Directive.size()
+                              : WrittenPos + Inc.Written.size();
+    ReferencesResult::Reference Result;
+    Result.Loc.range = clangd::Range{
+        offsetToPosition(MainFileCode, Inc.HashOffset),
+        offsetToPosition(MainFileCode, Inc.HashOffset + DirectiveLen)};
+    Result.Loc.uri = URIMainFile;
+    Results.References.push_back(std::move(Result));
+    // Only one directive can start on the cursor's line; return now instead
+    // of falling through to the symbol-based lookup below.
+    return Results;
+  }
+
   llvm::DenseSet<SymbolID> IDsToQuery, OverriddenMethods;
 
   const auto *IdentifierAtCursor =
@@ -1944,15 +2015,15 @@
   return QualType();
 }
 
-// Given a type targeted by the cursor, return one or more types that are more interesting
-// to target.
-static void unwrapFindType(
-    QualType T, const HeuristicResolver* H, llvm::SmallVector<QualType>& Out) {
+// Given a type targeted by the cursor, return one or more types that are more
+// interesting to target.
+static void unwrapFindType(QualType T, const HeuristicResolver *H,
+                           llvm::SmallVector<QualType> &Out) {
   if (T.isNull())
     return;
 
   // If there's a specific type alias, point at that rather than unwrapping.
-  if (const auto* TDT = T->getAs<TypedefType>())
+  if (const auto *TDT = T->getAs<TypedefType>())
     return Out.push_back(QualType(TDT, 0));
 
   // Pointers etc => pointee type.
@@ -1968,30 +2025,31 @@
     return unwrapFindType(FT->getReturnType(), H, Out);
   if (auto *CRD = T->getAsCXXRecordDecl()) {
     if (CRD->isLambda())
-      return unwrapFindType(CRD->getLambdaCallOperator()->getReturnType(), H, Out);
+      return unwrapFindType(CRD->getLambdaCallOperator()->getReturnType(), H,
+                            Out);
     // FIXME: more cases we'd prefer the return type of the call operator?
     //        std::function etc?
   }
 
   // For smart pointer types, add the underlying type
   if (H)
-    if (const auto* PointeeType = H->getPointeeType(T.getNonReferenceType().getTypePtr())) {
-        unwrapFindType(QualType(PointeeType, 0), H, Out);
-        return Out.push_back(T);
+    if (const auto *PointeeType =
+            H->getPointeeType(T.getNonReferenceType().getTypePtr())) {
+      unwrapFindType(QualType(PointeeType, 0), H, Out);
+      return Out.push_back(T);
     }
 
   return Out.push_back(T);
 }
 
 // Convenience overload, to allow calling this without the out-parameter
-static llvm::SmallVector<QualType> unwrapFindType(
-    QualType T, const HeuristicResolver* H) {
-    llvm::SmallVector<QualType> Result;
-    unwrapFindType(T, H, Result);
-    return Result;
+static llvm::SmallVector<QualType> unwrapFindType(QualType T,
+                                                  const HeuristicResolver *H) {
+  llvm::SmallVector<QualType> Result;
+  unwrapFindType(T, H, Result);
+  return Result;
 }
 
-
 std::vector<LocatedSymbol> findType(ParsedAST &AST, Position Pos) {
   const SourceManager &SM = AST.getSourceManager();
   auto Offset = positionToOffset(SM.getBufferData(SM.getMainFileID()), Pos);
@@ -2007,11 +2065,13 @@
     std::vector<LocatedSymbol> LocatedSymbols;
 
     // NOTE: unwrapFindType might return duplicates for something like
-    // unique_ptr<unique_ptr<T>>. Let's *not* remove them, because it gives you some
-    // information about the type you may have not known before
-    // (since unique_ptr<unique_ptr<T>> != unique_ptr<T>).
-    for (const QualType& Type : unwrapFindType(typeForNode(N), AST.getHeuristicResolver()))
-      llvm::copy(locateSymbolForType(AST, Type), std::back_inserter(LocatedSymbols));
+    // unique_ptr<unique_ptr<T>>. Let's *not* remove them, because it gives you
+    // some information about the type you may have not known before (since
+    // unique_ptr<unique_ptr<T>> != unique_ptr<T>).
+    for (const QualType &Type :
+         unwrapFindType(typeForNode(N), AST.getHeuristicResolver()))
+      llvm::copy(locateSymbolForType(AST, Type),
+                 std::back_inserter(LocatedSymbols));
 
     return LocatedSymbols;
   };
diff --git a/clang-tools-extra/clangd/unittests/XRefsTests.cpp b/clang-tools-extra/clangd/unittests/XRefsTests.cpp
--- a/clang-tools-extra/clangd/unittests/XRefsTests.cpp
+++ b/clang-tools-extra/clangd/unittests/XRefsTests.cpp
@@ -5,8 +5,8 @@
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
 //===----------------------------------------------------------------------===//
-#include "Annotations.h"
 #include "AST.h"
+#include "Annotations.h"
 #include "ParsedAST.h"
 #include "Protocol.h"
 #include "SourceCode.h"
@@ -43,6 +43,12 @@
 using ::testing::UnorderedElementsAreArray;
 using ::testing::UnorderedPointwise;
 
+// Returns Code prefixed with `#pragma once`. Presumably this is so the
+// include-cleaner analysis treats the test header as self-contained; confirm.
+std::string guard(llvm::StringRef Code) {
+  return "#pragma once\n" + Code.str();
+}
+
 MATCHER_P2(FileRange, File, Range, "") {
   return Location{URIForFile::canonicalize(File, testRoot()), Range} == arg;
 }
@@ -1876,8 +1882,8 @@
     ASSERT_GT(A.points().size(), 0u) << Case;
     for (auto Pos : A.points())
       EXPECT_THAT(findType(AST, Pos),
-                  ElementsAre(
-                    sym("Target", HeaderA.range("Target"), HeaderA.range("Target"))))
+                  ElementsAre(sym("Target", HeaderA.range("Target"),
+                                  HeaderA.range("Target"))))
           << Case;
   }
 
@@ -1888,11 +1894,12 @@
     TU.Code = A.code().str();
     ParsedAST AST = TU.build();
 
-    EXPECT_THAT(findType(AST, A.point()),
-                UnorderedElementsAre(
-                  sym("Target", HeaderA.range("Target"), HeaderA.range("Target")),
-                  sym("smart_ptr", HeaderA.range("smart_ptr"), HeaderA.range("smart_ptr"))
-                ))
+    EXPECT_THAT(
+        findType(AST, A.point()),
+        UnorderedElementsAre(
+            sym("Target", HeaderA.range("Target"), HeaderA.range("Target")),
+            sym("smart_ptr", HeaderA.range("smart_ptr"),
+                HeaderA.range("smart_ptr"))))
         << Case;
   }
 }
@@ -1901,6 +1906,25 @@
   Annotations T(Test);
   auto TU = TestTU::withCode(T.code());
   TU.ExtraArgs.push_back("-std=c++20");
+  TU.AdditionalFiles["bar.h"] = guard(R"cpp(
+    #define BAR 5
+    int bar1();
+    int bar2();
+    class Bar {};
+  )cpp");
+  TU.AdditionalFiles["private.h"] = guard(R"cpp(
+    // IWYU pragma: private, include "public.h"
+    int foo();
+  )cpp");
+  TU.AdditionalFiles["public.h"] = guard("");
+  TU.AdditionalFiles["system/vector"] = guard(R"cpp(
+    namespace std {
+      template<typename>
+      class vector{};
+    }
+  )cpp");
+  TU.AdditionalFiles["forward.h"] = guard("class Bar;");
+  TU.ExtraArgs.push_back("-isystem" + testPath("system"));
 
   auto AST = TU.build();
   std::vector<Matcher<ReferencesResult::Reference>> ExpectedLocations;
@@ -2293,6 +2317,50 @@
   checkFindRefs(Test);
 }
 
+TEST(FindReferences, UsedSymbolsFromInclude) {
+  const char *Tests[] = {
+      // Functions, a class and a macro declared in bar.h.
+      R"cpp([[#include ^"bar.h"]]
+    int fstBar = [[bar1]]();
+    int sndBar = [[bar2]]();
+    [[Bar]] bar;
+    int macroBar = [[BAR]];
+  )cpp",
+
+      // A class template from a header on the -isystem path.
+      R"cpp([[#in^clude <vector>]]
+    std::[[vector]]<int> vec;
+  )cpp",
+
+      // private.h names public.h as its public header via an IWYU pragma, so
+      // the use of foo() is attributed to public.h.
+      R"cpp([[#in^clude "public.h"]]
+    #include "private.h"
+    int fooVar = [[foo]]();
+  )cpp",
+
+      // Bar is only forward-declared in forward.h (bar.h defines it); the use
+      // is not attributed to forward.h, so no references are expected.
+      R"cpp(#include "bar.h"
+    #include "for^ward.h"
+    Bar *x;
+  )cpp",
+
+      // A use inside a macro body is reported at the macro invocation.
+      R"cpp([[#include "b^ar.h"]]
+    #define DEF(X) const Bar *X
+    [[DEF]](a);
+  )cpp",
+
+      // A use in a macro argument is reported where the argument is spelled.
+      R"cpp([[#in^clude "bar.h"]]
+    #define BAZ(X) const X x
+    BAZ([[Bar]]);
+  )cpp"};
+  for (const char *Test : Tests)
+    checkFindRefs(Test);
+}
+
 TEST(FindReferences, NeedsIndexForSymbols) {
   const char *Header = "int foo();";
   Annotations Main("int main() { [[f^oo]](); }");