diff --git a/clang-tools-extra/clangd/AST.h b/clang-tools-extra/clangd/AST.h
--- a/clang-tools-extra/clangd/AST.h
+++ b/clang-tools-extra/clangd/AST.h
@@ -94,7 +94,10 @@
 
 /// Returns a QualType as string. The result doesn't contain unwritten scopes
 /// like anonymous/inline namespace.
-std::string printType(const QualType QT, const DeclContext &CurContext);
+/// \p Placeholder is passed through to QualType::print(); callers use it to
+/// print a declarator name (e.g. a parameter name) together with the type.
+std::string printType(const QualType QT, const DeclContext &CurContext,
+                      llvm::StringRef Placeholder = "");
 
 /// Indicates if \p D is a template instantiation implicitly generated by the
 /// compiler, e.g.
diff --git a/clang-tools-extra/clangd/AST.cpp b/clang-tools-extra/clangd/AST.cpp
--- a/clang-tools-extra/clangd/AST.cpp
+++ b/clang-tools-extra/clangd/AST.cpp
@@ -349,7 +349,8 @@
   return SymbolID(USR);
 }
 
-std::string printType(const QualType QT, const DeclContext &CurContext) {
+std::string printType(const QualType QT, const DeclContext &CurContext,
+                      const llvm::StringRef Placeholder) {
   std::string Result;
   llvm::raw_string_ostream OS(Result);
   PrintingPolicy PP(CurContext.getParentASTContext().getPrintingPolicy());
@@ -370,7 +371,7 @@
   PrintCB PCB(&CurContext);
   PP.Callbacks = &PCB;
 
-  QT.print(OS, PP);
+  QT.print(OS, PP, Placeholder);
   return OS.str();
 }
 
diff --git a/clang-tools-extra/clangd/refactor/InsertionPoint.cpp b/clang-tools-extra/clangd/refactor/InsertionPoint.cpp
--- a/clang-tools-extra/clangd/refactor/InsertionPoint.cpp
+++ b/clang-tools-extra/clangd/refactor/InsertionPoint.cpp
@@ -119,8 +119,11 @@
   // Fallback: insert at the end.
   if (Loc.isInvalid())
     Loc = endLoc(DC);
+  if (Loc.isInvalid())
+    return error("Couldn't find a valid location for insertion");
   const auto &SM = DC.getParentASTContext().getSourceManager();
-  if (!SM.isWrittenInSameFile(Loc, cast<Decl>(DC).getLocation()))
+  auto DeclLoc = cast<Decl>(DC).getLocation();
+  if (DeclLoc.isValid() && !SM.isWrittenInSameFile(Loc, DeclLoc))
     return error("{0} body in wrong file: {1}", DC.getDeclKindName(),
                  Loc.printToString(SM));
   return tooling::Replacement(SM, Loc, 0, Code);
diff --git a/clang-tools-extra/clangd/refactor/tweaks/AddSubclass.cpp b/clang-tools-extra/clangd/refactor/tweaks/AddSubclass.cpp
new file mode 100644
--- /dev/null
+++ b/clang-tools-extra/clangd/refactor/tweaks/AddSubclass.cpp
@@ -0,0 +1,422 @@
+//===--- AddSubclass.cpp -----------------------------------------*- C++-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#include "refactor/Tweak.h"
+
+#include "XRefs.h"
+#include "refactor/InsertionPoint.h"
+#include "support/Logger.h"
+#include "clang/AST/Type.h"
+#include "clang/AST/TypeLoc.h"
+#include "clang/Basic/LLVM.h"
+#include "llvm/ADT/None.h"
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/Support/Error.h"
+#include <algorithm>
+#include <string>
+#include <utility>
+#include <vector>
+
+namespace clang {
+namespace clangd {
+namespace {
+
+/// The methods to override, each paired with the access level the override
+/// gets in the generated subclass.
+using MethodVector =
+    std::vector<std::pair<AccessSpecifier, const CXXMethodDecl *>>;
+
+/// Adds a subclass implementing all virtual functions
+/// Given:
+///   struct Base {
+///     Base(int x, std::string y);
+///     /// Do something
+///     virtual int foo();
+///     /// Do something else
+///     virtual double bar(double x) = 0;
+///   };
+/// Adds a new class:
+///   struct BaseSub : public Base {
+///     using Base::Base;
+///     // Do something
+///     int foo() override { return Base::foo(); }
+///     // Do something else
+///     double bar(double x) override { return Base::bar(x); }
+///   };
+/// The new class is added directly below its base class in the same file. In
+/// the future, a separate tweak could be used to move this class somewhere
+/// else. The subclass is always called `BaseSub`, potentially disambiguated
+/// with a numeric suffix to avoid name clashes. The "Rename" functionality can
+/// be used to rename the new class.
+///
+/// The generated methods delegate to the respective implementations of the
+/// parent class by default. If not appropriate, the corresponding code can be
+/// deleted by the user explicitly. We provide this default implementation to
+/// guard the user against accidentally forgetting to call the base class'
+/// method. It's easier to spot incorrect code than incorrectly missing code.
+/// The method bodies are defined inline in the class; the "Define Outline"
+/// tweak can be used to move them into a `.cpp` file.
+class AddSubclass : public Tweak {
+public:
+  /// \p UnimplementedPureVirtualOnly selects which methods get overridden:
+  /// only pure virtual methods (true) or every virtual method (false).
+  explicit AddSubclass(bool UnimplementedPureVirtualOnly)
+      : UnimplementedPureVirtualOnly(UnimplementedPureVirtualOnly) {}
+
+  llvm::StringLiteral kind() const override {
+    return CodeAction::REFACTOR_KIND;
+  }
+  /// Fills the caches below; the tweak is offered only if a class is selected
+  /// and at least one method qualifies for overriding.
+  bool prepare(const Selection &Inputs) override;
+  /// Generates the subclass from the state cached by prepare().
+  Expected<Effect> apply(const Selection &Inputs) override;
+
+private:
+  bool UnimplementedPureVirtualOnly;
+  /// Cache the CXXMethodDecls, so that we do not need to search twice.
+  MethodVector CachedMethods;
+  /// Cache the CXXRecordDecl, so that we do not need to search twice.
+  llvm::Optional<const CXXRecordDecl *> CachedLocation;
+};
+
+/// Variation of the `AddSubclass` tweak which overrides all
+/// virtual methods.
+class AddSubclassAllVirtuals : public AddSubclass {
+public:
+  AddSubclassAllVirtuals() : AddSubclass(false) {}
+  const char *id() const final;
+  std::string title() const override {
+    return "Add subclass, overriding all virtual methods";
+  }
+};
+
+/// Variation of the `AddSubclass` tweak which overrides only
+/// pure virtual methods.
+class AddSubclassPureVirtualOnly : public AddSubclass {
+public:
+  AddSubclassPureVirtualOnly() : AddSubclass(true) {}
+  const char *id() const final;
+  std::string title() const override {
+    return "Add subclass, overriding unimplemented pure virtual methods";
+  }
+};
+
+REGISTER_TWEAK(AddSubclassAllVirtuals)
+REGISTER_TWEAK(AddSubclassPureVirtualOnly)
+
+// FIXME: copied from `clang-doc`. Can we somehow deduplicate this?
+/// Returns the more restrictive of the two access levels. `AS_none` wins over
+/// everything, then `private`, then `protected`, then `public`.
+static AccessSpecifier getFinalAccessSpecifier(AccessSpecifier FirstAS,
+                                               AccessSpecifier SecondAS) {
+  if (FirstAS == AccessSpecifier::AS_none ||
+      SecondAS == AccessSpecifier::AS_none)
+    return AccessSpecifier::AS_none;
+  if (FirstAS == AccessSpecifier::AS_private ||
+      SecondAS == AccessSpecifier::AS_private)
+    return AccessSpecifier::AS_private;
+  if (FirstAS == AccessSpecifier::AS_protected ||
+      SecondAS == AccessSpecifier::AS_protected)
+    return AccessSpecifier::AS_protected;
+  return AccessSpecifier::AS_public;
+}
+
+/// Appends to \p Target the methods of \p Decl and of all its (transitive)
+/// base classes that a new subclass should override. Base classes are visited
+/// first, so methods appear in base-to-derived order.
+/// \p InheritanceAS is the access level with which \p Decl is reachable from
+/// the class the tweak was invoked on; it caps the access of every collected
+/// method.
+/// If \p UnimplementedPureVirtualOnly is set, only pure virtual methods are
+/// collected and those already overridden by \p Decl are removed again.
+static void collectRelevantMethodDecls(const CXXRecordDecl &Decl,
+                                       MethodVector &Target,
+                                       AccessSpecifier InheritanceAS,
+                                       bool UnimplementedPureVirtualOnly) {
+  // Collect all virtual methods of all base classes
+  for (const CXXBaseSpecifier &Base : Decl.bases()) {
+    auto BaseType = Base.getType().getCanonicalType();
+    auto BaseAS = Base.getAccessSpecifier();
+    if (BaseAS == AccessSpecifier::AS_none)
+      BaseAS = AccessSpecifier::AS_private;
+    BaseAS = getFinalAccessSpecifier(BaseAS, InheritanceAS);
+    auto *BaseRecordDecl = BaseType.getTypePtr()->getAsCXXRecordDecl();
+    // We simply ignore all base classes which are not a `CXXRecordDecl`. Base
+    // classes might, e.g., also be template type arguments instead. Bases
+    // without a definition cannot be inspected either.
+    if (BaseRecordDecl && BaseRecordDecl->hasDefinition()) {
+      collectRelevantMethodDecls(*BaseRecordDecl, Target, BaseAS,
+                                 UnimplementedPureVirtualOnly);
+    }
+  }
+
+  // Collect all virtual methods of this class
+  for (const CXXMethodDecl *M : Decl.methods()) {
+    // Only collect actual base functions. Ignore functions which already
+    // override other functions from one of our base classes.
+    if (M->size_overridden_methods() != 0)
+      continue;
+    bool Qualifies =
+        UnimplementedPureVirtualOnly ? M->isPure() : M->isVirtual();
+    if (!Qualifies)
+      continue;
+    // The same class can be reached through several inheritance paths (e.g.
+    // diamond inheritance). Emitting its methods twice would generate
+    // duplicate, ill-formed overrides, so the first occurrence wins.
+    if (llvm::any_of(Target, [&](const auto &E) { return E.second == M; }))
+      continue;
+    auto MethodAS = getFinalAccessSpecifier(InheritanceAS, M->getAccess());
+    Target.emplace_back(MethodAS, M);
+  }
+
+  // Remove all overridden methods
+  if (UnimplementedPureVirtualOnly) {
+    for (const CXXMethodDecl *M : Decl.methods()) {
+      for (const CXXMethodDecl *Overridden : M->overridden_methods()) {
+        auto Iter = llvm::find_if(
+            Target, [&](const auto &E) { return E.second == Overridden; });
+        if (Iter != Target.end())
+          Target.erase(Iter);
+      }
+    }
+  }
+}
+
+/// Caches the selected class and the methods to override. Returns true only
+/// if a suitable class definition is selected and at least one method
+/// qualifies.
+bool AddSubclass::prepare(const Selection &Inputs) {
+  // This tweak assumes move semantics.
+  if (!Inputs.AST->getLangOpts().CPlusPlus11)
+    return false;
+
+  CachedLocation = llvm::None;
+  CachedMethods.clear();
+  auto *Node = Inputs.ASTSelection.commonAncestor();
+  if (!Node)
+    return false;
+  auto *Class = Node->ASTNode.get<CXXRecordDecl>();
+  if (!Class)
+    return false;
+  // Inspecting bases/methods requires a definition, and the subclass is
+  // inserted right below the selected declaration, where the base must already
+  // be complete.
+  if (!Class->isThisDeclarationADefinition())
+    return false;
+  // The generated code spells the base as a plain identifier. That does not
+  // work for unnamed classes and lambdas, nor for templates and their
+  // specializations.
+  if (!Class->getIdentifier() || Class->getDescribedClassTemplate() ||
+      Class->getTemplateSpecializationKind() != TSK_Undeclared)
+    return false;
+  // Nothing can derive from a final class.
+  if (Class->isEffectivelyFinal())
+    return false;
+
+  CachedLocation = Class;
+  collectRelevantMethodDecls(*Class, CachedMethods, AccessSpecifier::AS_public,
+                             UnimplementedPureVirtualOnly);
+  return !CachedMethods.empty();
+}
+
+/// Find a name which does not conflict with existing names
+/// by appending a number to the name, if necessary
+std::string getNewIdentifier(std::string Name, const ASTContext &AC,
+                             const DeclContext &DC) {
+  unsigned Counter = 0;
+  auto &Idents = AC.Idents;
+  while (true) {
+    std::string NumberedName = Name;
+    if (Counter) {
+      NumberedName += std::to_string(Counter);
+    }
+    // An identifier that was never seen in this translation unit cannot
+    // clash with anything.
+    IdentifierTable::iterator IdIter = Idents.find(NumberedName);
+    if (IdIter == Idents.end()) {
+      return NumberedName;
+    }
+    const IdentifierInfo *Identifier = IdIter->getValue();
+    if (DC.lookup(DeclarationName{Identifier}).empty()) {
+      return NumberedName;
+    }
+    ++Counter;
+  }
+}
+
+/// Describes how one parameter of an overridden method is declared in the
+/// override and passed on to the base class implementation.
+struct ForwardingParamInfo {
+  QualType Type;
+  std::string Name;
+  /// Whether the argument is wrapped in `std::move` when forwarded.
+  bool Move;
+};
+
+static bool canMoveRecordDecl(const CXXRecordDecl &C) {
+  // We can't always tell if C is copyable/movable without doing Sema work.
+  // We assume operations are possible unless we can prove not.
+  if (C.hasUserDeclaredMoveConstructor()) {
+    return llvm::none_of(C.ctors(), [](const CXXConstructorDecl *CCD) {
+      return CCD->isMoveConstructor() && CCD->isDeleted();
+    });
+  }
+  return C.needsOverloadResolutionForMoveConstructor() ||
+         !C.defaultedMoveConstructorIsDeleted();
+}
+
+/// Returns true if a by-value argument of type \p T should be forwarded with
+/// `std::move`.
+static bool shouldMoveType(const Type *T) {
+  if (auto *RecordDecl = T->getAsCXXRecordDecl()) {
+    // A method declaration may name an incomplete class as a by-value
+    // parameter; its special members cannot be inspected.
+    if (!RecordDecl->hasDefinition())
+      return false;
+    return canMoveRecordDecl(*RecordDecl);
+  }
+  return false;
+}
+
+/// Computes type, name and forwarding mode for each parameter of \p Func.
+/// Unnamed parameters get the synthesized name `_N`, N being the 1-based
+/// parameter position.
+static std::vector<ForwardingParamInfo>
+prepareForwardedArgs(const FunctionDecl &Func) {
+  std::vector<ForwardingParamInfo> Parameters;
+  Parameters.reserve(Func.param_size());
+  unsigned ParamNr = 0;
+  for (const ParmVarDecl *Param : Func.parameters()) {
+    ++ParamNr;
+    auto Name = Param->getNameAsString();
+    if (Name.empty()) {
+      // Synthesize name
+      Name = llvm::formatv("_{0}", ParamNr);
+    }
+    auto ParamType = Param->getOriginalType();
+    bool Move = ParamType.getLocalUnqualifiedType()->isRValueReferenceType() ||
+                shouldMoveType(ParamType.getCanonicalType().getTypePtr());
+    Parameters.push_back({ParamType, Name, Move});
+  }
+  return Parameters;
+}
+
+/// Returns true if \p Method is declared `noexcept`, `noexcept(true)` or
+/// `throw()`. An override of such a method must be non-throwing as well.
+static bool hasNonThrowingExceptionSpec(const CXXMethodDecl &Method) {
+  switch (Method.getExceptionSpecType()) {
+  case EST_BasicNoexcept:
+  case EST_NoexceptTrue:
+  case EST_DynamicNone:
+    return true;
+  default:
+    return false;
+  }
+}
+
+/// Renders the source text of the new subclass of \p BaseClass, overriding
+/// every method in \p Methods with the access level stored next to it.
+/// Returns an error for methods that cannot be overridden by generated code.
+static llvm::Expected<std::string>
+formatSubclassCode(const CXXRecordDecl &BaseClass,
+                   const MethodVector &Methods) {
+  auto &DC = *BaseClass.getParent();
+  auto &AC = BaseClass.getASTContext();
+
+  std::string S;
+  llvm::raw_string_ostream OS(S);
+  OS << "\n";
+
+  // Use the same keyword (struct or class) as the base class
+  AccessSpecifier CurrentAS;
+  if (BaseClass.isClass()) {
+    OS << "class ";
+    // We want a `private:` section header even if the first functions are
+    // private. Hence, don't set `CurrentAS` to `private` but to `none`.
+    CurrentAS = AccessSpecifier::AS_none;
+  } else {
+    OS << "struct ";
+    CurrentAS = AccessSpecifier::AS_public;
+  }
+
+  // Find a class name which does not conflict with existing names
+  std::string SubclassName =
+      getNewIdentifier(BaseClass.getNameAsString() + "Sub", AC, DC);
+  OS << SubclassName;
+
+  // Inherit from the base class
+  OS << " : public " << BaseClass.getName() << " {\n";
+
+  // We always inherit the constructors
+  OS << " using " << BaseClass.getName() << "::" << BaseClass.getName()
+     << ";\n";
+
+  // Add the methods
+  for (auto &M : Methods) {
+    AccessSpecifier MethodAS = M.first;
+    const CXXMethodDecl *Method = M.second;
+
+    auto DeclName = Method->getDeclName();
+    bool PrintReturnType = false;
+    std::string MethodName;
+    switch (DeclName.getNameKind()) {
+    case DeclarationName::Identifier:
+      PrintReturnType = true;
+      MethodName = DeclName.getAsString();
+      break;
+    case DeclarationName::CXXDestructorName:
+      PrintReturnType = false;
+      MethodName = std::string{"~"} + SubclassName;
+      break;
+    case DeclarationName::CXXConversionFunctionName:
+      PrintReturnType = false;
+      MethodName = std::string{"operator "} +
+                   printType(Method->getReturnType(), BaseClass);
+      break;
+    case DeclarationName::CXXOperatorName:
+      PrintReturnType = true;
+      MethodName = std::string{"operator"} +
+                   getOperatorSpelling(DeclName.getCXXOverloadedOperator());
+      break;
+    case DeclarationName::CXXConstructorName:
+    case DeclarationName::ObjCZeroArgSelector:
+    case DeclarationName::ObjCOneArgSelector:
+    case DeclarationName::CXXDeductionGuideName:
+    case DeclarationName::CXXLiteralOperatorName:
+    case DeclarationName::CXXUsingDirective:
+    case DeclarationName::ObjCMultiArgSelector:
+      return error("Unsupported method type `{0}`", DeclName.getNameKind());
+    }
+    // C-style variadic arguments cannot be forwarded, and dropping the
+    // ellipsis would produce a signature which overrides nothing.
+    if (Method->isVariadic())
+      return error("Cannot override variadic method `{0}`", MethodName);
+
+    if (MethodAS != CurrentAS) {
+      OS << getAccessSpelling(MethodAS) << ":\n";
+      CurrentAS = MethodAS;
+    }
+
+    // Copy over the comment from the base class
+    auto *Comment = AC.getRawCommentForAnyRedecl(Method);
+    if (Comment && !Comment->isTrailingComment())
+      OS << " // " << Comment->getBriefText(AC) << "\n";
+
+    OS << " ";
+    if (Method->isConstexprSpecified())
+      OS << "constexpr ";
+    if (Method->isConsteval())
+      OS << "consteval ";
+
+    if (PrintReturnType) {
+      OS << printType(Method->getReturnType(), BaseClass);
+      OS << " ";
+    }
+
+    OS << MethodName;
+
+    // Print the argument list
+    auto ForwardedArgs = prepareForwardedArgs(*Method);
+    OS << "(";
+    const char *Sep = "";
+    for (auto &Arg : ForwardedArgs) {
+      OS << Sep;
+      OS << printType(Arg.Type, BaseClass, /*Placeholder=*/Arg.Name);
+      Sep = ", ";
+    }
+    OS << ")";
+
+    // cv- and ref-qualifiers are part of the signature: without them the new
+    // method would not override the base method.
+    if (Method->isConst())
+      OS << " const";
+    if (Method->isVolatile())
+      OS << " volatile";
+    const RefQualifierKind RefQualifier = Method->getRefQualifier();
+    if (RefQualifier == RQ_LValue)
+      OS << " &";
+    else if (RefQualifier == RQ_RValue)
+      OS << " &&";
+    if (hasNonThrowingExceptionSpec(*Method))
+      OS << " noexcept";
+    OS << " override";
+
+    if (MethodAS == AccessSpecifier::AS_private) {
+      // We don't provide a default implementation if the overridden method is
+      // private.
+      OS << ";\n";
+    } else if (Method->getDeclName().getNameKind() ==
+               DeclarationName::CXXDestructorName) {
+      OS << " = default;\n";
+    } else {
+      // The default implementation simply delegates to the base class
+      OS << " { ";
+      if (!Method->getReturnType()->isVoidType())
+        OS << "return ";
+      // Inside an `&&`-qualified method `*this` is an lvalue; it has to be
+      // turned back into an rvalue to select the base implementation.
+      if (RefQualifier == RQ_RValue)
+        OS << "std::move(*this).";
+      OS << BaseClass.getName() << "::" << MethodName << "(";
+      Sep = "";
+      for (auto &Arg : ForwardedArgs) {
+        OS << Sep;
+        if (Arg.Move)
+          OS << "std::move(";
+        OS << Arg.Name;
+        if (Arg.Move)
+          OS << ")";
+        Sep = ", ";
+      }
+      OS << "); }\n";
+    }
+  }
+
+  OS << "};\n";
+  OS.flush();
+  return S;
+}
+
+/// Inserts the generated subclass directly below the class cached by
+/// prepare(), as an edit to the main file.
+Expected<Tweak::Effect> AddSubclass::apply(const Selection &Inputs) {
+  if (!CachedLocation)
+    return error("AddSubclass: no class was selected");
+  auto *Class = *CachedLocation;
+  auto &SM = Inputs.AST->getSourceManager();
+
+  auto SubclassCode = formatSubclassCode(*Class, CachedMethods);
+  if (!SubclassCode)
+    return SubclassCode.takeError();
+
+  tooling::Replacements Replacements;
+  auto Insertion = insertDecl(
+      *SubclassCode, *Class->getLexicalParent(),
+      {Anchor{[&](const Decl *D) { return D == Class; }, Anchor::Below}});
+  if (!Insertion)
+    return Insertion.takeError();
+  auto AddError = Replacements.add(std::move(*Insertion));
+  // llvm::Error is move-only; move explicitly instead of relying on implicit
+  // move into the converting constructor of Expected.
+  if (AddError)
+    return std::move(AddError);
+  return Effect::mainFileEdit(SM, std::move(Replacements));
+}
+
+} // namespace
+} // namespace clangd
+} // namespace clang
diff --git a/clang-tools-extra/clangd/refactor/tweaks/CMakeLists.txt b/clang-tools-extra/clangd/refactor/tweaks/CMakeLists.txt
--- a/clang-tools-extra/clangd/refactor/tweaks/CMakeLists.txt
+++ b/clang-tools-extra/clangd/refactor/tweaks/CMakeLists.txt
@@ -12,6 +12,7 @@
 # $<TARGET_OBJECTS:obj.clangDaemonTweaks> to a list of sources, see
 # clangd/tool/CMakeLists.txt for an example.
 add_clang_library(clangDaemonTweaks OBJECT
+  AddSubclass.cpp
   AddUsing.cpp
   AnnotateHighlightings.cpp
   DumpAST.cpp
diff --git a/clang-tools-extra/clangd/unittests/CMakeLists.txt b/clang-tools-extra/clangd/unittests/CMakeLists.txt
--- a/clang-tools-extra/clangd/unittests/CMakeLists.txt
+++ b/clang-tools-extra/clangd/unittests/CMakeLists.txt
@@ -106,6 +106,7 @@
   support/TestTracer.cpp
   support/TraceTests.cpp
 
+  tweaks/AddSubclassTests.cpp
   tweaks/AddUsingTests.cpp
   tweaks/AnnotateHighlightingsTests.cpp
   tweaks/DefineInlineTests.cpp
diff --git a/clang-tools-extra/clangd/unittests/tweaks/AddSubclassTests.cpp b/clang-tools-extra/clangd/unittests/tweaks/AddSubclassTests.cpp
new file mode 100644
--- /dev/null
+++ b/clang-tools-extra/clangd/unittests/tweaks/AddSubclassTests.cpp
@@ -0,0 +1,526 @@
+//===-- AddSubclassTests.cpp ------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "Annotations.h"
+#include "TweakTesting.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace clang {
+namespace clangd {
+namespace {
+
+TWEAK_TEST(AddSubclassAllVirtuals);
+
+TEST_F(AddSubclassAllVirtualsTest, Prepare) {
+  // Not available if there are no virtual functions
+  EXPECT_UNAVAILABLE("^struct ^Base { void foo(); };");
+  // Available on virtual functions
+  EXPECT_AVAILABLE("^struct ^Base { virtual void foo(); };");
+  // Available on pure virtual functions
+  EXPECT_AVAILABLE("^struct ^Base { virtual void foo() = 0; };");
+  // Available for inherited virtual functions
+  EXPECT_AVAILABLE(R"cpp(
+struct Base { virtual void foo() = 0; };
+^struct ^Intermediate : public Base {};
+)cpp");
+  // Available for inherited virtual functions, even if already overridden
+  EXPECT_AVAILABLE(R"cpp(
+struct Base { virtual void foo() = 0; };
+^struct ^Intermediate : public Base { void foo() override; };
+)cpp");
+}
+
+// Checks where the generated subclass is placed and how it is named.
+TEST_F(AddSubclassAllVirtualsTest, ApplyInDifferentScopes) {
+  struct {
+    llvm::StringRef TestSource;
+    llvm::StringRef ExpectedSource;
+  } Cases[]{
+      // Basic case, outside any namespace
+      {
+          R"cpp(
+struct ^Base { virtual int foo() = 0; };
+)cpp",
+          R"cpp(
+struct Base { virtual int foo() = 0; };
+
+struct BaseSub : public Base {
+ using Base::Base;
+ int foo() override { return Base::foo(); }
+};
+)cpp",
+      },
+      // Inserted between two classes
+      {
+          R"cpp(
+struct ^Base { virtual int foo() = 0; };
+struct OtherStruct {};
+)cpp",
+          R"cpp(
+struct Base { virtual int foo() = 0; };
+
+struct BaseSub : public Base {
+ using Base::Base;
+ int foo() override { return Base::foo(); }
+};
+struct OtherStruct {};
+)cpp",
+      },
+      // Inside a namespace
+      {
+          R"cpp(
+namespace NS {
+struct ^Base { virtual int foo() = 0; };
+})cpp",
+          R"cpp(
+namespace NS {
+struct Base { virtual int foo() = 0; };
+
+struct BaseSub : public Base {
+ using Base::Base;
+ int foo() override { return Base::foo(); }
+};
+})cpp",
+      },
+      // Inside an outer class
+      {
+          R"cpp(
+struct Outer {
+struct ^Base { virtual int foo() = 0; };
+};)cpp",
+          R"cpp(
+struct Outer {
+struct Base { virtual int foo() = 0; };
+
+struct BaseSub : public Base {
+ using Base::Base;
+ int foo() override { return Base::foo(); }
+};
+};)cpp",
+      },
+      // Chooses a fresh unused name
+      {
+          R"cpp(
+struct ^Base { virtual int foo() = 0; };
+struct BaseSub;
+struct BaseSub1;
+struct BaseSub2;
+struct BaseSub4;
+)cpp",
+          R"cpp(
+struct Base { virtual int foo() = 0; };
+
+struct BaseSub3 : public Base {
+ using Base::Base;
+ int foo() override { return Base::foo(); }
+};
+struct BaseSub;
+struct BaseSub1;
+struct BaseSub2;
+struct BaseSub4;
+)cpp",
+      },
+      // Does not accidentally collide with a comment on the last line
+      // of the file without a newline.
+      {
+          R"cpp(
+struct ^Base { virtual int foo() = 0; };
+// Some comment...)cpp",
+          R"cpp(
+struct Base { virtual int foo() = 0; };
+// Some comment...
+struct BaseSub : public Base {
+ using Base::Base;
+ int foo() override { return Base::foo(); }
+};
+)cpp",
+      },
+      // Is not confused by forward declarations (which might leak in,
+      // e.g., through `#include`s). Inserts the subclass directly after the
+      // declaration the refactoring was triggered on.
+      {
+          R"cpp(
+struct Base;
+struct ^Base { virtual int foo() = 0; };
+struct Base;
+)cpp",
+          R"cpp(
+struct Base;
+struct Base { virtual int foo() = 0; };
+
+struct BaseSub : public Base {
+ using Base::Base;
+ int foo() override { return Base::foo(); }
+};
+struct Base;
+)cpp",
+      },
+  };
+  llvm::StringMap<std::string> EditedFiles;
+  for (const auto &Case : Cases) {
+    for (const auto &SubCase : expandCases(Case.TestSource)) {
+      EXPECT_EQ(apply(SubCase, &EditedFiles), Case.ExpectedSource);
+    }
+  }
+}
+
+// Checks the text of the generated subclass. The expected output is the
+// unannotated input followed by `GeneratedSubclass`.
+TEST_F(AddSubclassAllVirtualsTest, GeneratesCorrectSubclass) {
+  struct {
+    llvm::StringRef BaseClass;
+    llvm::StringRef GeneratedSubclass;
+    std::vector<std::string> ExtraArgs = {};
+  } Cases[]{
+      // Basic case; generating a `struct` inheriting from the base class
+      {
+          R"cpp(
+struct ^Base { virtual int foo() = 0; };
+)cpp",
+          R"cpp(
+struct BaseSub : public Base {
+ using Base::Base;
+ int foo() override { return Base::foo(); }
+};
+)cpp"},
+      // Only overrides virtual functions; leaves other functions alone
+      {
+          R"cpp(
+struct ^Base {
+ virtual int foo() = 0;
+ int bar();
+};
+)cpp",
+          R"cpp(
+struct BaseSub : public Base {
+ using Base::Base;
+ int foo() override { return Base::foo(); }
+};
+)cpp"},
+      // Also supports overriding the virtual destructor, overloaded operators
+      // and conversion functions.
+      {
+          R"cpp(
+struct ^Base {
+ virtual ~Base() = 0;
+ virtual operator double() = 0;
+ virtual int operator[](int X) = 0;
+};
+)cpp",
+          R"cpp(
+struct BaseSub : public Base {
+ using Base::Base;
+ ~BaseSub() override = default;
+ operator double() override { return Base::operator double(); }
+ int operator[](int X) override { return Base::operator[](X); }
+};
+)cpp"},
+      // Function attributes like `const`, `noexcept` etc. are copied
+      {
+          R"cpp(
+struct ^Base {
+ consteval virtual int foo() noexcept;
+ constexpr virtual operator double() const;
+};
+)cpp",
+          R"cpp(
+struct BaseSub : public Base {
+ using Base::Base;
+ consteval int foo() noexcept override { return Base::foo(); }
+ constexpr operator double() const override { return Base::operator double(); }
+};
+)cpp",
+          {"-std=c++20"}},
+      // Uses `class` instead of struct if the base class was also a `class`
+      {
+          R"cpp(
+class ^Base {
+public:
+ virtual int foo() = 0;
+};
+)cpp",
+          R"cpp(
+class BaseSub : public Base {
+ using Base::Base;
+public:
+ int foo() override { return Base::foo(); }
+};
+)cpp"},
+      // Default implementation does not contain a `return` for void functions
+      {
+          R"cpp(
+struct ^Base {
+ virtual void foo() = 0;
+};
+)cpp",
+          R"cpp(
+struct BaseSub : public Base {
+ using Base::Base;
+ void foo() override { Base::foo(); }
+};
+)cpp"},
+      // No default implementation for private functions.
+      // We can't call the private implementation of the base class.
+      {
+          R"cpp(
+struct ^Base {
+private:
+ virtual void foo() = 0;
+};
+)cpp",
+          R"cpp(
+struct BaseSub : public Base {
+ using Base::Base;
+private:
+ void foo() override;
+};
+)cpp"},
+      // Default implementation forwards parameters
+      {
+          R"cpp(
+struct Moveable {
+ Moveable() = default;
+ Moveable(Moveable&&) = default;
+};
+
+struct ^Base {
+ virtual void foo(int a) = 0;
+ virtual void bar(int a, double b) = 0;
+ virtual void baz(int, double b, char) = 0;
+ virtual void foobar(int&& x) = 0;
+ virtual void foobaz(Moveable x) = 0;
+};
+)cpp",
+          R"cpp(
+struct BaseSub : public Base {
+ using Base::Base;
+ void foo(int a) override { Base::foo(a); }
+ void bar(int a, double b) override { Base::bar(a, b); }
+ void baz(int _1, double b, char _3) override { Base::baz(_1, b, _3); }
+ void foobar(int &&x) override { Base::foobar(std::move(x)); }
+ void foobaz(Moveable x) override { Base::foobaz(std::move(x)); }
+};
+)cpp"},
+      // Can expand multiple overloaded functions
+      {
+          R"cpp(
+struct ^Base {
+ virtual void foo(int a) = 0;
+ virtual void foo(double b) = 0;
+};
+)cpp",
+          R"cpp(
+struct BaseSub : public Base {
+ using Base::Base;
+ void foo(int a) override { Base::foo(a); }
+ void foo(double b) override { Base::foo(b); }
+};
+)cpp"},
+      // Collects virtual functions from *all* base classes
+      {
+          R"cpp(
+struct Base1 {
+ virtual void foo() = 0;
+};
+struct Intermediate : public Base1 {
+ virtual void bar() = 0;
+};
+struct Base2 {
+ virtual void baz() = 0;
+};
+struct ^Base : public Intermediate, Base2 {
+ virtual void foobar() = 0;
+};
+)cpp",
+          R"cpp(
+struct BaseSub : public Base {
+ using Base::Base;
+ void foo() override { Base::foo(); }
+ void bar() override { Base::bar(); }
+ void baz() override { Base::baz(); }
+ void foobar() override { Base::foobar(); }
+};
+)cpp",
+      },
+      // Correctly propagates visibility
+      {
+          R"cpp(
+struct ^Base {
+ virtual int publicFoo() = 0;
+private:
+ virtual int privateFoo() = 0;
+protected:
+ virtual int protectedFoo() = 0;
+public:
+ virtual int publicBar() = 0;
+};
+)cpp",
+          R"cpp(
+struct BaseSub : public Base {
+ using Base::Base;
+ int publicFoo() override { return Base::publicFoo(); }
+private:
+ int privateFoo() override;
+protected:
+ int protectedFoo() override { return Base::protectedFoo(); }
+public:
+ int publicBar() override { return Base::publicBar(); }
+};
+)cpp",
+      },
+      // Keeps the structuring into multiple `public`/`protected`/`private`
+      // blocks from the base class
+      {
+          R"cpp(
+class ^Base {
+public:
+ virtual int publicFoo() = 0;
+ virtual int publicBar() = 0;
+protected:
+ virtual int protectedFoo() = 0;
+public:
+ virtual int publicBaz() = 0;
+};
+)cpp",
+          R"cpp(
+class BaseSub : public Base {
+ using Base::Base;
+public:
+ int publicFoo() override { return Base::publicFoo(); }
+ int publicBar() override { return Base::publicBar(); }
+protected:
+ int protectedFoo() override { return Base::protectedFoo(); }
+public:
+ int publicBaz() override { return Base::publicBaz(); }
+};
+)cpp",
+      },
+      // Correctly propagates visibility also for non-public inheritance
+      {
+          R"cpp(
+struct Base1 {
+ virtual void foo() = 0;
+private:
+ virtual void privateFoo() = 0;
+};
+struct Intermediate : protected Base1 {
+ virtual void bar() = 0;
+};
+struct Base2 {
+ virtual void baz() = 0;
+};
+struct ^Base : public Intermediate, private Base2 {
+ virtual void foobar() = 0;
+};
+)cpp",
+          R"cpp(
+struct BaseSub : public Base {
+ using Base::Base;
+protected:
+ void foo() override { Base::foo(); }
+private:
+ void privateFoo() override;
+public:
+ void bar() override { Base::bar(); }
+private:
+ void baz() override;
+public:
+ void foobar() override { Base::foobar(); }
+};
+)cpp",
+      },
+      // Copies comments
+      {
+          R"cpp(
+struct ^Base {
+ // Some comment
+ virtual void foo() = 0;
+
+ // A method with a brief description
+ //
+ // And a longer description
+ virtual void bar() = 0;
+};
+)cpp",
+          R"cpp(
+struct BaseSub : public Base {
+ using Base::Base;
+ // Some comment
+ void foo() override { Base::foo(); }
+ // A method with a brief description
+ void bar() override { Base::bar(); }
+};
+)cpp",
+      },
+  };
+  llvm::StringMap<std::string> EditedFiles;
+  for (const auto &Case : Cases) {
+    ExtraArgs = Case.ExtraArgs;
+    Annotations Code(Case.BaseClass);
+    for (const auto &SubCase : expandCases(Case.BaseClass)) {
+      EXPECT_EQ(apply(SubCase, &EditedFiles),
+                (Code.code() + Case.GeneratedSubclass).str());
+    }
+  }
+}
+
+// We do not test everything again, but only test the difference in behavior
+TWEAK_TEST(AddSubclassPureVirtualOnly);
+
+TEST_F(AddSubclassPureVirtualOnlyTest,
+       OnlyOverridesNonImplementedVirtualFunctions) {
+  struct {
+    llvm::StringRef BaseClass;
+    llvm::StringRef GeneratedSubclass;
+  } Cases[]{
+      // Only overrides pure virtual functions; leaves other virtual functions
+      // alone
+      {
+          R"cpp(
+struct ^Base {
+ virtual int foo() = 0;
+ virtual int bar();
+};
+)cpp",
+          R"cpp(
+struct BaseSub : public Base {
+ using Base::Base;
+ int foo() override { return Base::foo(); }
+};
+)cpp"},
+      // Does not override pure virtual functions if they were already
+      // implemented by an intermediate class.
+      {
+          R"cpp(
+struct Base {
+ virtual int foo() = 0;
+ virtual int bar() = 0;
+};
+struct ^Intermediate : public Base {
+ int foo() override;
+};
+)cpp",
+          R"cpp(
+struct IntermediateSub : public Intermediate {
+ using Intermediate::Intermediate;
+ int bar() override { return Intermediate::bar(); }
+};
+)cpp"},
+  };
+  llvm::StringMap<std::string> EditedFiles;
+  for (const auto &Case : Cases) {
+    Annotations Code(Case.BaseClass);
+    for (const auto &SubCase : expandCases(Case.BaseClass)) {
+      EXPECT_EQ(apply(SubCase, &EditedFiles),
+                (Code.code() + Case.GeneratedSubclass).str());
+    }
+  }
+}
+
+} // namespace
+} // namespace clangd
+} // namespace clang