diff --git a/clang-tools-extra/include-cleaner/lib/FindHeaders.cpp b/clang-tools-extra/include-cleaner/lib/FindHeaders.cpp
--- a/clang-tools-extra/include-cleaner/lib/FindHeaders.cpp
+++ b/clang-tools-extra/include-cleaner/lib/FindHeaders.cpp
@@ -82,16 +82,75 @@
   llvm_unreachable("unhandled Symbol kind!");
 }
 
+// Hints::PublicHeader unless PI marks FE as private or non-self-contained.
+Hints isPublicHeader(const FileEntry *FE, const PragmaIncludes &PI) {
+  if (PI.isPrivate(FE) || !PI.isSelfContained(FE))
+    return Hints::None;
+  return Hints::PublicHeader;
+}
+
+// Hinted candidates for the standard library Headers (given in preference
+// order), each followed by its exporters; exporters need a non-null PI.
+llvm::SmallVector<Hinted<Header>>
+hintedHeadersForStdHeaders(llvm::ArrayRef<tooling::stdlib::Header> Headers,
+                           const SourceManager &SM, const PragmaIncludes *PI) {
+  llvm::SmallVector<Hinted<Header>> Results;
+  for (const auto &H : Headers) {
+    Results.emplace_back(H, Hints::PublicHeader);
+    if (!PI)
+      continue;
+    for (const auto *Export : PI->getExporters(H, SM.getFileManager()))
+      Results.emplace_back(Header(Export), isPublicHeader(Export, *PI));
+  }
+  // StandardLibrary returns headers in preference order, so only mark the
+  // first.
+  if (!Results.empty())
+    Results.front().Hint |= Hints::PreferredHeader;
+  return Results;
+}
+
+// Special-case the ambiguous standard library symbols (e.g. std::move) which
+// are not supported by the tooling stdlib lib.
+// Returns an empty list for any other symbol, so that the caller falls back
+// to the regular symbol lookup.
+llvm::SmallVector<Hinted<Header>>
+headersForSpecialSymbol(const Symbol &S, const SourceManager &SM,
+                        const PragmaIncludes *PI) {
+  if (S.kind() != Symbol::Declaration || !S.declaration().isInStdNamespace())
+    return {};
+
+  const auto *FD = S.declaration().getAsFunction();
+  // getName() asserts on non-identifier names (e.g. std::operator<<); none of
+  // them is special-cased, so bail out before calling it.
+  if (!FD || !FD->getDeclName().isIdentifier())
+    return {};
+
+  llvm::StringRef FName = FD->getName();
+  llvm::SmallVector<tooling::stdlib::Header> Headers;
+  if (FName == "move") {
+    if (FD->getNumParams() == 1)
+      // move(T&& t)
+      Headers.push_back(*tooling::stdlib::Header::named("<utility>"));
+    if (FD->getNumParams() == 3)
+      // move(InputIt first, InputIt last, OutputIt dest);
+      Headers.push_back(*tooling::stdlib::Header::named("<algorithm>"));
+  } else if (FName == "remove") {
+    if (FD->getNumParams() == 1)
+      // remove(const char*);
+      Headers.push_back(*tooling::stdlib::Header::named("<cstdio>"));
+    if (FD->getNumParams() == 3)
+      // remove(ForwardIt first, ForwardIt last, const T& value);
+      Headers.push_back(*tooling::stdlib::Header::named("<algorithm>"));
+  }
+  return applyHints(hintedHeadersForStdHeaders(Headers, SM, PI),
+                    Hints::CompleteSymbol);
+}
+
 } // namespace
 
 llvm::SmallVector<Hinted<Header>> findHeaders(const SymbolLocation &Loc,
                                               const SourceManager &SM,
                                               const PragmaIncludes *PI) {
-  auto IsPublicHeader = [&PI](const FileEntry *FE) {
-    return (PI->isPrivate(FE) || !PI->isSelfContained(FE))
-               ? Hints::None
-               : Hints::PublicHeader;
-  };
   llvm::SmallVector<Hinted<Header>> Results;
   switch (Loc.kind()) {
   case SymbolLocation::Physical: {
@@ -102,11 +161,11 @@
     if (!PI)
       return {{FE, Hints::PublicHeader}};
     while (FE) {
-      Hints CurrentHints = IsPublicHeader(FE);
+      Hints CurrentHints = isPublicHeader(FE, *PI);
       Results.emplace_back(FE, CurrentHints);
       // FIXME: compute transitive exporter headers.
       for (const auto *Export : PI->getExporters(FE, SM.getFileManager()))
-        Results.emplace_back(Export, IsPublicHeader(Export));
+        Results.emplace_back(Export, isPublicHeader(Export, *PI));
 
       if (auto Verbatim = PI->getPublic(FE); !Verbatim.empty()) {
         Results.emplace_back(Verbatim,
@@ -123,16 +182,7 @@
     return Results;
   }
   case SymbolLocation::Standard: {
-    for (const auto &H : Loc.standard().headers()) {
-      Results.emplace_back(H, Hints::PublicHeader);
-      for (const auto *Export : PI->getExporters(H, SM.getFileManager()))
-        Results.emplace_back(Header(Export), IsPublicHeader(Export));
-    }
-    // StandardLibrary returns headers in preference order, so only mark the
-    // first.
-    if (!Results.empty())
-      Results.front().Hint |= Hints::PreferredHeader;
-    return Results;
+    return hintedHeadersForStdHeaders(Loc.standard().headers(), SM, PI);
   }
   }
   llvm_unreachable("unhandled SymbolLocation kind!");
@@ -144,9 +194,11 @@
   // Get headers for all the locations providing Symbol. Same header can be
   // reached through different traversals, deduplicate those into a single
   // Header by merging their hints.
-  llvm::SmallVector<Hinted<Header>> Headers;
-  for (auto &Loc : locateSymbol(S))
-    Headers.append(applyHints(findHeaders(Loc, SM, PI), Loc.Hint));
+  llvm::SmallVector<Hinted<Header>> Headers =
+      headersForSpecialSymbol(S, SM, PI);
+  if (Headers.empty())
+    for (auto &Loc : locateSymbol(S))
+      Headers.append(applyHints(findHeaders(Loc, SM, PI), Loc.Hint));
   // If two Headers probably refer to the same file (e.g. Verbatim(foo.h) and
   // Physical(/path/to/foo.h), we won't deduplicate them or merge their hints
   llvm::stable_sort(
diff --git a/clang-tools-extra/include-cleaner/unittests/FindHeadersTest.cpp b/clang-tools-extra/include-cleaner/unittests/FindHeadersTest.cpp
--- a/clang-tools-extra/include-cleaner/unittests/FindHeadersTest.cpp
+++ b/clang-tools-extra/include-cleaner/unittests/FindHeadersTest.cpp
@@ -278,25 +278,32 @@
 
 class HeadersForSymbolTest : public FindHeadersTest {
 protected:
-  llvm::SmallVector<Header> headersForFoo() {
+  llvm::SmallVector<Header> headersFor(llvm::StringRef Name) {
     struct Visitor : public RecursiveASTVisitor<Visitor> {
       const NamedDecl *Out = nullptr;
+      llvm::StringRef Name;
+      explicit Visitor(llvm::StringRef Name) : Name(Name) {}
       bool VisitNamedDecl(const NamedDecl *ND) {
-        if (ND->getName() == "foo") {
+        if (auto *TD = ND->getDescribedTemplate())
+          ND = TD;
+
+        // getName() asserts on non-identifier names (e.g. operator<<).
+        if (ND->getNameAsString() == Name) {
           EXPECT_TRUE(Out == nullptr || Out == ND->getCanonicalDecl())
-              << "Found multiple matches for foo.";
+              << "Found multiple matches for " << Name << ".";
           Out = cast<NamedDecl>(ND->getCanonicalDecl());
         }
         return true;
       }
     };
-    Visitor V;
+    Visitor V(Name);
     V.TraverseDecl(AST->context().getTranslationUnitDecl());
     if (!V.Out)
-      ADD_FAILURE() << "Couldn't find any decls named foo.";
+      ADD_FAILURE() << "Couldn't find any decls named " << Name << ".";
     assert(V.Out);
     return headersForSymbol(*V.Out, AST->sourceManager(), &PI);
   }
+  llvm::SmallVector<Header> headersForFoo() { return headersFor("foo"); }
 };
 
 TEST_F(HeadersForSymbolTest, Deduplicates) {
@@ -430,5 +437,71 @@
   EXPECT_THAT(headersForFoo(), ElementsAre(Header(StringRef("\"public.h\"")),
                                            physicalHeader("foo.h")));
 }
+
+TEST_F(HeadersForSymbolTest, AmbiguousStdSymbols) {
+  struct {
+    llvm::StringRef Code;
+    llvm::StringRef Name;
+
+    llvm::StringRef ExpectedHeader;
+  } TestCases[] = {
+      {
+          R"cpp(
+            namespace std {
+             template <class InputIt, class OutputIt>
+             constexpr OutputIt move(InputIt first, InputIt last, OutputIt dest);
+            })cpp",
+          "move",
+          "<algorithm>",
+      },
+      {
+          R"cpp(
+            namespace std {
+              template<class T> constexpr T move(T&& t) noexcept;
+            })cpp",
+          "move",
+          "<utility>",
+      },
+      {
+          R"cpp(
+            namespace std {
+              template<class ForwardIt, class T>
+              ForwardIt remove(ForwardIt first, ForwardIt last, const T& value);
+            })cpp",
+          "remove",
+          "<algorithm>",
+      },
+      {
+          "namespace std { int remove(const char*); }",
+          "remove",
+          "<cstdio>",
+      },
+  };
+
+  for (const auto &T : TestCases) {
+    Inputs.Code = T.Code;
+    buildAST();
+    EXPECT_THAT(headersFor(T.Name),
+                UnorderedElementsAre(
+                    Header(*tooling::stdlib::Header::named(T.ExpectedHeader))));
+  }
+}
+
+TEST_F(HeadersForSymbolTest, StdOperatorsAreNotSpecialCased) {
+  Inputs.Code = R"cpp(#include "ostream.h")cpp";
+  Inputs.ExtraFiles["ostream.h"] = R"cpp(
+    #pragma once
+    namespace std {
+    struct ostream {};
+    ostream &operator<<(ostream &, const char *);
+    }
+  )cpp";
+  buildAST();
+  // The non-identifier name must not trip headersForSpecialSymbol; the
+  // operator is located through its declaration instead.
+  EXPECT_THAT(headersFor("operator<<"),
+              ElementsAre(physicalHeader("ostream.h")));
+}
+
 } // namespace
 } // namespace clang::include_cleaner