diff --git a/clang-tools-extra/clangd/refactor/tweaks/ExtractVariable.cpp b/clang-tools-extra/clangd/refactor/tweaks/ExtractVariable.cpp
--- a/clang-tools-extra/clangd/refactor/tweaks/ExtractVariable.cpp
+++ b/clang-tools-extra/clangd/refactor/tweaks/ExtractVariable.cpp
@@ -13,6 +13,7 @@
 #include "refactor/Tweak.h"
 #include "clang/AST/ASTContext.h"
 #include "clang/AST/Expr.h"
+#include "clang/AST/ExprCXX.h"
 #include "clang/AST/OperationKinds.h"
 #include "clang/AST/RecursiveASTVisitor.h"
 #include "clang/AST/Stmt.h"
@@ -37,18 +38,18 @@
                     const ASTContext &Ctx);
   const clang::Expr *getExpr() const { return Expr; }
   const SelectionTree::Node *getExprNode() const { return ExprNode; }
-  bool isExtractable() const { return Extractable; }
+  // True iff a Stmt was found before which the declaration can be inserted.
+  bool isExtractable() const { return InsertionPoint != nullptr; }
   // Generate Replacement for replacing selected expression with given VarName
   tooling::Replacement replaceWithVar(llvm::StringRef VarName) const;
   // Generate Replacement for declaring the selected Expr as a new variable
   tooling::Replacement insertDeclaration(llvm::StringRef VarName) const;
 
 private:
-  bool Extractable = false;
   const clang::Expr *Expr;
   const SelectionTree::Node *ExprNode;
   // Stmt before which we will extract
   const clang::Stmt *InsertionPoint = nullptr;
   const SourceManager &SM;
   const ASTContext &Ctx;
   // Decls referenced in the Expr
@@ -77,29 +78,13 @@
   return Visitor.ReferencedDecls;
 }
 
-// An expr is not extractable if it's null or an expression of type void
-// FIXME: Ignore assignment (a = 1) Expr since it is extracted as dummy = a =
-static bool isExtractableExpr(const clang::Expr *Expr) {
-  if (Expr) {
-    const Type *ExprType = Expr->getType().getTypePtrOrNull();
-    // FIXME: check if we need to cover any other types
-    if (ExprType)
-      return !ExprType->isVoidType();
-  }
-  return false;
-}
-
 ExtractionContext::ExtractionContext(const SelectionTree::Node *Node,
                                      const SourceManager &SM,
                                      const ASTContext &Ctx)
     : ExprNode(Node), SM(SM), Ctx(Ctx) {
   Expr = Node->ASTNode.get<clang::Expr>();
-  if (isExtractableExpr(Expr)) {
-    ReferencedDecls = computeReferencedDecls(Expr);
-    InsertionPoint = computeInsertionPoint();
-    if (InsertionPoint)
-      Extractable = true;
-  }
+  ReferencedDecls = computeReferencedDecls(Expr);
+  InsertionPoint = computeInsertionPoint();
 }
 
 // checks whether extracting before InsertionPoint will take a
@@ -121,9 +106,9 @@
 // the current Stmt. We ALWAYS insert before a Stmt whose parent is a
 // CompoundStmt
 //
 
-// FIXME: Extraction from switch and case statements
-// FIXME: Doens't work for FoldExpr
+// FIXME: Doesn't work for FoldExpr
+// FIXME: Ensure extraction from loops doesn't change semantics
 const clang::Stmt *ExtractionContext::computeInsertionPoint() const {
   // returns true if we can extract before InsertionPoint
   auto CanExtractOutside =
@@ -209,6 +194,10 @@
     return "Extract subexpression to variable";
   }
   Intent intent() const override { return Refactor; }
 
 private:
+  // Finds the node to extract for the selection N and stores its context in
+  // Target. Returns false if nothing extractable is selected.
+  bool computeExtractionContext(const SelectionTree::Node *N,
+                                const SourceManager &SM, const ASTContext &Ctx);
   // the expression to extract
@@ -219,10 +208,9 @@
   const ASTContext &Ctx = Inputs.AST.getASTContext();
   const SourceManager &SM = Inputs.AST.getSourceManager();
   const SelectionTree::Node *N = Inputs.ASTSelection.commonAncestor();
-  if (!N)
+  if (!computeExtractionContext(N, SM, Ctx))
     return false;
-  Target = llvm::make_unique<ExtractionContext>(N, SM, Ctx);
   return Target->isExtractable();
 }
 
 Expected<Tweak::Effect> ExtractVariable::apply(const Selection &Inputs) {
@@ -238,6 +226,77 @@
   return Effect::applyEdit(Result);
 }
 
+// Returns the node of the nearest CallExpr ancestor of CalleeNode, provided
+// CalleeNode is (part of) that call's callee, e.g. the DeclRefExpr `f` in
+// `f(x)` or the MemberExpr `t.m` in `t.m()`. Returns nullptr otherwise, e.g.
+// when CalleeNode is one of the call's arguments.
+const SelectionTree::Node *getCallExpr(const SelectionTree::Node *CalleeNode) {
+  // We maintain a stack of all exprs encountered while traversing the
+  // selection tree because the callee of the CallExpr can be an ancestor of
+  // CalleeNode, e.g. an ImplicitCastExpr wrapping it.
+  std::vector<const clang::Expr *> ExprStack;
+  for (auto *CurNode = CalleeNode; CurNode; CurNode = CurNode->Parent) {
+    if (const auto *CallPar = CurNode->ASTNode.get<CallExpr>()) {
+      // Only accept the call if its callee is one of the exprs we walked
+      // through on the way up.
+      if (llvm::is_contained(ExprStack, CallPar->getCallee()))
+        return CurNode;
+      return nullptr;
+    }
+    ExprStack.push_back(CurNode->ASTNode.get<clang::Expr>());
+  }
+  return nullptr;
+}
+
+// Checks whether the expression at ExprNode can be assigned to a variable,
+// i.e. it has a known, non-void type.
+bool canBeAssigned(const SelectionTree::Node *ExprNode) {
+  const clang::Expr *Expr = ExprNode->ASTNode.get<clang::Expr>();
+  if (!Expr)
+    return false;
+  const Type *ExprType = Expr->getType().getTypePtrOrNull();
+  // FIXME: check if we need to cover any other types
+  return ExprType && !ExprType->isVoidType();
+}
+
+// Find the node that will form our ExtractionContext and store it in Target.
+// We don't want to trigger for assignment expressions and variable/field
+// DeclRefs. For function/member function references, we want to extract the
+// entire function call.
+bool ExtractVariable::computeExtractionContext(const SelectionTree::Node *N,
+                                               const SourceManager &SM,
+                                               const ASTContext &Ctx) {
+  if (!N)
+    return false;
+  const clang::Expr *SelectedExpr = N->ASTNode.get<clang::Expr>();
+  if (!SelectedExpr)
+    return false;
+  const SelectionTree::Node *TargetNode = N;
+  // Extracting assignments like a = 1 gives dummy = a = 1 which isn't useful.
+  // This also covers compound assignments such as a += 1.
+  const auto *BinOpExpr = dyn_cast<BinaryOperator>(SelectedExpr);
+  if (BinOpExpr && BinOpExpr->isAssignmentOp())
+    return false;
+  // For function and member function DeclRefs, we look for a parent that is a
+  // CallExpr.
+  if (const auto *DeclRef = dyn_cast<DeclRefExpr>(SelectedExpr)) {
+    // Extracting just a variable isn't that useful.
+    if (!isa<FunctionDecl>(DeclRef->getDecl()))
+      return false;
+    TargetNode = getCallExpr(N);
+  }
+  if (const auto *Member = dyn_cast<MemberExpr>(SelectedExpr)) {
+    // Extracting just a field member isn't that useful.
+    if (!isa<CXXMethodDecl>(Member->getMemberDecl()))
+      return false;
+    TargetNode = getCallExpr(N);
+  }
+  if (!TargetNode || !canBeAssigned(TargetNode))
+    return false;
+  Target = llvm::make_unique<ExtractionContext>(TargetNode, SM, Ctx);
+  return true;
+}
+
 } // namespace
 } // namespace clangd
 } // namespace clang
diff --git a/clang-tools-extra/clangd/unittests/TweakTests.cpp b/clang-tools-extra/clangd/unittests/TweakTests.cpp
--- a/clang-tools-extra/clangd/unittests/TweakTests.cpp
+++ b/clang-tools-extra/clangd/unittests/TweakTests.cpp
@@ -291,18 +291,23 @@
   const char *Input = "struct ^X { int x; int y; }";
   EXPECT_THAT(getMessage(ID, Input), ::testing::HasSubstr("0 | int x"));
 }
+
 TEST(TweakTest, ExtractVariable) {
   llvm::StringLiteral ID = "ExtractVariable";
   checkAvailable(ID, R"cpp(
-    int xyz() {
+    int xyz(int a = 1) {
+      struct T {
+        int bar(int a = 1);
+        int z;
+      } t;
       // return statement
-      return ^1;
+      return [[[[t.bar]](t.z)]];
     }
     void f() {
       int a = 5 + [[4 ^* ^xyz^()]];
       // multivariable initialization
       if(1)
-        int x = ^1, y = ^a + 1, a = ^1, z = a + 1;
+        int x = ^1, y = [[a + 1]], a = ^1, z = a + 1;
       // if without else
       if(^1) {}
       // if with else
@@ -315,13 +320,13 @@
         a = ^4;
       else
         a = ^5;
-      // for loop 
+      // for loop
       for(a = ^1; a > ^3^+^4; a++)
         a = ^2;
-      // while 
+      // while
       while(a < ^1)
-        ^a++;
-      // do while 
+        [[a++]];
+      // do while
       do
         a = ^1;
       while(a < ^3);
@@ -337,28 +342,37 @@
   )cpp");
   checkNotAvailable(ID, R"cpp(
     int xyz(int a = ^1) {
-      return 1;
-      class T {
-        T(int a = ^1) {};
-        int xyz = ^1;
-      };
+      struct T {
+        int bar(int a = [[1]]);
+        int z = [[1]];
+      } t;
+      return t.bar([[t.z]]);
     }
+    void v() { return; }
     // function default argument
     void f(int b = ^1) {
       // void expressions
       auto i = new int, j = new int;
       de^lete i^, del^ete j;
+      [[v]]();
       // if
       if(1)
         int x = 1, y = a + 1, a = 1, z = ^a + 1;
       if(int a = 1)
         if(^a == 4)
           a = ^a ^+ 1;
-      // for loop 
+      // for loop
       for(int a = 1, b = 2, c = 3; ^a > ^b ^+ ^c; ^a++)
         a = ^a ^+ 1;
-      // lambda 
+      // lambda
       auto lamb = [&^a, &^b](int r = ^1) {return 1;}
+    }
+    void g(int a, int b) {
+      // assignment
+      [[a ^= 5]];
+      [[a += 5]];
+      // Variable DeclRefExpr
+      a = [[b]];
     }
   )cpp");
   // vector of pairs of input and output strings
@@ -412,6 +426,24 @@
           R"cpp(void f(int a) {
     auto dummy = 1; label: [ [gsl::suppress("type")] ] for (;;) a = dummy;
   })cpp"},
+          // MemberExpr
+          {R"cpp(class T {
+            T f() {
+              return [[T().f().f]]();
+            }
+          };)cpp",
+           R"cpp(class T {
+            T f() {
+              auto dummy = T().f().f(); return dummy;
+            }
+          };)cpp"},
+          // Function DeclRefExpr
+          {R"cpp(int f() {
+            return [[f]]();
+          })cpp",
+           R"cpp(int f() {
+            auto dummy = f(); return dummy;
+          })cpp"},
           // FIXME: Doesn't work because bug in selection tree
           /*{R"cpp(#define PLUS(x) x++
           void f(int a) {
@@ -421,8 +453,8 @@
           void f(int a) {
             auto dummy = a; PLUS(dummy);
           })cpp"},*/
-          // FIXME: Doesn't work correctly for \[\[clang::uninitialized\]\] int b
-          // = 1; since the attr is inside the DeclStmt and the bounds of
+          // FIXME: Wrong result for \[\[clang::uninitialized\]\] int b = 1;
+          // since the attr is inside the DeclStmt and the bounds of
           // DeclStmt don't cover the attribute
       };
   for (const auto &IO : InputOutputs) {