diff --git a/clang-tools-extra/clangd/refactor/tweaks/CMakeLists.txt b/clang-tools-extra/clangd/refactor/tweaks/CMakeLists.txt
--- a/clang-tools-extra/clangd/refactor/tweaks/CMakeLists.txt
+++ b/clang-tools-extra/clangd/refactor/tweaks/CMakeLists.txt
@@ -21,6 +21,7 @@
   ExpandMacro.cpp
   ExtractFunction.cpp
   ExtractVariable.cpp
+  ImplementAbstract.cpp
   ObjCLocalizeStringLiteral.cpp
   PopulateSwitch.cpp
   RawStringLiteral.cpp
diff --git a/clang-tools-extra/clangd/refactor/tweaks/ImplementAbstract.cpp b/clang-tools-extra/clangd/refactor/tweaks/ImplementAbstract.cpp
new file mode 100644
--- /dev/null
+++ b/clang-tools-extra/clangd/refactor/tweaks/ImplementAbstract.cpp
@@ -0,0 +1,308 @@
+//===--- ImplementAbstract.cpp -----------------------------------*- C++-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "refactor/Tweak.h"
+#include "support/Logger.h"
+#include "clang/Basic/Specifiers.h"
+#include "llvm/ADT/PointerIntPair.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/SmallVector.h"
+
+namespace clang {
+namespace clangd {
+
+namespace {
+
+using MethodAndAccess =
+    llvm::PointerIntPair<const CXXMethodDecl *, 2, AccessSpecifier>;
+
+AccessSpecifier getMostConstrained(AccessSpecifier InheritSpecifier,
+                                   AccessSpecifier DefinedAs) {
+  return std::max(InheritSpecifier, DefinedAs);
+}
+
+bool collectPureVirtual(const CXXRecordDecl &Record,
+                        llvm::SmallVectorImpl<MethodAndAccess> &Results,
+                        AccessSpecifier Access,
+                        llvm::SmallPtrSetImpl<const CXXMethodDecl *> &Overrides,
+                        bool IsRoot) {
+  if (Record.getNumBases() > 0) {
+    for (const CXXMethodDecl *Method : Record.methods()) {
+      if (!Method->isVirtual())
+        continue;
+      // If we have any pure virtual methods declared in the root (The class
+      // this tweak was invoked on), assume the user probably doesn't want to
+      // implement all abstract methods as the class will still be abstract.
+      if (IsRoot && Method->isPure())
+        return true;
+      for (const auto *Overriding : Method->overridden_methods())
+        Overrides.insert(Overriding);
+    }
+    for (auto Base : Record.bases()) {
+      const RecordType *RT = Base.getType()->getAs<RecordType>();
+      if (!RT)
+        // Probably a dependent base, just error out.
+        return true;
+      const CXXRecordDecl *BaseDecl = cast<CXXRecordDecl>(RT->getDecl());
+      if (!BaseDecl->isPolymorphic())
+        continue;
+      if (collectPureVirtual(
+              *BaseDecl, Results,
+              getMostConstrained(Access, Base.getAccessSpecifier()), Overrides,
+              false))
+        // Propagate any error back up.
+        return true;
+    }
+  } else {
+    assert(!IsRoot && "We should have filtered out this case already");
+  }
+  // Add the Pure methods from this class after traversing the bases. This means
+  // when it comes time to create implementation, methods from classes higher up
+  // the hierarchy will appear first.
+  for (const CXXMethodDecl *Method : Record.methods()) {
+    if (!Method->isPure())
+      continue;
+    if (!Overrides.contains(Method))
+      Results.emplace_back(Method,
+                           getMostConstrained(Access, Method->getAccess()));
+  }
+  return false;
+}
+
+const CXXRecordDecl *getSelectedRecord(const Tweak::Selection &Inputs) {
+  // FIXME: This method won't return the class when the caret is in the body of
+  // the class. So the only way to get the tweak offered is to be touching the
+  // marked ranges. It would be nicer if this was offered if cursor was inside
+  // the class (but perhaps not inside the classes decls).
+  //   [[class]] [[Derived]] [[:]] [[public]] Base [[{]]
+  //   ^
+  //   [[}]];
+  if (const SelectionTree::Node *Node = Inputs.ASTSelection.commonAncestor())
+    return Node->ASTNode.get<CXXRecordDecl>();
+  return nullptr;
+}
+
+/// Some quick-to-check basic heuristics to run before we try and collect
+/// virtual methods.
+bool isClassOK(const CXXRecordDecl &RecordDecl) {
+  if (!RecordDecl.isThisDeclarationADefinition())
+    return false;
+  if (!RecordDecl.isClass() && !RecordDecl.isStruct())
+    return false;
+  if (RecordDecl.hasAnyDependentBases() || RecordDecl.getNumBases() == 0)
+    return false;
+  // We should check for abstract, but that prevents working on template classes
+  // that don't have any dependent bases.
+  if (!RecordDecl.isPolymorphic())
+    return false;
+  return true;
+}
+
+struct InsertionDetail {
+  SourceLocation Loc = {};
+  AccessSpecifier Access;
+  unsigned char AfterPriority = 0;
+};
+
+// This is a little hacky because EndLoc of a decl doesn't include
+// the semi-colon.
+auto getLocAfterDecl(const Decl &D, const SourceManager &SM,
+                     const LangOptions &LO) {
+  if (D.hasBody())
+    return D.getEndLoc().getLocWithOffset(1);
+  if (auto Next = Lexer::findNextToken(D.getEndLoc(), SM, LO)) {
+    if (Next->is(tok::semi))
+      return Next->getEndLoc();
+  }
+  return D.getEndLoc().getLocWithOffset(1);
+}
+
+/// Generate insertion points in \p R that don't require inserting access
+/// specifiers. The insertion points generally try to appear after the last
+/// method declared in the class with a specific access. \p ShouldIncludeAccess
+/// is a way to avoid generating insertion points for access specifiers we
+/// aren't going to fill in.
+SmallVector<InsertionDetail, 3>
+getInsertionPoints(const CXXRecordDecl &R, ArrayRef<bool> ShouldIncludeAccess,
+                   const SourceManager &SM, const LangOptions &LO) {
+  SmallVector<InsertionDetail, 3> Result;
+  auto GetDetailForAccess = [&](AccessSpecifier Spec) -> InsertionDetail & {
+    assert(Spec != AS_none);
+    for (InsertionDetail &Item : Result) {
+      if (Item.Access == Spec)
+        return Item;
+    }
+    return Result.emplace_back(InsertionDetail{{}, Spec});
+  };
+
+  // This whole block is designed to get an insertion point after the last
+  // method has been declared with each access specifier. Doing this ensures we
+  // keep the same visibility for implemented methods without the need to add
+  // unnecessary access specifiers.
+  for (auto *Decl : R.decls()) {
+    if (!ShouldIncludeAccess[Decl->getAccess()])
+      continue;
+    // Ignore things like compiler generated special member functions.
+    if (Decl->isImplicit())
+      continue;
+    // Hack to try and leave the destructor as last method in a block.
+    if (isa<CXXDestructorDecl>(Decl))
+      continue;
+    InsertionDetail &Detail = GetDetailForAccess(Decl->getAccess());
+    if (isa<CXXMethodDecl>(Decl)) {
+      Detail.Loc = getLocAfterDecl(*Decl, SM, LO);
+      Detail.AfterPriority = 2;
+    } else {
+      // Try to put methods after access spec but before fields.
+      auto Priority = isa<AccessSpecDecl>(Decl) ? 1 : 0;
+      if (Detail.AfterPriority <= Priority) {
+        Detail.Loc = getLocAfterDecl(*Decl, SM, LO);
+        Detail.AfterPriority = Priority;
+      }
+    }
+  }
+  if (Result.empty()) {
+    auto Access = R.isClass() ? AS_private : AS_public;
+    if (ShouldIncludeAccess[Access]) {
+      // An empty class so start inserting methods that don't need an access
+      // specifier just after the open curly brace.
+      GetDetailForAccess(Access).Loc =
+          R.getBraceRange().getBegin().getLocWithOffset(1);
+    }
+  }
+  return Result;
+}
+
+void printMethods(llvm::raw_ostream &Out, ArrayRef<MethodAndAccess> Items,
+                  AccessSpecifier AccessKind, const CXXRecordDecl *PrintContext,
+                  bool PrintAccessSpec) {
+  class PrintCB : public PrintingCallbacks {
+  public:
+    PrintCB(const DeclContext *CurContext) : CurContext(CurContext) {}
+    virtual ~PrintCB() {}
+    bool isScopeVisible(const DeclContext *DC) const override {
+      return DC->Encloses(CurContext);
+    }
+
+  private:
+    const DeclContext *CurContext;
+  };
+  PrintCB Callbacks(PrintContext);
+  auto Policy = PrintContext->getASTContext().getPrintingPolicy();
+  Policy.SuppressScope = false;
+  Policy.Callbacks = &Callbacks;
+  if (PrintAccessSpec)
+    Out << "\n" << getAccessSpelling(AccessKind) << ":\n";
+  Out << "\n";
+  for (const auto &MethodAndAccess : Items) {
+    if (MethodAndAccess.getInt() != AccessKind)
+      continue;
+    const CXXMethodDecl *Method = MethodAndAccess.getPointer();
+    Method->getReturnType().print(Out, Policy);
+    Out << ' ';
+    Out << Method->getNameAsString() << "(";
+    bool IsFirst = true;
+    for (const auto &Param : Method->parameters()) {
+      if (!IsFirst)
+        Out << ", ";
+      else
+        IsFirst = false;
+      Param->print(Out, Policy);
+    }
+    Out << ") ";
+    if (Method->isConst())
+      Out << "const ";
+    if (Method->isVolatile())
+      Out << "volatile ";
+    // Always suggest `override` over `final`.
+    Out << "override;\n";
+  }
+}
+
+class ImplementAbstract : public Tweak {
+public:
+  const char *id() const override;
+
+  bool prepare(const Selection &Inputs) override {
+    Selected = getSelectedRecord(Inputs);
+    if (!Selected)
+      return false;
+    if (!isClassOK(*Selected))
+      return false;
+    llvm::SmallPtrSet<const CXXMethodDecl *, 16> Overrides;
+    if (collectPureVirtual(*Selected, PureVirtualMethods, AS_public, Overrides,
+                           true))
+      return false;
+    return !PureVirtualMethods.empty();
+  }
+
+  Expected<Effect> apply(const Selection &Inputs) override {
+    // We should have at least one pure virtual method to add.
+    assert(!PureVirtualMethods.empty() &&
+           "Prepare returned true when no methods existed");
+    bool AccessNeedsProcessing[3] = {0};
+    for (auto Item : PureVirtualMethods) {
+      AccessNeedsProcessing[Item.getInt()] = true;
+    }
+
+    auto InsertionPoints = getInsertionPoints(*Selected, AccessNeedsProcessing,
+                                              Inputs.AST->getSourceManager(),
+                                              Inputs.AST->getLangOpts());
+    SmallString<256> Buffer;
+    llvm::raw_svector_ostream OS(Buffer);
+    tooling::Replacements Replacements;
+    for (auto &Item : InsertionPoints) {
+      assert(Item.Loc.isValid());
+      if (!AccessNeedsProcessing[Item.Access])
+        continue;
+      AccessNeedsProcessing[Item.Access] = false;
+      printMethods(OS, PureVirtualMethods, Item.Access, Selected,
+                   /*PrintAccessSpec=*/false);
+      if (auto Err = Replacements.add(tooling::Replacement(
+              Inputs.AST->getSourceManager(), Item.Loc, 0, Buffer))) {
+        return std::move(Err);
+      }
+      Buffer.clear();
+    }
+
+    // Any access specifiers not covered can be added in one insertion.
+    for (AccessSpecifier Spec : {AS_public, AS_protected, AS_private}) {
+      if (!AccessNeedsProcessing[Spec])
+        continue;
+      printMethods(OS, PureVirtualMethods, Spec, Selected,
+                   /*PrintAccessSpec=*/true);
+    }
+    if (!Buffer.empty()) {
+      if (auto Err = Replacements.add(tooling::Replacement(
+              Inputs.AST->getSourceManager(),
+              Selected->getBraceRange().getEnd(), 0, Buffer))) {
+        return std::move(Err);
+      }
+    }
+    return Effect::mainFileEdit(Inputs.AST->getASTContext().getSourceManager(),
+                                std::move(Replacements));
+  }
+
+  std::string title() const override {
+    return "Implement pure virtual methods";
+  }
+
+  llvm::StringLiteral kind() const override {
+    return CodeAction::REFACTOR_KIND;
+  }
+
+private:
+  const CXXRecordDecl *Selected = nullptr;
+  llvm::SmallVector<MethodAndAccess, 0> PureVirtualMethods;
+};
+
+REGISTER_TWEAK(ImplementAbstract)
+
+} // namespace
+} // namespace clangd
+} // namespace clang
diff --git a/clang-tools-extra/clangd/unittests/CMakeLists.txt b/clang-tools-extra/clangd/unittests/CMakeLists.txt
--- a/clang-tools-extra/clangd/unittests/CMakeLists.txt
+++ b/clang-tools-extra/clangd/unittests/CMakeLists.txt
@@ -118,6 +118,7 @@
   tweaks/ExpandMacroTests.cpp
   tweaks/ExtractFunctionTests.cpp
   tweaks/ExtractVariableTests.cpp
+  tweaks/ImplementAbstractTests.cpp
   tweaks/ObjCLocalizeStringLiteralTests.cpp
   tweaks/PopulateSwitchTests.cpp
   tweaks/RawStringLiteralTests.cpp
diff --git a/clang-tools-extra/clangd/unittests/tweaks/ImplementAbstractTests.cpp b/clang-tools-extra/clangd/unittests/tweaks/ImplementAbstractTests.cpp
new file mode 100644
--- /dev/null
+++ b/clang-tools-extra/clangd/unittests/tweaks/ImplementAbstractTests.cpp
@@ -0,0 +1,349 @@
+//===-- ImplementAbstractTests.cpp ------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "TestTU.h"
+#include "TweakTesting.h"
+#include "gmock/gmock-matchers.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+using ::testing::Not;
+
+namespace clang {
+namespace clangd {
+namespace {
+
+TWEAK_TEST(ImplementAbstract);
+
+TEST_F(ImplementAbstractTest, TestUnavailable) {
+
+  StringRef Cases[]{
+      // Not a pure virtual method.
+      R"cpp(
+        class A {
+          virtual void Foo();
+        };
+        class ^B : public A {};
+      )cpp",
+      // Pure virtual method overridden in class.
+      R"cpp(
+        class A {
+          virtual void Foo() = 0;
+        };
+        class ^B : public A {
+          void Foo() override;
+        };
+      )cpp",
+      // Pure virtual method overridden in class with virtual keyword
+      R"cpp(
+        class A {
+          virtual void Foo() = 0;
+        };
+        class ^B : public A {
+          virtual void Foo() override;
+        };
+      )cpp",
+      // Pure virtual method overridden in class without override keyword
+      R"cpp(
+        class A {
+          virtual void Foo() = 0;
+        };
+        class ^B : public A {
+          void Foo();
+        };
+      )cpp",
+      // Pure virtual method overridden in base class.
+      R"cpp(
+        class A {
+          virtual void Foo() = 0;
+        };
+        class B : public A {
+          void Foo() override;
+        };
+        class ^C : public B {
+        };
+      )cpp"};
+  for (const auto &Case : Cases) {
+    EXPECT_THAT(Case, Not(isAvailable()));
+  }
+}
+
+TEST_F(ImplementAbstractTest, NormalAvailable) {
+  struct Case {
+    llvm::StringRef TestHeader;
+    llvm::StringRef TestSource;
+    llvm::StringRef ExpectedSource;
+  };
+
+  Case Cases[]{
+      {
+          R"cpp(
+            class A {
+              virtual void Foo() = 0;
+            };)cpp",
+          R"cpp(
+            class B : public A {^};
+          )cpp",
+          R"cpp(
+            class B : public A {
+void Foo() override;
+};
+          )cpp",
+      },
+      {
+          R"cpp(
+            class A {
+              public:
+              virtual void Foo() = 0;
+            };)cpp",
+          R"cpp(
+            class ^B : public A {};
+          )cpp",
+          R"cpp(
+            class B : public A {
+public:
+
+void Foo() override;
+};
+          )cpp",
+      },
+      {
+          R"cpp(
+            class A {
+              virtual void Foo(int Param) = 0;
+            };)cpp",
+          R"cpp(
+            class ^B : public A {};
+          )cpp",
+          R"cpp(
+            class B : public A {
+void Foo(int Param) override;
+};
+          )cpp",
+      },
+      {
+          R"cpp(
+            class A {
+              virtual void Foo(int Param) = 0;
+            };)cpp",
+          R"cpp(
+            struct ^B : public A {};
+          )cpp",
+          R"cpp(
+            struct B : public A {
+private:
+
+void Foo(int Param) override;
+};
+          )cpp",
+      },
+      {
+          R"cpp(
+            class A {
+              virtual void Foo(int Param) const volatile = 0;
+              public:
+              virtual void Bar(int Param) = 0;
+            };)cpp",
+          R"cpp(
+            class ^B : public A {
+              void Foo(int Param) const volatile override;
+            };
+          )cpp",
+          R"cpp(
+            class B : public A {
+              void Foo(int Param) const volatile override;
+            
+public:
+
+void Bar(int Param) override;
+};
+          )cpp",
+      },
+      {
+          R"cpp(
+            class A {
+              virtual void Foo() = 0;
+              virtual void Bar() = 0;
+            };
+            class B : public A {
+              void Foo() override;
+            };
+          )cpp",
+          R"cpp(
+            class ^C : public B {
+              virtual void Baz();
+            };
+          )cpp",
+          R"cpp(
+            class C : public B {
+              virtual void Baz();
+void Bar() override;
+
+            };
+          )cpp",
+      },
+      {
+          R"cpp(
+            class A {
+              virtual void Foo() = 0;
+            };)cpp",
+          R"cpp(
+            class ^B : public A {
+              ~B();
+            };
+          )cpp",
+          R"cpp(
+            class B : public A {
+void Foo() override;
+
+              ~B();
+            };
+          )cpp",
+      },
+      {
+          R"cpp(
+            class A {
+              virtual void Foo() = 0;
+              public:
+              virtual void Bar() = 0;
+            };)cpp",
+          R"cpp(
+            class ^B : public A {
+            };
+          )cpp",
+          R"cpp(
+            class B : public A {
+void Foo() override;
+
+            
+public:
+
+void Bar() override;
+};
+          )cpp",
+      },
+      {
+          R"cpp(
+            class A {
+              virtual void Foo() = 0;
+            };
+            struct B : public A {
+              virtual void Bar() = 0;
+            };)cpp",
+          R"cpp(
+            class ^C : public B {
+            };
+          )cpp",
+          R"cpp(
+            class C : public B {
+void Foo() override;
+
+            
+public:
+
+void Bar() override;
+};
+          )cpp",
+      },
+      {
+          R"cpp(
+            class A {
+              virtual void Foo() = 0;
+            };
+            struct B : public A {
+              virtual void Bar() = 0;
+            };)cpp",
+          R"cpp(
+            class ^C : private B {
+            };
+          )cpp",
+          R"cpp(
+            class C : private B {
+void Foo() override;
+void Bar() override;
+
+            };
+          )cpp",
+      },
+  };
+
+  for (const auto &Case : Cases) {
+    Header = Case.TestHeader.str();
+    EXPECT_EQ(apply(Case.TestSource), Case.ExpectedSource);
+  }
+}
+
+TEST_F(ImplementAbstractTest, TemplateUnavailable) {
+  StringRef Cases[]{
+      R"cpp(
+        template <typename T>
+        class A {
+          virtual void Foo() = 0;
+        };
+        template <typename T>
+        class ^B : public A<T> {};
+      )cpp",
+      R"cpp(
+        template <typename T>
+        class ^B : public T {};
+      )cpp",
+  };
+  for (const auto &Case : Cases) {
+    EXPECT_THAT(Case, Not(isAvailable()));
+  }
+}
+
+TEST_F(ImplementAbstractTest, TemplateAvailable) {
+  struct Case {
+    llvm::StringRef TestHeader;
+    llvm::StringRef TestSource;
+    llvm::StringRef ExpectedSource;
+  };
+  Case Cases[]{
+      {
+          R"cpp(
+            template <typename T>
+            class A {
+              virtual void Foo() = 0;
+            };
+          )cpp",
+          R"cpp(
+            class ^B : public A<int> {};
+          )cpp",
+          R"cpp(
+            class B : public A<int> {
+void Foo() override;
+};
+          )cpp",
+      },
+      {
+          R"cpp(
+            class A {
+              virtual void Foo() = 0;
+            };)cpp",
+          R"cpp(
+            template <typename T>
+            class ^B : public A {};
+          )cpp",
+          R"cpp(
+            template <typename T>
+            class B : public A {
+void Foo() override;
+};
+          )cpp",
+      },
+  };
+  for (const auto &Case : Cases) {
+    Header = Case.TestHeader.str();
+    EXPECT_EQ(apply(Case.TestSource), Case.ExpectedSource);
+  }
+}
+
+} // namespace
+} // namespace clangd
+} // namespace clang