diff --git a/clang-tools-extra/clangd/refactor/tweaks/ExtractFunction.cpp b/clang-tools-extra/clangd/refactor/tweaks/ExtractFunction.cpp --- a/clang-tools-extra/clangd/refactor/tweaks/ExtractFunction.cpp +++ b/clang-tools-extra/clangd/refactor/tweaks/ExtractFunction.cpp @@ -48,6 +48,7 @@ #include "AST.h" #include "FindTarget.h" +#include "Headers.h" #include "ParsedAST.h" #include "Selection.h" #include "SourceCode.h" @@ -79,6 +80,13 @@ using Node = SelectionTree::Node; +struct HoistSetComparator { + bool operator()(const Decl *const Lhs, const Decl *const Rhs) const { + return Lhs->getLocation() < Rhs->getLocation(); + } +}; +using HoistSet = llvm::SmallSet; + // ExtractionZone is the part of code that is being extracted. // EnclosingFunction is the function/method inside which the zone lies. // We split the file into 4 parts relative to extraction zone. @@ -171,12 +179,13 @@ // semicolon after the extraction. const Node *getLastRootStmt() const { return Parent->Children.back(); } - // Checks if declarations inside extraction zone are accessed afterwards. + // Checks if declarations inside extraction zone are accessed afterwards and + // adds these declarations to the returned set. // // This performs a partial AST traversal proportional to the size of the // enclosing function, so it is possibly expensive. - bool requiresHoisting(const SourceManager &SM, - const HeuristicResolver *Resolver) const { + HoistSet getDeclsToHoist(const SourceManager &SM, + const HeuristicResolver *Resolver) const { // First find all the declarations that happened inside extraction zone. llvm::SmallSet DeclsInExtZone; for (auto *RootStmt : RootStmts) { @@ -191,29 +200,28 @@ } // Early exit without performing expensive traversal below. if (DeclsInExtZone.empty()) - return false; - // Then make sure they are not used outside the zone. + return {}; + // Add any decl used after the selection to the returned set + HoistSet DeclsToHoist{}; for (const auto *S : EnclosingFunction->getBody()->children()) { if (SM.isBeforeInTranslationUnit(S->getSourceRange().getEnd(), ZoneRange.getEnd())) continue; - bool HasPostUse = false; findExplicitReferences( S, [&](const ReferenceLoc &Loc) { - if (HasPostUse || - SM.isBeforeInTranslationUnit(Loc.NameLoc, ZoneRange.getEnd())) + if (SM.isBeforeInTranslationUnit(Loc.NameLoc, ZoneRange.getEnd())) return; - HasPostUse = llvm::any_of(Loc.Targets, - [&DeclsInExtZone](const Decl *Target) { - return DeclsInExtZone.contains(Target); - }); + for (const NamedDecl *const PostUse : llvm::make_filter_range( + Loc.Targets, [&DeclsInExtZone](const Decl *Target) { + return DeclsInExtZone.contains(Target); + })) { + DeclsToHoist.insert(PostUse); + } }, Resolver); - if (HasPostUse) - return true; } - return false; + return DeclsToHoist; } }; @@ -367,14 +375,17 @@ bool Static = false; ConstexprSpecKind Constexpr = ConstexprSpecKind::Unspecified; bool Const = false; + const HoistSet &ToHoist; // Decides whether the extracted function body and the function call need a // semicolon after extraction. tooling::ExtractionSemicolonPolicy SemicolonPolicy; const LangOptions *LangOpts; - NewFunction(tooling::ExtractionSemicolonPolicy SemicolonPolicy, + NewFunction(const HoistSet &ToHoist, + tooling::ExtractionSemicolonPolicy SemicolonPolicy, const LangOptions *LangOpts) - : SemicolonPolicy(SemicolonPolicy), LangOpts(LangOpts) {} + : ToHoist(ToHoist), SemicolonPolicy(SemicolonPolicy), LangOpts(LangOpts) { + } // Render the call for this function. std::string renderCall() const; // Render the definition for this function. @@ -390,6 +401,7 @@ std::string renderSpecifiers(FunctionDeclKind K) const; std::string renderQualifiers() const; std::string renderDeclarationName(FunctionDeclKind K) const; + std::string renderHoistedCall() const; // Generate the function body. std::string getFuncBody(const SourceManager &SM) const; }; @@ -462,7 +474,55 @@ return llvm::formatv("{0}{1}", QualifierName, Name); } +// Renders the HoistSet to a comma separated list or a single named decl. +std::string renderHoistSet(const HoistSet &ToHoist) { + std::string Res{}; + bool NeedsComma = false; + + for (const NamedDecl *DeclToHoist : ToHoist) { + if (llvm::isa(DeclToHoist) || + llvm::isa(DeclToHoist)) { + if (NeedsComma) { + Res += ", "; + } + Res += DeclToHoist->getNameAsString(); + NeedsComma = true; + } + } + return Res; +} + +std::string NewFunction::renderHoistedCall() const { + auto HoistedVarDecls = std::string{}; + auto ExplicitUnpacking = std::string{}; + const auto HasStructuredBinding = LangOpts->CPlusPlus17; + + if (ToHoist.size() > 1) { + if (HasStructuredBinding) { + HoistedVarDecls = "auto [" + renderHoistSet(ToHoist) + "] = "; + } else { + HoistedVarDecls = "auto returned = "; + auto DeclIter = ToHoist.begin(); + for (size_t Index = 0U; Index < ToHoist.size(); ++Index, ++DeclIter) { + ExplicitUnpacking += + llvm::formatv("\nauto {0} = std::get<{1}>(returned);", + (*DeclIter)->getNameAsString(), Index); + } + } + } else { + HoistedVarDecls = "auto " + renderHoistSet(ToHoist) + " = "; + } + + return llvm::formatv( + "{0}{1}({2}){3}{4}", HoistedVarDecls, Name, renderParametersForCall(), + (SemicolonPolicy.isNeededInOriginalFunction() ? ";" : ""), + ExplicitUnpacking); +} + std::string NewFunction::renderCall() const { + if (!ToHoist.empty()) + return renderHoistedCall(); + return std::string( llvm::formatv("{0}{1}({2}){3}", CallerReturnsValue ? "return " : "", Name, renderParametersForCall(), @@ -495,8 +555,22 @@ // - hoist decls // - add return statement // - Add semicolon - return toSourceCode(SM, BodyRange).str() + - (SemicolonPolicy.isNeededInExtractedFunction() ? ";" : ""); + auto Body = toSourceCode(SM, BodyRange).str() + + (SemicolonPolicy.isNeededInExtractedFunction() ? ";" : ""); + + if (ToHoist.empty()) + return Body; + + if (const bool NeedsTupleOrPair = ToHoist.size() > 1; NeedsTupleOrPair) { + const auto NeedsPair = ToHoist.size() == 2; + + Body += "\nreturn " + + std::string(NeedsPair ? "std::pair{" : "std::tuple{") + + renderHoistSet(ToHoist) + "};"; + } else { + Body += "\nreturn " + renderHoistSet(ToHoist) + ";"; + } + return Body; } std::string NewFunction::Parameter::render(const DeclContext *Context) const { @@ -674,10 +748,6 @@ const auto &DeclInfo = KeyVal.second; // If a Decl was Declared in zone and referenced in post zone, it // needs to be hoisted (we bail out in that case). - // FIXME: Support Decl Hoisting. - if (DeclInfo.DeclaredIn == ZoneRelative::Inside && - DeclInfo.IsReferencedInPostZone) - return false; if (!DeclInfo.IsReferencedInZone) continue; // no need to pass as parameter, not referenced if (DeclInfo.DeclaredIn == ZoneRelative::Inside || @@ -723,6 +793,19 @@ return SemicolonPolicy; } +QualType getReturnTypeForHoisted(const FunctionDecl &EnclosingFunc, + const HoistSet &ToHoist) { + // Hoisting just one variable, use that variables type instead of auto + if (ToHoist.size() == 1) { + if (const auto *const VDecl = llvm::dyn_cast(*ToHoist.begin()); + VDecl != nullptr) { + return VDecl->getType(); + } + } + + return EnclosingFunc.getParentASTContext().getAutoDeductType(); +} + // Generate return type for ExtractedFunc. Return false if unable to do so. bool generateReturnProperties(NewFunction &ExtractedFunc, const FunctionDecl &EnclosingFunc, @@ -744,7 +827,11 @@ return true; } // FIXME: Generate new return statement if needed. - ExtractedFunc.ReturnType = EnclosingFunc.getParentASTContext().VoidTy; + ExtractedFunc.ReturnType = + ExtractedFunc.ToHoist.empty() + ? EnclosingFunc.getParentASTContext().VoidTy + : getReturnTypeForHoisted(EnclosingFunc, ExtractedFunc.ToHoist); + return true; } @@ -758,6 +845,7 @@ // FIXME: add support for adding other function return types besides void. // FIXME: assign the value returned by non void extracted function. llvm::Expected getExtractedFunction(ExtractionZone &ExtZone, + const HoistSet &ToHoist, const SourceManager &SM, const LangOptions &LangOpts) { CapturedZoneInfo CapturedInfo = captureZoneInfo(ExtZone); @@ -765,7 +853,7 @@ if (CapturedInfo.BrokenControlFlow) return error("Cannot extract break/continue without corresponding " "loop/switch statement."); - NewFunction ExtractedFunc(getSemicolonPolicy(ExtZone, SM, LangOpts), + NewFunction ExtractedFunc(ToHoist, getSemicolonPolicy(ExtZone, SM, LangOpts), &LangOpts); ExtractedFunc.SyntacticDC = @@ -814,6 +902,7 @@ private: ExtractionZone ExtZone; + HoistSet ToHoist; }; REGISTER_TWEAK(ExtractFunction) @@ -879,8 +968,19 @@ (hasReturnStmt(*MaybeExtZone) && !alwaysReturns(*MaybeExtZone))) return false; - // FIXME: Get rid of this check once we support hoisting. - if (MaybeExtZone->requiresHoisting(SM, Inputs.AST->getHeuristicResolver())) + ToHoist = + MaybeExtZone->getDeclsToHoist(SM, Inputs.AST->getHeuristicResolver()); + + // Cannot extract a selection that contains a type declaration that is used + // outside of the selected range + if (llvm::any_of(ToHoist, [](const NamedDecl *NDecl) { + return llvm::isa(NDecl); + })) + return false; + + const auto HasAutoReturnTypeDeduction = LangOpts.CPlusPlus14; + const auto RequiresPairOrTuple = ToHoist.size() > 1; + if (RequiresPairOrTuple && !HasAutoReturnTypeDeduction) return false; ExtZone = std::move(*MaybeExtZone); @@ -890,7 +990,7 @@ Expected ExtractFunction::apply(const Selection &Inputs) { const SourceManager &SM = Inputs.AST->getSourceManager(); const LangOptions &LangOpts = Inputs.AST->getLangOpts(); - auto ExtractedFunc = getExtractedFunction(ExtZone, SM, LangOpts); + auto ExtractedFunc = getExtractedFunction(ExtZone, ToHoist, SM, LangOpts); // FIXME: Add more types of errors. if (!ExtractedFunc) return ExtractedFunc.takeError(); @@ -913,8 +1013,8 @@ tooling::Replacements OtherEdit( createForwardDeclaration(*ExtractedFunc, SM)); - if (auto PathAndEdit = Tweak::Effect::fileEdit(SM, SM.getFileID(*FwdLoc), - OtherEdit)) + if (auto PathAndEdit = + Tweak::Effect::fileEdit(SM, SM.getFileID(*FwdLoc), OtherEdit)) MultiFileEffect->ApplyEdits.try_emplace(PathAndEdit->first, PathAndEdit->second); else diff --git a/clang-tools-extra/clangd/unittests/tweaks/ExtractFunctionTests.cpp b/clang-tools-extra/clangd/unittests/tweaks/ExtractFunctionTests.cpp --- a/clang-tools-extra/clangd/unittests/tweaks/ExtractFunctionTests.cpp +++ b/clang-tools-extra/clangd/unittests/tweaks/ExtractFunctionTests.cpp @@ -30,8 +30,9 @@ EXPECT_EQ(apply("auto lam = [](){ [[int x;]] }; "), "unavailable"); // Partial statements aren't extracted. EXPECT_THAT(apply("int [[x = 0]];"), "unavailable"); - // FIXME: Support hoisting. - EXPECT_THAT(apply(" [[int a = 5;]] a++; "), "unavailable"); + + // Extract regions that require hoisting + EXPECT_THAT(apply(" [[int a = 5;]] a++; "), HasSubstr("extracted")); // Ensure that end of Zone and Beginning of PostZone being adjacent doesn't // lead to break being included in the extraction zone. @@ -192,6 +193,310 @@ EXPECT_EQ(apply(CompoundFailInput), "unavailable"); } +TEST_F(ExtractFunctionTest, Hoisting) { + ExtraArgs.emplace_back("-std=c++17"); + std::string HoistingInput = R"cpp( + int foo() { + int a = 3; + [[int x = 39 + a; + ++x; + int y = x * 2; + int z = 4;]] + return x + y + z; + } + )cpp"; + std::string HoistingOutput = R"cpp( + auto extracted(int &a) { +int x = 39 + a; + ++x; + int y = x * 2; + int z = 4; +return std::tuple{x, y, z}; +} +int foo() { + int a = 3; + auto [x, y, z] = extracted(a); + return x + y + z; + } + )cpp"; + EXPECT_EQ(apply(HoistingInput), HoistingOutput); + + std::string HoistingInput2 = R"cpp( + int foo() { + int a{}; + [[int b = a + 1;]] + return b; + } + )cpp"; + std::string HoistingOutput2 = R"cpp( + int extracted(int &a) { +int b = a + 1; +return b; +} +int foo() { + int a{}; + auto b = extracted(a); + return b; + } + )cpp"; + EXPECT_EQ(apply(HoistingInput2), HoistingOutput2); + + std::string HoistingInput3 = R"cpp( + int foo(int b) { + int a{}; + if (b == 42) { + [[a = 123; + return a + b;]] + } + a = 456; + return a; + } + )cpp"; + std::string HoistingOutput3 = R"cpp( + int extracted(int &b, int &a) { +a = 123; + return a + b; +} +int foo(int b) { + int a{}; + if (b == 42) { + return extracted(b, a); + } + a = 456; + return a; + } + )cpp"; + EXPECT_EQ(apply(HoistingInput3), HoistingOutput3); + + std::string HoistingInput4 = R"cpp( + struct A { + bool flag; + int val; + }; + A bar(); + int foo(int b) { + int a = 0; + [[auto [flag, val] = bar(); + int c = 4; + val = c + a;]] + return a + b + c + val; + } + )cpp"; + std::string HoistingOutput4 = R"cpp( + struct A { + bool flag; + int val; + }; + A bar(); + auto extracted(int &a) { +auto [flag, val] = bar(); + int c = 4; + val = c + a; +return std::pair{val, c}; +} +int foo(int b) { + int a = 0; + auto [val, c] = extracted(a); + return a + b + c + val; + } + )cpp"; + EXPECT_EQ(apply(HoistingInput4), HoistingOutput4); + + // Cannot extract a selection that contains a type declaration that is used + // outside of the selected range + EXPECT_THAT(apply(R"cpp( + [[using MyType = int;]] + MyType x = 42; + MyType y = x; + )cpp"), + "unavailable"); + EXPECT_THAT(apply(R"cpp( + [[using MyType = int; + MyType x = 42;]] + MyType y = x; + )cpp"), + "unavailable"); + EXPECT_THAT(apply(R"cpp( + [[struct Bar { + int X; + }; + auto Y = Bar{42};]] + auto Z = Bar{Y}; + )cpp"), + "unavailable"); + + // Check that selections containing type declarations can be extracted if + // there are no uses of the type after the selection + std::string FullTypeAliasInput = R"cpp( + void foo() { + [[using MyType = int; + MyType x = 42; + MyType y = x;]] + } + )cpp"; + std::string FullTypeAliasOutput = R"cpp( + void extracted() { +using MyType = int; + MyType x = 42; + MyType y = x; +} +void foo() { + extracted(); + } + )cpp"; + EXPECT_EQ(apply(FullTypeAliasInput), FullTypeAliasOutput); + + std::string FullStructInput = R"cpp( + int foo() { + [[struct Bar { + int X; + }; + auto Y = Bar{42}; + auto Z = Bar{Y}; + return 42;]] + } + )cpp"; + std::string FullStructOutput = R"cpp( + int extracted() { +struct Bar { + int X; + }; + auto Y = Bar{42}; + auto Z = Bar{Y}; + return 42; +} +int foo() { + return extracted(); + } + )cpp"; + EXPECT_EQ(apply(FullStructInput), FullStructOutput); + + std::string ReturnTypeIsAliasedInput = R"cpp( + int foo() { + [[struct Bar { + int X; + }; + auto Y = Bar{42}; + auto Z = Bar{Y}; + using MyInt = int; + MyInt A = 42; + return A;]] + } + )cpp"; + std::string ReturnTypeIsAliasedOutput = R"cpp( + int extracted() { +struct Bar { + int X; + }; + auto Y = Bar{42}; + auto Z = Bar{Y}; + using MyInt = int; + MyInt A = 42; + return A; +} +int foo() { + return extracted(); + } + )cpp"; + EXPECT_EQ(apply(ReturnTypeIsAliasedInput), ReturnTypeIsAliasedOutput); + + EXPECT_THAT(apply(R"cpp( + [[struct Bar { + int X; + }; + auto Y = Bar{42};]] + auto Z = Bar{Y}; + )cpp"), + "unavailable"); +} + +TEST_F(ExtractFunctionTest, HoistingCXX11) { + ExtraArgs.emplace_back("-std=c++11"); + std::string HoistingInput = R"cpp( + int foo() { + int a = 3; + [[int x = 39 + a; + ++x; + int y = x * 2; + int z = 4;]] + return x + y + z; + } + )cpp"; + EXPECT_THAT(apply(HoistingInput), HasSubstr("unavailable")); + + std::string HoistingInput2 = R"cpp( + int foo() { + int a; + [[int b = a + 1;]] + return b; + } + )cpp"; + std::string HoistingOutput2 = R"cpp( + int extracted(int &a) { +int b = a + 1; +return b; +} +int foo() { + int a; + auto b = extracted(a); + return b; + } + )cpp"; + EXPECT_EQ(apply(HoistingInput2), HoistingOutput2); +} + +TEST_F(ExtractFunctionTest, HoistingCXX14) { + ExtraArgs.emplace_back("-std=c++14"); + std::string HoistingInput = R"cpp( + int foo() { + int a = 3; + [[int x = 39 + a; + ++x; + int y = x * 2; + int z = 4;]] + return x + y + z; + } + )cpp"; + std::string HoistingOutput = R"cpp( + auto extracted(int &a) { +int x = 39 + a; + ++x; + int y = x * 2; + int z = 4; +return std::tuple{x, y, z}; +} +int foo() { + int a = 3; + auto returned = extracted(a); +auto x = std::get<0>(returned); +auto y = std::get<1>(returned); +auto z = std::get<2>(returned); + return x + y + z; + } + )cpp"; + EXPECT_EQ(apply(HoistingInput), HoistingOutput); + + std::string HoistingInput2 = R"cpp( + int foo() { + int a; + [[int b = a + 1;]] + return b; + } + )cpp"; + std::string HoistingOutput2 = R"cpp( + int extracted(int &a) { +int b = a + 1; +return b; +} +int foo() { + int a; + auto b = extracted(a); + return b; + } + )cpp"; + EXPECT_EQ(apply(HoistingInput2), HoistingOutput2); +} + TEST_F(ExtractFunctionTest, DifferentHeaderSourceTest) { Header = R"cpp( class SomeClass { diff --git a/clang-tools-extra/docs/ReleaseNotes.rst b/clang-tools-extra/docs/ReleaseNotes.rst --- a/clang-tools-extra/docs/ReleaseNotes.rst +++ b/clang-tools-extra/docs/ReleaseNotes.rst @@ -78,6 +78,9 @@ Miscellaneous ^^^^^^^^^^^^^ +- The extract function tweak gained support for hoisting, i.e. returning decls declared + inside the selection that are used outside of the selection. + Improvements to clang-doc -------------------------