diff --git a/clang-tools-extra/clangd/refactor/tweaks/CMakeLists.txt b/clang-tools-extra/clangd/refactor/tweaks/CMakeLists.txt --- a/clang-tools-extra/clangd/refactor/tweaks/CMakeLists.txt +++ b/clang-tools-extra/clangd/refactor/tweaks/CMakeLists.txt @@ -14,6 +14,7 @@ add_clang_library(clangDaemonTweaks OBJECT RawStringLiteral.cpp SwapIfBranches.cpp + ExtractVariable.cpp LINK_LIBS clangAST diff --git a/clang-tools-extra/clangd/refactor/tweaks/ExtractVariable.cpp b/clang-tools-extra/clangd/refactor/tweaks/ExtractVariable.cpp new file mode 100644 --- /dev/null +++ b/clang-tools-extra/clangd/refactor/tweaks/ExtractVariable.cpp @@ -0,0 +1,223 @@ +//===--- ExtractVariable.cpp ------------------------------------*- C++-*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#include "ClangdUnit.h" +#include "Logger.h" +#include "Protocol.h" +#include "Selection.h" +#include "SourceCode.h" +#include "refactor/Tweak.h" +#include "clang/AST/ASTContext.h" +#include "clang/AST/Expr.h" +#include "clang/AST/RecursiveASTVisitor.h" +#include "clang/AST/Stmt.h" +#include "clang/Basic/LangOptions.h" +#include "clang/Basic/SourceLocation.h" +#include "clang/Basic/SourceManager.h" +#include "clang/Tooling/Core/Replacement.h" +#include "llvm/ADT/None.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Error.h" + +namespace clang { +namespace clangd { +namespace { + +// +/// Extracts the selected expression into a new `auto` variable `dummy`, declared just before the statement containing it. +/// Before: +/// int x = 5 + 4 * 3; +///             ^^^^^ +/// After: +/// auto dummy = 4 * 3; +/// int x = 5 dummy; +class ExtractVariable : public Tweak { +public: + const char *id() const override final; + bool prepare(const Selection &Inputs) override; + Expected apply(const Selection 
&Inputs) override; + std::string title() const override; + +private: + // The selected expression to extract; set by prepare(). + const clang::Expr *Exp = nullptr; + // The statement before which the new variable is declared; set by prepare(). + const clang::Stmt *Stm = nullptr; +}; +REGISTER_TWEAK(ExtractVariable) + +// Returns true if any DeclRefExpr in Exp refers to a declaration that begins inside Decl's source range. +static bool isDeclaredIn(const DeclStmt *Decl, const Expr *Exp, + const SourceManager &M) { + + // RAV subclass that collects every DeclRefExpr in the traversed Stmt + class FindDeclRefsVisitor + : public clang::RecursiveASTVisitor { + public: + std::vector DeclRefExprs; + bool VisitDeclRefExpr(DeclRefExpr *DeclRef) { // NOLINT + DeclRefExprs.push_back(DeclRef); + return true; + } + }; + + FindDeclRefsVisitor Visitor; + Visitor.TraverseStmt(const_cast(dyn_cast(Exp))); + for (const DeclRefExpr *DeclRef : Visitor.DeclRefExprs) { + // Beginning location of the ValueDecl of the DeclRef + auto ValueDeclLoc = DeclRef->getDecl()->getBeginLoc(); + // return true if this ValueDecl location is within the DeclStmt + if (M.isPointWithin(ValueDeclLoc, Decl->getBeginLoc(), Decl->getEndLoc())) + return true; + } + return false; +} +// Returns true if the start of N's source range lies within Body's source range; false if Body is null. +bool isNodeInside(const Stmt *Body, const SelectionTree::Node *N, + const SourceManager &M) { + if(!Body) return false; + auto BodyRng = Body->getSourceRange(); + SourceRange NodeRng = N->ASTNode.getSourceRange(); + return M.isPointWithin(NodeRng.getBegin(), BodyRng.getBegin(), + BodyRng.getEnd()); +} +// Returns true if N lies in the controlled body of Stm (for/while/do body, case/default sub-statement, if then/else branch), where extraction would need braces. 
+static bool needBraces(const SelectionTree::Node *N, const Stmt *Stm, + const SourceManager &M) { + + if (const ForStmt *Stmt = llvm::dyn_cast_or_null(Stm)) + return isNodeInside(Stmt->getBody(), N, M); + if (const WhileStmt *Stmt = llvm::dyn_cast_or_null(Stm)) + return isNodeInside(Stmt->getBody(), N, M); + if (const DoStmt *Stmt = llvm::dyn_cast_or_null(Stm)) + return isNodeInside(Stmt->getBody(), N, 
M); + if (const CaseStmt *Stmt = llvm::dyn_cast_or_null(Stm)) + return isNodeInside(Stmt->getSubStmt(), N, M); + if (const DefaultStmt *Stmt = llvm::dyn_cast_or_null(Stm)) + return isNodeInside(Stmt->getSubStmt(), N, M); + if (const IfStmt *Stmt = llvm::dyn_cast_or_null(Stm)) + return isNodeInside(Stmt->getThen(), N, M) || + isNodeInside(Stmt->getElse(), N, M); + return false; +} + +// Returns false if Exp references a variable declared in the init DeclStmt of for-loop F (it would be out of scope above the loop); true otherwise, including when the init is absent or not a DeclStmt. +static bool checkFor(const ForStmt *F, const Expr *Exp, + const SourceManager &M) { + const Stmt *Init = F->getInit(); + if (!Init || !isa(Init)) + return true; // FIXME + // if any variable in Exp is declared in the DeclStmt, return false + return !isDeclaredIn(dyn_cast(Init), Exp, M); +} + +// Returns the Stmt before which we need to insert the extraction. +// To find the Stmt, we walk up the AST and if the Parent of the current Stmt +// is a CompoundStmt, we can extract inside this CompoundStmt just above the +// current Stmt. We ALWAYS insert above a Stmt whose parent is a CompoundStmt. +// +// Otherwise if we encounter an if/while/do-while/for/case Stmt, we must check +// whether we are in their "body" or the "condition" part. If we are in the +// "body", that means we need to insert braces i.e. create a CompoundStmt. For +// now, we don't allow extraction if we need to insert braces. Else if we are in +// the "condition" part, we can extract above the if/while/do-while/for/case +// Stmt. Remember that we only insert above a stmt if its Parent Stmt is a +// CompoundStmt, thus we continue looping to find a CompoundStmt. +// +// We have a special case for the Case Stmt constant for which we currently +// don't offer extraction. +// +// Another special case is if we are in the condition/increment part of a for +// statement, we must ensure that the variables in the expression being +// extracted are not declared in the init of the for statement. Returns nullptr if no such Stmt is found.
+static const clang::Stmt *getStmt(const SelectionTree::Node *N, + const SourceManager &M) { + auto CurN = N; + while (CurN->Parent) { + auto ParStmt = CurN->Parent->ASTNode.get(); + if (!ParStmt) { // parent node is not a Stmt: keep walking up + } else if (isa(ParStmt)) { + return CurN->ASTNode.get(); + } else if (needBraces(N, ParStmt, M)) { + // Extraction into an unbraced body is not supported yet + break; + } else if (isa(ParStmt) && !isNodeInside(ParStmt, N, M)) { + // give up if it's a Case Statement constant. NOTE(review): N descends from ParStmt, so its range normally starts inside ParStmt and this looks always false; getSubStmt() may have been intended - confirm + break; + } else if (isa(ParStmt) && !checkFor(dyn_cast(ParStmt), + N->ASTNode.get(), M)) { + // Check whether the expression references any variable in the for + // initializer and if so, we can't extract + break; + } + CurN = CurN->Parent; + } + return nullptr; +} +// Returns the replacement that overwrites the file text of Exp with VarName. +tooling::Replacement replaceExpression(std::string VarName, const Expr *Exp, + const ASTContext &Ctx) { + auto &M = Ctx.getSourceManager(); + auto ExpRng = + toHalfOpenFileRange(M, Ctx.getLangOpts(), Exp->getSourceRange()); + auto ExpCode = toSourceCode(M, *ExpRng); + return tooling::Replacement(M, ExpRng->getBegin(), ExpCode.size(), VarName); + // NOTE(review): ExpRng is dereferenced without a check; TODO confirm + // toHalfOpenFileRange cannot come back empty here (e.g. for macro-expanded Exp). +} + +// Returns the Replacement that inserts `auto VarName = <source of Exp>; ` +// at the start of Stm's file range (ExpRng/StmRng are dereferenced unchecked). +tooling::Replacement insertExtractedVar(std::string VarName, const Stmt *Stm, + const Expr *Exp, + const ASTContext &Ctx) { + auto &M = Ctx.getSourceManager(); + auto ExpRng = + toHalfOpenFileRange(M, Ctx.getLangOpts(), Exp->getSourceRange()); + auto ExpCode = toSourceCode(M, *ExpRng); + auto StmRng = + toHalfOpenFileRange(M, Ctx.getLangOpts(), Stm->getSourceRange()); + std::string ExtractedVarDecl = + std::string("auto ") + VarName + " = " + ExpCode.str() + "; "; + return tooling::Replacement(M, StmRng->getBegin(), 0, ExtractedVarDecl); +} + +// FIXME(Bug in selection tree): doesn't work for int a = 6, b = 8; +// FIXME: support extraction from case constants +// FIXME: if, 
while, else, for, case, default statements without curly braces +// and their combinations +// FIXME: Ignore assignment (a = 1) Expr since it is extracted as dummy = a = 1 +bool ExtractVariable::prepare(const Selection &Inputs) { // available iff the selection's common ancestor is an Expr and getStmt finds a Stmt to insert before + const auto &M = Inputs.AST.getSourceManager(); + const SelectionTree::Node *N = Inputs.ASTSelection.commonAncestor(); + if (!N) + return false; + Exp = N->ASTNode.get(); + if (!Exp) + return false; + return (Stm = getStmt(N, M)); +} + +Expected +ExtractVariable::apply(const Selection &Inputs) { // declares `dummy` before Stm, then replaces Exp with `dummy` + auto &Ctx = Inputs.AST.getASTContext(); + tooling::Replacements Result; + if (auto Err = Result.add(insertExtractedVar("dummy", Stm, Exp, Ctx))) + return std::move(Err); + if (auto Err = Result.add(replaceExpression("dummy", Exp, Ctx))) + return std::move(Err); + return Result; +} + +std::string ExtractVariable::title() const { + return "Extract to dummy variable"; +} + +} // namespace +} // namespace clangd +} // namespace clang diff --git a/clang-tools-extra/clangd/unittests/TweakTests.cpp b/clang-tools-extra/clangd/unittests/TweakTests.cpp --- a/clang-tools-extra/clangd/unittests/TweakTests.cpp +++ b/clang-tools-extra/clangd/unittests/TweakTests.cpp @@ -216,6 +216,208 @@ checkTransform(ID, Input, Output); } +TEST(TweakTest, ExtractVariable) { + llvm::StringLiteral ID = "ExtractVariable"; + checkAvailable(ID, R"cpp( + int xyz() { + return 1; + } + void f() { + int a = 5 + [[4 * [[^xyz()]]]]; + // FIXME: add test case for multiple variable initialization once + // SelectionTree commonAncestor bug is fixed + switch(a) { + case 1: { + a = ^1; + break; + } + default: { + a = ^3; + } + } + if(^1) {} + if(a < ^3) + if(a == 4) + a = 5; + else + a = 6; + else if (a < 4) { + a = ^4; + } + else { + a = ^5; + } + // for loop testing + for(a = ^1; a > ^3+^4; a++) + a = 2; + // while testing + while(a < ^1) { + ^a++; + } + // do while testing + do + a = 1; + while(a < ^3); + } + )cpp"); + checkNotAvailable(ID, R"cpp( + void f(int b = ^1) { + int a = 5 + 4 
* 3; + // switch testing + switch(a) { + case 1: + a = ^1; + break; + default: + a = ^3; + } + // if testing + if(a < 3) + if(a == ^4) + a = ^5; + else + a = ^6; + else if (a < ^4) { + a = 4; + } + else { + a = 5; + } + // for loop testing + for(int a = 1, b = 2, c = 3; ^a > ^b ^+ ^c; ^a++) + a = ^2; + // while testing + while(a < 1) { + a++; + } + // do while testing + do + a = ^1; + while(a < 3); + // testing in cases where braces are required + if (true) + do + a = 1; + while(a < ^1); + } + )cpp"); + // (input, expected output) pairs, each checked with checkTransform + const std::vector> + InputOutputs = { + // extraction from a variable declaration's initializer + {R"cpp(void varDecl() { + int a = 5 * (4 + (3 [[- 1)]]); + })cpp", + R"cpp(void varDecl() { + auto dummy = (3 - 1); int a = 5 * (4 + dummy); + })cpp"}, + // extraction from for loop condition + {R"cpp(void forLoop() { + for(int a = 1; a < ^3; a++) { + a = 5 + 4 * 3; + } + })cpp", + R"cpp(void forLoop() { + auto dummy = 3; for(int a = 1; a < dummy; a++) { + a = 5 + 4 * 3; + } + })cpp"}, + // extraction inside for loop body + {R"cpp(void forBody() { + for(int a = 1; a < 3; a++) { + a = 5 + [[4 * 3]]; + } + })cpp", + R"cpp(void forBody() { + for(int a = 1; a < 3; a++) { + auto dummy = 4 * 3; a = 5 + dummy; + } + })cpp"}, + // extraction inside while loop condition + {R"cpp(void whileLoop(int a) { + while(a < 5 + [[4 * 3]]) + a += 1; + })cpp", + R"cpp(void whileLoop(int a) { + auto dummy = 4 * 3; while(a < 5 + dummy) + a += 1; + })cpp"}, + // extraction inside while loop body + {R"cpp(void whileBody(int a) { + while(a < 1) { + a += ^7 * 3; + } + })cpp", + R"cpp(void whileBody(int a) { + while(a < 1) { + auto dummy = 7; a += dummy * 3; + } + })cpp"}, + // extraction inside do-while loop condition + {R"cpp(void doWhileLoop(int a) { + do + a += 3; + while(a < ^1); + })cpp", + R"cpp(void doWhileLoop(int a) { + auto dummy = 1; do + a += 3; + while(a < dummy); + })cpp"}, + // extraction inside do-while body + {R"cpp(void 
doWhileBody(int a) { + do { + a += ^3; + } + while(a < 1); + })cpp", + R"cpp(void doWhileBody(int a) { + do { + auto dummy = 3; a += dummy; + } + while(a < 1); + })cpp"}, + // extraction inside switch condition + {R"cpp(void switchLoop(int a) { + switch(a = 1 + [[3 * 5]]) { + default: + break; + } + })cpp", + R"cpp(void switchLoop(int a) { + auto dummy = 3 * 5; switch(a = 1 + dummy) { + default: + break; + } + })cpp"}, + // extraction inside a braced case body + {R"cpp(void caseBody(int a) { + switch(1) { + case 1: { + a = ^1; + break; + } + default: + break; + } + })cpp", + R"cpp(void caseBody(int a) { + switch(1) { + case 1: { + auto dummy = 1; a = dummy; + break; + } + default: + break; + } + })cpp"}, + }; + for (const auto &IO : InputOutputs) { + checkTransform(ID, IO.first, IO.second); + } + +} + } // namespace } // namespace clangd } // namespace clang