clangd ExtractFunction: report the tweak as unavailable when hoisting is needed.

Adds ExtractionZone::requiresHoisting(), which checks whether a declaration
made inside the extraction zone is referenced after the zone, and calls it
from ExtractFunction::prepare(). RootStmts becomes a public member populated
in findExtractionZone(); generateRootStmts() is removed. The unit test for
" [[int a = 5;]] a++; " now expects "unavailable" instead of a "fail" prefix.

diff --git a/clang-tools-extra/clangd/refactor/tweaks/ExtractFunction.cpp b/clang-tools-extra/clangd/refactor/tweaks/ExtractFunction.cpp
--- a/clang-tools-extra/clangd/refactor/tweaks/ExtractFunction.cpp
+++ b/clang-tools-extra/clangd/refactor/tweaks/ExtractFunction.cpp
@@ -47,6 +47,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "AST.h"
+#include "FindTarget.h"
 #include "ParsedAST.h"
 #include "Selection.h"
 #include "SourceCode.h"
@@ -54,6 +55,7 @@
 #include "support/Logger.h"
 #include "clang/AST/ASTContext.h"
 #include "clang/AST/Decl.h"
+#include "clang/AST/DeclBase.h"
 #include "clang/AST/DeclTemplate.h"
 #include "clang/AST/RecursiveASTVisitor.h"
 #include "clang/AST/Stmt.h"
@@ -65,6 +67,8 @@
 #include "clang/Tooling/Refactoring/Extract/SourceExtraction.h"
 #include "llvm/ADT/None.h"
 #include "llvm/ADT/Optional.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/ADT/iterator_range.h"
@@ -152,6 +156,9 @@
   const FunctionDecl *EnclosingFunction = nullptr;
   // The half-open file range of the enclosing function.
   SourceRange EnclosingFuncRange;
+  // Set of statements that form the ExtractionZone.
+  llvm::DenseSet<const Stmt *> RootStmts;
+
   SourceLocation getInsertionPoint() const {
     return EnclosingFuncRange.getBegin();
   }
@@ -159,10 +166,45 @@
   // The last root statement is important to decide where we need to insert a
   // semicolon after the extraction.
   const Node *getLastRootStmt() const { return Parent->Children.back(); }
-  void generateRootStmts();
 
-private:
-  llvm::DenseSet<const Stmt *> RootStmts;
+  // Checks if declarations inside extraction zone are accessed afterwards.
+  //
+  // This performs a partial AST traversal proportional to the size of the
+  // enclosing function, so it is possibly expensive.
+  bool requiresHoisting(const SourceManager &SM) const {
+    // First find all the declarations that happened inside extraction zone.
+    llvm::SmallSet<const Decl *, 1> DeclsInExtZone;
+    for (auto *RootStmt : RootStmts) {
+      findExplicitReferences(RootStmt,
+                             [&DeclsInExtZone](const ReferenceLoc &Loc) {
+                               if (!Loc.IsDecl)
+                                 return;
+                               DeclsInExtZone.insert(Loc.Targets.front());
+                             });
+    }
+    // Early exit without performing expensive traversal below.
+    if (DeclsInExtZone.empty())
+      return false;
+    // Then make sure they are not used outside the zone.
+    for (const auto *S : EnclosingFunction->getBody()->children()) {
+      if (SM.isBeforeInTranslationUnit(S->getSourceRange().getEnd(),
+                                       ZoneRange.getEnd()))
+        continue;
+      bool HasPostUse = false;
+      findExplicitReferences(S, [&](const ReferenceLoc &Loc) {
+        if (HasPostUse ||
+            SM.isBeforeInTranslationUnit(Loc.NameLoc, ZoneRange.getEnd()))
+          return;
+        HasPostUse =
+            llvm::any_of(Loc.Targets, [&DeclsInExtZone](const Decl *Target) {
+              return DeclsInExtZone.contains(Target);
+            });
+      });
+      if (HasPostUse)
+        return true;
+    }
+    return false;
+  }
 };
 
 // Whether the code in the extraction zone is guaranteed to return, assuming
@@ -185,12 +227,6 @@
   return RootStmts.find(S) != RootStmts.end();
 }
 
-// Generate RootStmts set
-void ExtractionZone::generateRootStmts() {
-  for (const Node *Child : Parent->Children)
-    RootStmts.insert(Child->ASTNode.get<Stmt>());
-}
-
 // Finds the function in which the zone lies.
 const FunctionDecl *findEnclosingFunction(const Node *CommonAnc) {
   // Walk up the SelectionTree until we find a function Decl
@@ -281,7 +317,10 @@
   ExtZone.ZoneRange = *ZoneRange;
   if (ExtZone.EnclosingFuncRange.isInvalid() || ExtZone.ZoneRange.isInvalid())
     return llvm::None;
-  ExtZone.generateRootStmts();
+
+  for (const Node *Child : ExtZone.Parent->Children)
+    ExtZone.RootStmts.insert(Child->ASTNode.get<Stmt>());
+
   return ExtZone;
 }
 
@@ -670,16 +709,18 @@
 }
 
 bool ExtractFunction::prepare(const Selection &Inputs) {
-  const Node *CommonAnc = Inputs.ASTSelection.commonAncestor();
-  const SourceManager &SM = Inputs.AST->getSourceManager();
   const LangOptions &LangOpts = Inputs.AST->getLangOpts();
   if (!LangOpts.CPlusPlus)
     return false;
-  if (auto MaybeExtZone = findExtractionZone(CommonAnc, SM, LangOpts)) {
-    ExtZone = std::move(*MaybeExtZone);
-    return true;
-  }
-  return false;
+  const Node *CommonAnc = Inputs.ASTSelection.commonAncestor();
+  const SourceManager &SM = Inputs.AST->getSourceManager();
+  auto MaybeExtZone = findExtractionZone(CommonAnc, SM, LangOpts);
+  // FIXME: Get rid of this check once we support hoisting.
+  if (!MaybeExtZone || MaybeExtZone->requiresHoisting(SM))
+    return false;
+
+  ExtZone = std::move(*MaybeExtZone);
+  return true;
 }
 
 Expected<Tweak::Effect> ExtractFunction::apply(const Selection &Inputs) {
diff --git a/clang-tools-extra/clangd/unittests/TweakTests.cpp b/clang-tools-extra/clangd/unittests/TweakTests.cpp
--- a/clang-tools-extra/clangd/unittests/TweakTests.cpp
+++ b/clang-tools-extra/clangd/unittests/TweakTests.cpp
@@ -593,6 +593,8 @@
   EXPECT_EQ(apply("auto lam = [](){ [[int x;]] }; "), "unavailable");
   // Partial statements aren't extracted.
   EXPECT_THAT(apply("int [[x = 0]];"), "unavailable");
+  // FIXME: Support hoisting.
+  EXPECT_THAT(apply(" [[int a = 5;]] a++; "), "unavailable");
 
   // Ensure that end of Zone and Beginning of PostZone being adjacent doesn't
   // lead to break being included in the extraction zone.
@@ -600,8 +602,6 @@
   // FIXME: ExtractFunction should be unavailable inside loop construct
   // initializer/condition.
   EXPECT_THAT(apply(" for([[int i = 0;]];);"), HasSubstr("extracted"));
-  // Don't extract because needs hoisting.
-  EXPECT_THAT(apply(" [[int a = 5;]] a++; "), StartsWith("fail"));
   // Extract certain return
   EXPECT_THAT(apply(" if(true) [[{ return; }]] "), HasSubstr("extracted"));
   // Don't extract uncertain return