Index: clang-tools-extra/clangd/refactor/tweaks/ExtractFunction.cpp
===================================================================
--- clang-tools-extra/clangd/refactor/tweaks/ExtractFunction.cpp
+++ clang-tools-extra/clangd/refactor/tweaks/ExtractFunction.cpp
@@ -56,6 +56,7 @@
 #include "clang/AST/ASTContext.h"
 #include "clang/AST/Decl.h"
 #include "clang/AST/DeclBase.h"
+#include "clang/AST/NestedNameSpecifier.h"
 #include "clang/AST/RecursiveASTVisitor.h"
 #include "clang/AST/Stmt.h"
 #include "clang/Basic/LangOptions.h"
@@ -71,6 +72,7 @@
 #include "llvm/ADT/StringRef.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/Error.h"
+#include "llvm/Support/raw_ostream.h"
 
 namespace clang {
 namespace clangd {
@@ -88,6 +90,13 @@
   OutsideFunc // Outside EnclosingFunction.
 };
 
+// Which declaration of the extracted function is being rendered.
+enum FunctionDeclKind {
+  InlineDefinition,
+  ForwardDeclaration,
+  OutOfLineDefinition
+};
+
 // A RootStmt is a statement that's fully selected including all it's children
 // and it's parent is unselected.
 // Check if a node is a root statement.
@@ -237,9 +246,6 @@
     if (CurNode->ASTNode.get<LambdaExpr>())
       return nullptr;
     if (const FunctionDecl *Func = CurNode->ASTNode.get<FunctionDecl>()) {
-      // FIXME: Support extraction from methods.
-      if (isa<CXXMethodDecl>(Func))
-        return nullptr;
       // FIXME: Support extraction from templated functions.
       if (Func->isTemplated())
         return nullptr;
@@ -343,34 +349,69 @@
   QualType ReturnType;
   std::vector<Parameter> Parameters;
   SourceRange BodyRange;
-  SourceLocation InsertionPoint;
-  const DeclContext *EnclosingFuncContext;
+  SourceLocation DefinitionPoint;
+  llvm::Optional<SourceRange> ForwardDeclarationPoint;
+  const CXXRecordDecl *EnclosingClass = nullptr;
+  NestedNameSpecifier *NestedNameSpec = nullptr;
+  const DeclContext *EnclosingFuncContext = nullptr;
+  // Lexical context of the enclosing function. The return type of an
+  // out-of-line definition is spelled in this scope, not in the class scope.
+  const DeclContext *EnclosingFuncLexicalContext = nullptr;
   bool CallerReturnsValue = false;
+  // Specifiers and qualifiers copied from the enclosing method, if any.
+  bool Static = false;
+  ConstexprSpecKind Constexpr = ConstexprSpecKind::Unspecified;
+  bool Const = false;
+
+  // Default context for printing the declaration's types: the class when a
+  // forward declaration is emitted, otherwise EnclosingFuncContext.
+  const DeclContext &getEnclosing() const;
+
   // Decides whether the extracted function body and the function call need a
   // semicolon after extraction.
   tooling::ExtractionSemicolonPolicy SemicolonPolicy;
-  NewFunction(tooling::ExtractionSemicolonPolicy SemicolonPolicy)
-      : SemicolonPolicy(SemicolonPolicy) {}
+  const LangOptions &LangOpts;
+  NewFunction(tooling::ExtractionSemicolonPolicy SemicolonPolicy,
+              const LangOptions &LangOpts)
+      : SemicolonPolicy(SemicolonPolicy), LangOpts(LangOpts) {}
   // Render the call for this function.
   std::string renderCall() const;
   // Render the definition for this function.
-  std::string renderDefinition(const SourceManager &SM) const;
+  std::string renderDefinition(FunctionDeclKind K, const DeclContext &Enclosing,
+                               const SourceManager &SM) const;
+  // Render the declaration (no body, no trailing semicolon) for this function.
+  std::string renderDeclaration(FunctionDeclKind K,
+                                const DeclContext &Enclosing,
+                                const SourceManager &SM) const;
 
 private:
-  std::string renderParametersForDefinition() const;
+  std::string
+  renderParametersForDeclaration(const DeclContext &Enclosing) const;
   std::string renderParametersForCall() const;
+  std::string renderSpecifiers(FunctionDeclKind K) const;
+  std::string renderQualifiers() const;
+  std::string renderNestedName() const;
   // Generate the function body.
   std::string getFuncBody(const SourceManager &SM) const;
 };
 
-std::string NewFunction::renderParametersForDefinition() const {
+const DeclContext &NewFunction::getEnclosing() const {
+  if (EnclosingClass != nullptr && ForwardDeclarationPoint.hasValue()) {
+    return *EnclosingClass;
+  }
+
+  return *EnclosingFuncContext;
+}
+
+std::string NewFunction::renderParametersForDeclaration(
+    const DeclContext &Enclosing) const {
   std::string Result;
   bool NeedCommaBefore = false;
   for (const Parameter &P : Parameters) {
     if (NeedCommaBefore)
       Result += ", ";
     NeedCommaBefore = true;
-    Result += P.render(EnclosingFuncContext);
+    Result += P.render(&Enclosing);
   }
   return Result;
 }
@@ -387,6 +428,49 @@
   return Result;
 }
 
+std::string NewFunction::renderSpecifiers(FunctionDeclKind K) const {
+  std::string Attributes;
+
+  if (Static && K != FunctionDeclKind::OutOfLineDefinition) {
+    Attributes += "static ";
+  }
+
+  switch (Constexpr) {
+  case ConstexprSpecKind::Unspecified:
+  case ConstexprSpecKind::Constinit:
+    break;
+  case ConstexprSpecKind::Constexpr:
+    Attributes += "constexpr ";
+    break;
+  case ConstexprSpecKind::Consteval:
+    Attributes += "consteval ";
+    break;
+  }
+
+  return Attributes;
+}
+
+std::string NewFunction::renderQualifiers() const {
+  std::string Attributes;
+
+  if (Const) {
+    Attributes += " const";
+  }
+
+  return Attributes;
+}
+
+std::string NewFunction::renderNestedName() const {
+  if (NestedNameSpec == nullptr) {
+    return {};
+  }
+
+  std::string NestedName;
+  llvm::raw_string_ostream Oss(NestedName);
+  NestedNameSpec->print(Oss, LangOpts);
+  return Oss.str(); // str() flushes the stream into NestedName.
+}
+
 std::string NewFunction::renderCall() const {
   return std::string(
       llvm::formatv("{0}{1}({2}){3}", CallerReturnsValue ? "return " : "", Name,
@@ -394,10 +478,41 @@
                     (SemicolonPolicy.isNeededInOriginalFunction() ? ";" : "")));
 }
 
-std::string NewFunction::renderDefinition(const SourceManager &SM) const {
+std::string NewFunction::renderDefinition(FunctionDeclKind K,
+                                          const DeclContext &Enclosing,
+                                          const SourceManager &SM) const {
   return std::string(llvm::formatv(
-      "{0} {1}({2}) {\n{3}\n}\n", printType(ReturnType, *EnclosingFuncContext),
-      Name, renderParametersForDefinition(), getFuncBody(SM)));
+      "{0} {\n{1}\n}\n", renderDeclaration(K, Enclosing, SM), getFuncBody(SM)));
+}
+
+std::string NewFunction::renderDeclaration(FunctionDeclKind K,
+                                           const DeclContext &Enclosing,
+                                           const SourceManager &SM) const {
+  std::string FullName;
+
+  switch (K) {
+  case ForwardDeclaration:
+    FullName = Name;
+    break;
+  case OutOfLineDefinition:
+    FullName = llvm::formatv("{0}{1}", renderNestedName(), Name);
+    break;
+  case InlineDefinition:
+    FullName = Name;
+    break;
+  }
+
+  // In an out-of-line definition the return type precedes the qualified name,
+  // so it is looked up from the lexical scope and must be spelled from there.
+  const DeclContext &ReturnTypeContext =
+      (K == OutOfLineDefinition && EnclosingFuncLexicalContext != nullptr)
+          ? *EnclosingFuncLexicalContext
+          : Enclosing;
+
+  return std::string(llvm::formatv(
+      "{0}{1} {2}({3}){4}", renderSpecifiers(K),
+      printType(ReturnType, ReturnTypeContext), FullName,
+      renderParametersForDeclaration(Enclosing), renderQualifiers()));
 }
 
 std::string NewFunction::getFuncBody(const SourceManager &SM) const {
@@ -668,9 +783,35 @@
   if (CapturedInfo.BrokenControlFlow)
     return error("Cannot extract break/continue without corresponding "
                  "loop/switch statement.");
-  NewFunction ExtractedFunc(getSemicolonPolicy(ExtZone, SM, LangOpts));
+  NewFunction ExtractedFunc(getSemicolonPolicy(ExtZone, SM, LangOpts),
+                            LangOpts);
+
+  if (const auto *Method =
+          llvm::dyn_cast<CXXMethodDecl>(ExtZone.EnclosingFunction)) {
+    const auto *FirstOriginalDecl = Method->getCanonicalDecl();
+
+    ExtractedFunc.Static = Method->isStatic();
+    ExtractedFunc.Constexpr = Method->getConstexprKind();
+    ExtractedFunc.Const = Method->isConst();
+    ExtractedFunc.NestedNameSpec = ExtZone.EnclosingFunction->getQualifier();
+
+    ExtractedFunc.EnclosingClass = Method->getParent();
+    ExtractedFunc.EnclosingFuncLexicalContext = Method->getLexicalDeclContext();
+
+    if (Method->isOutOfLine()) {
+      // FIXME: Put the extracted method in a private section of the class
+      auto DeclPos = toHalfOpenFileRange(SM, LangOpts,
+                                         FirstOriginalDecl->getSourceRange());
+      // Without a file range for the declaration there is no place to insert
+      // the forward declaration of the extracted method.
+      if (!DeclPos)
+        return error("Cannot locate the declaration of the enclosing method.");
+      ExtractedFunc.ForwardDeclarationPoint = *DeclPos;
+    }
+  }
+
   ExtractedFunc.BodyRange = ExtZone.ZoneRange;
-  ExtractedFunc.InsertionPoint = ExtZone.getInsertionPoint();
+  ExtractedFunc.DefinitionPoint = ExtZone.getInsertionPoint();
   ExtractedFunc.EnclosingFuncContext =
       ExtZone.EnclosingFunction->getDeclContext();
   ExtractedFunc.CallerReturnsValue = CapturedInfo.AlwaysReturns;
@@ -706,8 +847,53 @@
 
 tooling::Replacement createFunctionDefinition(const NewFunction &ExtractedFunc,
                                               const SourceManager &SM) {
-  std::string FunctionDef = ExtractedFunc.renderDefinition(SM);
-  return tooling::Replacement(SM, ExtractedFunc.InsertionPoint, 0, FunctionDef);
+  std::string FunctionDef;
+
+  const auto &Enclosing = ExtractedFunc.getEnclosing();
+
+  FunctionDeclKind DeclKind = InlineDefinition;
+
+  if (ExtractedFunc.ForwardDeclarationPoint.hasValue()) {
+    DeclKind = OutOfLineDefinition;
+  }
+
+  FunctionDef = ExtractedFunc.renderDefinition(DeclKind, Enclosing, SM);
+
+  return tooling::Replacement(SM, ExtractedFunc.DefinitionPoint, 0,
+                              FunctionDef);
+}
+
+// Inserts the forward declaration of the extracted method right before the
+// first declaration of the enclosing method.
+tooling::Replacement createFunctionDeclaration(const NewFunction &ExtractedFunc,
+                                               const SourceManager &SM) {
+  const auto &Enclosing = ExtractedFunc.getEnclosing();
+  auto FunctionDecl =
+      ExtractedFunc.renderDeclaration(ForwardDeclaration, Enclosing, SM) +
+      ";\n";
+  auto DeclPoint = ExtractedFunc.ForwardDeclarationPoint.getValue();
+
+  return tooling::Replacement(SM, DeclPoint.getBegin(), 0, FunctionDecl);
+}
+
+// Records the forward declaration as a separate edit of the file containing
+// it. FileEdit must hold a value.
+llvm::Error addDeclarationTweakEffect(llvm::Expected<Tweak::Effect> &FileEdit,
+                                      const NewFunction &ExtractedFunc,
+                                      const SourceManager &SM) {
+  auto DeclPoint = ExtractedFunc.ForwardDeclarationPoint.getValue();
+  tooling::Replacements ResultDecl;
+  if (auto Err = ResultDecl.add(createFunctionDeclaration(ExtractedFunc, SM)))
+    return Err;
+  auto PathAndEdit = Tweak::Effect::fileEdit(
+      SM, SM.getFileID(DeclPoint.getBegin()), std::move(ResultDecl));
+  if (!PathAndEdit)
+    return PathAndEdit.takeError();
+
+  FileEdit->ApplyEdits.try_emplace(PathAndEdit->first,
+                                   std::move(PathAndEdit->second));
+
+  return llvm::Error::success();
 }
 
 // Returns true if ExtZone contains any ReturnStmts.
@@ -762,7 +948,26 @@
     return std::move(Err);
   if (auto Err = Result.add(replaceWithFuncCall(*ExtractedFunc, SM, LangOpts)))
     return std::move(Err);
-  return Effect::mainFileEdit(SM, std::move(Result));
+
+  if (!ExtractedFunc->ForwardDeclarationPoint.hasValue())
+    return Effect::mainFileEdit(SM, std::move(Result));
+
+  // When the class is defined in the main file, the forward declaration is
+  // part of the main-file edit. Otherwise it becomes a separate edit of the
+  // file that defines the class (e.g. a header).
+  SourceLocation DeclBegin = ExtractedFunc->ForwardDeclarationPoint->getBegin();
+  if (SM.getFileID(DeclBegin) == SM.getMainFileID()) {
+    if (auto Err = Result.add(createFunctionDeclaration(*ExtractedFunc, SM)))
+      return std::move(Err);
+    return Effect::mainFileEdit(SM, std::move(Result));
+  }
+
+  auto MainRes = Effect::mainFileEdit(SM, std::move(Result));
+  if (!MainRes)
+    return MainRes.takeError();
+  if (auto Err = addDeclarationTweakEffect(MainRes, *ExtractedFunc, SM))
+    return std::move(Err);
+  return MainRes;
 }
 
 } // namespace
Index: clang-tools-extra/clangd/unittests/tweaks/ExtractFunctionTests.cpp
===================================================================
--- clang-tools-extra/clangd/unittests/tweaks/ExtractFunctionTests.cpp
+++ clang-tools-extra/clangd/unittests/tweaks/ExtractFunctionTests.cpp
@@ -114,16 +114,48 @@
   // Don't extract when we need to make a function as a parameter.
   EXPECT_THAT(apply("void f() { [[int a; f();]] }"), StartsWith("fail"));
 
-  // We don't extract from methods for now since they may involve multi-file
-  // edits
-  std::string MethodFailInput = R"cpp(
+  std::string MethodInput = R"cpp(
     class T {
       void f() {
         [[int x;]]
       }
     };
   )cpp";
-  EXPECT_EQ(apply(MethodFailInput), "unavailable");
+  std::string MethodCheckOutput = R"cpp(
+    class T {
+      void extracted() {
+int x;
+}
+void f() {
+        extracted();
+      }
+    };
+  )cpp";
+  EXPECT_EQ(apply(MethodInput), MethodCheckOutput);
+
+  std::string OutOfLineMethodInput = R"cpp(
+    class T {
+      void f();
+    };
+
+    void T::f() {
+      [[int x;]]
+    }
+  )cpp";
+  std::string OutOfLineMethodCheckOutput = R"cpp(
+    class T {
+      void extracted();
+void f();
+    };
+
+    void T::extracted() {
+int x;
+}
+void T::f() {
+      extracted();
+    }
+  )cpp";
+  EXPECT_EQ(apply(OutOfLineMethodInput), OutOfLineMethodCheckOutput);
 
   // We don't extract from templated functions for now as templates are hard
   // to deal with.
@@ -159,6 +191,198 @@
   EXPECT_EQ(apply(CompoundFailInput), "unavailable");
 }
 
+TEST_F(ExtractFunctionTest, DifferentHeaderSourceTest) {
+  Header = R"cpp(
+    class SomeClass {
+      void f();
+    };
+  )cpp";
+
+  std::string OutOfLineSource = R"cpp(
+    void SomeClass::f() {
+      [[int x;]]
+    }
+  )cpp";
+
+  std::string OutOfLineSourceOutputCheck = R"cpp(
+    void SomeClass::extracted() {
+int x;
+}
+void SomeClass::f() {
+      extracted();
+    }
+  )cpp";
+
+  std::string HeaderOutputCheck = R"cpp(
+    class SomeClass {
+      void extracted();
+void f();
+    };
+  )cpp";
+
+  llvm::StringMap<std::string> EditedFiles;
+
+  EXPECT_EQ(apply(OutOfLineSource, &EditedFiles), OutOfLineSourceOutputCheck);
+  EXPECT_EQ(EditedFiles.begin()->second, HeaderOutputCheck);
+}
+
+TEST_F(ExtractFunctionTest, DifferentFilesNestedTest) {
+  Header = R"cpp(
+    class T {
+      class SomeClass {
+        void f();
+      };
+    };
+  )cpp";
+
+  std::string NestedOutOfLineSource = R"cpp(
+    void T::SomeClass::f() {
+      [[int x;]]
+    }
+  )cpp";
+
+  std::string NestedOutOfLineSourceOutputCheck = R"cpp(
+    void T::SomeClass::extracted() {
+int x;
+}
+void T::SomeClass::f() {
+      extracted();
+    }
+  )cpp";
+
+  std::string NestedHeaderOutputCheck = R"cpp(
+    class T {
+      class SomeClass {
+        void extracted();
+void f();
+      };
+    };
+  )cpp";
+
+  llvm::StringMap<std::string> EditedFiles;
+
+  EXPECT_EQ(apply(NestedOutOfLineSource, &EditedFiles),
+            NestedOutOfLineSourceOutputCheck);
+  EXPECT_EQ(EditedFiles.begin()->second, NestedHeaderOutputCheck);
+}
+
+TEST_F(ExtractFunctionTest, ConstexprDifferentHeaderSourceTest) {
+  Header = R"cpp(
+    class SomeClass {
+      constexpr void f() const;
+    };
+  )cpp";
+
+  std::string OutOfLineSource = R"cpp(
+    constexpr void SomeClass::f() const {
+      [[int x;]]
+    }
+  )cpp";
+
+  std::string OutOfLineSourceOutputCheck = R"cpp(
+    constexpr void SomeClass::extracted() const {
+int x;
+}
+constexpr void SomeClass::f() const {
+      extracted();
+    }
+  )cpp";
+
+  std::string HeaderOutputCheck = R"cpp(
+    class SomeClass {
+      constexpr void extracted() const;
+constexpr void f() const;
+    };
+  )cpp";
+
+  llvm::StringMap<std::string> EditedFiles;
+
+  EXPECT_EQ(apply(OutOfLineSource, &EditedFiles), OutOfLineSourceOutputCheck);
+  EXPECT_NE(EditedFiles.begin(), EditedFiles.end())
+      << "The header should be edited and receives the declaration of the new "
+         "function";
+
+  if (EditedFiles.begin() != EditedFiles.end()) {
+    EXPECT_EQ(EditedFiles.begin()->second, HeaderOutputCheck);
+  }
+}
+
+TEST_F(ExtractFunctionTest, ConstDifferentHeaderSourceTest) {
+  Header = R"cpp(
+    class SomeClass {
+      void f() const;
+    };
+  )cpp";
+
+  std::string OutOfLineSource = R"cpp(
+    void SomeClass::f() const {
+      [[int x;]]
+    }
+  )cpp";
+
+  std::string OutOfLineSourceOutputCheck = R"cpp(
+    void SomeClass::extracted() const {
+int x;
+}
+void SomeClass::f() const {
+      extracted();
+    }
+  )cpp";
+
+  std::string HeaderOutputCheck = R"cpp(
+    class SomeClass {
+      void extracted() const;
+void f() const;
+    };
+  )cpp";
+
+  llvm::StringMap<std::string> EditedFiles;
+
+  EXPECT_EQ(apply(OutOfLineSource, &EditedFiles), OutOfLineSourceOutputCheck);
+  EXPECT_NE(EditedFiles.begin(), EditedFiles.end())
+      << "The header should be edited and receives the declaration of the new "
+         "function";
+
+  if (EditedFiles.begin() != EditedFiles.end()) {
+    EXPECT_EQ(EditedFiles.begin()->second, HeaderOutputCheck);
+  }
+}
+
+TEST_F(ExtractFunctionTest, StaticDifferentHeaderSourceTest) {
+  Header = R"cpp(
+    class SomeClass {
+      static void f();
+    };
+  )cpp";
+
+  std::string OutOfLineSource = R"cpp(
+    void SomeClass::f() {
+      [[int x;]]
+    }
+  )cpp";
+
+  std::string OutOfLineSourceOutputCheck = R"cpp(
+    void SomeClass::extracted() {
+int x;
+}
+void SomeClass::f() {
+      extracted();
+    }
+  )cpp";
+
+  std::string HeaderOutputCheck = R"cpp(
+    class SomeClass {
+      static void extracted();
+static void f();
+    };
+  )cpp";
+
+  llvm::StringMap<std::string> EditedFiles;
+
+  EXPECT_EQ(apply(OutOfLineSource, &EditedFiles), OutOfLineSourceOutputCheck);
+  EXPECT_EQ(EditedFiles.begin()->second, HeaderOutputCheck);
+}
+
 TEST_F(ExtractFunctionTest, ControlFlow) {
   Context = Function;
   // We should be able to extract break/continue with a parent loop/switch.