// Patch summary (review note placed ahead of the `diff --git` header).
//
// Files touched:
//   clang-tools-extra/clangd/refactor/tweaks/ExtractFunction.cpp
//   clang-tools-extra/clangd/unittests/TweakTests.cpp
//
// ExtractFunction.cpp:
//   * Adds alwaysReturns(const ExtractionZone &). It starts from the AST node
//     of the last child of EZ.Parent, repeatedly unwraps compound statements
//     via body_back() (returning false for an empty one), and finally returns
//     the result of an llvm::isa check on the innermost last statement. Its
//     own comment describes this as a naive "does it end with a return stmt"
//     check.
//   * NewFunction::ReturnType changes from std::string to QualType, and a new
//     member `bool CallerReturnsValue = false` is added.
//   * renderCall() is rewritten with llvm::formatv and emits a "return "
//     prefix when CallerReturnsValue is set.
//   * renderDefinition() is rewritten with llvm::formatv and renders the
//     return type through printType(ReturnType, *EnclosingFuncContext).
//   * An `AlwaysReturns` flag is added next to `HasReturnStmt`; it is filled
//     in from alwaysReturns(ExtZone) on the CapturedZoneInfo result before
//     that result is returned.
//   * generateReturnProperties() gains a `const FunctionDecl &EnclosingFunc`
//     parameter. When the zone has a return statement it returns false unless
//     AlwaysReturns is set, returns false when the enclosing function's
//     return type is dependent, and otherwise uses that return type. When the
//     zone has no return statement the return type is the AST context VoidTy.
//   * ExtractedFunc.CallerReturnsValue is assigned CapturedInfo.AlwaysReturns,
//     and the generateReturnProperties call site passes
//     *ExtZone.EnclosingFunction.
//
// TweakTests.cpp:
//   * ` if(true) [[{ return; }]] ` is now expected to contain "extracted";
//     ` if(true) [[if (false) return;]] ` is expected to start with "fail".
//   * Adds TEST_F(ExtractFunctionTest, ExistingReturnStatement), which
//     compares apply(Before) against an exact After string using EXPECT_EQ.
//
// NOTE(review): this copy of the patch looks damaged - hunks are collapsed
// onto a few physical lines, and template argument lists appear to have been
// stripped (e.g. `llvm::dyn_cast(Last)`, `ASTNode.get()`, `llvm::isa(Last)`,
// `llvm::DenseSet RootStmts`, `std::vector Parameters`, `std::pair Cases[]`).
// It presumably will not apply or compile as-is; whitespace inside the test's
// raw string literals cannot be recovered from this copy either.
// TODO: confirm against the upstream revision before using this patch.
diff --git a/clang-tools-extra/clangd/refactor/tweaks/ExtractFunction.cpp b/clang-tools-extra/clangd/refactor/tweaks/ExtractFunction.cpp --- a/clang-tools-extra/clangd/refactor/tweaks/ExtractFunction.cpp +++ b/clang-tools-extra/clangd/refactor/tweaks/ExtractFunction.cpp @@ -165,6 +165,21 @@ llvm::DenseSet RootStmts; }; +// Whether the code in the extraction zone is guaranteed to return, assuming +// no broken control flow (unbound break/continue). +// This is a very naive check (does it end with a return stmt). +// Doing some rudimentary control flow analysis would cover more cases. +bool alwaysReturns(const ExtractionZone &EZ) { + const Stmt *Last = EZ.Parent->Children.back()->ASTNode.get(); + // Unwrap enclosing (unconditional) compound statement. + while (const auto *CS = llvm::dyn_cast(Last)) + if (CS->body_empty()) + return false; + else + Last = CS->body_back(); + return llvm::isa(Last); +} + bool ExtractionZone::isRootStmt(const Stmt *S) const { return RootStmts.find(S) != RootStmts.end(); } @@ -283,11 +298,12 @@ } }; std::string Name = "extracted"; - std::string ReturnType; + QualType ReturnType; std::vector Parameters; SourceRange BodyRange; SourceLocation InsertionPoint; const DeclContext *EnclosingFuncContext; + bool CallerReturnsValue = false; // Decides whether the extracted function body and the function call need a // semicolon after extraction. tooling::ExtractionSemicolonPolicy SemicolonPolicy; @@ -330,13 +346,16 @@ } std::string NewFunction::renderCall() const { - return Name + "(" + renderParametersForCall() + ")" + - (SemicolonPolicy.isNeededInOriginalFunction() ? ";" : ""); + return llvm::formatv( + "{0}{1}({2}){3}", CallerReturnsValue ? "return " : "", Name, + renderParametersForCall(), + (SemicolonPolicy.isNeededInOriginalFunction() ? 
";" : "")); } std::string NewFunction::renderDefinition(const SourceManager &SM) const { - return ReturnType + " " + Name + "(" + renderParametersForDefinition() + ")" + - " {\n" + getFuncBody(SM) + "\n}\n"; + return llvm::formatv("{0} {1}({2}) {\n{3}\n}\n", + printType(ReturnType, *EnclosingFuncContext), Name, + renderParametersForDefinition(), getFuncBody(SM)); } std::string NewFunction::getFuncBody(const SourceManager &SM) const { @@ -370,8 +389,8 @@ }; // Maps Decls to their DeclInfo llvm::DenseMap DeclInfoMap; - // True if there is a return statement in zone. - bool HasReturnStmt = false; + bool HasReturnStmt = false; // Are there any return statements in the zone? + bool AlwaysReturns = false; // Does the zone always return? // Control flow is broken if we are extracting a break/continue without a // corresponding parent loop/switch bool BrokenControlFlow = false; @@ -519,7 +538,9 @@ unsigned CurNumberOfSwitch = 0; }; ExtractionZoneVisitor Visitor(ExtZone); - return std::move(Visitor.Info); + CapturedZoneInfo Result = std::move(Visitor.Info); + Result.AlwaysReturns = alwaysReturns(ExtZone); + return Result; } // Adds parameters to ExtractedFunc. @@ -582,13 +603,26 @@ // Generate return type for ExtractedFunc. Return false if unable to do so. bool generateReturnProperties(NewFunction &ExtractedFunc, + const FunctionDecl &EnclosingFunc, const CapturedZoneInfo &CapturedInfo) { - - // FIXME: Use Existing Return statements (if present) + // If the selected code always returns, we preserve those return statements. + // The return type should be the same as the enclosing function. + // (Others are possible if there are conversions, but this seems clearest). + if (CapturedInfo.HasReturnStmt) { + // If the return is conditional, neither replacing the code with + // `extracted()` nor `return extracted()` is correct. 
+ if (!CapturedInfo.AlwaysReturns) + return false; + QualType Ret = EnclosingFunc.getReturnType(); + // Once we support members, it'd be nice to support e.g. extracting a method + // of Foo that returns T. But it's not clear when that's safe. + if (Ret->isDependentType()) + return false; + ExtractedFunc.ReturnType = Ret; + return true; + } // FIXME: Generate new return statement if needed. - if (CapturedInfo.HasReturnStmt) - return false; - ExtractedFunc.ReturnType = "void"; + ExtractedFunc.ReturnType = EnclosingFunc.getParentASTContext().VoidTy; return true; } @@ -608,8 +642,10 @@ ExtractedFunc.InsertionPoint = ExtZone.getInsertionPoint(); ExtractedFunc.EnclosingFuncContext = ExtZone.EnclosingFunction->getDeclContext(); + ExtractedFunc.CallerReturnsValue = CapturedInfo.AlwaysReturns; if (!createParameters(ExtractedFunc, CapturedInfo) || - !generateReturnProperties(ExtractedFunc, CapturedInfo)) + !generateReturnProperties(ExtractedFunc, *ExtZone.EnclosingFunction, + CapturedInfo)) return llvm::createStringError(llvm::inconvertibleErrorCode(), +"Too complex to extract."); return ExtractedFunc; diff --git a/clang-tools-extra/clangd/unittests/TweakTests.cpp b/clang-tools-extra/clangd/unittests/TweakTests.cpp --- a/clang-tools-extra/clangd/unittests/TweakTests.cpp +++ b/clang-tools-extra/clangd/unittests/TweakTests.cpp @@ -573,8 +573,10 @@ EXPECT_THAT(apply(" for([[int i = 0;]];);"), HasSubstr("extracted")); // Don't extract because needs hoisting. 
EXPECT_THAT(apply(" [[int a = 5;]] a++; "), StartsWith("fail")); - // Don't extract return - EXPECT_THAT(apply(" if(true) [[return;]] "), StartsWith("fail")); + // Extract certain return + EXPECT_THAT(apply(" if(true) [[{ return; }]] "), HasSubstr("extracted")); + // Don't extract uncertain return + EXPECT_THAT(apply(" if(true) [[if (false) return;]] "), StartsWith("fail")); } TEST_F(ExtractFunctionTest, FileTest) { @@ -679,6 +681,42 @@ StartsWith("fail")); } +TEST_F(ExtractFunctionTest, ExistingReturnStatement) { + Context = File; + const char* Before = R"cpp( + bool lucky(int N); + int getNum(bool Superstitious, int Min, int Max) { + if (Superstitious) [[{ + for (int I = Min; I <= Max; ++I) + if (lucky(I)) + return I; + return -1; + }]] else { + return (Min + Max) / 2; + } + } + )cpp"; + // FIXME: min/max should be by value. + // FIXME: avoid emitting redundant braces + const char* After = R"cpp( + bool lucky(int N); + int extracted(int &Min, int &Max) { +{ + for (int I = Min; I <= Max; ++I) + if (lucky(I)) + return I; + return -1; + } +} +int getNum(bool Superstitious, int Min, int Max) { + if (Superstitious) return extracted(Min, Max); else { + return (Min + Max) / 2; + } + } + )cpp"; + EXPECT_EQ(apply(Before), After); +} + TWEAK_TEST(RemoveUsingNamespace); TEST_F(RemoveUsingNamespaceTest, All) { std::pair Cases[] = {