Index: clang-tools-extra/clangd/refactor/tweaks/ExtractFunction.cpp =================================================================== --- clang-tools-extra/clangd/refactor/tweaks/ExtractFunction.cpp +++ clang-tools-extra/clangd/refactor/tweaks/ExtractFunction.cpp @@ -56,6 +56,7 @@ #include "clang/AST/ASTContext.h" #include "clang/AST/Decl.h" #include "clang/AST/DeclBase.h" +#include "clang/AST/NestedNameSpecifier.h" #include "clang/AST/RecursiveASTVisitor.h" #include "clang/AST/Stmt.h" #include "clang/Basic/LangOptions.h" @@ -71,6 +72,7 @@ #include "llvm/ADT/StringRef.h" #include "llvm/Support/Casting.h" #include "llvm/Support/Error.h" +#include "llvm/Support/raw_os_ostream.h" namespace clang { namespace clangd { @@ -88,6 +90,12 @@ OutsideFunc // Outside EnclosingFunction. }; +enum FunctionDeclKind { + InlineDefinition, + ForwardDeclaration, + OutOfLineDefinition +}; + // A RootStmt is a statement that's fully selected including all it's children // and it's parent is unselected. // Check if a node is a root statement. @@ -237,9 +245,6 @@ if (CurNode->ASTNode.get()) return nullptr; if (const FunctionDecl *Func = CurNode->ASTNode.get()) { - // FIXME: Support extraction from methods. - if (isa(Func)) - return nullptr; // FIXME: Support extraction from templated functions. if (Func->isTemplated()) return nullptr; @@ -343,34 +348,53 @@ QualType ReturnType; std::vector Parameters; SourceRange BodyRange; - SourceLocation InsertionPoint; - const DeclContext *EnclosingFuncContext; + SourceLocation DefinitionPoint; + llvm::Optional ForwardDeclarationPoint; + const CXXRecordDecl *EnclosingClass = nullptr; + const NestedNameSpecifier *DefinitionQualifier = nullptr; + const DeclContext *SemanticDC = nullptr; + const DeclContext *SyntacticDC = nullptr; + const DeclContext *ForwardDeclarationSyntacticDC = nullptr; bool CallerReturnsValue = false; + bool Static = false; + ConstexprSpecKind Constexpr = ConstexprSpecKind::Unspecified; + bool Const = false; + // Decides whether the extracted function body and the function call need a // semicolon after extraction. tooling::ExtractionSemicolonPolicy SemicolonPolicy; - NewFunction(tooling::ExtractionSemicolonPolicy SemicolonPolicy) - : SemicolonPolicy(SemicolonPolicy) {} + const LangOptions *LangOpts; + NewFunction(tooling::ExtractionSemicolonPolicy SemicolonPolicy, + const LangOptions *LangOpts) + : SemicolonPolicy(SemicolonPolicy), LangOpts(LangOpts) {} // Render the call for this function. std::string renderCall() const; // Render the definition for this function. - std::string renderDefinition(const SourceManager &SM) const; + std::string renderDeclaration(FunctionDeclKind K, + const DeclContext &SemanticDC, + const DeclContext &SyntacticDC, + const SourceManager &SM) const; private: - std::string renderParametersForDefinition() const; + std::string + renderParametersForDeclaration(const DeclContext &Enclosing) const; std::string renderParametersForCall() const; + std::string renderSpecifiers(FunctionDeclKind K) const; + std::string renderQualifiers() const; + std::string renderDeclarationName(FunctionDeclKind K) const; // Generate the function body. std::string getFuncBody(const SourceManager &SM) const; }; -std::string NewFunction::renderParametersForDefinition() const { +std::string NewFunction::renderParametersForDeclaration( + const DeclContext &Enclosing) const { std::string Result; bool NeedCommaBefore = false; for (const Parameter &P : Parameters) { if (NeedCommaBefore) Result += ", "; NeedCommaBefore = true; - Result += P.render(EnclosingFuncContext); + Result += P.render(&Enclosing); } return Result; } @@ -387,6 +411,49 @@ return Result; } +std::string NewFunction::renderSpecifiers(FunctionDeclKind K) const { + std::string Attributes; + + if (Static && K != FunctionDeclKind::OutOfLineDefinition) { + Attributes += "static "; + } + + switch (Constexpr) { + case ConstexprSpecKind::Unspecified: + case ConstexprSpecKind::Constinit: + break; + case ConstexprSpecKind::Constexpr: + Attributes += "constexpr "; + break; + case ConstexprSpecKind::Consteval: + Attributes += "consteval "; + break; + } + + return Attributes; +} + +std::string NewFunction::renderQualifiers() const { + std::string Attributes; + + if (Const) { + Attributes += " const"; + } + + return Attributes; +} + +std::string NewFunction::renderDeclarationName(FunctionDeclKind K) const { + if (DefinitionQualifier == nullptr || K != OutOfLineDefinition) { + return Name; + } + + std::string QualifierName; + llvm::raw_string_ostream Oss(QualifierName); + DefinitionQualifier->print(Oss, *LangOpts); + return llvm::formatv("{0}{1}", QualifierName, Name); +} + std::string NewFunction::renderCall() const { return std::string( llvm::formatv("{0}{1}({2}){3}", CallerReturnsValue ? "return " : "", Name, @@ -394,10 +461,24 @@ (SemicolonPolicy.isNeededInOriginalFunction() ? ";" : ""))); } -std::string NewFunction::renderDefinition(const SourceManager &SM) const { - return std::string(llvm::formatv( - "{0} {1}({2}) {\n{3}\n}\n", printType(ReturnType, *EnclosingFuncContext), - Name, renderParametersForDefinition(), getFuncBody(SM))); +std::string NewFunction::renderDeclaration(FunctionDeclKind K, + const DeclContext &SemanticDC, + const DeclContext &SyntacticDC, + const SourceManager &SM) const { + std::string Declaration = std::string(llvm::formatv( + "{0}{1} {2}({3}){4}", renderSpecifiers(K), + printType(ReturnType, SyntacticDC), renderDeclarationName(K), + renderParametersForDeclaration(SemanticDC), renderQualifiers())); + + switch (K) { + case ForwardDeclaration: + return std::string(llvm::formatv("{0};\n", Declaration)); + case OutOfLineDefinition: + case InlineDefinition: + return std::string( + llvm::formatv("{0} {\n{1}\n}\n", Declaration, getFuncBody(SM))); + break; + } } std::string NewFunction::getFuncBody(const SourceManager &SM) const { @@ -668,11 +749,40 @@ if (CapturedInfo.BrokenControlFlow) return error("Cannot extract break/continue without corresponding " "loop/switch statement."); - NewFunction ExtractedFunc(getSemicolonPolicy(ExtZone, SM, LangOpts)); + NewFunction ExtractedFunc(getSemicolonPolicy(ExtZone, SM, LangOpts), + &LangOpts); + + ExtractedFunc.ForwardDeclarationSyntacticDC = ExtractedFunc.SemanticDC = + ExtractedFunc.SyntacticDC = ExtZone.EnclosingFunction->getDeclContext(); + + if (isa(ExtZone.EnclosingFunction)) { + const auto *Method = + llvm::dyn_cast(ExtZone.EnclosingFunction); + + ExtractedFunc.Static = Method->isStatic(); + ExtractedFunc.Constexpr = Method->getConstexprKind(); + ExtractedFunc.Const = Method->isConst(); + ExtractedFunc.DefinitionQualifier = + ExtZone.EnclosingFunction->getQualifier(); + + ExtractedFunc.EnclosingClass = Method->getParent(); + ExtractedFunc.SemanticDC = ExtractedFunc.EnclosingClass; + ExtractedFunc.ForwardDeclarationSyntacticDC = ExtractedFunc.SemanticDC; + + if (Method->isOutOfLine()) { + // FIXME: Put the extracted method in a private section of the class + const auto *FirstOriginalDecl = Method->getCanonicalDecl(); + auto DeclPos = toHalfOpenFileRange(SM, LangOpts, + FirstOriginalDecl->getSourceRange()); + ExtractedFunc.ForwardDeclarationPoint = DeclPos.getValue().getBegin(); + } else { + ExtractedFunc.SyntacticDC = ExtractedFunc.SemanticDC; + } + } + ExtractedFunc.BodyRange = ExtZone.ZoneRange; - ExtractedFunc.InsertionPoint = ExtZone.getInsertionPoint(); - ExtractedFunc.EnclosingFuncContext = - ExtZone.EnclosingFunction->getDeclContext(); + ExtractedFunc.DefinitionPoint = ExtZone.getInsertionPoint(); + ExtractedFunc.CallerReturnsValue = CapturedInfo.AlwaysReturns; if (!createParameters(ExtractedFunc, CapturedInfo) || !generateReturnProperties(ExtractedFunc, *ExtZone.EnclosingFunction, @@ -706,8 +816,47 @@ tooling::Replacement createFunctionDefinition(const NewFunction &ExtractedFunc, const SourceManager &SM) { - std::string FunctionDef = ExtractedFunc.renderDefinition(SM); - return tooling::Replacement(SM, ExtractedFunc.InsertionPoint, 0, FunctionDef); + std::string FunctionDef; + + FunctionDeclKind DeclKind = InlineDefinition; + + if (ExtractedFunc.ForwardDeclarationPoint.hasValue()) { + DeclKind = OutOfLineDefinition; + } + + FunctionDef = ExtractedFunc.renderDeclaration( + DeclKind, *ExtractedFunc.SemanticDC, *ExtractedFunc.SyntacticDC, SM); + + return tooling::Replacement(SM, ExtractedFunc.DefinitionPoint, 0, + FunctionDef); +} + +tooling::Replacement createForwardDeclaration(const NewFunction &ExtractedFunc, + const SourceManager &SM) { + auto FunctionDecl = ExtractedFunc.renderDeclaration( + ForwardDeclaration, *ExtractedFunc.SemanticDC, + *ExtractedFunc.ForwardDeclarationSyntacticDC, SM); + auto DeclPoint = ExtractedFunc.ForwardDeclarationPoint.getValue(); + + return tooling::Replacement(SM, DeclPoint, 0, FunctionDecl); +} + +llvm::Error +addForwardDeclarationTweakEffect(llvm::Expected &FileEdit, + const NewFunction &ExtractedFunc, + const SourceManager &SM) { + auto DeclPoint = ExtractedFunc.ForwardDeclarationPoint.getValue(); + tooling::Replacements ResultDecl; + if (auto Err = ResultDecl.add(createForwardDeclaration(ExtractedFunc, SM))) + return Err; + auto PathAndEdit = Tweak::Effect::fileEdit(SM, SM.getFileID(DeclPoint), + std::move(ResultDecl)); + if (!PathAndEdit) + return PathAndEdit.takeError(); + + FileEdit->ApplyEdits.try_emplace(PathAndEdit->first, PathAndEdit->second); + + return llvm::Error::success(); } // Returns true if ExtZone contains any ReturnStmts. @@ -762,7 +911,22 @@ return std::move(Err); if (auto Err = Result.add(replaceWithFuncCall(*ExtractedFunc, SM, LangOpts))) return std::move(Err); - return Effect::mainFileEdit(SM, std::move(Result)); + + if (!ExtractedFunc->ForwardDeclarationPoint.hasValue()) { + return Effect::mainFileEdit(SM, std::move(Result)); + } + + if (Result.add(createForwardDeclaration(*ExtractedFunc, SM)).success()) { + return Effect::mainFileEdit(SM, std::move(Result)); + } + + auto MainRes = Effect::mainFileEdit(SM, std::move(Result)); + if (MainRes) { + if (auto Err = + addForwardDeclarationTweakEffect(MainRes, *ExtractedFunc, SM)) + return std::move(Err); + } + return MainRes; } } // namespace Index: clang-tools-extra/clangd/unittests/tweaks/ExtractFunctionTests.cpp =================================================================== --- clang-tools-extra/clangd/unittests/tweaks/ExtractFunctionTests.cpp +++ clang-tools-extra/clangd/unittests/tweaks/ExtractFunctionTests.cpp @@ -114,16 +114,48 @@ // Don't extract when we need to make a function as a parameter. EXPECT_THAT(apply("void f() { [[int a; f();]] }"), StartsWith("fail")); - // We don't extract from methods for now since they may involve multi-file - // edits - std::string MethodFailInput = R"cpp( + std::string MethodInput = R"cpp( class T { void f() { [[int x;]] } }; )cpp"; - EXPECT_EQ(apply(MethodFailInput), "unavailable"); + std::string MethodCheckOutput = R"cpp( + class T { + void extracted() { +int x; +} +void f() { + extracted(); + } + }; + )cpp"; + EXPECT_EQ(apply(MethodInput), MethodCheckOutput); + + std::string OutOfLineMethodInput = R"cpp( + class T { + void f(); + }; + + void T::f() { + [[int x;]] + } + )cpp"; + std::string OutOfLineMethodCheckOutput = R"cpp( + class T { + void extracted(); +void f(); + }; + + void T::extracted() { +int x; +} +void T::f() { + extracted(); + } + )cpp"; + EXPECT_EQ(apply(OutOfLineMethodInput), OutOfLineMethodCheckOutput); // We don't extract from templated functions for now as templates are hard // to deal with. @@ -159,6 +191,305 @@ EXPECT_EQ(apply(CompoundFailInput), "unavailable"); } +TEST_F(ExtractFunctionTest, DifferentHeaderSourceTest) { + Header = R"cpp( + class SomeClass { + void f(); + }; + )cpp"; + + std::string OutOfLineSource = R"cpp( + void SomeClass::f() { + [[int x;]] + } + )cpp"; + + std::string OutOfLineSourceOutputCheck = R"cpp( + void SomeClass::extracted() { +int x; +} +void SomeClass::f() { + extracted(); + } + )cpp"; + + std::string HeaderOutputCheck = R"cpp( + class SomeClass { + void extracted(); +void f(); + }; + )cpp"; + + llvm::StringMap EditedFiles; + + EXPECT_EQ(apply(OutOfLineSource, &EditedFiles), OutOfLineSourceOutputCheck); + EXPECT_EQ(EditedFiles.begin()->second, HeaderOutputCheck); +} + +TEST_F(ExtractFunctionTest, DifferentFilesNestedTest) { + Header = R"cpp( + class T { + class SomeClass { + void f(); + }; + }; + )cpp"; + + std::string NestedOutOfLineSource = R"cpp( + void T::SomeClass::f() { + [[int x;]] + } + )cpp"; + + std::string NestedOutOfLineSourceOutputCheck = R"cpp( + void T::SomeClass::extracted() { +int x; +} +void T::SomeClass::f() { + extracted(); + } + )cpp"; + + std::string NestedHeaderOutputCheck = R"cpp( + class T { + class SomeClass { + void extracted(); +void f(); + }; + }; + )cpp"; + + llvm::StringMap EditedFiles; + + EXPECT_EQ(apply(NestedOutOfLineSource, &EditedFiles), + NestedOutOfLineSourceOutputCheck); + EXPECT_EQ(EditedFiles.begin()->second, NestedHeaderOutputCheck); +} + +TEST_F(ExtractFunctionTest, ConstexprDifferentHeaderSourceTest) { + Header = R"cpp( + class SomeClass { + constexpr void f() const; + }; + )cpp"; + + std::string OutOfLineSource = R"cpp( + constexpr void SomeClass::f() const { + [[int x;]] + } + )cpp"; + + std::string OutOfLineSourceOutputCheck = R"cpp( + constexpr void SomeClass::extracted() const { +int x; +} +constexpr void SomeClass::f() const { + extracted(); + } + )cpp"; + + std::string HeaderOutputCheck = R"cpp( + class SomeClass { + constexpr void extracted() const; +constexpr void f() const; + }; + )cpp"; + + llvm::StringMap EditedFiles; + + EXPECT_EQ(apply(OutOfLineSource, &EditedFiles), OutOfLineSourceOutputCheck); + EXPECT_NE(EditedFiles.begin(), EditedFiles.end()) + << "The header should be edited and receives the declaration of the new " + "function"; + + if (EditedFiles.begin() != EditedFiles.end()) { + EXPECT_EQ(EditedFiles.begin()->second, HeaderOutputCheck); + } +} + +TEST_F(ExtractFunctionTest, ConstevalDifferentHeaderSourceTest) { + ExtraArgs.push_back("--std=c++20"); + Header = R"cpp( + class SomeClass { + consteval void f() const; + }; + )cpp"; + + std::string OutOfLineSource = R"cpp( + consteval void SomeClass::f() const { + [[int x;]] + } + )cpp"; + + std::string OutOfLineSourceOutputCheck = R"cpp( + consteval void SomeClass::extracted() const { +int x; +} +consteval void SomeClass::f() const { + extracted(); + } + )cpp"; + + std::string HeaderOutputCheck = R"cpp( + class SomeClass { + consteval void extracted() const; +consteval void f() const; + }; + )cpp"; + + llvm::StringMap EditedFiles; + + EXPECT_EQ(apply(OutOfLineSource, &EditedFiles), OutOfLineSourceOutputCheck); + EXPECT_NE(EditedFiles.begin(), EditedFiles.end()) + << "The header should be edited and receives the declaration of the new " + "function"; + + if (EditedFiles.begin() != EditedFiles.end()) { + EXPECT_EQ(EditedFiles.begin()->second, HeaderOutputCheck); + } +} + +TEST_F(ExtractFunctionTest, ConstDifferentHeaderSourceTest) { + Header = R"cpp( + class SomeClass { + void f() const; + }; + )cpp"; + + std::string OutOfLineSource = R"cpp( + void SomeClass::f() const { + [[int x;]] + } + )cpp"; + + std::string OutOfLineSourceOutputCheck = R"cpp( + void SomeClass::extracted() const { +int x; +} +void SomeClass::f() const { + extracted(); + } + )cpp"; + + std::string HeaderOutputCheck = R"cpp( + class SomeClass { + void extracted() const; +void f() const; + }; + )cpp"; + + llvm::StringMap EditedFiles; + + EXPECT_EQ(apply(OutOfLineSource, &EditedFiles), OutOfLineSourceOutputCheck); + EXPECT_NE(EditedFiles.begin(), EditedFiles.end()) + << "The header should be edited and receives the declaration of the new " + "function"; + + if (EditedFiles.begin() != EditedFiles.end()) { + EXPECT_EQ(EditedFiles.begin()->second, HeaderOutputCheck); + } +} + +TEST_F(ExtractFunctionTest, StaticDifferentHeaderSourceTest) { + Header = R"cpp( + class SomeClass { + static void f(); + }; + )cpp"; + + std::string OutOfLineSource = R"cpp( + void SomeClass::f() { + [[int x;]] + } + )cpp"; + + std::string OutOfLineSourceOutputCheck = R"cpp( + void SomeClass::extracted() { +int x; +} +void SomeClass::f() { + extracted(); + } + )cpp"; + + std::string HeaderOutputCheck = R"cpp( + class SomeClass { + static void extracted(); +static void f(); + }; + )cpp"; + + llvm::StringMap EditedFiles; + + EXPECT_EQ(apply(OutOfLineSource, &EditedFiles), OutOfLineSourceOutputCheck); + EXPECT_NE(EditedFiles.begin(), EditedFiles.end()) + << "The header should be edited and receives the declaration of the new " + "function"; + + if (EditedFiles.begin() != EditedFiles.end()) { + EXPECT_EQ(EditedFiles.begin()->second, HeaderOutputCheck); + } +} + +TEST_F(ExtractFunctionTest, DifferentContextHeaderSourceTest) { + Header = R"cpp( + namespace ns{ + class A { + class C { + public: + class RType {}; + }; + + class T { + class SomeClass { + static C::RType f(); + }; + }; + }; + } // ns + )cpp"; + + std::string OutOfLineSource = R"cpp( + ns::A::C::RType ns::A::T::SomeClass::f() { + [[A::C::RType x; + return x;]] + } + )cpp"; + + std::string OutOfLineSourceOutputCheck = R"cpp( + ns::A::C::RType ns::A::T::SomeClass::extracted() { +A::C::RType x; + return x; +} +ns::A::C::RType ns::A::T::SomeClass::f() { + return extracted(); + } + )cpp"; + + std::string HeaderOutputCheck = R"cpp( + namespace ns{ + class A { + class C { + public: + class RType {}; + }; + + class T { + class SomeClass { + static ns::A::C::RType extracted(); +static C::RType f(); + }; + }; + }; + } // ns + )cpp"; + + llvm::StringMap EditedFiles; + + EXPECT_EQ(apply(OutOfLineSource, &EditedFiles), OutOfLineSourceOutputCheck); + EXPECT_EQ(EditedFiles.begin()->second, HeaderOutputCheck); +} + TEST_F(ExtractFunctionTest, ControlFlow) { Context = Function; // We should be able to extract break/continue with a parent loop/switch.