diff --git a/clang-tools-extra/clangd/refactor/tweaks/CMakeLists.txt b/clang-tools-extra/clangd/refactor/tweaks/CMakeLists.txt
--- a/clang-tools-extra/clangd/refactor/tweaks/CMakeLists.txt
+++ b/clang-tools-extra/clangd/refactor/tweaks/CMakeLists.txt
@@ -21,6 +21,7 @@
   ExpandMacro.cpp
   ExtractFunction.cpp
   ExtractVariable.cpp
+  ImplementAbstract.cpp
   ObjCLocalizeStringLiteral.cpp
   PopulateSwitch.cpp
   RawStringLiteral.cpp
diff --git a/clang-tools-extra/clangd/refactor/tweaks/ImplementAbstract.cpp b/clang-tools-extra/clangd/refactor/tweaks/ImplementAbstract.cpp
new file mode 100644
--- /dev/null
+++ b/clang-tools-extra/clangd/refactor/tweaks/ImplementAbstract.cpp
@@ -0,0 +1,335 @@
+//===--- ImplementAbstract.cpp -----------------------------------*- C++-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "refactor/Tweak.h"
+#include "support/Logger.h"
+#include "clang/Lex/Lexer.h"
+#include "llvm/ADT/PointerIntPair.h"
+#include "llvm/ADT/SmallPtrSet.h"
+
+namespace clang {
+namespace clangd {
+using MethodAndAccess =
+    llvm::PointerIntPair<const CXXMethodDecl *, 2, AccessSpecifier>;
+
+namespace {
+/// Declares overrides for the pure virtual methods that the selected class
+/// inherits but doesn't override yet.
+class ImplementAbstract : public Tweak {
+public:
+  const char *id() const override;
+
+  bool prepare(const Selection &Inputs) override;
+  Expected<Effect> apply(const Selection &Inputs) override;
+  std::string title() const override {
+    return "Implement pure virtual methods";
+  }
+  llvm::StringLiteral kind() const override {
+    return CodeAction::REFACTOR_KIND;
+  }
+
+private:
+  /// The class definition the tweak was invoked on.
+  const CXXRecordDecl *Selected = nullptr;
+  /// Inherited pure virtual methods without an overrider, each paired with
+  /// the access it has in Selected.
+  std::vector<MethodAndAccess> PureVirtualMethods;
+};
+
+/// Returns the access a member declared as \p DefinedAs has when inherited
+/// through a path with access \p InheritSpecifier: the more restrictive one.
+AccessSpecifier getConstrained(AccessSpecifier InheritSpecifier,
+                               AccessSpecifier DefinedAs) {
+  // Relies on the enum order AS_public < AS_protected < AS_private.
+  return std::max(InheritSpecifier, DefinedAs);
+}
+
+/// Appends to \p Results each pure virtual method declared in \p Record or in
+/// one of its bases (transitively) that has no overrider among the classes
+/// visited so far, paired with the access it has in the selected class.
+/// \p Access is the access accumulated along the inheritance path, and
+/// \p Overrides holds the methods already overridden or already collected.
+/// Returns true when the tweak can't be offered: a base isn't a record type
+/// (e.g. it is dependent), or \p IsRoot is set and \p Record itself declares
+/// a pure virtual method.
+bool collectPureVirtual(const CXXRecordDecl &Record,
+                        std::vector<MethodAndAccess> &Results,
+                        AccessSpecifier Access,
+                        llvm::SmallPtrSetImpl<const CXXMethodDecl *> &Overrides,
+                        bool IsRoot) {
+  // Record what this class overrides before visiting its bases, so that the
+  // bases' pure methods implemented here aren't collected.
+  for (const CXXMethodDecl *Method : Record.methods()) {
+    if (!Method->isVirtual())
+      continue;
+    if (IsRoot && Method->isPure())
+      return true;
+    for (const CXXMethodDecl *Overridden : Method->overridden_methods())
+      Overrides.insert(Overridden);
+  }
+  for (const CXXBaseSpecifier &Base : Record.bases()) {
+    const RecordType *RT = Base.getType()->getAs<RecordType>();
+    if (!RT)
+      // Probably a dependent base, just error out.
+      return true;
+    const CXXRecordDecl *BaseDecl = cast<CXXRecordDecl>(RT->getDecl());
+    if (!BaseDecl->isPolymorphic())
+      continue;
+    if (collectPureVirtual(*BaseDecl, Results,
+                           getConstrained(Access, Base.getAccessSpecifier()),
+                           Overrides, false))
+      // Propagate any error back up.
+      return true;
+  }
+  // Add the pure methods from this class after traversing the bases, so they
+  // appear in Results after the methods inherited from those bases.
+  for (const CXXMethodDecl *Method : Record.methods()) {
+    if (!Method->isPure())
+      continue;
+    // insert() fails if a derived class overrides the method, or if it was
+    // already collected through another path (e.g. diamond inheritance),
+    // which would otherwise yield a duplicate declaration.
+    if (Overrides.insert(Method).second)
+      Results.emplace_back(Method, getConstrained(Access, Method->getAccess()));
+  }
+  return false;
+}
+
+/// Returns the class whose declaration is the selected node, if any.
+static const CXXRecordDecl *
+getSelectedRecord(const SelectionTree::Node *SelNode) {
+  if (!SelNode)
+    return nullptr;
+  const DynTypedNode &AstNode = SelNode->ASTNode;
+  return AstNode.get<CXXRecordDecl>();
+}
+
+/// Available on the definition of a class or struct with non-dependent bases
+/// that inherits pure virtual methods it doesn't override, as long as it
+/// declares no pure virtual methods of its own.
+bool ImplementAbstract::prepare(const Selection &Inputs) {
+  // FIXME: This method won't return the class when the caret is in the body
+  // of the class. So the only way to get the tweak offered is to be touching
+  // the marked ranges. It would be nicer if this was offered if the cursor was
+  // inside the class (but perhaps not inside the class's decls).
+  // [[class]] [[Derived]] [[:]] [[public]] Base [[{]]
+  // ^
+  // [[}]];
+  Selected = getSelectedRecord(Inputs.ASTSelection.commonAncestor());
+  if (!Selected)
+    return false;
+
+  // Some sanity checks before we try.
+  if (!Selected->isThisDeclarationADefinition())
+    return false;
+  if (!Selected->isClass() && !Selected->isStruct())
+    return false;
+  if (Selected->hasAnyDependentBases() || Selected->getNumBases() == 0)
+    return false;
+  // We should check for abstract, but that prevents working on template classes
+  // that don't have any dependent bases.
+  if (!Selected->isPolymorphic())
+    return false;
+
+  llvm::SmallPtrSet<const CXXMethodDecl *, 4> Overrides;
+  if (collectPureVirtual(*Selected, PureVirtualMethods, AS_public, Overrides,
+                         true))
+    return false;
+  return !PureVirtualMethods.empty();
+}
+
+/// Prints an overriding declaration of each method in \p Items. When
+/// \p AccessSpec is non-empty, the declarations are preceded by an
+/// `AccessSpec:` label. The printing callbacks treat scopes enclosing
+/// \p PrintContext as visible, so names needn't be qualified with them.
+static void printMethods(llvm::raw_ostream &Out,
+                         ArrayRef<const CXXMethodDecl *> Items,
+                         const CXXRecordDecl *PrintContext,
+                         StringRef AccessSpec = {}) {
+  class PrintCB : public PrintingCallbacks {
+  public:
+    PrintCB(const DeclContext *CurContext) : CurContext(CurContext) {}
+    virtual ~PrintCB() {}
+    bool isScopeVisible(const DeclContext *DC) const override {
+      return DC->Encloses(CurContext);
+    }
+
+  private:
+    const DeclContext *CurContext;
+  };
+  PrintCB Callbacks(PrintContext);
+  auto Policy = PrintContext->getASTContext().getPrintingPolicy();
+  Policy.SuppressScope = false;
+  Policy.Callbacks = &Callbacks;
+  if (!AccessSpec.empty())
+    Out << "\n" << AccessSpec << ":\n";
+  Out << "\n";
+  for (const CXXMethodDecl *Method : Items) {
+    // A conversion function's name already spells out its return type.
+    if (!isa<CXXConversionDecl>(Method)) {
+      Method->getReturnType().print(Out, Policy);
+      Out << ' ';
+    }
+    Out << Method->getNameAsString() << "(";
+    bool IsFirst = true;
+    for (const auto &Param : Method->parameters()) {
+      if (!IsFirst)
+        Out << ", ";
+      else
+        IsFirst = false;
+      Param->print(Out, Policy);
+    }
+    if (Method->isVariadic())
+      Out << (IsFirst ? "..." : ", ...");
+    Out << ") ";
+    if (Method->isConst())
+      Out << "const ";
+    if (Method->isVolatile())
+      Out << "volatile ";
+    // The ref-qualifier must match, otherwise the method wouldn't override.
+    switch (Method->getRefQualifier()) {
+    case RQ_None:
+      break;
+    case RQ_LValue:
+      Out << "& ";
+      break;
+    case RQ_RValue:
+      Out << "&& ";
+      break;
+    }
+    // An overrider can't have a looser exception specification than the
+    // method it overrides.
+    if (const auto *FPT = Method->getType()->getAs<FunctionProtoType>())
+      if (FPT->isNothrow())
+        Out << "noexcept ";
+    // Always suggest `override` over `final`.
+    Out << "override;\n";
+  }
+}
+
+/// Inserts a declaration for each collected method. Methods are placed after
+/// the last method (or, failing that, the last other declaration) that already
+/// has the same access, so no access specifier needs to be added. If the body
+/// has no usable declarations, methods with the default access go right after
+/// the opening brace. Any remaining methods are added before the closing
+/// brace, under an access specifier label.
+Expected<Tweak::Effect> ImplementAbstract::apply(const Selection &Inputs) {
+  // Indexed by AccessSpecifier (AS_public, AS_protected, AS_private).
+  llvm::SmallVector<const CXXMethodDecl *, 4> GroupedAccessMethods[3];
+
+  for (const MethodAndAccess &PVM : PureVirtualMethods) {
+    GroupedAccessMethods[PVM.getInt()].push_back(PVM.getPointer());
+  }
+
+  // We should have at least one pure virtual method to add.
+  assert(llvm::any_of(
+      GroupedAccessMethods,
+      [](ArrayRef<const CXXMethodDecl *> Array) { return !Array.empty(); }));
+
+  struct InsertionDetail {
+    SourceLocation Loc = {};
+    bool RefersToMethod = false;
+  };
+
+  using DetailAndAccess = std::pair<InsertionDetail, AccessSpecifier>;
+  SmallVector<DetailAndAccess, 3> InsertionPoints;
+
+  auto GetDetailForAccess = [&](AccessSpecifier Spec) -> InsertionDetail & {
+    assert(Spec != AS_none);
+    for (DetailAndAccess &Item : InsertionPoints) {
+      if (Item.second == Spec)
+        return Item.first;
+    }
+    return InsertionPoints.emplace_back(InsertionDetail{}, Spec).first;
+  };
+
+  const SourceManager &SM = Inputs.AST->getSourceManager();
+  const LangOptions &LangOpts = Inputs.AST->getLangOpts();
+  // The end location of a declaration is the start of its last token and
+  // excludes a trailing semi-colon. Returns the location just past that token,
+  // or past the semi-colon if one follows it.
+  auto Next = [&](SourceLocation Loc) {
+    if (auto Tok = Lexer::findNextToken(Loc, SM, LangOpts))
+      if (Tok->is(tok::semi))
+        return Tok->getEndLoc();
+    return Lexer::getLocForEndOfToken(Loc, 0, SM, LangOpts);
+  };
+  // This whole block is designed to get an insertion point after the last
+  // method has been declared with each access specifier. Doing this ensures we
+  // keep the same visibility for implemented methods without the need to add
+  // unnecessary access specifiers.
+  for (auto *Decl : Selected->decls()) {
+    // Ignore things like compiler generated special member functions.
+    if (Decl->isImplicit())
+      continue;
+    // Hack to try and leave the destructor as last method in a block.
+    if (isa<CXXDestructorDecl>(Decl))
+      continue;
+    // Some members, such as static_assert declarations, have no access
+    // specifier. Skip them: AS_none would index past GroupedAccessMethods.
+    if (Decl->getAccess() == AS_none)
+      continue;
+    InsertionDetail &Detail = GetDetailForAccess(Decl->getAccess());
+    if (isa<CXXMethodDecl>(Decl)) {
+      Detail.Loc = Next(Decl->getSourceRange().getEnd());
+      Detail.RefersToMethod = true;
+    } else if (!Detail.RefersToMethod) {
+      // Last decl with this access wasn't method decl.
+      Detail.Loc = Next(Decl->getSourceRange().getEnd());
+    }
+  }
+  if (InsertionPoints.empty()) {
+    // No usable declarations in the body, use the default access for the
+    // first potential insertion, right after the opening brace.
+    GetDetailForAccess(Selected->isClass() ? AS_private : AS_public) =
+        InsertionDetail{
+            Selected->getBraceRange().getBegin().getLocWithOffset(1), true};
+  }
+
+  SmallString<256> Buffer;
+  llvm::raw_svector_ostream OS(Buffer);
+  tooling::Replacements Replacements;
+  for (auto &Item : InsertionPoints) {
+    assert(Item.first.Loc.isValid());
+    llvm::SmallVectorImpl<const CXXMethodDecl *> &GroupedMethods =
+        GroupedAccessMethods[Item.second];
+    if (GroupedMethods.empty())
+      continue;
+    printMethods(OS, GroupedMethods, Selected);
+    if (auto Err = Replacements.add(
+            tooling::Replacement(SM, Item.first.Loc, 0, Buffer))) {
+      return std::move(Err);
+    }
+    // Clear the methods as in the fallback loop we don't want to print them
+    // again.
+    GroupedMethods.clear();
+    Buffer.clear();
+  }
+
+  // Any access specifiers not covered can be added in one insertion.
+  for (AccessSpecifier Spec : {AS_public, AS_protected, AS_private}) {
+    llvm::SmallVectorImpl<const CXXMethodDecl *> &GroupedMethods =
+        GroupedAccessMethods[Spec];
+    if (GroupedMethods.empty())
+      continue;
+    printMethods(OS, GroupedMethods, Selected, getAccessSpelling(Spec));
+  }
+  if (!Buffer.empty()) {
+    if (auto Err = Replacements.add(tooling::Replacement(
+            SM, Selected->getBraceRange().getEnd(), 0, Buffer))) {
+      return std::move(Err);
+    }
+  }
+  return Effect::mainFileEdit(SM, std::move(Replacements));
+}
+
+REGISTER_TWEAK(ImplementAbstract)
+
+} // namespace
+} // namespace clangd
+} // namespace clang
diff --git a/clang-tools-extra/clangd/unittests/CMakeLists.txt b/clang-tools-extra/clangd/unittests/CMakeLists.txt
--- a/clang-tools-extra/clangd/unittests/CMakeLists.txt
+++ b/clang-tools-extra/clangd/unittests/CMakeLists.txt
@@ -118,6 +118,7 @@
   tweaks/ExpandMacroTests.cpp
   tweaks/ExtractFunctionTests.cpp
   tweaks/ExtractVariableTests.cpp
+  tweaks/ImplementAbstractTests.cpp
   tweaks/ObjCLocalizeStringLiteralTests.cpp
   tweaks/PopulateSwitchTests.cpp
   tweaks/RawStringLiteralTests.cpp
diff --git a/clang-tools-extra/clangd/unittests/tweaks/ImplementAbstractTests.cpp b/clang-tools-extra/clangd/unittests/tweaks/ImplementAbstractTests.cpp
new file mode 100644
--- /dev/null
+++ b/clang-tools-extra/clangd/unittests/tweaks/ImplementAbstractTests.cpp
@@ -0,0 +1,349 @@
+//===-- ImplementAbstractTests.cpp ------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "TestTU.h" +#include "TweakTesting.h" +#include "gmock/gmock-matchers.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +using ::testing::Not; + +namespace clang { +namespace clangd { +namespace { + +TWEAK_TEST(ImplementAbstract); + +TEST_F(ImplementAbstractTest, TestUnavailable) { + // In each case nothing is left to implement, so no tweak is offered. + StringRef Cases[]{ + // The base method is virtual but not pure. + R"cpp( + class A { + virtual void Foo(); + }; + class ^B : public A {}; + )cpp", + // Pure virtual method overridden in the selected class. + R"cpp( + class A { + virtual void Foo() = 0; + }; + class ^B : public A { + void Foo() override; + }; + )cpp", + // Overridden in the selected class, with the virtual keyword. + R"cpp( + class A { + virtual void Foo() = 0; + }; + class ^B : public A { + virtual void Foo() override; + }; + )cpp", + // Overridden in the selected class, without the override keyword. + R"cpp( + class A { + virtual void Foo() = 0; + }; + class ^B : public A { + void Foo(); + }; + )cpp", + // Pure virtual method overridden in an intermediate base class.
+ R"cpp( + class A { + virtual void Foo() = 0; + }; + class B : public A { + void Foo() override; + }; + class ^C : public B { + }; + )cpp"}; + for (const auto &Case : Cases) { + EXPECT_THAT(Case, Not(isAvailable())); + } +} + +TEST_F(ImplementAbstractTest, NormalAvailable) { + struct Case { + llvm::StringRef TestHeader; + llvm::StringRef TestSource; + llvm::StringRef ExpectedSource; + }; + // TestHeader declares the bases; ^ in TestSource marks the selection. + Case Cases[]{ + { + R"cpp( + class A { + virtual void Foo() = 0; + };)cpp", + R"cpp( + class B : public A {^}; + )cpp", + R"cpp( + class B : public A { +void Foo() override; +}; + )cpp", + }, + { + R"cpp( + class A { + public: + virtual void Foo() = 0; + };)cpp", + R"cpp( + class ^B : public A {}; + )cpp", + R"cpp( + class B : public A { +public: + +void Foo() override; +}; + )cpp", + }, + { + R"cpp( + class A { + virtual void Foo(int Param) = 0; + };)cpp", + R"cpp( + class ^B : public A {}; + )cpp", + R"cpp( + class B : public A { +void Foo(int Param) override; +}; + )cpp", + }, + { + R"cpp( + class A { + virtual void Foo(int Param) = 0; + };)cpp", + R"cpp( + struct ^B : public A {}; + )cpp", + R"cpp( + struct B : public A { +private: + +void Foo(int Param) override; +}; + )cpp", + }, + { + R"cpp( + class A { + virtual void Foo(int Param) const volatile = 0; + public: + virtual void Bar(int Param) = 0; + };)cpp", + R"cpp( + class ^B : public A { + void Foo(int Param) const volatile override; + }; + )cpp", + R"cpp( + class B : public A { + void Foo(int Param) const volatile override; + +public: + +void Bar(int Param) override; +}; + )cpp", + }, + { + R"cpp( + class A { + virtual void Foo() = 0; + virtual void Bar() = 0; + }; + class B : public A { + void Foo() override; + }; + )cpp", + R"cpp( + class ^C : public B { + virtual void Baz(); + }; + )cpp", + R"cpp( + class C : public B { + virtual void Baz(); +void Bar() override; + + }; + )cpp", + }, + { + R"cpp( + class A { + virtual void Foo() = 0; + };)cpp", + R"cpp( + class ^B : public A { + ~B(); + }; + )cpp", + R"cpp(
+ class B : public A { +void Foo() override; + + ~B(); + }; + )cpp", + }, + { + R"cpp( + class A { + virtual void Foo() = 0; + public: + virtual void Bar() = 0; + };)cpp", + R"cpp( + class ^B : public A { + }; + )cpp", + R"cpp( + class B : public A { +void Foo() override; + + +public: + +void Bar() override; +}; + )cpp", + }, + { + R"cpp( + class A { + virtual void Foo() = 0; + }; + struct B : public A { + virtual void Bar() = 0; + };)cpp", + R"cpp( + class ^C : public B { + }; + )cpp", + R"cpp( + class C : public B { +void Foo() override; + + +public: + +void Bar() override; +}; + )cpp", + }, + { + R"cpp( + class A { + virtual void Foo() = 0; + }; + struct B : public A { + virtual void Bar() = 0; + };)cpp", + R"cpp( + class ^C : private B { + }; + )cpp", + R"cpp( + class C : private B { +void Foo() override; +void Bar() override; + + }; + )cpp", + }, + }; + + for (const auto &Case : Cases) { + Header = Case.TestHeader.str(); + EXPECT_EQ(apply(Case.TestSource), Case.ExpectedSource); + } +} + +TEST_F(ImplementAbstractTest, TemplateUnavailable) { // Classes with a dependent base aren't supported. + StringRef Cases[]{ + R"cpp( + template + class A { + virtual void Foo() = 0; + }; + template + class ^B : public A {}; + )cpp", + R"cpp( + template + class ^B : public T {}; + )cpp", + }; + for (const auto &Case : Cases) { + EXPECT_THAT(Case, Not(isAvailable())); + } +} + +TEST_F(ImplementAbstractTest, TemplateAvailable) { // Templates work when no base is dependent. + struct Case { + llvm::StringRef TestHeader; + llvm::StringRef TestSource; + llvm::StringRef ExpectedSource; + }; + Case Cases[]{ + { + R"cpp( + template + class A { + virtual void Foo() = 0; + }; + )cpp", + R"cpp( + class ^B : public A {}; + )cpp", + R"cpp( + class B : public A { +void Foo() override; +}; + )cpp", + }, + { + R"cpp( + class A { + virtual void Foo() = 0; + };)cpp", + R"cpp( + template + class ^B : public A {}; + )cpp", + R"cpp( + template + class B : public A { +void Foo() override; +}; + )cpp", + }, + }; + for (const auto &Case : Cases) { + Header = Case.TestHeader.str(); +
EXPECT_EQ(apply(Case.TestSource), Case.ExpectedSource); + } +} + +} // namespace +} // namespace clangd +} // namespace clang