diff --git a/clang-tools-extra/clangd/refactor/tweaks/ExtractFunction.cpp b/clang-tools-extra/clangd/refactor/tweaks/ExtractFunction.cpp
--- a/clang-tools-extra/clangd/refactor/tweaks/ExtractFunction.cpp
+++ b/clang-tools-extra/clangd/refactor/tweaks/ExtractFunction.cpp
@@ -708,6 +708,28 @@
   return tooling::Replacement(SM, ExtractedFunc.InsertionPoint, 0, FunctionDef);
 }
 
+// Returns true if ExtZone contains any ReturnStmts, including ones produced by
+// macro expansions (the traversal walks the post-expansion AST).
+bool hasReturnStmt(const ExtractionZone &ExtZone) {
+  class ReturnStmtVisitor
+      : public clang::RecursiveASTVisitor<ReturnStmtVisitor> {
+  public:
+    bool VisitReturnStmt(ReturnStmt *) {
+      Found = true;
+      return false; // We found the answer, abort the scan.
+    }
+    bool Found = false;
+  };
+
+  ReturnStmtVisitor V;
+  for (const Stmt *RootStmt : ExtZone.RootStmts) {
+    V.TraverseStmt(const_cast<Stmt *>(RootStmt));
+    if (V.Found)
+      break;
+  }
+  return V.Found;
+}
+
 bool ExtractFunction::prepare(const Selection &Inputs) {
   const LangOptions &LangOpts = Inputs.AST->getLangOpts();
   if (!LangOpts.CPlusPlus)
@@ -715,8 +737,14 @@
   const Node *CommonAnc = Inputs.ASTSelection.commonAncestor();
   const SourceManager &SM = Inputs.AST->getSourceManager();
   auto MaybeExtZone = findExtractionZone(CommonAnc, SM, LangOpts);
+  // A return that may or may not execute can't be extracted: the call site
+  // would have no way to know whether it should return after the call.
+  if (!MaybeExtZone ||
+      (hasReturnStmt(*MaybeExtZone) && !alwaysReturns(*MaybeExtZone)))
+    return false;
+
   // FIXME: Get rid of this check once we support hoisting.
-  if (!MaybeExtZone || MaybeExtZone->requiresHoisting(SM))
+  if (MaybeExtZone->requiresHoisting(SM))
     return false;
 
   ExtZone = std::move(*MaybeExtZone);
diff --git a/clang-tools-extra/clangd/unittests/TweakTests.cpp b/clang-tools-extra/clangd/unittests/TweakTests.cpp
--- a/clang-tools-extra/clangd/unittests/TweakTests.cpp
+++ b/clang-tools-extra/clangd/unittests/TweakTests.cpp
@@ -607,7 +607,12 @@
   // Extract certain return
   EXPECT_THAT(apply(" if(true) [[{ return; }]] "), HasSubstr("extracted"));
   // Don't extract uncertain return
-  EXPECT_THAT(apply(" if(true) [[if (false) return;]] "), StartsWith("fail"));
+  EXPECT_THAT(apply(" if(true) [[if (false) return;]] "),
+              StartsWith("unavailable"));
+  // Uncertain returns hidden inside a macro expansion are rejected too.
+  EXPECT_THAT(
+      apply("#define RETURN_IF_ERROR(x) if (x) return\nRETU^RN_IF_ERROR(4);"),
+      StartsWith("unavailable"));
 
   FileName = "a.c";
   EXPECT_THAT(apply(" for([[int i = 0;]];);"), HasSubstr("unavailable"));