diff --git a/clang-tools-extra/include-cleaner/lib/FindHeaders.cpp b/clang-tools-extra/include-cleaner/lib/FindHeaders.cpp
--- a/clang-tools-extra/include-cleaner/lib/FindHeaders.cpp
+++ b/clang-tools-extra/include-cleaner/lib/FindHeaders.cpp
@@ -82,6 +82,43 @@
   llvm_unreachable("unhandled Symbol kind!");
 }
 
+// Special-case the ambiguous standard library symbols (e.g. std::move) which
+// are not supported by the tooling stdlib lib: the same name is provided by
+// different headers depending on the overload, so the overload is picked by
+// its number of parameters. Returns an empty list for any other symbol.
+llvm::SmallVector<Header> specialStandardSymbols(const Symbol &S) {
+  if (S.kind() != Symbol::Declaration || !S.declaration().isInStdNamespace())
+    return {};
+
+  const auto *FD = S.declaration().getAsFunction();
+  // Operators, literal operators and deduction guides have no identifier
+  // name (getName() asserts on them), and none of them is ambiguous.
+  if (!FD || !FD->getIdentifier())
+    return {};
+
+  llvm::StringRef FName = FD->getName();
+  unsigned NumParams = FD->getNumParams();
+  if (FName == "move") {
+    if (NumParams == 1)
+      // move(T&& t);
+      return {Header("<utility>")};
+    if (NumParams == 3 || NumParams == 4)
+      // move([ExecutionPolicy&& policy,] InputIt first, InputIt last,
+      //      OutputIt dest);
+      return {Header("<algorithm>")};
+  }
+  if (FName == "remove") {
+    if (NumParams == 1)
+      // remove(const char* filename);
+      return {Header("<cstdio>")};
+    if (NumParams == 3 || NumParams == 4)
+      // remove([ExecutionPolicy&& policy,] ForwardIt first, ForwardIt last,
+      //        const T& value);
+      return {Header("<algorithm>")};
+  }
+  return {};
+}
+
 } // namespace
 
 llvm::SmallVector<Hinted<Header>> findHeaders(const SymbolLocation &Loc,
@@ -141,6 +178,9 @@
 llvm::SmallVector<Header> headersForSymbol(const Symbol &S,
                                            const SourceManager &SM,
                                            const PragmaIncludes *PI) {
+  if (auto Headers = specialStandardSymbols(S); !Headers.empty())
+    return Headers;
+
   // Get headers for all the locations providing Symbol. Same header can be
   // reached through different traversals, deduplicate those into a single
   // Header by merging their hints.
diff --git a/clang-tools-extra/include-cleaner/unittests/FindHeadersTest.cpp b/clang-tools-extra/include-cleaner/unittests/FindHeadersTest.cpp
--- a/clang-tools-extra/include-cleaner/unittests/FindHeadersTest.cpp
+++ b/clang-tools-extra/include-cleaner/unittests/FindHeadersTest.cpp
@@ -278,25 +278,34 @@
 
 class HeadersForSymbolTest : public FindHeadersTest {
 protected:
-  llvm::SmallVector<Header> headersForFoo() {
+  // Returns the headers providing the single decl named Name in the AST.
+  llvm::SmallVector<Header> headersFor(llvm::StringRef Name) {
     struct Visitor : public RecursiveASTVisitor<Visitor> {
       const NamedDecl *Out = nullptr;
+      llvm::StringRef Name;
+      explicit Visitor(llvm::StringRef Name) : Name(Name) {}
       bool VisitNamedDecl(const NamedDecl *ND) {
-        if (ND->getName() == "foo") {
+        if (auto *TD = ND->getDescribedTemplate())
+          ND = TD;
+
+        // getNameAsString() (unlike getName()) also accepts non-identifier
+        // names such as operators.
+        if (ND->getNameAsString() == Name) {
           EXPECT_TRUE(Out == nullptr || Out == ND->getCanonicalDecl())
-              << "Found multiple matches for foo.";
+              << "Found multiple matches for " << Name << ".";
           Out = cast<NamedDecl>(ND->getCanonicalDecl());
         }
         return true;
       }
     };
-    Visitor V;
+    Visitor V(Name);
     V.TraverseDecl(AST->context().getTranslationUnitDecl());
     if (!V.Out)
-      ADD_FAILURE() << "Couldn't find any decls named foo.";
+      ADD_FAILURE() << "Couldn't find any decls named " << Name << ".";
     assert(V.Out);
     return headersForSymbol(*V.Out, AST->sourceManager(), &PI);
   }
+  llvm::SmallVector<Header> headersForFoo() { return headersFor("foo"); }
 };
 
 TEST_F(HeadersForSymbolTest, Deduplicates) {
@@ -430,5 +439,87 @@
   EXPECT_THAT(headersForFoo(), ElementsAre(Header(StringRef("\"public.h\"")),
                                            physicalHeader("foo.h")));
 }
+
+TEST_F(HeadersForSymbolTest, AmbiguousStdSymbols) {
+  struct {
+    llvm::StringRef Code;
+    llvm::StringRef Name;
+
+    llvm::StringRef ExpectedHeader;
+  } TestCases[] = {
+      {
+          R"cpp(
+    namespace std {
+      template <class InputIt, class OutputIt>
+      constexpr OutputIt move(InputIt first, InputIt last, OutputIt dest);
+    })cpp",
+          "move",
+          "<algorithm>",
+      },
+      {
+          R"cpp(
+    namespace std {
+      template <class ExecutionPolicy, class ForwardIt1, class ForwardIt2>
+      ForwardIt2 move(ExecutionPolicy &&policy, ForwardIt1 first,
+                      ForwardIt1 last, ForwardIt2 dest);
+    })cpp",
+          "move",
+          "<algorithm>",
+      },
+      {
+          R"cpp(
+    namespace std {
+      template <typename T> constexpr T move(T &&t) noexcept;
+    })cpp",
+          "move",
+          "<utility>",
+      },
+      {
+          R"cpp(
+    namespace std {
+      template <class ForwardIt, class T>
+      ForwardIt remove(ForwardIt first, ForwardIt last, const T &value);
+    })cpp",
+          "remove",
+          "<algorithm>",
+      },
+      {
+          R"cpp(
+    namespace std {
+      template <class ExecutionPolicy, class ForwardIt, class T>
+      ForwardIt remove(ExecutionPolicy &&policy, ForwardIt first,
+                       ForwardIt last, const T &value);
+    })cpp",
+          "remove",
+          "<algorithm>",
+      },
+      {
+          "namespace std { int remove(const char*); }",
+          "remove",
+          "<cstdio>",
+      },
+  };
+
+  for (const auto &T : TestCases) {
+    SCOPED_TRACE(T.Code);
+    Inputs.Code = T.Code;
+    buildAST();
+    EXPECT_THAT(headersFor(T.Name),
+                UnorderedElementsAre(Header(T.ExpectedHeader)));
+  }
+}
+
+TEST_F(HeadersForSymbolTest, StdOperatorsAreNotSpecialCased) {
+  // Operators have no identifier name; they must be skipped (rather than
+  // assert) when looking for ambiguous std symbols.
+  Inputs.Code = R"cpp(
+    namespace std {
+      struct ostream {};
+      ostream &operator<<(ostream &os, int value);
+    })cpp";
+  buildAST();
+  EXPECT_THAT(headersFor("operator<<"), ::testing::Not(::testing::IsEmpty()));
+}
+
 } // namespace
 } // namespace clang::include_cleaner