// NOTE(review): Everything below is a unified diff touching two files:
//   clang-tools-extra/clangd/refactor/tweaks/ExtractVariable.cpp
//   clang-tools-extra/clangd/unittests/TweakTests.cpp
// This copy is damaged and should not be applied as-is:
//   * each hunk has been collapsed onto a handful of physical lines, so the
//     '+', '-' and context prefixes no longer begin lines;
//   * text between angle brackets appears to have been stripped, e.g.
//     `isa(Stmt)`, `Node->ASTNode.get()`, `std::vector ExprStack;`,
//     `llvm::make_unique(N, SM, Ctx)`, `Expected ExtractVariable::apply`,
//     `template struct Test`. Those template arguments cannot be recovered
//     from this copy; regenerate the patch from its original source.
//
// What the ExtractVariable.cpp part changes (as far as the visible text shows):
//   * removes isExtractableExpr(); the ExtractionContext constructor now
//     computes ReferencedDecls / InsertionPoint without its null/void check;
//   * drops one entry from the Stmt allow-list in CanExtractOutside (the
//     stripped name is presumably LabelStmt, since the FIXME now mentions
//     labels and the label test moves to checkNotAvailable -- confirm);
//   * prepare() rejects empty selections first, then delegates to the new
//     ExtractVariable::computeExtractionContext(), which rejects `a = b`
//     (BO_Assign) selections and variable/field references, widens
//     function / member-function references to the enclosing call via
//     getCallExpr(), and rejects void-typed targets via canBeAssigned().
//
// Review points on the added code:
//   1. The constructor no longer null-checks Expr before calling
//      computeReferencedDecls(Expr); it now relies on
//      computeExtractionContext() passing either a node it verified holds an
//      Expr or the CallExpr node returned by getCallExpr(). Consider keeping
//      an early return for a null Expr in the constructor.
//   2. canBeAssigned() returns true when getTypePtrOrNull() is null, whereas
//      the removed isExtractableExpr() returned false in that case -- confirm
//      this behavior change is intended.
//   3. computeExtractionContext() is declared above `private:` (presumably
//      in the public section next to intent()); presumably only prepare()
//      needs it, so it could be private.
//   4. After the `if (!SelectedExpr) return false;` early return, the
//      dyn_cast_or_null calls in computeExtractionContext() can be dyn_cast.
//   5. getCallExpr() gives up at the first CallExpr ancestor whose callee is
//      not on its stack, so a function name used only as a call argument
//      (e.g. `g(f)`) is not offered for extraction -- presumably intended.
//   6. The new "Macros" transform test pins `PLUS([[1+a]])` ->
//      `auto dummy = PLUS(1+a); dummy;`, i.e. the whole macro call is
//      extracted and a bare `dummy;` statement is left behind. Consider
//      tracking this as a FIXME rather than as expected output.
//   7. In TweakTests.cpp, `// for loop`, `// while`, `// do while` and
//      `// lambda` are removed and re-added with the same visible text --
//      presumably whitespace-only churn.
//   8. Typos: "assigment" (new test comment), "selectiontree" (getCallExpr
//      comment), and "Doens't" (a context line, so fix it in a separate
//      change).
diff --git a/clang-tools-extra/clangd/refactor/tweaks/ExtractVariable.cpp b/clang-tools-extra/clangd/refactor/tweaks/ExtractVariable.cpp --- a/clang-tools-extra/clangd/refactor/tweaks/ExtractVariable.cpp +++ b/clang-tools-extra/clangd/refactor/tweaks/ExtractVariable.cpp @@ -13,6 +13,7 @@ #include "refactor/Tweak.h" #include "clang/AST/ASTContext.h" #include "clang/AST/Expr.h" +#include "clang/AST/ExprCXX.h" #include "clang/AST/OperationKinds.h" #include "clang/AST/RecursiveASTVisitor.h" #include "clang/AST/Stmt.h" @@ -77,29 +78,15 @@ return Visitor.ReferencedDecls; } -// An expr is not extractable if it's null or an expression of type void -// FIXME: Ignore assignment (a = 1) Expr since it is extracted as dummy = a = -static bool isExtractableExpr(const clang::Expr *Expr) { - if (Expr) { - const Type *ExprType = Expr->getType().getTypePtrOrNull(); - // FIXME: check if we need to cover any other types - if (ExprType) - return !ExprType->isVoidType(); - } - return false; -} - ExtractionContext::ExtractionContext(const SelectionTree::Node *Node, const SourceManager &SM, const ASTContext &Ctx) : ExprNode(Node), SM(SM), Ctx(Ctx) { Expr = Node->ASTNode.get(); - if (isExtractableExpr(Expr)) { - ReferencedDecls = computeReferencedDecls(Expr); - InsertionPoint = computeInsertionPoint(); - if (InsertionPoint) - Extractable = true; - } + ReferencedDecls = computeReferencedDecls(Expr); + InsertionPoint = computeInsertionPoint(); + if (InsertionPoint) + Extractable = true; } // checks whether extracting before InsertionPoint will take a @@ -121,9 +108,9 @@ // the current Stmt. We ALWAYS insert before a Stmt whose parent is a // CompoundStmt // - -// FIXME: Extraction from switch and case statements // FIXME: Doens't work for FoldExpr +// FIXME: Extraction from label, switch and case statements // FIXME: Doens't work for FoldExpr +// FIXME: Ensure extraction from loops doesn't change semantics. 
const clang::Stmt *ExtractionContext::computeInsertionPoint() const { // returns true if we can extract before InsertionPoint auto CanExtractOutside = @@ -141,8 +128,7 @@ return isa(Stmt) || isa(Stmt) || isa(Stmt) || isa(Stmt) || isa(Stmt) || isa(Stmt) || isa(Stmt) || - isa(Stmt) || isa(Stmt) || - isa(Stmt); + isa(Stmt) || isa(Stmt); } if (InsertionPoint->ASTNode.get()) return true; @@ -209,6 +195,9 @@ return "Extract subexpression to variable"; } Intent intent() const override { return Refactor; } + // Compute the extraction context for the Selection + bool computeExtractionContext(const SelectionTree::Node *N, + const SourceManager &SM, const ASTContext &Ctx); private: // the expression to extract @@ -216,14 +205,13 @@ }; REGISTER_TWEAK(ExtractVariable) bool ExtractVariable::prepare(const Selection &Inputs) { + // we don't trigger on empty selections for now + if (Inputs.SelectionBegin == Inputs.SelectionEnd) + return false; const ASTContext &Ctx = Inputs.AST.getASTContext(); const SourceManager &SM = Inputs.AST.getSourceManager(); const SelectionTree::Node *N = Inputs.ASTSelection.commonAncestor(); - // we don't trigger on empty selections for now - if (!N || Inputs.SelectionBegin == Inputs.SelectionEnd) - return false; - Target = llvm::make_unique(N, SM, Ctx); - return Target->isExtractable(); + return computeExtractionContext(N, SM, Ctx); } Expected ExtractVariable::apply(const Selection &Inputs) { @@ -239,6 +227,75 @@ return Effect::applyEdit(Result); } +// Find the CallExpr whose callee is an ancestor of the DeclRef +const SelectionTree::Node *getCallExpr(const SelectionTree::Node *DeclRef) { + // we maintain a stack of all exprs encountered while traversing the + // selectiontree because the callee of the callexpr can be an ancestor of the + // DeclRef. e.g. Callee can be an ImplicitCastExpr. 
+ std::vector ExprStack; + for (auto *CurNode = DeclRef; CurNode; CurNode = CurNode->Parent) { + const Expr *CurExpr = CurNode->ASTNode.get(); + if (const CallExpr *CallPar = CurNode->ASTNode.get()) { + // check whether the callee of the callexpr is present in Expr stack. + if (std::find(ExprStack.begin(), ExprStack.end(), CallPar->getCallee()) != + ExprStack.end()) + return CurNode; + return nullptr; + } + ExprStack.push_back(CurExpr); + } + return nullptr; +} + +// check if Expr can be assigned to a variable i.e. is non-void type +bool canBeAssigned(const SelectionTree::Node *ExprNode) { + const clang::Expr *Expr = ExprNode->ASTNode.get(); + if (const Type *ExprType = Expr->getType().getTypePtrOrNull()) + // FIXME: check if we need to cover any other types + return !ExprType->isVoidType(); + return true; +} + +// Find the node that will form our ExtractionContext. +// We don't want to trigger for assignment expressions and variable/field +// DeclRefs. For function/member function, we want to extract the entire +// function call. +bool ExtractVariable::computeExtractionContext(const SelectionTree::Node *N, + const SourceManager &SM, + const ASTContext &Ctx) { + if (!N) + return false; + const clang::Expr *SelectedExpr = N->ASTNode.get(); + const SelectionTree::Node *TargetNode = N; + if (!SelectedExpr) + return false; + // Extracting Exprs like a = 1 gives dummy = a = 1 which isn't useful. + if (const BinaryOperator *BinOpExpr = + dyn_cast_or_null(SelectedExpr)) { + if (BinOpExpr->getOpcode() == BinaryOperatorKind::BO_Assign) + return false; + } + // For function and member function DeclRefs, we look for a parent that is a + // CallExpr + if (const DeclRefExpr *DeclRef = + dyn_cast_or_null(SelectedExpr)) { + // Extracting just a variable isn't that useful. + if (!isa(DeclRef->getDecl())) + return false; + TargetNode = getCallExpr(N); + } + if (const MemberExpr *Member = dyn_cast_or_null(SelectedExpr)) { + // Extracting just a field member isn't that useful. 
+ if (!isa(Member->getMemberDecl())) + return false; + TargetNode = getCallExpr(N); + } + if (!TargetNode || !canBeAssigned(TargetNode)) + return false; + Target = llvm::make_unique(TargetNode, SM, Ctx); + return Target->isExtractable(); +} + } // namespace } // namespace clangd } // namespace clang diff --git a/clang-tools-extra/clangd/unittests/TweakTests.cpp b/clang-tools-extra/clangd/unittests/TweakTests.cpp --- a/clang-tools-extra/clangd/unittests/TweakTests.cpp +++ b/clang-tools-extra/clangd/unittests/TweakTests.cpp @@ -291,18 +291,23 @@ const char *Input = "struct ^X { int x; int y; }"; EXPECT_THAT(getMessage(ID, Input), ::testing::HasSubstr("0 | int x")); } + TEST(TweakTest, ExtractVariable) { llvm::StringLiteral ID = "ExtractVariable"; checkAvailable(ID, R"cpp( - int xyz() { + int xyz(int a = 1) { + struct T { + int bar(int a = 1); + int z; + } t; // return statement - return [[1]]; + return [[[[t.b[[a]]r]](t.z)]]; } void f() { int a = [[5 +]] [[4 * [[[[xyz]]()]]]]; // multivariable initialization if(1) - int x = [[1]], y = [[a]] + 1, a = [[1]], z = a + 1; + int x = [[1]], y = [[a + 1]], a = [[1]], z = a + 1; // if without else if([[1]]) a = [[1]]; @@ -316,13 +321,13 @@ a = [[4]]; else a = [[5]]; - // for loop + // for loop for(a = [[1]]; a > [[[[3]] + [[4]]]]; a++) a = [[2]]; - // while + // while while(a < [[1]]) - [[a]]++; - // do while + [[a++]]; + // do while do a = [[1]]; while(a < [[3]]); @@ -332,18 +337,19 @@ checkNotAvailable(ID, R"cpp( template struct Test { - Test(const T &v) :val(^) {} + Test(const T &v) :val[[(^]]) {} T val; }; )cpp"); checkNotAvailable(ID, R"cpp( int xyz(int a = [[1]]) { - return 1; - class T { - T(int a = [[1]]) {}; - int xyz = [[1]]; - }; + struct T { + int bar(int a = [[1]]); + int z = [[1]]; + } t; + return [[t]].bar([[[[t]].z]]); } + void v() { return; } // function default argument void f(int b = [[1]]) { // empty selection @@ -351,17 +357,26 @@ // void expressions auto i = new int, j = new int; [[[[delete i]], delete 
j]]; + [[v]](); // if if(1) int x = 1, y = a + 1, a = 1, z = [[a + 1]]; if(int a = 1) - if([[a]] == 4) + if([[a + 1]] == 4) a = [[[[a]] +]] 1; - // for loop - for(int a = 1, b = 2, c = 3; [[a]] > [[b + c]]; [[a]]++) + // for loop + for(int a = 1, b = 2, c = 3; a > [[b + c]]; [[a++]]) a = [[a + 1]]; - // lambda + // lambda auto lamb = [&[[a]], &[[b]]](int r = [[1]]) {return 1;} + // assigment + [[a = 5]]; + // Variable DeclRefExpr + a = [[b]]; + // label statement + goto label; + label: + a = [[1]]; } )cpp"); // vector of pairs of input and output strings @@ -397,6 +412,15 @@ break; } })cpp"},*/ + // Macros + {R"cpp(#define PLUS(x) x++ + void f(int a) { + PLUS([[1+a]]); + })cpp", + R"cpp(#define PLUS(x) x++ + void f(int a) { + auto dummy = PLUS(1+a); dummy; + })cpp"}, // ensure InsertionPoint isn't inside a macro {R"cpp(#define LOOP(x) while (1) {a = x;} void f(int a) { @@ -425,25 +449,35 @@ auto dummy = 3; if(1) LOOP(5 + dummy) })cpp"}, - // label and attribute testing + // attribute testing {R"cpp(void f(int a) { - label: [ [gsl::suppress("type")] ] for (;;) a = [[1]]; + [ [gsl::suppress("type")] ] for (;;) a = [[1]]; })cpp", R"cpp(void f(int a) { - auto dummy = 1; label: [ [gsl::suppress("type")] ] for (;;) a = dummy; + auto dummy = 1; [ [gsl::suppress("type")] ] for (;;) a = dummy; })cpp"}, - // macro testing - {R"cpp(#define PLUS(x) x++ - void f(int a) { - PLUS([[a]]); + // MemberExpr + {R"cpp(class T { + T f() { + return [[T().f()]].f(); + } + };)cpp", + R"cpp(class T { + T f() { + auto dummy = T().f(); return dummy.f(); + } + };)cpp"}, + // Function DeclRefExpr + {R"cpp(int f() { + return [[f]](); })cpp", - R"cpp(#define PLUS(x) x++ - void f(int a) { - auto dummy = a; PLUS(dummy); + R"cpp(int f() { + auto dummy = f(); return dummy; })cpp"}, - // FIXME: Doesn't work correctly for \[\[clang::uninitialized\]\] int - // b = [[1]]; since the attr is inside the DeclStmt and the bounds of - // DeclStmt don't cover the attribute + + // FIXME: Wrong result for 
\[\[clang::uninitialized\]\] int b = [[1]]; + // since the attr is inside the DeclStmt and the bounds of + // DeclStmt don't cover the attribute. }; for (const auto &IO : InputOutputs) { checkTransform(ID, IO.first, IO.second);