diff --git a/clang-tools-extra/clangd/FindTarget.cpp b/clang-tools-extra/clangd/FindTarget.cpp --- a/clang-tools-extra/clangd/FindTarget.cpp +++ b/clang-tools-extra/clangd/FindTarget.cpp @@ -704,8 +704,24 @@ {OCID->getClassInterface()}}); Refs.push_back(ReferenceLoc{NestedNameSpecifierLoc(), OCID->getCategoryNameLoc(), - /*IsDecl=*/true, + /*IsDecl=*/false, {OCID->getCategoryDecl()}}); + Refs.push_back(ReferenceLoc{NestedNameSpecifierLoc(), + OCID->getCategoryNameLoc(), + /*IsDecl=*/true, + {OCID}}); + } + + void VisitObjCImplementationDecl(const ObjCImplementationDecl *OIMD) { + if (const auto *CI = OIMD->getClassInterface()) + Refs.push_back(ReferenceLoc{NestedNameSpecifierLoc(), + OIMD->getLocation(), + /*IsDecl=*/false, + {CI}}); + Refs.push_back(ReferenceLoc{NestedNameSpecifierLoc(), + OIMD->getLocation(), + /*IsDecl=*/true, + {OIMD}}); } }; diff --git a/clang-tools-extra/clangd/SemanticHighlighting.cpp b/clang-tools-extra/clangd/SemanticHighlighting.cpp --- a/clang-tools-extra/clangd/SemanticHighlighting.cpp +++ b/clang-tools-extra/clangd/SemanticHighlighting.cpp @@ -128,7 +128,7 @@ return HighlightingKind::Class; if (isa(D)) return HighlightingKind::Interface; - if (isa(D)) + if (isa(D)) return HighlightingKind::Namespace; if (auto *MD = dyn_cast(D)) return MD->isStatic() ?
HighlightingKind::StaticMethod diff --git a/clang-tools-extra/clangd/refactor/Rename.cpp b/clang-tools-extra/clangd/refactor/Rename.cpp --- a/clang-tools-extra/clangd/refactor/Rename.cpp +++ b/clang-tools-extra/clangd/refactor/Rename.cpp @@ -19,6 +19,7 @@ #include "clang/AST/ASTTypeTraits.h" #include "clang/AST/Decl.h" #include "clang/AST/DeclCXX.h" +#include "clang/AST/DeclObjC.h" #include "clang/AST/DeclTemplate.h" #include "clang/AST/ParentMapContext.h" #include "clang/AST/Stmt.h" @@ -137,6 +138,14 @@ if (const auto *TargetDecl = UD->getTargetDecl()) return canonicalRenameDecl(TargetDecl); } + if (const auto *ID = dyn_cast(D)) { + if (const auto *CI = ID->getClassInterface()) + return canonicalRenameDecl(CI); + } + if (const auto *CID = dyn_cast(D)) { + if (const auto *CD = CID->getCategoryDecl()) + return canonicalRenameDecl(CD); + } return dyn_cast(D->getCanonicalDecl()); } @@ -156,6 +165,16 @@ targetDecl(SelectedNode->ASTNode, DeclRelation::Alias | DeclRelation::TemplatePattern, AST.getHeuristicResolver())) { + // If we select the class name in `@interface Class (CategoryName)` or in + // its `@implementation`, rename the class interface, not the category. + // Categories without a class interface (invalid code) are left alone. + if (const auto *C = dyn_cast(D); C && C->getClassInterface()) { + if (C->getLocation() == TokenStartLoc) + D = C->getClassInterface(); + else if (const auto *I = C->getImplementation()) + if (I->getLocation() == TokenStartLoc) + D = C->getClassInterface(); + } Result.insert(canonicalRenameDecl(D)); } return Result; diff --git a/clang-tools-extra/clangd/unittests/RenameTests.cpp b/clang-tools-extra/clangd/unittests/RenameTests.cpp --- a/clang-tools-extra/clangd/unittests/RenameTests.cpp +++ b/clang-tools-extra/clangd/unittests/RenameTests.cpp @@ -840,6 +840,30 @@ foo('x'); } )cpp", + + // ObjC class with a category.
+ R"cpp( + @interface [[Fo^o]] + @end + @implementation [[F^oo]] + @end + @interface [[Fo^o]] (Category) + @end + @implementation [[F^oo]] (Category) + @end + + void func([[Fo^o]] *f) {} + )cpp", + + // ObjC category. + R"cpp( + @interface Foo + @end + @interface Foo ([[Cate^gory]]) + @end + @implementation Foo ([[Cate^gory]]) + @end + )cpp", }; llvm::StringRef NewName = "NewName"; for (llvm::StringRef T : Tests) { @@ -1468,7 +1492,7 @@ TEST(CrossFileRenameTests, WithUpToDateIndex) { MockCompilationDatabase CDB; - CDB.ExtraClangFlags = {"-xc++"}; + CDB.ExtraClangFlags = {"-xobjective-c++"}; // rename is runnning on all "^" points in FooH, and "[[]]" ranges are the // expected rename occurrences. struct Case { @@ -1557,13 +1581,12 @@ } )cpp", }, - { - // virtual templated method - R"cpp( + {// virtual templated method + R"cpp( template class Foo { virtual void [[m]](); }; class Bar : Foo { void [[^m]]() override; }; )cpp", - R"cpp( + R"cpp( #include "foo.h" template void Foo::[[m]]() {} @@ -1571,8 +1594,7 @@ // the canonical Foo::m(). // https://github.com/clangd/clangd/issues/1325 class Baz : Foo { void m() override; }; - )cpp" - }, + )cpp"}, { // rename on constructor and destructor. R"cpp( @@ -1677,6 +1699,20 @@ } )cpp", }, + { + // Objective-C classes. + R"cpp( + @interface [[Fo^o]] + @end + )cpp", + R"cpp( + #include "foo.h" + @implementation [[F^oo]] + @end + + void func([[Foo]] *f) {} + )cpp", + }, }; trace::TestTracer Tracer;