diff --git a/clang-tools-extra/clangd/HeuristicResolver.h b/clang-tools-extra/clangd/HeuristicResolver.h
--- a/clang-tools-extra/clangd/HeuristicResolver.h
+++ b/clang-tools-extra/clangd/HeuristicResolver.h
@@ -69,6 +69,12 @@
   const Type *
   resolveNestedNameSpecifierToType(const NestedNameSpecifier *NNS) const;
 
+  // Given the type T of a dependent expression that appears on the LHS of a
+  // "->", heuristically find a corresponding pointee type in whose scope we
+  // could look up the name appearing on the RHS.
+  // Returns nullptr if no pointee type could be determined.
+  const Type *getPointeeType(const Type *T) const;
+
 private:
   ASTContext &Ctx;
 
@@ -89,11 +95,6 @@
   // `E`.
   const Type *resolveExprToType(const Expr *E) const;
   std::vector<const NamedDecl *> resolveExprToDecls(const Expr *E) const;
-
-  // Given the type T of a dependent expression that appears of the LHS of a
-  // "->", heuristically find a corresponding pointee type in whose scope we
-  // could look up the name appearing on the RHS.
-  const Type *getPointeeType(const Type *T) const;
 };
 
 } // namespace clangd
diff --git a/clang-tools-extra/clangd/XRefs.cpp b/clang-tools-extra/clangd/XRefs.cpp
--- a/clang-tools-extra/clangd/XRefs.cpp
+++ b/clang-tools-extra/clangd/XRefs.cpp
@@ -9,6 +9,7 @@
 #include "AST.h"
 #include "FindSymbols.h"
 #include "FindTarget.h"
+#include "HeuristicResolver.h"
 #include "ParsedAST.h"
 #include "Protocol.h"
 #include "Quality.h"
@@ -1907,36 +1908,43 @@
   return QualType();
 }
 
-// Given a type targeted by the cursor, return a type that's more interesting
-// to target.
-static QualType unwrapFindType(QualType T) {
+// Given a type targeted by the cursor, return the types that are more
+// interesting to target. For a smart pointer type recognized by \p H (which
+// may be null), this is the smart pointer type itself followed by the
+// unwrapped pointee type(s). A null type yields an empty list.
+static llvm::SmallVector<QualType> unwrapFindType(QualType T,
+                                                  const HeuristicResolver *H) {
   if (T.isNull())
-    return T;
+    return {};
 
   // If there's a specific type alias, point at that rather than unwrapping.
   if (const auto* TDT = T->getAs<TypedefType>())
-    return QualType(TDT, 0);
+    return {QualType(TDT, 0)};
 
   // Pointers etc => pointee type.
   if (const auto *PT = T->getAs<PointerType>())
-    return unwrapFindType(PT->getPointeeType());
+    return unwrapFindType(PT->getPointeeType(), H);
   if (const auto *RT = T->getAs<ReferenceType>())
-    return unwrapFindType(RT->getPointeeType());
+    return unwrapFindType(RT->getPointeeType(), H);
   if (const auto *AT = T->getAsArrayTypeUnsafe())
-    return unwrapFindType(AT->getElementType());
-  // FIXME: use HeuristicResolver to unwrap smart pointers?
+    return unwrapFindType(AT->getElementType(), H);
 
   // Function type => return type.
   if (auto *FT = T->getAs<FunctionType>())
-    return unwrapFindType(FT->getReturnType());
+    return unwrapFindType(FT->getReturnType(), H);
   if (auto *CRD = T->getAsCXXRecordDecl()) {
     if (CRD->isLambda())
-      return unwrapFindType(CRD->getLambdaCallOperator()->getReturnType());
+      return unwrapFindType(CRD->getLambdaCallOperator()->getReturnType(), H);
     // FIXME: more cases we'd prefer the return type of the call operator?
     //        std::function etc?
   }
 
-  return T;
+  // Smart pointers => the smart pointer type itself, plus its pointee type.
+  // (References were already unwrapped above, so T is not a reference here.)
+  llvm::SmallVector<QualType> Result = {T};
+  if (H)
+    if (const Type *Pointee = H->getPointeeType(T.getTypePtr()))
+      Result.append(unwrapFindType(QualType(Pointee, 0), H));
 }
 
 std::vector<LocatedSymbol> findType(ParsedAST &AST, Position Pos) {
@@ -1951,10 +1959,13 @@
   // The general scheme is: position -> AST node -> type -> declaration.
   auto SymbolsFromNode =
       [&AST](const SelectionTree::Node *N) -> std::vector<LocatedSymbol> {
-    QualType Type = unwrapFindType(typeForNode(N));
-    if (Type.isNull())
-      return {};
-    return locateSymbolForType(AST, Type);
+    std::vector<LocatedSymbol> LocatedSymbols;
+    // Collect the symbols for every interesting type, e.g. both a smart
+    // pointer and its pointee. unwrapFindType never yields null types.
+    for (const QualType &Type :
+         unwrapFindType(typeForNode(N), AST.getHeuristicResolver()))
+      llvm::append_range(LocatedSymbols, locateSymbolForType(AST, Type));
+    return LocatedSymbols;
   };
   SelectionTree::createEach(AST.getASTContext(), AST.getTokens(), *Offset,
                             *Offset, [&](SelectionTree ST) {
diff --git a/clang-tools-extra/clangd/unittests/XRefsTests.cpp b/clang-tools-extra/clangd/unittests/XRefsTests.cpp
--- a/clang-tools-extra/clangd/unittests/XRefsTests.cpp
+++ b/clang-tools-extra/clangd/unittests/XRefsTests.cpp
@@ -1786,11 +1786,11 @@
 
 TEST(FindType, All) {
   Annotations HeaderA(R"cpp(
-    struct [[Target]] { operator int() const; };
+    struct $Target[[Target]] { operator int() const; };
     struct Aggregate { Target a, b; };
     Target t;
 
-    template <typename T> class smart_ptr {
+    template <typename T> class $smart_ptr[[smart_ptr]] {
       T& operator*();
       T* operator->();
       T* get();
@@ -1829,11 +1829,12 @@
     ASSERT_GT(A.points().size(), 0u) << Case;
     for (auto Pos : A.points())
       EXPECT_THAT(findType(AST, Pos),
-                  ElementsAre(sym("Target", HeaderA.range(), HeaderA.range())))
+                  ElementsAre(sym("Target", HeaderA.range("Target"),
+                                  HeaderA.range("Target"))))
           << Case;
   }
 
-  // FIXME: We'd like these cases to work. Fix them and move above.
+  // Smart pointers resolve to both the smart pointer and its pointee.
   for (const llvm::StringRef Case : {
            "smart_ptr<Target> ^tsmart;",
        }) {
@@ -1842,7 +1843,11 @@
     ParsedAST AST = TU.build();
 
     EXPECT_THAT(findType(AST, A.point()),
-                Not(Contains(sym("Target", HeaderA.range(), HeaderA.range()))))
+                UnorderedElementsAre(
+                    sym("Target", HeaderA.range("Target"),
+                        HeaderA.range("Target")),
+                    sym("smart_ptr", HeaderA.range("smart_ptr"),
+                        HeaderA.range("smart_ptr"))))
         << Case;
   }
 }