[clangd] ExtractFunction: extract from out-of-line methods of nested classes
and keep constexpr/consteval on the extracted method.

The extracted function is rendered as an inline definition, a forward
declaration inside the enclosing class, or an out-of-line member definition
qualified with the enclosing function's nested-name-specifier.

Review fixes relative to the pasted revision:
  * renderNestedName() returns Oss.str() so buffered raw_string_ostream output
    is flushed before the string is read (no reliance on NRVO + destructor).
  * Include llvm/Support/raw_ostream.h (declares raw_string_ostream) instead of
    raw_os_ostream.h; drop the unused clang/AST/ExternalASTSource.h include.
  * DifferentFilesNestedTest guards EditedFiles.begin() like the sibling test.
  * Line structure and template arguments lost in the paste were restored by
    inference (e.g. llvm::Expected<Tweak::Effect>, Optional<SourceRange>) and
    the hunk headers recomputed -- NOTE(review): verify context lines against
    the tree before applying.

Index: clang-tools-extra/clangd/refactor/tweaks/ExtractFunction.cpp
===================================================================
--- clang-tools-extra/clangd/refactor/tweaks/ExtractFunction.cpp
+++ clang-tools-extra/clangd/refactor/tweaks/ExtractFunction.cpp
@@ -56,6 +56,7 @@
 #include "clang/AST/ASTContext.h"
 #include "clang/AST/Decl.h"
 #include "clang/AST/DeclBase.h"
+#include "clang/AST/NestedNameSpecifier.h"
 #include "clang/AST/RecursiveASTVisitor.h"
 #include "clang/AST/Stmt.h"
 #include "clang/Basic/LangOptions.h"
@@ -71,6 +72,8 @@
 #include "llvm/ADT/StringRef.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/Error.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/raw_ostream.h"
 
 namespace clang {
 namespace clangd {
@@ -88,6 +91,13 @@
   OutsideFunc // Outside EnclosingFunction.
 };
 
+// The form in which the extracted function is rendered.
+enum FunctionDeclKind {
+  InlineDefinition,
+  ForwardDeclaration,
+  OutOfLineDefinition
+};
+
 // A RootStmt is a statement that's fully selected including all it's children
 // and it's parent is unselected.
 // Check if a node is a root statement.
@@ -340,45 +350,62 @@
   QualType ReturnType;
   std::vector<Parameter> Parameters;
   SourceRange BodyRange;
-  SourceLocation InsertionPoint;
-  llvm::Optional<SourceRange> DeclarationPoint;
-  llvm::Optional<const CXXRecordDecl *> ParentContext;
-  const DeclContext *EnclosingFuncContext;
+  SourceLocation DefinitionPoint;
+  llvm::Optional<SourceRange> ForwardDeclarationPoint;
+  llvm::Optional<const CXXRecordDecl *> EnclosingClass;
+  NestedNameSpecifier *NestedNameSpec = nullptr;
+  const DeclContext *EnclosingFuncContext = nullptr;
   bool CallerReturnsValue = false;
   bool Static = false;
-  bool Constexpr = false;
-  bool OutOfLine = false;
-  bool ContextConst = false;
+  ConstexprSpecKind Constexpr = ConstexprSpecKind::Unspecified;
+  bool Const = false;
+
+  const DeclContext &getEnclosing() const;
+
   // Decides whether the extracted function body and the function call need a
   // semicolon after extraction.
   tooling::ExtractionSemicolonPolicy SemicolonPolicy;
-  NewFunction(tooling::ExtractionSemicolonPolicy SemicolonPolicy)
-      : SemicolonPolicy(SemicolonPolicy) {}
+  const LangOptions &LangOpts;
+  NewFunction(tooling::ExtractionSemicolonPolicy SemicolonPolicy,
+              const LangOptions &LangOpts)
+      : SemicolonPolicy(SemicolonPolicy), LangOpts(LangOpts) {}
   // Render the call for this function.
   std::string renderCall() const;
   // Render the definition for this function.
-  std::string renderInlineDefinition(const SourceManager &SM) const;
-  std::string renderDefinition(const SourceManager &SM) const;
-  std::string renderDeclaration(const SourceManager &SM) const;
+  std::string renderDefinition(FunctionDeclKind K, const DeclContext &Enclosing,
+                               const SourceManager &SM) const;
+  std::string renderDeclaration(FunctionDeclKind K,
+                                const DeclContext &Enclosing,
+                                const SourceManager &SM) const;
 
 private:
-  std::string renderParametersForDefinition() const;
+  std::string
+  renderParametersForDeclaration(const DeclContext &Enclosing) const;
   std::string renderParametersForCall() const;
-  std::string renderAttributes() const;
-  std::string renderAttributesAfter() const;
-  std::string renderNamespaceAndClass() const;
+  std::string renderSpecifiers(FunctionDeclKind K) const;
+  std::string renderQualifiers() const;
+  std::string renderNestedName() const;
   // Generate the function body.
   std::string getFuncBody(const SourceManager &SM) const;
 };
 
-std::string NewFunction::renderParametersForDefinition() const {
+const DeclContext &NewFunction::getEnclosing() const {
+  if (EnclosingClass.hasValue() && ForwardDeclarationPoint.hasValue()) {
+    return *EnclosingClass.getValue();
+  }
+
+  return *EnclosingFuncContext;
+}
+
+std::string NewFunction::renderParametersForDeclaration(
+    const DeclContext &Enclosing) const {
   std::string Result;
   bool NeedCommaBefore = false;
   for (const Parameter &P : Parameters) {
     if (NeedCommaBefore)
       Result += ", ";
     NeedCommaBefore = true;
-    Result += P.render(EnclosingFuncContext);
+    Result += P.render(&Enclosing);
   }
   return Result;
 }
@@ -395,39 +422,48 @@
   return Result;
 }
 
-std::string NewFunction::renderAttributes() const {
+std::string NewFunction::renderSpecifiers(FunctionDeclKind K) const {
   std::string Attributes;
 
-  if (Static) {
+  if (Static && K != FunctionDeclKind::OutOfLineDefinition) {
     Attributes += "static ";
   }
 
-  if (Constexpr) {
+  switch (Constexpr) {
+  case ConstexprSpecKind::Unspecified:
+  case ConstexprSpecKind::Constinit:
+    break;
+  case ConstexprSpecKind::Constexpr:
     Attributes += "constexpr ";
+    break;
+  case ConstexprSpecKind::Consteval:
+    Attributes += "consteval ";
+    break;
   }
 
   return Attributes;
 }
 
-std::string NewFunction::renderAttributesAfter() const {
+std::string NewFunction::renderQualifiers() const {
   std::string Attributes;
 
-  if (ContextConst) {
+  if (Const) {
     Attributes += " const";
   }
 
   return Attributes;
 }
 
-std::string NewFunction::renderNamespaceAndClass() const {
-  std::string NamespaceClass;
-
-  if (ParentContext) {
-    NamespaceClass = ParentContext.getValue()->getNameAsString();
-    NamespaceClass += "::";
+std::string NewFunction::renderNestedName() const {
+  if (NestedNameSpec == nullptr) {
+    return {};
   }
 
-  return NamespaceClass;
+  std::string NestedName;
+  llvm::raw_string_ostream Oss(NestedName);
+  NestedNameSpec->print(Oss, LangOpts);
+  // str() flushes any buffered output into NestedName before it is returned.
+  return Oss.str();
 }
 
 std::string NewFunction::renderCall() const {
@@ -437,24 +473,34 @@
                     (SemicolonPolicy.isNeededInOriginalFunction() ? ";" : "")));
 }
 
-std::string NewFunction::renderInlineDefinition(const SourceManager &SM) const {
-  return std::string(
-      llvm::formatv("{0} {\n{1}\n}\n", renderDeclaration(SM), getFuncBody(SM)));
+std::string NewFunction::renderDefinition(FunctionDeclKind K,
+                                          const DeclContext &Enclosing,
+                                          const SourceManager &SM) const {
+  return std::string(llvm::formatv(
+      "{0} {\n{1}\n}\n", renderDeclaration(K, Enclosing, SM), getFuncBody(SM)));
 }
 
-std::string NewFunction::renderDefinition(const SourceManager &SM) const {
-  return std::string(llvm::formatv("{0} {1}{2}({3}){4} {\n{5}\n}\n",
-                                   printType(ReturnType, *EnclosingFuncContext),
-                                   renderNamespaceAndClass(), Name,
-                                   renderParametersForDefinition(),
-                                   renderAttributesAfter(), getFuncBody(SM)));
-}
+std::string NewFunction::renderDeclaration(FunctionDeclKind K,
+                                           const DeclContext &Enclosing,
+                                           const SourceManager &SM) const {
+  std::string FullName;
+
+  switch (K) {
+  case ForwardDeclaration:
+    FullName = Name;
+    break;
+  case OutOfLineDefinition:
+    FullName = renderNestedName() + Name;
+    break;
+  case InlineDefinition:
+    FullName = Name;
+    break;
+  }
 
-std::string NewFunction::renderDeclaration(const SourceManager &SM) const {
-  return std::string(llvm::formatv("{0}{1} {2}({3}){4}", renderAttributes(),
-                                   printType(ReturnType, *EnclosingFuncContext),
-                                   Name, renderParametersForDefinition(),
-                                   renderAttributesAfter()));
+  return std::string(llvm::formatv("{0}{1} {2}({3}){4}", renderSpecifiers(K),
+                                   printType(ReturnType, Enclosing), FullName,
+                                   renderParametersForDeclaration(Enclosing),
+                                   renderQualifiers()));
 }
 
 std::string NewFunction::getFuncBody(const SourceManager &SM) const {
@@ -725,27 +771,31 @@
   if (CapturedInfo.BrokenControlFlow)
     return error("Cannot extract break/continue without corresponding "
                  "loop/switch statement.");
-  NewFunction ExtractedFunc(getSemicolonPolicy(ExtZone, SM, LangOpts));
+  NewFunction ExtractedFunc(getSemicolonPolicy(ExtZone, SM, LangOpts),
+                            LangOpts);
 
   if (isa<CXXMethodDecl>(ExtZone.EnclosingFunction)) {
     const auto *Method =
         llvm::dyn_cast<CXXMethodDecl>(ExtZone.EnclosingFunction);
     const auto *FirstOriginalDecl = Method->getCanonicalDecl();
 
-    ExtractedFunc.OutOfLine = Method->isOutOfLine();
     ExtractedFunc.Static = Method->isStatic();
-    ExtractedFunc.Constexpr = Method->isConstexpr();
-    ExtractedFunc.ContextConst = Method->isConst();
+    ExtractedFunc.Constexpr = Method->getConstexprKind();
+    ExtractedFunc.Const = Method->isConst();
+    ExtractedFunc.NestedNameSpec = ExtZone.EnclosingFunction->getQualifier();
 
-    auto DeclPos =
-        toHalfOpenFileRange(SM, LangOpts, FirstOriginalDecl->getSourceRange());
+    ExtractedFunc.EnclosingClass = Method->getParent();
 
-    ExtractedFunc.ParentContext = Method->getParent();
-    ExtractedFunc.DeclarationPoint = DeclPos.getValue();
+    if (Method->isOutOfLine()) {
+      // FIXME: Put the extracted method in a private section of the class
+      auto DeclPos = toHalfOpenFileRange(SM, LangOpts,
+                                         FirstOriginalDecl->getSourceRange());
+      ExtractedFunc.ForwardDeclarationPoint = DeclPos.getValue();
+    }
   }
 
   ExtractedFunc.BodyRange = ExtZone.ZoneRange;
-  ExtractedFunc.InsertionPoint = ExtZone.getInsertionPoint();
+  ExtractedFunc.DefinitionPoint = ExtZone.getInsertionPoint();
   ExtractedFunc.EnclosingFuncContext =
       ExtZone.EnclosingFunction->getDeclContext();
   ExtractedFunc.CallerReturnsValue = CapturedInfo.AlwaysReturns;
@@ -783,19 +833,27 @@
                                               const SourceManager &SM) {
   std::string FunctionDef;
 
-  if (ExtractedFunc.OutOfLine) {
-    FunctionDef = ExtractedFunc.renderDefinition(SM);
-  } else {
-    FunctionDef = ExtractedFunc.renderInlineDefinition(SM);
+  const auto &Enclosing = ExtractedFunc.getEnclosing();
+
+  FunctionDeclKind DeclKind = InlineDefinition;
+
+  if (ExtractedFunc.ForwardDeclarationPoint.hasValue()) {
+    DeclKind = OutOfLineDefinition;
   }
 
-  return tooling::Replacement(SM, ExtractedFunc.InsertionPoint, 0, FunctionDef);
+  FunctionDef = ExtractedFunc.renderDefinition(DeclKind, Enclosing, SM);
+
+  return tooling::Replacement(SM, ExtractedFunc.DefinitionPoint, 0,
+                              FunctionDef);
 }
 
 tooling::Replacement createFunctionDeclaration(const NewFunction &ExtractedFunc,
                                                const SourceManager &SM) {
-  auto FunctionDecl = ExtractedFunc.renderDeclaration(SM) + ";\n";
-  auto DeclPoint = ExtractedFunc.DeclarationPoint.getValue();
+  const auto &Enclosing = ExtractedFunc.getEnclosing();
+  auto FunctionDecl =
+      ExtractedFunc.renderDeclaration(ForwardDeclaration, Enclosing, SM) +
+      ";\n";
+  auto DeclPoint = ExtractedFunc.ForwardDeclarationPoint.getValue();
 
   return tooling::Replacement(SM, DeclPoint.getBegin(), 0, FunctionDecl);
 }
@@ -803,7 +861,7 @@
 llvm::Error addDeclarationTweakEffect(llvm::Expected<Tweak::Effect> &FileEdit,
                                       const NewFunction &ExtractedFunc,
                                       const SourceManager &SM) {
-  auto DeclPoint = ExtractedFunc.DeclarationPoint.getValue();
+  auto DeclPoint = ExtractedFunc.ForwardDeclarationPoint.getValue();
   tooling::Replacements ResultDecl;
   if (auto Err = ResultDecl.add(createFunctionDeclaration(ExtractedFunc, SM)))
     return Err;
@@ -870,10 +928,7 @@
   if (auto Err = Result.add(replaceWithFuncCall(*ExtractedFunc, SM, LangOpts)))
     return std::move(Err);
 
-  auto DefinitionIsOutOfLine =
-      ExtractedFunc->OutOfLine && ExtractedFunc->DeclarationPoint.hasValue();
-
-  if (!DefinitionIsOutOfLine) {
+  if (!ExtractedFunc->ForwardDeclarationPoint.hasValue()) {
     return Effect::mainFileEdit(SM, std::move(Result));
   }
 
Index: clang-tools-extra/clangd/unittests/tweaks/ExtractFunctionTests.cpp
===================================================================
--- clang-tools-extra/clangd/unittests/tweaks/ExtractFunctionTests.cpp
+++ clang-tools-extra/clangd/unittests/tweaks/ExtractFunctionTests.cpp
@@ -226,6 +226,93 @@
   EXPECT_EQ(EditedFiles.begin()->second, HeaderOutputCheck);
 }
 
+TEST_F(ExtractFunctionTest, DifferentFilesNestedTest) {
+  Header = R"cpp(
+    class T {
+    class SomeClass {
+      void f();
+    };
+    };
+  )cpp";
+
+  std::string NestedOutOfLineSource = R"cpp(
+    void T::SomeClass::f() {
+      [[int x;]]
+    }
+  )cpp";
+
+  std::string NestedOutOfLineSourceOutputCheck = R"cpp(
+    void T::SomeClass::extracted() {
+int x;
+}
+void T::SomeClass::f() {
+      extracted();
+    }
+  )cpp";
+
+  std::string NestedHeaderOutputCheck = R"cpp(
+    class T {
+    class SomeClass {
+      void extracted();
+void f();
+    };
+    };
+  )cpp";
+
+  llvm::StringMap<std::string> EditedFiles;
+
+  EXPECT_EQ(apply(NestedOutOfLineSource, &EditedFiles),
+            NestedOutOfLineSourceOutputCheck);
+  EXPECT_NE(EditedFiles.begin(), EditedFiles.end())
+      << "The header should be edited and receives the declaration of the new "
+         "function";
+
+  if (EditedFiles.begin() != EditedFiles.end()) {
+    EXPECT_EQ(EditedFiles.begin()->second, NestedHeaderOutputCheck);
+  }
+}
+
+TEST_F(ExtractFunctionTest, ConstexprDifferentHeaderSourceTest) {
+  Header = R"cpp(
+    class SomeClass {
+      constexpr void f() const;
+    };
+  )cpp";
+
+  std::string OutOfLineSource = R"cpp(
+    constexpr void SomeClass::f() const {
+      [[int x;]]
+    }
+  )cpp";
+
+  std::string OutOfLineSourceOutputCheck = R"cpp(
+    constexpr void SomeClass::extracted() const {
+int x;
+}
+constexpr void SomeClass::f() const {
+      extracted();
+    }
+  )cpp";
+
+  std::string HeaderOutputCheck = R"cpp(
+    class SomeClass {
+      constexpr void extracted() const;
+constexpr void f() const;
+    };
+  )cpp";
+
+  llvm::StringMap<std::string> EditedFiles;
+
+  EXPECT_EQ(apply(OutOfLineSource, &EditedFiles), OutOfLineSourceOutputCheck);
+  EXPECT_NE(EditedFiles.begin(), EditedFiles.end())
+      << "The header should be edited and receives the declaration of the new "
+         "function";
+
+  if (EditedFiles.begin() != EditedFiles.end()) {
+    EXPECT_EQ(EditedFiles.begin()->second, HeaderOutputCheck);
+  }
+}
+
 TEST_F(ExtractFunctionTest, ConstDifferentHeaderSourceTest) {
   Header = R"cpp(
     class SomeClass {