diff --git a/clang-tools-extra/clangd/index/CanonicalIncludes.cpp b/clang-tools-extra/clangd/index/CanonicalIncludes.cpp --- a/clang-tools-extra/clangd/index/CanonicalIncludes.cpp +++ b/clang-tools-extra/clangd/index/CanonicalIncludes.cpp @@ -711,13 +711,6 @@ Lang = tooling::stdlib::Lang::C; else return ""; - // FIXME: remove the following special cases when the tooling stdlib supports - // them. - // There are two std::move()s, this is by far the most common. - if (Scope == "std::" && Name == "move") - return ""; - if (Scope == "std::" && Name == "size_t") - return ""; if (auto StdSym = tooling::stdlib::Symbol::named(Scope, Name, Lang)) return StdSym->header().name(); return ""; diff --git a/clang/include/clang/Tooling/Inclusions/StandardLibrary.h b/clang/include/clang/Tooling/Inclusions/StandardLibrary.h --- a/clang/include/clang/Tooling/Inclusions/StandardLibrary.h +++ b/clang/include/clang/Tooling/Inclusions/StandardLibrary.h @@ -70,7 +70,10 @@ public: static std::vector all(Lang L = Lang::CXX); /// \p Scope should have the trailing "::", for example: - /// named("std::chrono::", "system_clock") + /// named("std::chrono::", "system_clock"). + /// + /// For ambiguous symbols (which can not be uniquely identified by the + /// qualified name), only returns the first one. 
static std::optional named(llvm::StringRef Scope, llvm::StringRef Name, Lang Language = Lang::CXX); diff --git a/clang/lib/Tooling/Inclusions/Stdlib/StandardLibrary.cpp b/clang/lib/Tooling/Inclusions/Stdlib/StandardLibrary.cpp --- a/clang/lib/Tooling/Inclusions/Stdlib/StandardLibrary.cpp +++ b/clang/lib/Tooling/Inclusions/Stdlib/StandardLibrary.cpp @@ -8,6 +8,7 @@ #include "clang/Tooling/Inclusions/StandardLibrary.h" #include "clang/AST/Decl.h" +#include "clang/AST/DeclTemplate.h" #include "clang/Basic/LangOptions.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseSet.h" @@ -35,11 +36,24 @@ const char *Data; // std::vector unsigned ScopeLen; // ~~~~~ unsigned NameLen; // ~~~~~~ + SymbolName() = default; + SymbolName(const char *Data, unsigned ScopeLen, unsigned NameLen, + bool Ambiguous, unsigned ParameterCount) + : Data(Data), ScopeLen(ScopeLen), NameLen(NameLen), + Ambiguous(Ambiguous), ParameterCount(ParameterCount) {} + StringRef scope() const { return StringRef(Data, ScopeLen); } StringRef name() const { return StringRef(Data + ScopeLen, NameLen); } StringRef qualifiedName() const { return StringRef(Data, ScopeLen + NameLen); } + + // True if the symbol is ambiguous, i.e. it cannot be distinguished by the + // fully qualified name alone, e.g. std::move. + bool Ambiguous : 1; + // Extra symbol information for disambiguations. + // Only meaningful when Ambiguous is true. + unsigned ParameterCount : 4; } *SymbolNames = nullptr; // Symbol name -> Symbol::ID, within a namespace. 
llvm::DenseMap *NamespaceSymbols = nullptr; @@ -55,6 +69,7 @@ static int countSymbols(Lang Language) { llvm::DenseSet Set; + unsigned AmbiguousSymCount = 0; #define SYMBOL(Name, NS, Header) Set.insert(#NS #Name); switch (Language) { case Lang::C: @@ -62,10 +77,11 @@ break; case Lang::CXX: #include "StdSymbolMap.inc" + AmbiguousSymCount += 4; break; } #undef SYMBOL - return Set.size(); + return Set.size() + AmbiguousSymCount; } static int initialize(Lang Language) { @@ -95,7 +111,9 @@ }; auto Add = [&, SymIndex(-1)](llvm::StringRef QName, unsigned NSLen, - llvm::StringRef HeaderName) mutable { + llvm::StringRef HeaderName, + bool IsAmbiguous = false, + unsigned ParameterCount = 0) mutable { // Correct "Nonefoo" => foo. // FIXME: get rid of "None" from the generated mapping files. if (QName.take_front(NSLen) == "None") { @@ -104,7 +122,9 @@ } if (SymIndex >= 0 && - Mapping->SymbolNames[SymIndex].qualifiedName() == QName) { + Mapping->SymbolNames[SymIndex].qualifiedName() == QName && + Mapping->SymbolNames[SymIndex].Ambiguous == IsAmbiguous && + Mapping->SymbolNames[SymIndex].ParameterCount == ParameterCount) { // Not a new symbol, use the same index. assert(llvm::none_of(llvm::ArrayRef(Mapping->SymbolNames, SymIndex), [&QName](const SymbolHeaderMapping::SymbolName &S) { @@ -117,7 +137,8 @@ ++SymIndex; } Mapping->SymbolNames[SymIndex] = { - QName.data(), NSLen, static_cast(QName.size() - NSLen)}; + QName.data(), NSLen, static_cast(QName.size() - NSLen), + IsAmbiguous, ParameterCount}; Mapping->SymbolHeaderIDs[SymIndex].push_back(AddHeader(HeaderName)); NSSymbolMap &NSSymbols = AddNS(QName.take_front(NSLen)); @@ -130,6 +151,17 @@ break; case Lang::CXX: #include "StdSymbolMap.inc" + // !!NOTE!! when updating this list, please update the AmbiguousSymCount in + // countSymbols as well. + // FIXME: move this list to a separate .inc file. 
+ Add("std::move", /*NSLen=*/5, "", /*IsAmbiguous=*/true, + /*ParameterCount=*/1); + Add("std::move", /*NSLen=*/5, "", /*IsAmbiguous=*/true, + /*ParameterCount=*/3); + Add("std::remove", /*NSLen=*/5, "", /*IsAmbiguous=*/true, + /*ParameterCount=*/3); + Add("std::remove", /*NSLen=*/5, "", /*IsAmbiguous=*/true, + /*ParameterCount=*/1); break; } #undef SYMBOL @@ -278,7 +310,29 @@ auto It = Symbols->find(Name); if (It == Symbols->end()) return std::nullopt; - return Symbol(It->second, L); + + unsigned SymIndex = It->second; + auto SymbolNames = ArrayRef(getMappingPerLang(L)->SymbolNames, + getMappingPerLang(L)->SymbolCount); + if (!SymbolNames[SymIndex].Ambiguous) + return Symbol(SymIndex, L); + + // Perform disambiguation. + const auto *FD = llvm::dyn_cast(D); + if (const auto *FTD = llvm::dyn_cast(D)) + FD = FTD->getTemplatedDecl(); + if (!FD) + return std::nullopt; + + auto ParameterCount = FD->getNumParams(); + auto QName = SymbolNames[SymIndex].qualifiedName(); + do { + if (ParameterCount == SymbolNames[SymIndex].ParameterCount) + return Symbol(SymIndex, L); + ++SymIndex; + } while (SymIndex < SymbolNames.size() && + SymbolNames[SymIndex].qualifiedName() == QName); + return std::nullopt; } } // namespace stdlib diff --git a/clang/unittests/Tooling/StandardLibraryTest.cpp b/clang/unittests/Tooling/StandardLibraryTest.cpp --- a/clang/unittests/Tooling/StandardLibraryTest.cpp +++ b/clang/unittests/Tooling/StandardLibraryTest.cpp @@ -10,6 +10,7 @@ #include "clang/AST/ASTContext.h" #include "clang/AST/Decl.h" #include "clang/AST/DeclarationName.h" +#include "clang/AST/RecursiveASTVisitor.h" #include "clang/Testing/TestAST.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Casting.h" @@ -62,6 +63,9 @@ ElementsAre(stdlib::Header::named(""), stdlib::Header::named(""), stdlib::Header::named(""))); + // Ambiguous symbol: we return the first one, which always maps to <utility>. + EXPECT_THAT(stdlib::Symbol::named("std::", "move")->headers(), + 
ElementsAre(stdlib::Header::named(""))); EXPECT_THAT(stdlib::Header::all(), Contains(*VectorH)); EXPECT_THAT(stdlib::Symbol::all(), Contains(*Vector)); @@ -94,7 +98,6 @@ struct vector { class nested {}; }; class secret {}; - } // inl inline namespace __1 { @@ -144,6 +147,73 @@ EXPECT_EQ(Recognizer(Sec), std::nullopt); } +TEST(StdlibTest, RecognizerAmbiguousSymbol) { + struct { + llvm::StringRef Code; + llvm::StringRef QName; + + llvm::StringRef ExpectedHeader; + } TestCases[] = { + { + R"cpp( + namespace std { + template + constexpr OutputIt move(InputIt first, InputIt last, OutputIt dest); + })cpp", + "std::move", + "", + }, + { + R"cpp( + namespace std { + template constexpr T move(T&& t) noexcept; + })cpp", + "std::move", + "", + }, + { + R"cpp( + namespace std { + template + ForwardIt remove(ForwardIt first, ForwardIt last, const T& value); + })cpp", + "std::remove", + "", + }, + { + "namespace std { int remove(const char*); }", + "std::remove", + "", + }, + }; + + struct DeclCapturer : RecursiveASTVisitor { + llvm::StringRef TargetQName; + const NamedDecl *Out = nullptr; + + DeclCapturer(llvm::StringRef TargetQName) : TargetQName(TargetQName) {} + bool VisitNamedDecl(const NamedDecl *ND) { + if (auto *TD = ND->getDescribedTemplate()) + ND = TD; + if (ND->getQualifiedNameAsString() == TargetQName) { + EXPECT_TRUE(Out == nullptr || Out == ND->getCanonicalDecl()) + << "Found multiple matches for " << TargetQName; + Out = cast(ND->getCanonicalDecl()); + } + return true; + } + }; + stdlib::Recognizer Recognizer; + for (auto &T : TestCases) { + TestAST AST(T.Code); + DeclCapturer V(T.QName); + V.TraverseDecl(AST.context().getTranslationUnitDecl()); + ASSERT_TRUE(V.Out); + EXPECT_THAT(Recognizer(V.Out)->headers(), + ElementsAre(stdlib::Header::named(T.ExpectedHeader))); + } +} + } // namespace } // namespace tooling } // namespace clang