diff --git a/clang-tools-extra/clangd/ClangdLSPServer.h b/clang-tools-extra/clangd/ClangdLSPServer.h
--- a/clang-tools-extra/clangd/ClangdLSPServer.h
+++ b/clang-tools-extra/clangd/ClangdLSPServer.h
@@ -121,6 +121,8 @@
                          Callback<std::vector<Location>>);
   void onGoToDefinition(const TextDocumentPositionParams &,
                         Callback<std::vector<Location>>);
+  void onGoToType(const TextDocumentPositionParams &,
+                  Callback<std::vector<Location>>);
   void onGoToImplementation(const TextDocumentPositionParams &,
                             Callback<std::vector<Location>>);
   void onReference(const ReferenceParams &, Callback<std::vector<Location>>);
diff --git a/clang-tools-extra/clangd/ClangdLSPServer.cpp b/clang-tools-extra/clangd/ClangdLSPServer.cpp
--- a/clang-tools-extra/clangd/ClangdLSPServer.cpp
+++ b/clang-tools-extra/clangd/ClangdLSPServer.cpp
@@ -560,6 +560,7 @@
       {"declarationProvider", true},
       {"definitionProvider", true},
       {"implementationProvider", true},
+      {"typeDefinitionProvider", true},
       {"documentHighlightProvider", true},
       {"documentLinkProvider",
        llvm::json::Object{
@@ -1278,6 +1279,22 @@
       });
 }
 
+void ClangdLSPServer::onGoToType(const TextDocumentPositionParams &Params,
+                                 Callback<std::vector<Location>> Reply) {
+  Server->findType(
+      Params.textDocument.uri.file(), Params.position,
+      [Reply = std::move(Reply)](
+          llvm::Expected<std::vector<LocatedSymbol>> Types) mutable {
+        if (!Types)
+          return Reply(Types.takeError());
+        // Reply with locations only: each symbol's preferred declaration.
+        std::vector<Location> Response;
+        for (const LocatedSymbol &Sym : *Types)
+          Response.push_back(Sym.PreferredDeclaration);
+        return Reply(std::move(Response));
+      });
+}
+
 void ClangdLSPServer::onGoToImplementation(
     const TextDocumentPositionParams &Params,
     Callback<std::vector<Location>> Reply) {
@@ -1448,6 +1465,7 @@
   Bind.method("textDocument/signatureHelp", this, &ClangdLSPServer::onSignatureHelp);
   Bind.method("textDocument/definition", this, &ClangdLSPServer::onGoToDefinition);
   Bind.method("textDocument/declaration", this, &ClangdLSPServer::onGoToDeclaration);
+  Bind.method("textDocument/typeDefinition", this, &ClangdLSPServer::onGoToType);
   Bind.method("textDocument/implementation", this, &ClangdLSPServer::onGoToImplementation);
   Bind.method("textDocument/references", this, &ClangdLSPServer::onReference);
   Bind.method("textDocument/switchSourceHeader", this, &ClangdLSPServer::onSwitchSourceHeader);
diff --git a/clang-tools-extra/clangd/ClangdServer.h b/clang-tools-extra/clangd/ClangdServer.h
--- a/clang-tools-extra/clangd/ClangdServer.h
+++ b/clang-tools-extra/clangd/ClangdServer.h
@@ -282,6 +282,10 @@
   void findImplementations(PathRef File, Position Pos,
                            Callback<std::vector<LocatedSymbol>> CB);
 
+  /// Retrieve symbols for types referenced at \p Pos.
+  void findType(PathRef File, Position Pos,
+                Callback<std::vector<LocatedSymbol>> CB);
+
   /// Retrieve locations for symbol references.
   void findReferences(PathRef File, Position Pos, uint32_t Limit,
                       Callback<ReferencesResult> CB);
diff --git a/clang-tools-extra/clangd/ClangdServer.cpp b/clang-tools-extra/clangd/ClangdServer.cpp
--- a/clang-tools-extra/clangd/ClangdServer.cpp
+++ b/clang-tools-extra/clangd/ClangdServer.cpp
@@ -810,6 +810,17 @@
                             Transient);
 }
 
+void ClangdServer::findType(PathRef File, Position Pos,
+                            Callback<std::vector<LocatedSymbol>> CB) {
+  auto Action =
+      [Pos, CB = std::move(CB)](llvm::Expected<InputsAndAST> InpAST) mutable {
+        if (!InpAST)
+          return CB(InpAST.takeError());
+        CB(clangd::findType(InpAST->AST, Pos));
+      };
+  WorkScheduler->runWithAST("FindType", File, std::move(Action));
+}
+
 void ClangdServer::findImplementations(
     PathRef File, Position Pos, Callback<std::vector<LocatedSymbol>> CB) {
   auto Action = [Pos, CB = std::move(CB),
diff --git a/clang-tools-extra/clangd/XRefs.h b/clang-tools-extra/clangd/XRefs.h
--- a/clang-tools-extra/clangd/XRefs.h
+++ b/clang-tools-extra/clangd/XRefs.h
@@ -105,6 +105,12 @@
 std::vector<LocatedSymbol> findImplementations(ParsedAST &AST, Position Pos,
                                                const SymbolIndex *Index);
 
+/// Returns symbols for types referenced at \p Pos.
+///
+/// For example, given `b^ar()` where bar returns Foo, this function returns the
+/// definition of class Foo.
+std::vector<LocatedSymbol> findType(ParsedAST &AST, Position Pos);
+
 /// Returns references of the symbol at a specified \p Pos.
 /// \p Limit limits the number of results returned (0 means no limit).
 ReferencesResult findReferences(ParsedAST &AST, Position Pos, uint32_t Limit,
diff --git a/clang-tools-extra/clangd/XRefs.cpp b/clang-tools-extra/clangd/XRefs.cpp
--- a/clang-tools-extra/clangd/XRefs.cpp
+++ b/clang-tools-extra/clangd/XRefs.cpp
@@ -29,11 +29,13 @@
 #include "clang/AST/DeclCXX.h"
 #include "clang/AST/DeclObjC.h"
 #include "clang/AST/DeclTemplate.h"
+#include "clang/AST/DeclVisitor.h"
 #include "clang/AST/ExprCXX.h"
 #include "clang/AST/ExternalASTSource.h"
 #include "clang/AST/RecursiveASTVisitor.h"
 #include "clang/AST/Stmt.h"
 #include "clang/AST/StmtCXX.h"
+#include "clang/AST/StmtVisitor.h"
 #include "clang/AST/Type.h"
 #include "clang/Basic/CharInfo.h"
 #include "clang/Basic/LLVM.h"
@@ -503,6 +505,8 @@
     return {};
   }
 
+  // FIXME: this sends unique_ptr<Foo> to unique_ptr<T>.
+  // Likely it would be better to send it to Foo (heuristically) or to both.
   auto Decls = targetDecl(DynTypedNode::create(Type.getNonReferenceType()),
                           DeclRelation::TemplatePattern | DeclRelation::Alias,
                           AST.getHeuristicResolver());
@@ -1785,6 +1789,183 @@
   return Result;
 }
 
+// Return the type most associated with an AST node.
+// This isn't precisely defined: we want "go to type" to do something useful.
+static QualType typeForNode(const SelectionTree::Node *N) {
+  // If we're looking at a namespace qualifier, walk up to what it's qualifying.
+  // (If we're pointing at a *class* inside a NNS, N will be a TypeLoc).
+  while (N && N->ASTNode.get<NestedNameSpecifierLoc>())
+    N = N->Parent;
+  if (!N)
+    return QualType();
+
+  // If we're pointing at a type => return it.
+  if (const TypeLoc *TL = N->ASTNode.get<TypeLoc>()) {
+    if (llvm::isa<DeducedType>(TL->getTypePtr()))
+      if (auto Deduced = getDeducedType(
+              N->getDeclContext().getParentASTContext(), TL->getBeginLoc()))
+        return *Deduced;
+    // Exception: an alias => underlying type.
+    if (llvm::isa<TypedefType>(TL->getTypePtr()))
+      return TL->getTypePtr()->getLocallyUnqualifiedSingleStepDesugaredType();
+    return TL->getType();
+  }
+
+  // Constructor initializers => the type of thing being initialized.
+  if (const auto *CCI = N->ASTNode.get<CXXCtorInitializer>()) {
+    if (const FieldDecl *FD = CCI->getAnyMember())
+      return FD->getType();
+    if (const Type *Base = CCI->getBaseClass())
+      return QualType(Base, 0);
+  }
+
+  // Base specifier => the base type.
+  if (const auto *CBS = N->ASTNode.get<CXXBaseSpecifier>())
+    return CBS->getType();
+
+  if (const Decl *D = N->ASTNode.get<Decl>()) {
+    struct Visitor : ConstDeclVisitor<Visitor, QualType> {
+      QualType VisitValueDecl(const ValueDecl *D) { return D->getType(); }
+      // Declaration of a type => that type.
+      QualType VisitTypeDecl(const TypeDecl *D) {
+        return QualType(D->getTypeForDecl(), 0);
+      }
+      // Exception: alias declaration => the underlying type, not the alias.
+      QualType VisitTypedefNameDecl(const TypedefNameDecl *D) {
+        return D->getUnderlyingType();
+      }
+      // Look inside templates.
+      QualType VisitTemplateDecl(const TemplateDecl *D) {
+        return Visit(D->getTemplatedDecl());
+      }
+    } V;
+    return V.Visit(D);
+  }
+
+  if (const Stmt *S = N->ASTNode.get<Stmt>()) {
+    struct Visitor : ConstStmtVisitor<Visitor, QualType> {
+      // Null-safe version of visit simplifies recursive calls below.
+      QualType type(const Stmt *S) { return S ? Visit(S) : QualType(); }
+
+      // In general, expressions => type of expression.
+      QualType VisitExpr(const Expr *S) {
+        return S->IgnoreImplicitAsWritten()->getType();
+      }
+      // Exceptions for void expressions that operate on a type in some way.
+      QualType VisitCXXDeleteExpr(const CXXDeleteExpr *S) {
+        return S->getDestroyedType();
+      }
+      QualType VisitCXXPseudoDestructorExpr(const CXXPseudoDestructorExpr *S) {
+        return S->getDestroyedType();
+      }
+      QualType VisitCXXThrowExpr(const CXXThrowExpr *S) {
+        // `throw;` (a rethrow) has no operand, so this must be null-safe.
+        return type(S->getSubExpr());
+      }
+      QualType VisitCoyieldExpr(const CoyieldExpr *S) {
+        return type(S->getOperand());
+      }
+      // Treat a designated initializer like a reference to the field.
+      QualType VisitDesignatedInitExpr(const DesignatedInitExpr *S) {
+        // In .foo.bar we want to jump to bar's type, so find *last* field.
+        for (auto &D : llvm::reverse(S->designators()))
+          if (D.isFieldDesignator())
+            if (const auto *FD = D.getField())
+              return FD->getType();
+        return QualType();
+      }
+
+      // Control flow statements that operate on data: use the data type.
+      QualType VisitSwitchStmt(const SwitchStmt *S) {
+        return type(S->getCond());
+      }
+      QualType VisitWhileStmt(const WhileStmt *S) { return type(S->getCond()); }
+      QualType VisitDoStmt(const DoStmt *S) { return type(S->getCond()); }
+      QualType VisitIfStmt(const IfStmt *S) { return type(S->getCond()); }
+      QualType VisitCaseStmt(const CaseStmt *S) { return type(S->getLHS()); }
+      QualType VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
+        return S->getLoopVariable()->getType();
+      }
+      QualType VisitReturnStmt(const ReturnStmt *S) {
+        return type(S->getRetValue());
+      }
+      QualType VisitCoreturnStmt(const CoreturnStmt *S) {
+        return type(S->getOperand());
+      }
+      QualType VisitCXXCatchStmt(const CXXCatchStmt *S) {
+        return S->getCaughtType();
+      }
+      QualType VisitObjCAtThrowStmt(const ObjCAtThrowStmt *S) {
+        return type(S->getThrowExpr());
+      }
+      QualType VisitObjCAtCatchStmt(const ObjCAtCatchStmt *S) {
+        return S->getCatchParamDecl() ? S->getCatchParamDecl()->getType()
+                                      : QualType();
+      }
+    } V;
+    return V.Visit(S);
+  }
+
+  return QualType();
+}
+
+// Given a type targeted by the cursor, return a type that's more interesting
+// to target.
+static QualType unwrapFindType(QualType T) {
+  if (T.isNull())
+    return T;
+
+  // If there's a specific type alias, point at that rather than unwrapping.
+  if (const auto *TDT = T->getAs<TypedefType>())
+    return QualType(TDT, 0);
+
+  // Pointers etc => pointee type.
+  if (const auto *PT = T->getAs<PointerType>())
+    return unwrapFindType(PT->getPointeeType());
+  if (const auto *RT = T->getAs<ReferenceType>())
+    return unwrapFindType(RT->getPointeeType());
+  if (const auto *AT = T->getAsArrayTypeUnsafe())
+    return unwrapFindType(AT->getElementType());
+  // FIXME: use HeuristicResolver to unwrap smart pointers?
+
+  // Function type => return type.
+  if (auto FT = T->getAs<FunctionType>())
+    return unwrapFindType(FT->getReturnType());
+  if (auto CRD = T->getAsCXXRecordDecl()) {
+    if (CRD->isLambda())
+      return unwrapFindType(CRD->getLambdaCallOperator()->getReturnType());
+    // FIXME: more cases we'd prefer the return type of the call operator?
+    //        std::function etc?
+  }
+
+  return T;
+}
+
+std::vector<LocatedSymbol> findType(ParsedAST &AST, Position Pos) {
+  const SourceManager &SM = AST.getSourceManager();
+  auto Offset = positionToOffset(SM.getBufferData(SM.getMainFileID()), Pos);
+  std::vector<LocatedSymbol> Result;
+  if (!Offset) {
+    elog("failed to convert position {0} for findType: {1}", Pos,
+         Offset.takeError());
+    return Result;
+  }
+  // The general scheme is: position -> AST node -> type -> declaration.
+  auto SymbolsFromNode =
+      [&AST](const SelectionTree::Node *N) -> std::vector<LocatedSymbol> {
+    QualType Type = unwrapFindType(typeForNode(N));
+    if (Type.isNull())
+      return {};
+    return locateSymbolForType(AST, Type);
+  };
+  SelectionTree::createEach(AST.getASTContext(), AST.getTokens(), *Offset,
+                            *Offset, [&](SelectionTree ST) {
+                              Result = SymbolsFromNode(ST.commonAncestor());
+                              return !Result.empty();
+                            });
+  return Result;
+}
+
 std::vector<const CXXRecordDecl *> typeParents(const CXXRecordDecl *CXXRD) {
   std::vector<const CXXRecordDecl *> Result;
 
diff --git a/clang-tools-extra/clangd/test/initialize-params.test b/clang-tools-extra/clangd/test/initialize-params.test
--- a/clang-tools-extra/clangd/test/initialize-params.test
+++ b/clang-tools-extra/clangd/test/initialize-params.test
@@ -121,6 +121,7 @@
 # CHECK-NEXT:        "openClose": true,
 # CHECK-NEXT:        "save": true
 # CHECK-NEXT:      },
+# CHECK-NEXT:      "typeDefinitionProvider": true,
 # CHECK-NEXT:      "typeHierarchyProvider": true
 # CHECK-NEXT:      "workspaceSymbolProvider": true
 # CHECK-NEXT:    },
diff --git a/clang-tools-extra/clangd/test/type-definition.test b/clang-tools-extra/clangd/test/type-definition.test
new file mode 100644
--- /dev/null
+++ b/clang-tools-extra/clangd/test/type-definition.test
@@ -0,0 +1,32 @@
+# RUN: clangd -lit-test < %s | FileCheck -strict-whitespace %s
+{"jsonrpc":"2.0","id":0,"method":"initialize","params":{}}
+---
+{"jsonrpc":"2.0","method":"textDocument/didOpen","params":{"textDocument":{"uri":"test:///main.cpp","languageId":"cpp","version":1,
+  "text":"class X {};\nauto x = X{};"
+}}}
+---
+{"jsonrpc":"2.0","id":1,"method":"textDocument/typeDefinition","params":{
+  "textDocument":{"uri":"test:///main.cpp"},
+  "position":{"line":1,"character":5}
+}}
+#      CHECK: "id": 1
+# CHECK-NEXT: "jsonrpc": "2.0",
+# CHECK-NEXT: "result": [
+# CHECK-NEXT:   {
+# CHECK-NEXT:     "range": {
+# CHECK-NEXT:       "end": {
+# CHECK-NEXT:         "character": 7,
+# CHECK-NEXT:         "line": 0
+# CHECK-NEXT:       },
+# CHECK-NEXT:       "start": {
+# CHECK-NEXT:         "character": 6,
+# CHECK-NEXT:         "line": 0
+# CHECK-NEXT:       }
+# CHECK-NEXT:     },
+# CHECK-NEXT:     "uri": "file://{{.*}}/clangd-test/main.cpp"
+# CHECK-NEXT:   }
+# CHECK-NEXT: ]
+---
+{"jsonrpc":"2.0","id":3,"method":"shutdown"}
+---
+{"jsonrpc":"2.0","method":"exit"}
diff --git a/clang-tools-extra/clangd/unittests/XRefsTests.cpp b/clang-tools-extra/clangd/unittests/XRefsTests.cpp
--- a/clang-tools-extra/clangd/unittests/XRefsTests.cpp
+++ b/clang-tools-extra/clangd/unittests/XRefsTests.cpp
@@ -38,10 +38,12 @@
 namespace {
 
 using ::testing::AllOf;
+using ::testing::Contains;
 using ::testing::ElementsAre;
 using ::testing::Eq;
 using ::testing::IsEmpty;
 using ::testing::Matcher;
+using ::testing::Not;
 using ::testing::UnorderedElementsAre;
 using ::testing::UnorderedElementsAreArray;
 using ::testing::UnorderedPointwise;
@@ -1781,6 +1783,77 @@
         << Test;
 }
 
+TEST(FindType, All) {
+  Annotations HeaderA(R"cpp(
+    struct [[Target]] { operator int() const; };
+    struct Aggregate { Target a, b; };
+    Target t;
+
+    template <typename T> class smart_ptr {
+      T& operator*();
+      T* operator->();
+      T* get();
+    };
+  )cpp");
+  auto TU = TestTU::withHeaderCode(HeaderA.code());
+  for (const llvm::StringRef Case : {
+           "str^uct Target;",
+           "T^arget x;",
+           "Target ^x;",
+           "a^uto x = Target{};",
+           "namespace m { Target tgt; } auto x = m^::tgt;",
+           "Target funcCall(); auto x = ^funcCall();",
+           "Aggregate a = { {}, ^{} };",
+           "Aggregate a = { ^.a=t, };",
+           "struct X { Target a; X() : ^a() {} };",
+           "^using T = Target; ^T foo();",
+           "^template <typename T> Target foo();",
+           "void x() { try {} ^catch(Target e) {} }",
+           "void x() { ^throw t; }",
+           "int x() { ^return t; }",
+           "void x() { ^switch(t) {} }",
+           "void x() { ^delete (Target*)nullptr; }",
+           "Target& ^tref = t;",
+           "void x() { ^if (t) {} }",
+           "void x() { ^while (t) {} }",
+           "void x() { ^do { } while (t); }",
+           "^auto x = []() { return t; };",
+           "Target* ^tptr = &t;",
+           "Target ^tarray[3];",
+       }) {
+    Annotations A(Case);
+    TU.Code = A.code().str();
+    ParsedAST AST = TU.build();
+
+    ASSERT_GT(A.points().size(), 0u) << Case;
+    for (auto Pos : A.points())
+      EXPECT_THAT(findType(AST, Pos),
+                  ElementsAre(Sym("Target", HeaderA.range(), HeaderA.range())))
+          << Case;
+  }
+
+  // FIXME: We'd like these cases to work. Fix them and move above.
+  for (const llvm::StringRef Case : {
+           "smart_ptr<Target> ^tsmart;",
+       }) {
+    Annotations A(Case);
+    TU.Code = A.code().str();
+    ParsedAST AST = TU.build();
+
+    EXPECT_THAT(findType(AST, A.point()),
+                Not(Contains(Sym("Target", HeaderA.range(), HeaderA.range()))))
+        << Case;
+  }
+
+  // `throw;` (a rethrow) has no operand: must not crash, and finds no type.
+  {
+    Annotations A("void x() { try {} catch (...) { ^throw; } }");
+    TU.Code = A.code().str();
+    ParsedAST AST = TU.build();
+    EXPECT_THAT(findType(AST, A.point()), IsEmpty());
+  }
+}
+
 void checkFindRefs(llvm::StringRef Test, bool UseIndex = false) {
   Annotations T(Test);
   auto TU = TestTU::withCode(T.code());