clang-rename: support renaming class templates and their instantiations

USRFinder maps a TemplateSpecializationType (a use such as `Foo<int>`) to the
template declaration it names. AdditionalUSRFinder becomes a
RecursiveASTVisitor instead of an ASTMatchers callback: one traversal of the
translation unit records every virtual method and every class template
partial specialization. After that, it collects the USRs of the renamed
symbol's constructors, destructor, specializations and transitive overriders.
A forward-declared class with no definition contributes only its own USR.

Index: clang-rename/USRFinder.cpp
===================================================================
--- clang-rename/USRFinder.cpp
+++ clang-rename/USRFinder.cpp
@@ -81,6 +81,11 @@
             dyn_cast<TemplateTypeParmType>(Loc.getType())) {
       return setResult(TemplateTypeParm->getDecl(), TypeBeginLoc, TypeEndLoc);
     }
+    if (const auto *TemplateSpecType =
+            dyn_cast<TemplateSpecializationType>(Loc.getType())) {
+      return setResult(TemplateSpecType->getTemplateName().getAsTemplateDecl(),
+                       TypeBeginLoc, TypeEndLoc);
+    }
     return setResult(Loc.getType()->getAsCXXRecordDecl(), TypeBeginLoc,
                      TypeEndLoc);
   }
Index: clang-rename/USRFindingAction.cpp
===================================================================
--- clang-rename/USRFindingAction.cpp
+++ clang-rename/USRFindingAction.cpp
@@ -20,7 +20,6 @@
 #include "clang/AST/ASTContext.h"
 #include "clang/AST/Decl.h"
 #include "clang/AST/RecursiveASTVisitor.h"
-#include "clang/ASTMatchers/ASTMatchFinder.h"
 #include "clang/Basic/FileManager.h"
 #include "clang/Frontend/CompilerInstance.h"
 #include "clang/Frontend/FrontendAction.h"
@@ -36,7 +35,6 @@
 
 
 using namespace llvm;
-using namespace clang::ast_matchers;
 
 namespace clang {
 namespace rename {
@@ -46,45 +44,72 @@
 // AdditionalUSRFinder. AdditionalUSRFinder adds USRs of ctor and dtor if given
 // Decl refers to class and adds USRs of all overridden methods if Decl refers
 // to virtual method.
-class AdditionalUSRFinder : public MatchFinder::MatchCallback {
+class AdditionalUSRFinder : public RecursiveASTVisitor<AdditionalUSRFinder> {
 public:
   explicit AdditionalUSRFinder(const Decl *FoundDecl, ASTContext &Context,
                                std::vector<std::string> *USRs)
-      : FoundDecl(FoundDecl), Context(Context), USRs(USRs), USRSet(), Finder() {}
+      : FoundDecl(FoundDecl), Context(Context), USRs(USRs) {}
 
   void Find() {
-    USRSet.insert(getUSRForDecl(FoundDecl));
+    // Record all virtual methods and class template partial specializations.
+    TraverseDecl(Context.getTranslationUnitDecl());
     if (const auto *MethodDecl = dyn_cast<CXXMethodDecl>(FoundDecl)) {
-      addUSRsFromOverrideSets(MethodDecl);
-    }
-    if (const auto *RecordDecl = dyn_cast<CXXRecordDecl>(FoundDecl)) {
-      addUSRsOfCtorDtors(RecordDecl);
+      addUSRsOfOverriddenFunctions(MethodDecl);
+      for (const auto &OverriddenMethod : OverriddenMethods) {
+        if (checkIfOverriddenFunctionAscends(OverriddenMethod)) {
+          USRSet.insert(getUSRForDecl(OverriddenMethod));
+        }
+      }
+    } else if (const auto *RecordDecl = dyn_cast<CXXRecordDecl>(FoundDecl)) {
+      handleCXXRecordDecl(RecordDecl);
+    } else if (const auto *TemplateDecl =
+                   dyn_cast<ClassTemplateDecl>(FoundDecl)) {
+      handleClassTemplateDecl(TemplateDecl);
+    } else {
+      USRSet.insert(getUSRForDecl(FoundDecl));
     }
-    addMatchers();
-    Finder.matchAST(Context);
     USRs->insert(USRs->end(), USRSet.begin(), USRSet.end());
   }
 
+  bool VisitCXXMethodDecl(const CXXMethodDecl *MethodDecl) {
+    if (MethodDecl->isVirtual()) {
+      OverriddenMethods.push_back(MethodDecl);
+    }
+    return true;
+  }
+
+  bool VisitClassTemplatePartialSpecializationDecl(
+      const ClassTemplatePartialSpecializationDecl *PartialSpec) {
+    PartialSpecs.push_back(PartialSpec);
+    return true;
+  }
+
 private:
-  void addMatchers() {
-    const auto CXXMethodDeclMatcher =
-        cxxMethodDecl(forEachOverridden(cxxMethodDecl().bind("cxxMethodDecl")));
-    Finder.addMatcher(CXXMethodDeclMatcher, this);
+  void handleCXXRecordDecl(const CXXRecordDecl *RecordDecl) {
+    // A class that is only forward-declared has no definition: keep just the
+    // declaration's own USR instead of dereferencing a null definition below.
+    if (!RecordDecl->getDefinition()) {
+      USRSet.insert(getUSRForDecl(RecordDecl));
+      return;
+    }
+    RecordDecl = RecordDecl->getDefinition();
+    if (const auto *ClassTemplateSpecDecl =
+            dyn_cast<ClassTemplateSpecializationDecl>(RecordDecl)) {
+      handleClassTemplateDecl(ClassTemplateSpecDecl->getSpecializedTemplate());
+    }
+    addUSRsOfCtorDtors(RecordDecl);
   }
 
-  // FIXME: Implement matchesUSR matchers to make lookups more efficient.
-  virtual void run(const MatchFinder::MatchResult &Result) {
-    const auto *VirtualMethod =
-        Result.Nodes.getNodeAs<CXXMethodDecl>("cxxMethodDecl");
-    bool Found = false;
-    for (const auto &OverriddenMethod : VirtualMethod->overridden_methods()) {
-      if (USRSet.find(getUSRForDecl(OverriddenMethod)) != USRSet.end()) {
-        Found = true;
-      }
+  void handleClassTemplateDecl(const ClassTemplateDecl *TemplateDecl) {
+    for (const auto *Specialization : TemplateDecl->specializations()) {
+      addUSRsOfCtorDtors(Specialization);
     }
-    if (Found) {
-      USRSet.insert(getUSRForDecl(VirtualMethod));
+    for (const auto *PartialSpec : PartialSpecs) {
+      if (PartialSpec->getSpecializedTemplate() == TemplateDecl) {
+        addUSRsOfCtorDtors(PartialSpec);
+      }
     }
+    addUSRsOfCtorDtors(TemplateDecl->getTemplatedDecl());
   }
 
   void addUSRsOfCtorDtors(const CXXRecordDecl *RecordDecl) {
@@ -93,21 +118,36 @@
       USRSet.insert(getUSRForDecl(CtorDecl));
     }
     USRSet.insert(getUSRForDecl(RecordDecl->getDestructor()));
+    USRSet.insert(getUSRForDecl(RecordDecl));
   }
 
-  void addUSRsFromOverrideSets(const CXXMethodDecl *MethodDecl) {
+  void addUSRsOfOverriddenFunctions(const CXXMethodDecl *MethodDecl) {
     USRSet.insert(getUSRForDecl(MethodDecl));
     for (auto &OverriddenMethod : MethodDecl->overridden_methods()) {
       // Recursively visit each OverridenMethod.
-      addUSRsFromOverrideSets(OverriddenMethod);
+      addUSRsOfOverriddenFunctions(OverriddenMethod);
+    }
+  }
+
+  // Returns true if MethodDecl directly or transitively overrides a method
+  // whose USR is already in USRSet. All overridden methods are checked, as a
+  // method may override several bases under multiple inheritance.
+  bool checkIfOverriddenFunctionAscends(const CXXMethodDecl *MethodDecl) {
+    for (const auto &OverriddenMethod : MethodDecl->overridden_methods()) {
+      if (USRSet.find(getUSRForDecl(OverriddenMethod)) != USRSet.end() ||
+          checkIfOverriddenFunctionAscends(OverriddenMethod)) {
+        return true;
+      }
     }
+    return false;
   }
 
   const Decl *FoundDecl;
   ASTContext &Context;
   std::vector<std::string> *USRs;
   std::set<std::string> USRSet;
-  MatchFinder Finder;
+  std::vector<const CXXMethodDecl *> OverriddenMethods;
+  std::vector<const ClassTemplatePartialSpecializationDecl *> PartialSpecs;
 };
 
 } // namespace
Index: test/clang-rename/ComplexFunctionOverride.cpp
===================================================================
--- /dev/null
+++ test/clang-rename/ComplexFunctionOverride.cpp
@@ -0,0 +1,31 @@
+// RUN: cat %s > %t.cpp
+// RUN: clang-rename -offset=307 -new-name=bar %t.cpp -i -- -std=c++11
+// RUN: sed 's,//.*,,' %t.cpp | FileCheck %s
+
+struct A {
+  virtual void foo();   // CHECK: virtual void bar();
+};
+
+struct B : A {
+  void foo() override;  // CHECK: void bar() override;
+};
+
+struct C : B {
+  void foo() override;  // CHECK: void bar() override;
+};
+
+struct D : B {
+  void foo() override;  // CHECK: void bar() override;
+};
+
+struct E : D {
+  void foo() override;  // CHECK: void bar() override;
+};
+
+struct X {
+  virtual void foo();   // CHECK: virtual void foo();
+};
+
+struct F : X, C {
+  void foo() override;  // CHECK: void bar() override;
+};
Index: test/clang-rename/TemplateClassInstantiationFindByDeclaration.cpp
===================================================================
--- test/clang-rename/TemplateClassInstantiationFindByDeclaration.cpp
+++ test/clang-rename/TemplateClassInstantiationFindByDeclaration.cpp
@@ -1,14 +1,9 @@
 // RUN: cat %s > %t.cpp
-// RUN: clang-rename -offset=287 -new-name=Bar %t.cpp -i --
+// RUN: clang-rename -offset=159 -new-name=Bar %t.cpp -i --
 // RUN: sed 's,//.*,,' %t.cpp | FileCheck %s
 
-// Currently unsupported test.
-// FIXME: clang-rename should be able to rename classes with templates
-// correctly.
-// XFAIL: *
-
 template <typename T>
-class Foo {               // CHECK: class Bar;
+class Foo {               // CHECK: class Bar {
 public:
   T foo(T arg, T& ref, T* ptr) {
     T value;
Index: test/clang-rename/TemplateClassInstantiationFindByTypeUse.cpp
===================================================================
--- test/clang-rename/TemplateClassInstantiationFindByTypeUse.cpp
+++ test/clang-rename/TemplateClassInstantiationFindByTypeUse.cpp
@@ -1,14 +1,9 @@
 // RUN: cat %s > %t.cpp
-// RUN: clang-rename -offset=703 -new-name=Bar %t.cpp -i --
+// RUN: clang-rename -offset=575 -new-name=Bar %t.cpp -i --
 // RUN: sed 's,//.*,,' %t.cpp | FileCheck %s
 
-// Currently unsupported test.
-// FIXME: clang-rename should be able to rename classes with templates
-// correctly.
-// XFAIL: *
-
 template <typename T>
-class Foo {               // CHECK: class Bar;
+class Foo {               // CHECK: class Bar {
 public:
   T foo(T arg, T& ref, T* ptr) {
     T value;
Index: test/clang-rename/TemplateClassInstantiationFindByUninstantiatedType.cpp
===================================================================
--- /dev/null
+++ test/clang-rename/TemplateClassInstantiationFindByUninstantiatedType.cpp
@@ -0,0 +1,39 @@
+// RUN: cat %s > %t.cpp
+// RUN: clang-rename -offset=440 -new-name=Bar %t.cpp -i --
+// RUN: sed 's,//.*,,' %t.cpp | FileCheck %s
+
+template <typename T>
+class Foo {               // CHECK: class Bar {
+public:
+  T foo(T arg, T& ref, T* ptr) {
+    T value;
+    int number = 42;
+    value = (T)number;
+    value = static_cast<T>(number);
+    return value;
+  }
+  static void foo(T value) {}
+  T member;
+};
+
+template <typename T>
+void func() {
+  Foo<T> obj;             // CHECK: Bar<T> obj;
+  obj.member = T();
+  Foo<T>::foo();          // CHECK: Bar<T>::foo();
+}
+
+int main() {
+  Foo<int> i;             // CHECK: Bar<int> i;
+  i.member = 0;
+  Foo<int>::foo(0);       // CHECK: Bar<int>::foo(0);
+
+  Foo<bool> b;            // CHECK: Bar<bool> b;
+  b.member = false;
+  Foo<bool>::foo(false);  // CHECK: Bar<bool>::foo(false);
+
+  return 0;
+}
+
+// Use grep -FUbo 'Foo' to get the correct offset of Foo when changing
+// this file.