diff --git a/clang-tools-extra/clangd/refactor/tweaks/CMakeLists.txt b/clang-tools-extra/clangd/refactor/tweaks/CMakeLists.txt
--- a/clang-tools-extra/clangd/refactor/tweaks/CMakeLists.txt
+++ b/clang-tools-extra/clangd/refactor/tweaks/CMakeLists.txt
@@ -21,6 +21,7 @@
   ExpandMacro.cpp
   ExtractFunction.cpp
   ExtractVariable.cpp
+  ImplementAbstract.cpp
   ObjCLocalizeStringLiteral.cpp
   PopulateSwitch.cpp
   RawStringLiteral.cpp
diff --git a/clang-tools-extra/clangd/refactor/tweaks/ImplementAbstract.cpp b/clang-tools-extra/clangd/refactor/tweaks/ImplementAbstract.cpp
new file mode 100644
--- /dev/null
+++ b/clang-tools-extra/clangd/refactor/tweaks/ImplementAbstract.cpp
@@ -0,0 +1,297 @@
+//===--- ImplementAbstract.cpp -----------------------------------*- C++-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Defines the "Implement pure virtual methods" tweak, which declares an
+// override in a class for every pure virtual method that it inherits but does
+// not yet override.
+//
+//===----------------------------------------------------------------------===//
+
+#include "refactor/Tweak.h"
+#include "support/Logger.h"
+#include "clang/AST/CXXInheritance.h"
+#include "clang/Lex/Lexer.h"
+
+namespace clang {
+namespace clangd {
+namespace {
+
+/// Declares overrides for the pure virtual methods a class inherits, e.g.
+///
+///   class Base { virtual void foo() = 0; };
+///   class ^Derived : public Base {};
+/// becomes
+///   class Derived : public Base {
+///   void foo() override;
+///   };
+class ImplementAbstract : public Tweak {
+public:
+  const char *id() const override;
+
+  bool prepare(const Selection &Inputs) override;
+  Expected<Effect> apply(const Selection &Inputs) override;
+  std::string title() const override {
+    return "Implement pure virtual methods";
+  }
+  llvm::StringLiteral kind() const override {
+    return CodeAction::REFACTOR_KIND;
+  }
+
+private:
+  /// The class definition the tweak was invoked on; set by prepare().
+  const CXXRecordDecl *Selected = nullptr;
+  /// Final overriders of every virtual method in Selected; set by prepare().
+  CXXFinalOverriderMap OverrideMap;
+};
+
+/// Returns the class declaration the selection is directly on, if any.
+static const CXXRecordDecl *
+getSelectedRecord(const SelectionTree::Node *SelNode) {
+  if (!SelNode)
+    return nullptr;
+  return SelNode->ASTNode.get<CXXRecordDecl>();
+}
+
+/// Returns true if the virtual method of \p Entry, an entry of the final
+/// overrider map of \p Record, still has to be declared in \p Record: in some
+/// base subobject its final overrider is pure and not declared in \p Record.
+static bool
+needsImplementation(const CXXRecordDecl *Record,
+                    const CXXFinalOverriderMap::value_type &Entry) {
+  const CXXRecordDecl *Canonical = Record->getCanonicalDecl();
+  return llvm::any_of(Entry.second, [&](const auto &SubobjectOverriders) {
+    assert(!SubobjectOverriders.second.empty());
+    const CXXMethodDecl *Overrider = SubobjectOverriders.second.front().Method;
+    // Record may redeclare the method as pure (`void f() override = 0;`);
+    // declaring it a second time would be an error.
+    return Overrider->isPure() &&
+           Overrider->getParent()->getCanonicalDecl() != Canonical;
+  });
+}
+
+/// Returns the location just past the end of the class member \p D. Tokens
+/// like `override`, `= 0` or `= default` are not part of a declaration's
+/// source range, so this scans forward to the `;` terminating the member.
+/// Returns an invalid location if no such point is found, e.g. because the
+/// member ends inside a macro expansion.
+static SourceLocation getLocAfterDecl(const Decl *D, const SourceManager &SM,
+                                      const LangOptions &LangOpts) {
+  SourceLocation End = D->getEndLoc();
+  if (End.isInvalid() || !End.isFileID())
+    return {};
+  // `public:` and the like end at the colon; insert directly after it.
+  if (isa<AccessSpecDecl>(D))
+    return Lexer::getLocForEndOfToken(End, 0, SM, LangOpts);
+  // A member ending in `}` (e.g. an inline method definition) may or may not
+  // be followed by a `;`.
+  const bool EndsWithBrace = *SM.getCharacterData(End) == '}';
+  for (SourceLocation Loc = End;;) {
+    auto Tok = Lexer::findNextToken(Loc, SM, LangOpts);
+    if (Tok && Tok->is(tok::semi))
+      return Tok->getEndLoc();
+    if (EndsWithBrace)
+      return Lexer::getLocForEndOfToken(End, 0, SM, LangOpts);
+    // Never scan into a function body or past the end of the class.
+    if (!Tok || Tok->isOneOf(tok::eof, tok::l_brace, tok::r_brace))
+      return {};
+    Loc = Tok->getLocation();
+  }
+}
+
+bool ImplementAbstract::prepare(const Selection &Inputs) {
+  // FIXME: This method won't return the class when the caret is in the body of
+  // the class. So the only way to get the tweak offered is to be touching the
+  // marked ranges. It would be nicer if this was offered if the cursor was
+  // inside the class (but perhaps not inside the class's decls).
+  // [[class]] [[Derived]] [[:]] [[public]] Base [[{]]
+  //   ^
+  // [[}]];
+  Selected = getSelectedRecord(Inputs.ASTSelection.commonAncestor());
+  if (!Selected)
+    return false;
+
+  // Some sanity checks before we try.
+  if (!Selected->isThisDeclarationADefinition())
+    return false;
+  if (!Selected->isClass() && !Selected->isStruct())
+    return false;
+  // Skip class templates, partial specializations and classes nested in
+  // templates: their final overriders can't be computed reliably. (Note that
+  // Decl::isTemplateDecl() is always false for a CXXRecordDecl.)
+  if (Selected->isDependentContext())
+    return false;
+  if (Selected->getNumBases() == 0)
+    return false;
+  if (!Selected->isAbstract())
+    return false;
+
+  // Drop results of any earlier prepare() on this instance.
+  OverrideMap.clear();
+  Selected->getFinalOverriders(OverrideMap);
+  return llvm::any_of(OverrideMap, [&](const auto &Entry) {
+    return needsImplementation(Selected, Entry);
+  });
+}
+
+/// Prints a declaration `Ret Name(Params) cv ref noexcept override;` for each
+/// method in \p Items, preceded by an `AccessSpec:` line if \p AccessSpec is
+/// not empty.
+static void printMethods(llvm::raw_ostream &Out,
+                         ArrayRef<const CXXMethodDecl *> Items,
+                         StringRef AccessSpec = {}) {
+  if (!AccessSpec.empty())
+    Out << "\n" << AccessSpec << ":\n";
+  Out << "\n";
+  for (const CXXMethodDecl *Method : Items) {
+    // Conversion operators spell their return type as part of their name.
+    if (!isa<CXXConversionDecl>(Method)) {
+      Method->getReturnType().print(
+          Out, Method->getASTContext().getPrintingPolicy());
+      Out << ' ';
+    }
+    Out << Method->getNameAsString() << "(";
+    bool IsFirst = true;
+    for (const ParmVarDecl *Param : Method->parameters()) {
+      if (!IsFirst)
+        Out << ", ";
+      IsFirst = false;
+      Param->print(Out);
+    }
+    if (Method->isVariadic())
+      Out << (IsFirst ? "..." : ", ...");
+    Out << ") ";
+    if (Method->isConst())
+      Out << "const ";
+    if (Method->isVolatile())
+      Out << "volatile ";
+    // The ref-qualifier is part of the signature: without it the declaration
+    // would not override the base method.
+    switch (Method->getRefQualifier()) {
+    case RQ_None:
+      break;
+    case RQ_LValue:
+      Out << "& ";
+      break;
+    case RQ_RValue:
+      Out << "&& ";
+      break;
+    }
+    // An override may not have a laxer exception specification than the
+    // method it overrides.
+    const auto *Proto = Method->getType()->getAs<FunctionProtoType>();
+    if (Proto && Proto->isNothrow())
+      Out << "noexcept ";
+    // Always suggest `override` over `final`.
+    Out << "override;\n";
+  }
+}
+
+Expected<Effect> ImplementAbstract::apply(const Selection &Inputs) {
+  const SourceManager &SM = Inputs.AST->getSourceManager();
+  const LangOptions &LangOpts = Inputs.AST->getLangOpts();
+
+  // Methods to declare, indexed by their AccessSpecifier in the base class
+  // (AS_public, AS_protected or AS_private).
+  llvm::SmallVector<const CXXMethodDecl *, 4> GroupedAccessMethods[3];
+  for (const auto &Entry : OverrideMap) {
+    if (!needsImplementation(Selected, Entry))
+      continue;
+    const CXXMethodDecl *Method = Entry.first;
+    if (Method->getAccess() == AS_none)
+      return error("Invalid access for method {0}", Method->getNameAsString());
+    GroupedAccessMethods[Method->getAccess()].push_back(Method);
+  }
+
+  // prepare() only succeeds if there is at least one method to add.
+  assert(llvm::any_of(
+      GroupedAccessMethods,
+      [](ArrayRef<const CXXMethodDecl *> Array) { return !Array.empty(); }));
+
+  struct InsertionDetail {
+    SourceLocation Loc = {};
+    bool RefersToMethod = false;
+  };
+  using DetailAndAccess = std::pair<InsertionDetail, AccessSpecifier>;
+  SmallVector<DetailAndAccess, 3> InsertionPoints;
+
+  auto GetDetailForAccess = [&](AccessSpecifier Spec) -> InsertionDetail & {
+    assert(Spec != AS_none);
+    for (DetailAndAccess &Item : InsertionPoints) {
+      if (Item.second == Spec)
+        return Item.first;
+    }
+    return InsertionPoints.emplace_back(InsertionDetail{}, Spec).first;
+  };
+
+  // New methods go after the last method with the same access, or after the
+  // last member with that access if there is no such method.
+  for (const Decl *Member : Selected->decls()) {
+    // Members without an access (AS_none, e.g. static_assert) can't be
+    // grouped by access.
+    if (Member->isImplicit() || Member->getAccess() == AS_none)
+      continue;
+    // Hack to try and leave the destructor as last method in a block.
+    if (isa<CXXDestructorDecl>(Member))
+      continue;
+    SourceLocation After = getLocAfterDecl(Member, SM, LangOpts);
+    if (After.isInvalid() || !SM.isWrittenInMainFile(After))
+      continue;
+    InsertionDetail &Detail = GetDetailForAccess(Member->getAccess());
+    if (isa<CXXMethodDecl>(Member)) {
+      Detail.Loc = After;
+      Detail.RefersToMethod = true;
+    } else if (!Detail.RefersToMethod) {
+      // Last decl with this access wasn't a method decl.
+      Detail.Loc = After;
+    }
+  }
+  if (InsertionPoints.empty()) {
+    // No usable declarations in the body, use the default access for the
+    // first potential insertion.
+    GetDetailForAccess(Selected->isClass() ? AS_private : AS_public) =
+        InsertionDetail{
+            Selected->getBraceRange().getBegin().getLocWithOffset(1), true};
+  }
+
+  SmallString<256> Buffer;
+  llvm::raw_svector_ostream OS(Buffer);
+  tooling::Replacements Replacements;
+  for (const DetailAndAccess &Item : InsertionPoints) {
+    assert(Item.first.Loc.isValid());
+    auto &GroupedMethods = GroupedAccessMethods[Item.second];
+    if (GroupedMethods.empty())
+      continue;
+    printMethods(OS, GroupedMethods);
+    if (auto Err = Replacements.add(
+            tooling::Replacement(SM, Item.first.Loc, 0, Buffer)))
+      return std::move(Err);
+    // Clear the methods so the fallback below doesn't print them again.
+    GroupedMethods.clear();
+    Buffer.clear();
+  }
+
+  // Methods whose access has no insertion point get a new access section at
+  // the end of the class.
+  for (AccessSpecifier Spec : {AS_public, AS_protected, AS_private}) {
+    auto &GroupedMethods = GroupedAccessMethods[Spec];
+    if (!GroupedMethods.empty())
+      printMethods(OS, GroupedMethods, getAccessSpelling(Spec));
+  }
+  if (!Buffer.empty()) {
+    if (auto Err = Replacements.add(tooling::Replacement(
+            SM, Selected->getBraceRange().getEnd(), 0, Buffer)))
+      return std::move(Err);
+  }
+  return Effect::mainFileEdit(SM, std::move(Replacements));
+}
+
+REGISTER_TWEAK(ImplementAbstract)
+
+} // namespace
+} // namespace clangd
+} // namespace clang
diff --git a/clang-tools-extra/clangd/unittests/CMakeLists.txt b/clang-tools-extra/clangd/unittests/CMakeLists.txt
--- a/clang-tools-extra/clangd/unittests/CMakeLists.txt
+++ b/clang-tools-extra/clangd/unittests/CMakeLists.txt
@@ -117,6 +117,7 @@
   tweaks/ExpandMacroTests.cpp
   tweaks/ExtractFunctionTests.cpp
   tweaks/ExtractVariableTests.cpp
+  tweaks/ImplementAbstractTests.cpp
   tweaks/ObjCLocalizeStringLiteralTests.cpp
   tweaks/PopulateSwitchTests.cpp
   tweaks/RawStringLiteralTests.cpp
diff --git a/clang-tools-extra/clangd/unittests/tweaks/ImplementAbstractTests.cpp b/clang-tools-extra/clangd/unittests/tweaks/ImplementAbstractTests.cpp
new file mode 100644
--- /dev/null
+++ b/clang-tools-extra/clangd/unittests/tweaks/ImplementAbstractTests.cpp
@@ -0,0 +1,287 @@
+//===-- ImplementAbstractTests.cpp ------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "TestTU.h"
+#include "TweakTesting.h"
+#include "gmock/gmock-matchers.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+using ::testing::Not;
+
+namespace clang {
+namespace clangd {
+namespace {
+
+TWEAK_TEST(ImplementAbstract);
+
+TEST_F(ImplementAbstractTest, TestUnavailable) {
+  StringRef Cases[]{
+      // Not a pure virtual method.
+      R"cpp(
+        class A {
+          virtual void Foo();
+        };
+        class ^B : public A {};
+      )cpp",
+      // Pure virtual method overridden in class.
+      R"cpp(
+        class A {
+          virtual void Foo() = 0;
+        };
+        class ^B : public A {
+          void Foo() override;
+        };
+      )cpp",
+      // Pure virtual method overridden in class with virtual keyword.
+      R"cpp(
+        class A {
+          virtual void Foo() = 0;
+        };
+        class ^B : public A {
+          virtual void Foo() override;
+        };
+      )cpp",
+      // Pure virtual method overridden in class without override keyword.
+      R"cpp(
+        class A {
+          virtual void Foo() = 0;
+        };
+        class ^B : public A {
+          void Foo();
+        };
+      )cpp",
+      // Pure virtual method overridden in base class.
+      R"cpp(
+        class A {
+          virtual void Foo() = 0;
+        };
+        class B : public A {
+          void Foo() override;
+        };
+        class ^C : public B {
+        };
+      )cpp",
+      // Pure virtual method redeclared as pure in the class itself.
+      R"cpp(
+        class A {
+          virtual void Foo() = 0;
+        };
+        class ^B : public A {
+          void Foo() override = 0;
+        };
+      )cpp",
+      // Class templates are not supported.
+      R"cpp(
+        class A {
+          virtual void Foo() = 0;
+        };
+        template <typename T> class ^B : public A {};
+      )cpp"};
+  for (const auto &Case : Cases) {
+    EXPECT_THAT(Case, Not(isAvailable()));
+  }
+}
+
+TEST_F(ImplementAbstractTest, TestApply) {
+  struct Case {
+    llvm::StringRef TestHeader;
+    llvm::StringRef TestSource;
+    llvm::StringRef ExpectedSource;
+  };
+
+  // Sources start at column 0: inserted declarations are not indented, so
+  // this keeps trailing whitespace out of the expected outputs.
+  Case Cases[]{
+      {
+          R"cpp(
+class A {
+  virtual void Foo() = 0;
+};)cpp",
+          R"cpp(
+class B : public A {^};
+)cpp",
+          R"cpp(
+class B : public A {
+void Foo() override;
+};
+)cpp",
+      },
+      {
+          R"cpp(
+class A {
+public:
+  virtual void Foo() = 0;
+};)cpp",
+          R"cpp(
+class ^B : public A {};
+)cpp",
+          R"cpp(
+class B : public A {
+public:
+
+void Foo() override;
+};
+)cpp",
+      },
+      {
+          R"cpp(
+class A {
+  virtual void Foo(int Param) = 0;
+};)cpp",
+          R"cpp(
+class ^B : public A {};
+)cpp",
+          R"cpp(
+class B : public A {
+void Foo(int Param) override;
+};
+)cpp",
+      },
+      {
+          R"cpp(
+class A {
+  virtual void Foo(int Param) = 0;
+};)cpp",
+          R"cpp(
+struct ^B : public A {};
+)cpp",
+          R"cpp(
+struct B : public A {
+private:
+
+void Foo(int Param) override;
+};
+)cpp",
+      },
+      {
+          R"cpp(
+class A {
+  virtual void Foo(int Param) const volatile = 0;
+public:
+  virtual void Bar(int Param) = 0;
+};)cpp",
+          R"cpp(
+class ^B : public A {
+  void Foo(int Param) const volatile override;
+};
+)cpp",
+          R"cpp(
+class B : public A {
+  void Foo(int Param) const volatile override;
+
+public:
+
+void Bar(int Param) override;
+};
+)cpp",
+      },
+      {
+          R"cpp(
+class A {
+  virtual void Foo() = 0;
+  virtual void Bar() = 0;
+};
+class B : public A {
+  void Foo() override;
+};)cpp",
+          R"cpp(
+class ^C : public B {
+  virtual void Baz();
+};
+)cpp",
+          R"cpp(
+class C : public B {
+  virtual void Baz();
+void Bar() override;
+
+};
+)cpp",
+      },
+      {
+          R"cpp(
+class A {
+  virtual void Foo() = 0;
+};)cpp",
+          R"cpp(
+class ^B : public A {
+  ~B();
+};
+)cpp",
+          R"cpp(
+class B : public A {
+void Foo() override;
+
+  ~B();
+};
+)cpp",
+      },
+      {
+          // Insert after trailing specifiers such as `override`.
+          R"cpp(
+class A {
+  virtual void Foo() = 0;
+  virtual void Bar() = 0;
+};)cpp",
+          R"cpp(
+class ^B : public A {
+  void Foo() override;
+};
+)cpp",
+          R"cpp(
+class B : public A {
+  void Foo() override;
+void Bar() override;
+
+};
+)cpp",
+      },
+      {
+          // Ref-qualifiers and exception specifications are preserved.
+          R"cpp(
+class A {
+  virtual int Foo() const & noexcept = 0;
+};)cpp",
+          R"cpp(
+class ^B : public A {};
+)cpp",
+          R"cpp(
+class B : public A {
+int Foo() const & noexcept override;
+};
+)cpp",
+      },
+      {
+          // Members without an access, such as static_assert, are skipped.
+          R"cpp(
+class A {
+  virtual void Foo() = 0;
+};)cpp",
+          R"cpp(
+class ^B : public A {
+  static_assert(true, "");
+};
+)cpp",
+          R"cpp(
+class B : public A {
+void Foo() override;
+
+  static_assert(true, "");
+};
+)cpp",
+      },
+  };
+
+  for (const auto &Case : Cases) {
+    Header = Case.TestHeader.str();
+    EXPECT_EQ(apply(Case.TestSource), Case.ExpectedSource);
+  }
+}
+} // namespace
+} // namespace clangd
+} // namespace clang