diff --git a/clang-tools-extra/clangd/refactor/tweaks/ExtractFunction.cpp b/clang-tools-extra/clangd/refactor/tweaks/ExtractFunction.cpp --- a/clang-tools-extra/clangd/refactor/tweaks/ExtractFunction.cpp +++ b/clang-tools-extra/clangd/refactor/tweaks/ExtractFunction.cpp @@ -56,9 +56,12 @@ #include "clang/AST/ASTContext.h" #include "clang/AST/Decl.h" #include "clang/AST/DeclBase.h" +#include "clang/AST/ExprCXX.h" #include "clang/AST/NestedNameSpecifier.h" #include "clang/AST/RecursiveASTVisitor.h" #include "clang/AST/Stmt.h" +#include "clang/ASTMatchers/ASTMatchFinder.h" +#include "clang/ASTMatchers/ASTMatchers.h" #include "clang/Basic/LangOptions.h" #include "clang/Basic/SourceLocation.h" #include "clang/Basic/SourceManager.h" @@ -73,6 +76,9 @@ #include "llvm/Support/Error.h" #include "llvm/Support/raw_os_ostream.h" +#include +#include + namespace clang { namespace clangd { namespace { @@ -95,6 +101,208 @@ OutOfLineDefinition }; +// Helpers for handling "binary subexpressions" like a + [[b + c]] + d. This is +// taken from ExtractVariable, and adapted a little to handle collection of +// parameters. +struct ExtractedBinarySubexpressionSelection; + +class BinarySubexpressionSelection { + +public: + static inline std::optional + tryParse(const SelectionTree::Node &N, const SourceManager *SM) { + if (const BinaryOperator *Op = + llvm::dyn_cast_or_null(N.ASTNode.get())) { + return BinarySubexpressionSelection{SM, Op->getOpcode(), Op->getExprLoc(), + N.Children}; + } + if (const CXXOperatorCallExpr *Op = + llvm::dyn_cast_or_null( + N.ASTNode.get())) { + if (!Op->isInfixBinaryOp()) + return std::nullopt; + + llvm::SmallVector SelectedOps; + // Not all children are args, there's also the callee (operator). 
+ for (const auto *Child : N.Children) { + const Expr *E = Child->ASTNode.get(); + assert(E && "callee and args should be Exprs!"); + if (E == Op->getArg(0) || E == Op->getArg(1)) + SelectedOps.push_back(Child); + } + return BinarySubexpressionSelection{ + SM, BinaryOperator::getOverloadedOpcode(Op->getOperator()), + Op->getExprLoc(), std::move(SelectedOps)}; + } + return std::nullopt; + } + + bool associative() const { + // Must also be left-associative! + switch (Kind) { + case BO_Add: + case BO_Mul: + case BO_And: + case BO_Or: + case BO_Xor: + case BO_LAnd: + case BO_LOr: + return true; + default: + return false; + } + } + + bool crossesMacroBoundary() const { + FileID F = SM->getFileID(ExprLoc); + for (const SelectionTree::Node *Child : SelectedOperations) + if (SM->getFileID(Child->ASTNode.get()->getExprLoc()) != F) + return true; + return false; + } + + bool isExtractable() const { + return associative() and not crossesMacroBoundary(); + } + + void dumpSelectedOperations(llvm::raw_ostream &Os, + const ASTContext &Cont) const { + for (const auto *Op : SelectedOperations) + Op->ASTNode.dump(Os, Cont); + } + + std::optional tryExtract() const; + +protected: + struct SelectedOperands { + llvm::SmallVector Operands; + const SelectionTree::Node *Start; + const SelectionTree::Node *End; + }; + +private: + BinarySubexpressionSelection( + const SourceManager *SM, BinaryOperatorKind Kind, SourceLocation ExprLoc, + llvm::SmallVector SelectedOps) + : SM{SM}, Kind(Kind), ExprLoc(ExprLoc), + SelectedOperations(std::move(SelectedOps)) {} + + SelectedOperands getSelectedOperands() const { + auto [Start, End]{getClosedRangeWithSelectedOperations()}; + + llvm::SmallVector Operands; + Operands.reserve(SelectedOperations.size()); + const SelectionTree::Node *BinOpSelectionIt{Start->Parent}; + + // Edge case: the selection starts from the most-left LHS, e.g. 
[[a+b+c]]+d + if (BinOpSelectionIt->Children.size() == 2) + Operands.emplace_back(BinOpSelectionIt->Children.front()); // LHS + // In case of operator+ call, the Children will contain the callee as well. + else if (BinOpSelectionIt->Children.size() == 3) + Operands.emplace_back(BinOpSelectionIt->Children[1]); // LHS + + // Go up the Binary Operation tree, up to the most-right RHS + for (; BinOpSelectionIt->Children.back() != End; + BinOpSelectionIt = BinOpSelectionIt->Parent) + Operands.emplace_back(BinOpSelectionIt->Children.back()); // RHS + // Remember to add the most-right RHS + Operands.emplace_back(End); + + SelectedOperands Ops; + Ops.Start = Start; + Ops.End = End; + Ops.Operands = std::move(Operands); + return Ops; + } + + std::pair + getClosedRangeWithSelectedOperations() const { + BinaryOperatorKind OuterOp = Kind; + // Because the tree we're interested in contains only one operator type, and + // all eligible operators are left-associative, the shape of the tree is + // very restricted: it's a linked list along the left edges. + // This simplifies our implementation. + const SelectionTree::Node *Start = SelectedOperations.front(); // LHS + const SelectionTree::Node *End = SelectedOperations.back(); // RHS + + // End is already correct: it can't be an OuterOp (as it's + // left-associative). Start needs to be pushed down into the subtree to the + // right spot. + while (true) { + auto MaybeOp{tryParse(Start->ignoreImplicit(), SM)}; + if (not MaybeOp) + break; + const auto &Op{*MaybeOp}; + if (Op.Kind != OuterOp or Op.crossesMacroBoundary()) + break; + assert(!Op.SelectedOperations.empty() && + "got only operator on one side!"); + if (Op.SelectedOperations.size() == 1) { // Only Op.RHS selected + Start = Op.SelectedOperations.back(); + break; + } + // Op.LHS is (at least partially) selected, so descend into it. 
+ Start = Op.SelectedOperations.front(); + } + return {Start, End}; + } + +protected: + const SourceManager *SM; + BinaryOperatorKind Kind; + SourceLocation ExprLoc; + // May also contain partially selected operations, + // e.g. a + [[b + c]], will keep (a + b) BinaryOperator. + llvm::SmallVector SelectedOperations; +}; + +struct ExtractedBinarySubexpressionSelection : BinarySubexpressionSelection { + ExtractedBinarySubexpressionSelection(BinarySubexpressionSelection BinSubexpr, + SelectedOperands SelectedOps) + : BinarySubexpressionSelection::BinarySubexpressionSelection( + std::move(BinSubexpr)), + Operands{std::move(SelectedOps)} {} + + SourceRange getRange(const LangOptions &LangOpts) const { + auto MakeHalfOpenFileRange{[&](const SelectionTree::Node *N) { + return toHalfOpenFileRange(*SM, LangOpts, N->ASTNode.getSourceRange()); + }}; + + return SourceRange(MakeHalfOpenFileRange(Operands.Start)->getBegin(), + MakeHalfOpenFileRange(Operands.End)->getEnd()); + } + + void dumpSelectedOperands(llvm::raw_ostream &Os, + const ASTContext &Cont) const { + for (const auto *Op : Operands.Operands) + Op->ASTNode.dump(Os, Cont); + } + + llvm::SmallVector + collectReferences(ASTContext &Cont) const { + llvm::SmallVector Refs; + auto Matcher{ + ast_matchers::findAll(ast_matchers::declRefExpr().bind("ref"))}; + for (const auto *SelNode : Operands.Operands) { + auto Matches{ast_matchers::match(Matcher, SelNode->ASTNode, Cont)}; + for (const auto &Match : Matches) + if (const DeclRefExpr * Ref{Match.getNodeAs("ref")}; Ref) + Refs.push_back(Ref); + } + return Refs; + } + +private: + SelectedOperands Operands; +}; + +std::optional +BinarySubexpressionSelection::tryExtract() const { + if (not isExtractable()) + return std::nullopt; + return ExtractedBinarySubexpressionSelection{*this, getSelectedOperands()}; +} + // A RootStmt is a statement that's fully selected including all it's children // and it's parent is unselected. // Check if a node is a root statement. 
@@ -122,11 +330,14 @@ // begins in selection range, ends in selection range and any scope that begins // outside the selection range, ends outside as well. const Node *getParentOfRootStmts(const Node *CommonAnc) { - if (!CommonAnc) - return nullptr; const Node *Parent = nullptr; switch (CommonAnc->Selected) { case SelectionTree::Selection::Unselected: + // Workaround for an operator call: BinaryOperator will be selected + // completely, but the operator call would be unselected, thus we treat it + // as if it were completely selected. + if (CommonAnc->ASTNode.get() != nullptr) + return CommonAnc->Parent; // Typically a block, with the { and } unselected, could also be ForStmt etc // Ensure all Children are RootStmts. Parent = CommonAnc; @@ -152,6 +363,7 @@ // The ExtractionZone class forms a view of the code wrt Zone. struct ExtractionZone { + const Node *CommonAncestor; // Parent of RootStatements being extracted. const Node *Parent = nullptr; // The half-open file range of the code being extracted. @@ -162,6 +374,8 @@ SourceRange EnclosingFuncRange; // Set of statements that form the ExtractionZone. llvm::DenseSet RootStmts; + // If the extraction zone is a "binary subexpression", then this will be set. + std::optional MaybeBinarySubexpr; SourceLocation getInsertionPoint() const { return EnclosingFuncRange.getBegin(); @@ -292,20 +506,12 @@ return toHalfOpenFileRange(SM, LangOpts, EnclosingFunction->getSourceRange()); } -// returns true if Child can be a single RootStmt being extracted from -// EnclosingFunc. -bool validSingleChild(const Node *Child, const FunctionDecl *EnclosingFunc) { - // Don't extract expressions. - // FIXME: We should extract expressions that are "statements" i.e. not - // subexpressions - if (Child->ASTNode.get()) - return false; - // Extracting the body of EnclosingFunc would remove it's definition. 
- assert(EnclosingFunc->hasBody() && +bool isEntireFunctionBodySelected(const ExtractionZone &ExtZone) { + assert(ExtZone.EnclosingFunction->hasBody() && "We should always be extracting from a function body."); - if (Child->ASTNode.get() == EnclosingFunc->getBody()) - return false; - return true; + return ExtZone.Parent->Children.size() == 1 && + ExtZone.getLastRootStmt()->ASTNode.get() == + ExtZone.EnclosingFunction->getBody(); } // FIXME: Check we're not extracting from the initializer/condition of a control @@ -313,17 +519,30 @@ llvm::Optional findExtractionZone(const Node *CommonAnc, const SourceManager &SM, const LangOptions &LangOpts) { + if (CommonAnc == nullptr) + return std::nullopt; ExtractionZone ExtZone; + ExtZone.CommonAncestor = CommonAnc; + auto MaybeBinarySubexpr{ + BinarySubexpressionSelection::tryParse(CommonAnc->ignoreImplicit(), &SM)}; + if (MaybeBinarySubexpr) { + // FIXME: We shall not allow the user to extract expressions which we don't + // support, or which are weirdly selected (e.g. a [[+ b + c]]). If the + // selected subexpression is an entire expression (not only a part of + // expression), then we don't need the BinarySubexpressionSelection. + if (const auto &BinarySubexpr{*MaybeBinarySubexpr}; + BinarySubexpr.isExtractable()) { + ExtZone.MaybeBinarySubexpr = std::move(MaybeBinarySubexpr); + } + } ExtZone.Parent = getParentOfRootStmts(CommonAnc); if (!ExtZone.Parent || ExtZone.Parent->Children.empty()) return std::nullopt; ExtZone.EnclosingFunction = findEnclosingFunction(ExtZone.Parent); if (!ExtZone.EnclosingFunction) return std::nullopt; - // When there is a single RootStmt, we must check if it's valid for - // extraction. - if (ExtZone.Parent->Children.size() == 1 && - !validSingleChild(ExtZone.getLastRootStmt(), ExtZone.EnclosingFunction)) + // Extracting the body of EnclosingFunc would remove it's definition. 
+ if (isEntireFunctionBodySelected(ExtZone)) return std::nullopt; if (auto FuncRange = computeEnclosingFuncRange(ExtZone.EnclosingFunction, SM, LangOpts)) @@ -367,6 +586,7 @@ bool Static = false; ConstexprSpecKind Constexpr = ConstexprSpecKind::Unspecified; bool Const = false; + bool Expression = false; // Decides whether the extracted function body and the function call need a // semicolon after extraction. @@ -495,8 +715,11 @@ // - hoist decls // - add return statement // - Add semicolon - return toSourceCode(SM, BodyRange).str() + - (SemicolonPolicy.isNeededInExtractedFunction() ? ";" : ""); + auto NewBody{toSourceCode(SM, BodyRange).str() + + (SemicolonPolicy.isNeededInExtractedFunction() ? ";" : "")}; + if (Expression) + return "return " + NewBody; + return NewBody; } std::string NewFunction::Parameter::render(const DeclContext *Context) const { @@ -530,6 +753,7 @@ // FIXME: Capture type information as well. DeclInformation *createDeclInfo(const Decl *D, ZoneRelative RelativeLoc); DeclInformation *getDeclInfoFor(const Decl *D); + const DeclInformation *getDeclInfoFor(const Decl *D) const; }; CapturedZoneInfo::DeclInformation * @@ -543,7 +767,14 @@ CapturedZoneInfo::DeclInformation * CapturedZoneInfo::getDeclInfoFor(const Decl *D) { - // If the Decl doesn't exist, we + auto Iter = DeclInfoMap.find(D); + if (Iter == DeclInfoMap.end()) + return nullptr; + return &Iter->second; +} + +const CapturedZoneInfo::DeclInformation * +CapturedZoneInfo::getDeclInfoFor(const Decl *D) const { auto Iter = DeclInfoMap.find(D); if (Iter == DeclInfoMap.end()) return nullptr; @@ -664,12 +895,29 @@ return Result; } -// Adds parameters to ExtractedFunc. -// Returns true if able to find the parameters successfully and no hoisting -// needed. 
+static const ValueDecl *unpackDeclForParameter(const Decl *D) { + const ValueDecl *VD = dyn_cast_or_null(D); + // Can't parameterise if the Decl isn't a ValueDecl or is a FunctionDecl + // (this includes the case of recursive call to EnclosingFunc in Zone). + if (!VD || isa(D)) + return nullptr; + return VD; +} + +static QualType getParameterTypeInfo(const ValueDecl *VD) { + // Parameter qualifiers are same as the Decl's qualifiers. + return VD->getType().getNonReferenceType(); +} + +using Parameters = std::vector; +using MaybeParameters = std::optional; + // FIXME: Check if the declaration has a local/anonymous type -bool createParameters(NewFunction &ExtractedFunc, - const CapturedZoneInfo &CapturedInfo) { +// Returns actual parameters if able to find the parameters successfully and no +// hoisting needed. +static MaybeParameters +createParamsForNoSubexpr(const CapturedZoneInfo &CapturedInfo) { + std::vector Params; for (const auto &KeyVal : CapturedInfo.DeclInfoMap) { const auto &DeclInfo = KeyVal.second; // If a Decl was Declared in zone and referenced in post zone, it @@ -677,20 +925,16 @@ // FIXME: Support Decl Hoisting. if (DeclInfo.DeclaredIn == ZoneRelative::Inside && DeclInfo.IsReferencedInPostZone) - return false; + return std::nullopt; if (!DeclInfo.IsReferencedInZone) continue; // no need to pass as parameter, not referenced if (DeclInfo.DeclaredIn == ZoneRelative::Inside || DeclInfo.DeclaredIn == ZoneRelative::OutsideFunc) continue; // no need to pass as parameter, still accessible. - // Parameter specific checks. - const ValueDecl *VD = dyn_cast_or_null(DeclInfo.TheDecl); - // Can't parameterise if the Decl isn't a ValueDecl or is a FunctionDecl - // (this includes the case of recursive call to EnclosingFunc in Zone). - if (!VD || isa(DeclInfo.TheDecl)) - return false; - // Parameter qualifiers are same as the Decl's qualifiers. 
- QualType TypeInfo = VD->getType().getNonReferenceType(); + const auto *VD{unpackDeclForParameter(DeclInfo.TheDecl)}; + if (VD == nullptr) + return std::nullopt; + QualType TypeInfo{getParameterTypeInfo(VD)}; // FIXME: Need better qualifier checks: check mutated status for // Decl(e.g. was it assigned, passed as nonconst argument, etc) // FIXME: check if parameter will be a non l-value reference. @@ -698,12 +942,61 @@ // pointers, etc by reference. bool IsPassedByReference = true; // We use the index of declaration as the ordering priority for parameters. - ExtractedFunc.Parameters.push_back({std::string(VD->getName()), TypeInfo, - IsPassedByReference, - DeclInfo.DeclIndex}); + Params.push_back({std::string(VD->getName()), TypeInfo, IsPassedByReference, + DeclInfo.DeclIndex}); } - llvm::sort(ExtractedFunc.Parameters); - return true; + llvm::sort(Params); + return Params; +} + +static MaybeParameters +createParamsForSubexpr(const CapturedZoneInfo &CapturedInfo, + const ExtractedBinarySubexpressionSelection &Subexpr, + ASTContext &ASTCont) { + // We use the Set here, to avoid duplicates, but since the Set will not + // care about the order, we need to use a vector to collect the unique + // references in the order of referencing. + llvm::SmallVector RefsAsDecls; + llvm::DenseSet UniqueRefsAsDecls; + + for (const auto *Ref : Subexpr.collectReferences(ASTCont)) { + const auto *D{Ref->getDecl()}; + const auto *VD{unpackDeclForParameter(D)}; + // Only collect the ValueDecl-s. 
+ if (VD == nullptr) + continue; + const auto *DeclInfo{CapturedInfo.getDeclInfoFor(D)}; + if (DeclInfo == nullptr or DeclInfo->DeclaredIn != ZoneRelative::Before) + continue; + auto [It, IsNew]{UniqueRefsAsDecls.insert(VD)}; + if (IsNew) + RefsAsDecls.emplace_back(VD); + } + + std::vector Params; + std::transform(std::begin(RefsAsDecls), std::end(RefsAsDecls), + std::back_inserter(Params), [](const ValueDecl *VD) { + QualType TypeInfo{getParameterTypeInfo(VD)}; + // FIXME: Need better qualifier checks: check mutated status + // for Decl(e.g. was it assigned, passed as nonconst + // argument, etc) + // FIXME: check if parameter will be a non l-value reference. + // FIXME: We don't want to always pass variables of types + // like int, pointers, etc by reference. + bool IsPassedByRef = true; + return NewFunction::Parameter{std::string(VD->getName()), + TypeInfo, IsPassedByRef, 0}; + }); + return Params; +} + +// Adds parameters to ExtractedFunc. +MaybeParameters createParams( + const std::optional &MaybeSubexpr, + const CapturedZoneInfo &CapturedInfo, ASTContext &ASTCont) { + if (MaybeSubexpr) + return createParamsForSubexpr(CapturedInfo, *MaybeSubexpr, ASTCont); + return createParamsForNoSubexpr(CapturedInfo); } // Clangd uses open ranges while ExtractionSemicolonPolicy (in Clang Tooling) @@ -723,29 +1016,47 @@ return SemicolonPolicy; } +// Returns true if the selected code is an expression, false otherwise. +bool isExpression(const ExtractionZone &ExtZone) { + const auto &Node{*ExtZone.Parent}; + return Node.Children.size() == 1 and + ExtZone.getLastRootStmt()->ASTNode.get() != nullptr; +} + // Generate return type for ExtractedFunc. Return false if unable to do so. 
-bool generateReturnProperties(NewFunction &ExtractedFunc, - const FunctionDecl &EnclosingFunc, - const CapturedZoneInfo &CapturedInfo) { +std::optional +generateReturnProperties(const ExtractionZone &ExtZone, + const CapturedZoneInfo &CapturedInfo) { // If the selected code always returns, we preserve those return statements. // The return type should be the same as the enclosing function. // (Others are possible if there are conversions, but this seems clearest). + const auto &EnclosingFunc{*ExtZone.EnclosingFunction}; if (CapturedInfo.HasReturnStmt) { // If the return is conditional, neither replacing the code with // `extracted()` nor `return extracted()` is correct. if (!CapturedInfo.AlwaysReturns) - return false; + return std::nullopt; QualType Ret = EnclosingFunc.getReturnType(); - // Once we support members, it'd be nice to support e.g. extracting a method - // of Foo that returns T. But it's not clear when that's safe. + // Once we support members, it'd be nice to support e.g. extracting a + // method of Foo that returns T. But it's not clear when that's safe. if (Ret->isDependentType()) - return false; - ExtractedFunc.ReturnType = Ret; - return true; + return std::nullopt; + return Ret; + } + // If the selected code is an expression, then take the return type of it. + if (const auto &Node{*ExtZone.Parent}; Node.Children.size() == 1) { + if (const Expr * Expression{ExtZone.getLastRootStmt()->ASTNode.get()}; + Expression) { + if (const auto *Call{llvm::dyn_cast_or_null(Expression)}; + Call) { + const auto &ASTCont{ExtZone.EnclosingFunction->getParentASTContext()}; + return Call->getCallReturnType(ASTCont); + } + return Expression->getType(); + } } // FIXME: Generate new return statement if needed. 
- ExtractedFunc.ReturnType = EnclosingFunc.getParentASTContext().VoidTy; - return true; + return EnclosingFunc.getParentASTContext().VoidTy; } void captureMethodInfo(NewFunction &ExtractedFunc, @@ -791,14 +1102,25 @@ ExtractedFunc.ForwardDeclarationSyntacticDC = ExtractedFunc.SemanticDC; } - ExtractedFunc.BodyRange = ExtZone.ZoneRange; - ExtractedFunc.DefinitionPoint = ExtZone.getInsertionPoint(); + auto &ASTCont{ExtZone.EnclosingFunction->getASTContext()}; + ExtractedFunc.Expression = isExpression(ExtZone); + std::optional MaybeExtractedSubexpr; + if (ExtZone.MaybeBinarySubexpr) { + MaybeExtractedSubexpr = ExtZone.MaybeBinarySubexpr->tryExtract(); + ExtractedFunc.BodyRange = MaybeExtractedSubexpr->getRange(LangOpts); + } else { + ExtractedFunc.BodyRange = ExtZone.ZoneRange; + } + ExtractedFunc.DefinitionPoint = ExtZone.getInsertionPoint(); ExtractedFunc.CallerReturnsValue = CapturedInfo.AlwaysReturns; - if (!createParameters(ExtractedFunc, CapturedInfo) || - !generateReturnProperties(ExtractedFunc, *ExtZone.EnclosingFunction, - CapturedInfo)) + + auto MaybeRetType{generateReturnProperties(ExtZone, CapturedInfo)}; + auto MaybeParams{createParams(MaybeExtractedSubexpr, CapturedInfo, ASTCont)}; + if (not MaybeRetType || not MaybeParams) return error("Too complex to extract."); + ExtractedFunc.ReturnType = std::move(*MaybeRetType); + ExtractedFunc.Parameters = std::move(*MaybeParams); return ExtractedFunc; } @@ -913,8 +1235,8 @@ tooling::Replacements OtherEdit( createForwardDeclaration(*ExtractedFunc, SM)); - if (auto PathAndEdit = Tweak::Effect::fileEdit(SM, SM.getFileID(*FwdLoc), - OtherEdit)) + if (auto PathAndEdit = + Tweak::Effect::fileEdit(SM, SM.getFileID(*FwdLoc), OtherEdit)) MultiFileEffect->ApplyEdits.try_emplace(PathAndEdit->first, PathAndEdit->second); else diff --git a/clang-tools-extra/clangd/tool/CMakeLists.txt b/clang-tools-extra/clangd/tool/CMakeLists.txt --- a/clang-tools-extra/clangd/tool/CMakeLists.txt +++ 
b/clang-tools-extra/clangd/tool/CMakeLists.txt @@ -16,6 +16,7 @@ clang_target_link_libraries(clangd PRIVATE clangAST + clangASTMatchers clangBasic clangFormat clangFrontend diff --git a/clang-tools-extra/clangd/unittests/tweaks/ExtractFunctionTests.cpp b/clang-tools-extra/clangd/unittests/tweaks/ExtractFunctionTests.cpp --- a/clang-tools-extra/clangd/unittests/tweaks/ExtractFunctionTests.cpp +++ b/clang-tools-extra/clangd/unittests/tweaks/ExtractFunctionTests.cpp @@ -24,8 +24,6 @@ // Root statements should have common parent. EXPECT_EQ(apply("for(;;) [[1+2; 1+2;]]"), "unavailable"); - // Expressions aren't extracted. - EXPECT_EQ(apply("int x = 0; [[x++;]]"), "unavailable"); // We don't support extraction from lambdas. EXPECT_EQ(apply("auto lam = [](){ [[int x;]] }; "), "unavailable"); // Partial statements aren't extracted. @@ -190,6 +188,16 @@ }]] )cpp"; EXPECT_EQ(apply(CompoundFailInput), "unavailable"); + + std::string CompoundWithMultipleStatementsFailInput = R"cpp( + void f() [[{ + int a = 1; + int b = 2; + ++b; + b += a; + }]] + )cpp"; + EXPECT_EQ(apply(CompoundWithMultipleStatementsFailInput), "unavailable"); } TEST_F(ExtractFunctionTest, DifferentHeaderSourceTest) { @@ -571,6 +579,795 @@ EXPECT_EQ(apply(Before), After); } +TEST_F(ExtractFunctionTest, Expressions) { + std::vector> InputOutputs{ + // FULL BINARY EXPRESSIONS + // Full binary expression, basic maths + {R"cpp( +void wrapperFun() { + double a{2.0}, b{3.2}, c{31.55}; + double v{[[b * b - 4 * a * c]]}; +} + )cpp", + R"cpp( +double extracted(double &a, double &b, double &c) { +return b * b - 4 * a * c; +} +void wrapperFun() { + double a{2.0}, b{3.2}, c{31.55}; + double v{extracted(a, b, c)}; +} + )cpp"}, + // Full binary expression composed of '+' operator overloads ops + { + R"cpp( +struct S { + S operator+(const S&) { + return *this; + } +}; +void wrapperFun() { + S S1, S2, S3; + auto R{[[S1 + S2 + S3]]}; +} + )cpp", + R"cpp( +struct S { + S operator+(const S&) { + return *this; + } +}; +S 
extracted(S &S1, S &S2, S &S3) { +return S1 + S2 + S3; +} +void wrapperFun() { + S S1, S2, S3; + auto R{extracted(S1, S2, S3)}; +} + )cpp"}, + // Boolean predicate as expression + { + R"cpp( +void wrapperFun() { + int a{1}; + auto R{[[a > 1]]}; +} + )cpp", + R"cpp( +bool extracted(int &a) { +return a > 1; +} +void wrapperFun() { + int a{1}; + auto R{extracted(a)}; +} + )cpp"}, + // Expression: captures no global variable + {R"cpp( +static int a{2}; +void wrapperFun() { + int b{3}, c{31}, d{311}; + auto v{[[a + b + c + d]]}; +} + )cpp", + R"cpp( +static int a{2}; +int extracted(int &b, int &c, int &d) { +return a + b + c + d; +} +void wrapperFun() { + int b{3}, c{31}, d{311}; + auto v{extracted(b, c, d)}; +} + )cpp"}, + // Full expr: infers return type of call returning by ref + { + R"cpp( +struct S { + S& operator+(const S&) { + return *this; + } +}; +void wrapperFun() { + S S1, S2, S3; + auto R{[[S1 + S2 + S3]]}; +} + )cpp", + R"cpp( +struct S { + S& operator+(const S&) { + return *this; + } +}; +S & extracted(S &S1, S &S2, S &S3) { +return S1 + S2 + S3; +} +void wrapperFun() { + S S1, S2, S3; + auto R{extracted(S1, S2, S3)}; +} + )cpp"}, + // Full expr: infers return type of call returning by const-ref + { + R"cpp( +struct S { + const S& operator+(const S&) const { + return *this; + } +}; +void wrapperFun() { + S S1, S2, S3; + auto R{[[S1 + S2 + S3]]}; +} + )cpp", + R"cpp( +struct S { + const S& operator+(const S&) const { + return *this; + } +}; +const S & extracted(S &S1, S &S2, S &S3) { +return S1 + S2 + S3; +} +void wrapperFun() { + S S1, S2, S3; + auto R{extracted(S1, S2, S3)}; +} + )cpp"}, + // Captures deeply nested arguments + { + R"cpp( +int fw(int a) { return a; }; +int add(int a, int b) { return a + b; } +void wrapper() { + int a{0}, b{1}, c{2}, d{3}, e{4}, f{5}; + int r{[[fw(fw(fw(a))) + fw(fw(add(b, c))) + fw(fw(fw(add(d, e)))) + fw(fw(f))]]}; +} + )cpp", + R"cpp( +int fw(int a) { return a; }; +int add(int a, int b) { return a + b; } +int 
extracted(int &a, int &b, int &c, int &d, int &e, int &f) { +return fw(fw(fw(a))) + fw(fw(add(b, c))) + fw(fw(fw(add(d, e)))) + fw(fw(f)); +} +void wrapper() { + int a{0}, b{1}, c{2}, d{3}, e{4}, f{5}; + int r{extracted(a, b, c, d, e, f)}; +} + )cpp"}, + // SUBEXPRESSIONS + // Left-aligned subexpression + {R"cpp( +void wrapperFun() { + int a{2}, b{3}, c{31}, d{13}; + auto v{[[a + b]] + c + d}; +} + )cpp", + R"cpp( +int extracted(int &a, int &b) { +return a + b; +} +void wrapperFun() { + int a{2}, b{3}, c{31}, d{13}; + auto v{extracted(a, b) + c + d}; +} + )cpp"}, + {R"cpp( +void wrapperFun() { + int a{2}, b{3}, c{31}, d{13}; + auto v{[[a + b + c]] + d}; +} + )cpp", + R"cpp( +int extracted(int &a, int &b, int &c) { +return a + b + c; +} +void wrapperFun() { + int a{2}, b{3}, c{31}, d{13}; + auto v{extracted(a, b, c) + d}; +} + )cpp"}, + // Subexpression from the middle + {R"cpp( +void wrapperFun() { + int a{2}, b{3}, c{31}, d{15}, e{300}; + auto v{a + [[b + c + d]] + e}; +} + )cpp", + R"cpp( +int extracted(int &b, int &c, int &d) { +return b + c + d; +} +void wrapperFun() { + int a{2}, b{3}, c{31}, d{15}, e{300}; + auto v{a + extracted(b, c, d) + e}; +} + )cpp"}, + // Right-aligned subexpression + {R"cpp( +void wrapperFun() { + int a{2}, b{3}, c{31}, d{15}, e{300}; + auto v{a + b + [[c + d + e]]}; +} + )cpp", + R"cpp( +int extracted(int &c, int &d, int &e) { +return c + d + e; +} +void wrapperFun() { + int a{2}, b{3}, c{31}, d{15}, e{300}; + auto v{a + b + extracted(c, d, e)}; +} + )cpp"}, + // Larger subexpression from the middle + {R"cpp( +void wrapperFun() { + int a{2}, b{3}, c{31}, d{311}; + auto v{a + [[a + b + c + d]] + c}; +} + )cpp", + R"cpp( +int extracted(int &a, int &b, int &c, int &d) { +return a + b + c + d; +} +void wrapperFun() { + int a{2}, b{3}, c{31}, d{311}; + auto v{a + extracted(a, b, c, d) + c}; +} + )cpp"}, + // Subexpression with duplicated references + {R"cpp( +void wrapperFun() { + int a{2}, b{3}, c{31}, d{311}; + auto v{a + b + [[c + c + c 
+ d + d]] + c}; +} + )cpp", + R"cpp( +int extracted(int &c, int &d) { +return c + c + c + d + d; +} +void wrapperFun() { + int a{2}, b{3}, c{31}, d{311}; + auto v{a + b + extracted(c, d) + c}; +} + )cpp"}, + // Subexpression: captures no global variable + {R"cpp( +static int a{2}; +void wrapperFun() { + int b{3}, c{31}, d{311}; + auto v{[[a + b + c]] + d}; +} + )cpp", + R"cpp( +static int a{2}; +int extracted(int &b, int &c) { +return a + b + c; +} +void wrapperFun() { + int b{3}, c{31}, d{311}; + auto v{extracted(b, c) + d}; +} + )cpp"}, + // Subexpression: infers return type of call returning by ref, LHS + { + R"cpp( +struct LargeStruct { + char LargeMember[1024]; + LargeStruct& get() { + return *this; + } + LargeStruct operator+(const LargeStruct&) { + return *this; + } +}; +void wrapperFun() { + LargeStruct LS1, LS2; + auto LS3{[[LS1.get()]] + LS2}; +} + )cpp", + R"cpp( +struct LargeStruct { + char LargeMember[1024]; + LargeStruct& get() { + return *this; + } + LargeStruct operator+(const LargeStruct&) { + return *this; + } +}; +LargeStruct & extracted(LargeStruct &LS1) { +return LS1.get(); +} +void wrapperFun() { + LargeStruct LS1, LS2; + auto LS3{extracted(LS1) + LS2}; +} + )cpp"}, + // Subexpression: infers return type of call returning by ref, most-RHS + { + R"cpp( +struct LargeStruct { + char LargeMember[1024]; + LargeStruct& get() { + return *this; + } + LargeStruct operator+(const LargeStruct&) { + return *this; + } +}; +void wrapperFun() { + LargeStruct LS1, LS2, LS3; + auto LS4{LS1 + LS2 + [[LS3.get()]]}; +} + )cpp", + R"cpp( +struct LargeStruct { + char LargeMember[1024]; + LargeStruct& get() { + return *this; + } + LargeStruct operator+(const LargeStruct&) { + return *this; + } +}; +LargeStruct & extracted(LargeStruct &LS3) { +return LS3.get(); +} +void wrapperFun() { + LargeStruct LS1, LS2, LS3; + auto LS4{LS1 + LS2 + extracted(LS3)}; +} + )cpp"}, + // Subexpression: infers return type of call returning by ref, middle RHS + { + R"cpp( +struct 
LargeStruct { + char LargeMember[1024]; + LargeStruct& get() { + return *this; + } + LargeStruct getCopy() { + return *this; + } + LargeStruct operator+(const LargeStruct&) { + return *this; + } +}; +void wrapperFun() { + LargeStruct LS1, LS2, LS3; + auto LS4{LS1.getCopy() + [[LS2.get()]] + LS3}; +} + )cpp", + R"cpp( +struct LargeStruct { + char LargeMember[1024]; + LargeStruct& get() { + return *this; + } + LargeStruct getCopy() { + return *this; + } + LargeStruct operator+(const LargeStruct&) { + return *this; + } +}; +LargeStruct & extracted(LargeStruct &LS2) { +return LS2.get(); +} +void wrapperFun() { + LargeStruct LS1, LS2, LS3; + auto LS4{LS1.getCopy() + extracted(LS2) + LS3}; +} + )cpp"}, + // Subexpr: infers return type of call returning by const-ref + { + R"cpp( +struct LargeStruct { + char LargeMember[1024]; + const LargeStruct& get() { + return *this; + } + LargeStruct operator+(const LargeStruct&) { + return *this; + } +}; +void wrapperFun() { + LargeStruct LS1, LS2; + auto LS3{LS1 + [[LS2.get()]]}; +} + )cpp", + R"cpp( +struct LargeStruct { + char LargeMember[1024]; + const LargeStruct& get() { + return *this; + } + LargeStruct operator+(const LargeStruct&) { + return *this; + } +}; +const LargeStruct & extracted(LargeStruct &LS2) { +return LS2.get(); +} +void wrapperFun() { + LargeStruct LS1, LS2; + auto LS3{LS1 + extracted(LS2)}; +} + )cpp"}, + // Subexpression on operator overload, left-aligned + { + R"cpp( +struct LargeStruct { + char LargeMember[1024]; + const LargeStruct& get() { + return *this; + } + LargeStruct& operator+(const LargeStruct&) { + return *this; + } +}; +void wrapperFun() { + LargeStruct LS1, LS2, LS3, LS4; + auto& LS5{[[LS1 + LS2.get()]] + LS3.get() + LS4}; +} + )cpp", + R"cpp( +struct LargeStruct { + char LargeMember[1024]; + const LargeStruct& get() { + return *this; + } + LargeStruct& operator+(const LargeStruct&) { + return *this; + } +}; +LargeStruct & extracted(LargeStruct &LS1, LargeStruct &LS2) { +return LS1 + LS2.get(); 
+} +void wrapperFun() { + LargeStruct LS1, LS2, LS3, LS4; + auto& LS5{extracted(LS1, LS2) + LS3.get() + LS4}; +} + )cpp"}, + { + R"cpp( +struct LargeStruct { + char LargeMember[1024]; + const LargeStruct& get() { + return *this; + } + LargeStruct& operator+(const LargeStruct&) { + return *this; + } +}; +void wrapperFun() { + LargeStruct LS1, LS2, LS3, LS4; + auto& LS5{[[LS1 + LS2.get() + LS3.get()]] + LS4}; +} + )cpp", + R"cpp( +struct LargeStruct { + char LargeMember[1024]; + const LargeStruct& get() { + return *this; + } + LargeStruct& operator+(const LargeStruct&) { + return *this; + } +}; +LargeStruct & extracted(LargeStruct &LS1, LargeStruct &LS2, LargeStruct &LS3) { +return LS1 + LS2.get() + LS3.get(); +} +void wrapperFun() { + LargeStruct LS1, LS2, LS3, LS4; + auto& LS5{extracted(LS1, LS2, LS3) + LS4}; +} + )cpp"}, + // Subexpression on operator overload, middle-aligned + { + R"cpp( +struct LargeStruct { + char LargeMember[1024]; + const LargeStruct& get() { + return *this; + } + LargeStruct& operator+(const LargeStruct&) { + return *this; + } +}; +void wrapperFun() { + LargeStruct LS1, LS2, LS3, LS4, LS5; + auto& R{LS1 + [[LS2.get() + LS3 + LS4.get()]] + LS5}; +} + )cpp", + R"cpp( +struct LargeStruct { + char LargeMember[1024]; + const LargeStruct& get() { + return *this; + } + LargeStruct& operator+(const LargeStruct&) { + return *this; + } +}; +LargeStruct & extracted(LargeStruct &LS2, LargeStruct &LS3, LargeStruct &LS4) { +return LS2.get() + LS3 + LS4.get(); +} +void wrapperFun() { + LargeStruct LS1, LS2, LS3, LS4, LS5; + auto& R{LS1 + extracted(LS2, LS3, LS4) + LS5}; +} + )cpp"}, + // Subexpression on operator overload, right-aligned + { + R"cpp( +struct LargeStruct { + char LargeMember[1024]; + const LargeStruct& get() { + return *this; + } + LargeStruct& operator+(const LargeStruct&) { + return *this; + } +}; +void wrapperFun() { + LargeStruct LS1, LS2, LS3, LS4, LS5; + auto& R{LS1 + LS2.get() + [[LS3 + LS4.get() + LS5]]}; +})cpp", + R"cpp( +struct 
LargeStruct { + char LargeMember[1024]; + const LargeStruct& get() { + return *this; + } + LargeStruct& operator+(const LargeStruct&) { + return *this; + } +}; +LargeStruct & extracted(LargeStruct &LS3, LargeStruct &LS4, LargeStruct &LS5) { +return LS3 + LS4.get() + LS5; +} +void wrapperFun() { + LargeStruct LS1, LS2, LS3, LS4, LS5; + auto& R{LS1 + LS2.get() + extracted(LS3, LS4, LS5)}; +})cpp"}, + // Boolean predicate as subexpression + { + R"cpp( +void wrapperFun() { + int a{1}, b{2}; + auto R{a > 1 ? [[b <= 0]] : false}; +} + )cpp", + R"cpp( +bool extracted(int &b) { +return b <= 0; +} +void wrapperFun() { + int a{1}, b{2}; + auto R{a > 1 ? extracted(b) : false}; +} + )cpp"}, + // Collects deeply nested arguments, left-aligned + { + R"cpp( +int fw(int a) { return a; }; +int add(int a, int b) { return a + b; } +void wrapper() { + int a{0}, b{1}, c{2}, d{3}, e{4}, f{5}; + int r{[[fw(fw(fw(a))) + fw(fw(add(b, c))) + fw(fw(fw(add(d, e))))]] + fw(fw(f))}; +} + )cpp", + R"cpp( +int fw(int a) { return a; }; +int add(int a, int b) { return a + b; } +int extracted(int &a, int &b, int &c, int &d, int &e) { +return fw(fw(fw(a))) + fw(fw(add(b, c))) + fw(fw(fw(add(d, e)))); +} +void wrapper() { + int a{0}, b{1}, c{2}, d{3}, e{4}, f{5}; + int r{extracted(a, b, c, d, e) + fw(fw(f))}; +} + )cpp"}, + // Collects deeply nested arguments, middle-aligned + { + R"cpp( +int fw(int a) { return a; }; +int add(int a, int b) { return a + b; } +void wrapper() { + int a{0}, b{1}, c{2}, d{3}, e{4}, f{5}; + int r{fw(fw(fw(a))) + [[fw(fw(add(b, c))) + fw(fw(fw(add(d, e))))]] + fw(fw(f))}; +} + )cpp", + R"cpp( +int fw(int a) { return a; }; +int add(int a, int b) { return a + b; } +int extracted(int &b, int &c, int &d, int &e) { +return fw(fw(add(b, c))) + fw(fw(fw(add(d, e)))); +} +void wrapper() { + int a{0}, b{1}, c{2}, d{3}, e{4}, f{5}; + int r{fw(fw(fw(a))) + extracted(b, c, d, e) + fw(fw(f))}; +} + )cpp"}, + // Collects deeply nested arguments, right-aligned + { + R"cpp( +int fw(int a) { 
return a; }; +int add(int a, int b) { return a + b; } +void wrapper() { + int a{0}, b{1}, c{2}, d{3}, e{4}, f{5}; + int r{fw(fw(fw(a))) + [[fw(fw(add(b, c))) + fw(fw(fw(add(d, e)))) + fw(fw(f))]]}; +} + )cpp", + R"cpp( +int fw(int a) { return a; }; +int add(int a, int b) { return a + b; } +int extracted(int &b, int &c, int &d, int &e, int &f) { +return fw(fw(add(b, c))) + fw(fw(fw(add(d, e)))) + fw(fw(f)); +} +void wrapper() { + int a{0}, b{1}, c{2}, d{3}, e{4}, f{5}; + int r{fw(fw(fw(a))) + extracted(b, c, d, e, f)}; +} + )cpp"}, + // FIXME: Support macros: In this case the most-LHS is not omitted! + {R"cpp( +#define ECHO(X) X +void f() { + int x = 1 + [[ECHO(2 + 3) + 4]] + 5; +})cpp", + R"cpp( +#define ECHO(X) X +int extracted() { +return 1 + ECHO(2 + 3) + 4; +} +void f() { + int x = extracted() + 5; +})cpp"}, + }; + + for (const auto &[Input, Output] : InputOutputs) { + EXPECT_EQ(Output, apply(Input)) << Input; + } +} + +TEST_F(ExtractFunctionTest, ExpressionsInMethodsSingleFile) { + // TODO: unavailable + // TODO: available + + std::vector> InputOutputs{ + // Expression: Does not capture members as parameters + // FIXME: If selected area does mutate members, make extracted() const + {R"cpp( +struct S { +void f() const { + int a{1}, b{2}; + auto r{[[a + b + mem1 + mem2]]}; +} +int mem1{0}, mem2{0}; +}; +)cpp", + R"cpp( +struct S { +int extracted(int &a, int &b) const { +return a + b + mem1 + mem2; +} +void f() const { + int a{1}, b{2}; + auto r{extracted(a, b)}; +} +int mem1{0}, mem2{0}; +}; +)cpp"}, + // Subexpression: Does not capture members as parameters + {R"cpp( +struct S { +void f() const { + int a{1}, b{2}; + auto r{a + [[mem1 + mem2 + b + mem1]] + mem2}; +} +int mem1{0}, mem2{0}; +}; +)cpp", + R"cpp( +struct S { +int extracted(int &b) const { +return mem1 + mem2 + b + mem1; +} +void f() const { + int a{1}, b{2}; + auto r{a + extracted(b) + mem2}; +} +int mem1{0}, mem2{0}; +}; +)cpp"}, + }; + + for (const auto &[Input, Output] : InputOutputs) { + 
EXPECT_EQ(Output, apply(Input)) << Input; + } +} + +TEST_F(ExtractFunctionTest, ExpressionInMethodMultiFile) { + Header = R"cpp( + class SomeClass { + void f(); + int mem1{0}, mem2{0}; + }; + )cpp"; + + std::string OutOfLineSource = R"cpp( + void SomeClass::f() { + int a{1}, b{2}; + int x = [[a + mem1 + b + mem2]]; + } + )cpp"; + + std::string OutOfLineSourceOutputCheck = R"cpp( + int SomeClass::extracted(int &a, int &b) { +return a + mem1 + b + mem2; +} +void SomeClass::f() { + int a{1}, b{2}; + int x = extracted(a, b); + } + )cpp"; + + std::string HeaderOutputCheck = R"cpp( + class SomeClass { + int extracted(int &a, int &b); +void f(); + int mem1{0}, mem2{0}; + }; + )cpp"; + + llvm::StringMap EditedFiles; + + EXPECT_EQ(apply(OutOfLineSource, &EditedFiles), OutOfLineSourceOutputCheck); + EXPECT_EQ(EditedFiles.begin()->second, HeaderOutputCheck); +} + +TEST_F(ExtractFunctionTest, SubexpressionInMethodMultiFile) { + Header = R"cpp( + class SomeClass { + void f(); + int mem1{0}, mem2{0}; + }; + )cpp"; + + std::string OutOfLineSource = R"cpp( + void SomeClass::f() { + int a{1}, b{2}; + int x = a + [[mem1 + b + mem2]] + mem1; + } + )cpp"; + + std::string OutOfLineSourceOutputCheck = R"cpp( + int SomeClass::extracted(int &b) { +return mem1 + b + mem2; +} +void SomeClass::f() { + int a{1}, b{2}; + int x = a + extracted(b) + mem1; + } + )cpp"; + + std::string HeaderOutputCheck = R"cpp( + class SomeClass { + int extracted(int &b); +void f(); + int mem1{0}, mem2{0}; + }; + )cpp"; + + llvm::StringMap EditedFiles; + + EXPECT_EQ(apply(OutOfLineSource, &EditedFiles), OutOfLineSourceOutputCheck); + EXPECT_EQ(EditedFiles.begin()->second, HeaderOutputCheck); +} + } // namespace } // namespace clangd } // namespace clang