diff --git a/clang-tools-extra/clangd/refactor/tweaks/CMakeLists.txt b/clang-tools-extra/clangd/refactor/tweaks/CMakeLists.txt
--- a/clang-tools-extra/clangd/refactor/tweaks/CMakeLists.txt
+++ b/clang-tools-extra/clangd/refactor/tweaks/CMakeLists.txt
@@ -15,6 +15,7 @@
   DumpAST.cpp
+  ExtractVariable.cpp
   RawStringLiteral.cpp
   SwapIfBranches.cpp
 
   LINK_LIBS
   clangAST
diff --git a/clang-tools-extra/clangd/refactor/tweaks/ExtractVariable.cpp b/clang-tools-extra/clangd/refactor/tweaks/ExtractVariable.cpp
new file mode 100644
--- /dev/null
+++ b/clang-tools-extra/clangd/refactor/tweaks/ExtractVariable.cpp
@@ -0,0 +1,328 @@
+//===--- ExtractVariable.cpp ------------------------------------*- C++-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#include "ClangdUnit.h"
+#include "Logger.h"
+#include "Protocol.h"
+#include "Selection.h"
+#include "SourceCode.h"
+#include "refactor/Tweak.h"
+#include "clang/AST/ASTContext.h"
+#include "clang/AST/Expr.h"
+#include "clang/AST/OperationKinds.h"
+#include "clang/AST/RecursiveASTVisitor.h"
+#include "clang/AST/Stmt.h"
+#include "clang/Basic/LangOptions.h"
+#include "clang/Basic/SourceLocation.h"
+#include "clang/Basic/SourceManager.h"
+#include "clang/Tooling/Core/Replacement.h"
+#include "llvm/ADT/None.h"
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/Error.h"
+#include <cassert>
+#include <string>
+#include <utility>
+#include <vector>
+
+namespace clang {
+namespace clangd {
+namespace {
+
+/// Describes the expression selected for extraction and answers the scope
+/// questions that decide where its new variable declaration can be inserted.
+class Extract {
+public:
+  Extract(const clang::Expr *Expr, const SelectionTree::Node *Node);
+
+  const clang::Expr *getExpr() const { return Expr; }
+  const SelectionTree::Node *getNode() const { return Node; }
+
+  /// Returns true if a declaration referenced by the expression begins inside
+  /// the source range of \p Scope.
+  bool referencesLocalDecls(const DeclStmt *Scope,
+                            const SourceManager &SM) const;
+  /// Returns true if the expression begins inside a body or branch of
+  /// \p InsertionPoint (for, while, do, case, default or if), i.e. the new
+  /// declaration could only be placed next to it by adding braces.
+  bool needBraces(const Stmt *InsertionPoint, const SourceManager &SM) const;
+  /// Returns true if the expression begins within the source range of \p Body.
+  bool isSubExprOf(const clang::Stmt *Body, const SourceManager &SM) const;
+  /// Returns false if inserting before \p InsertionPoint would place the
+  /// extracted expression ahead of a declaration it references.
+  bool extractionAllowed(const Stmt *InsertionPoint,
+                         const SourceManager &SM) const;
+
+private:
+  const clang::Expr *Expr;
+  const SelectionTree::Node *Node;
+  /// Target of every DeclRefExpr inside Expr, collected at construction.
+  std::vector<const clang::Decl *> ReferencedDecls;
+
+  static std::vector<const clang::Decl *>
+  collectReferencedDecls(const clang::Expr *E);
+  bool checkForStmt(const ForStmt *F, const SourceManager &SM) const;
+  bool checkDeclStmt(const DeclStmt *D, const SourceManager &SM) const;
+};
+
+Extract::Extract(const clang::Expr *Expr, const SelectionTree::Node *Node)
+    : Expr(Expr), Node(Node), ReferencedDecls(collectReferencedDecls(Expr)) {}
+
+bool Extract::referencesLocalDecls(const DeclStmt *Scope,
+                                   const SourceManager &SM) const {
+  const SourceLocation ScopeBegin = Scope->getBeginLoc();
+  const SourceLocation ScopeEnd = Scope->getEndLoc();
+  return llvm::any_of(ReferencedDecls, [&](const clang::Decl *D) {
+    return SM.isPointWithin(D->getBeginLoc(), ScopeBegin, ScopeEnd);
+  });
+}
+
+// Collects the declaration referenced by each DeclRefExpr within \p E.
+std::vector<const clang::Decl *>
+Extract::collectReferencedDecls(const clang::Expr *E) {
+  // RAV subclass to find all DeclRefs in a given Stmt
+  class FindDeclRefsVisitor
+      : public clang::RecursiveASTVisitor<FindDeclRefsVisitor> {
+  public:
+    std::vector<const clang::Decl *> ReferencedDecls;
+    bool VisitDeclRefExpr(DeclRefExpr *DeclRef) { // NOLINT
+      ReferencedDecls.push_back(DeclRef->getDecl());
+      return true;
+    }
+  };
+  FindDeclRefsVisitor Visitor;
+  // TraverseStmt takes a non-const pointer; this visitor only reads the AST.
+  Visitor.TraverseStmt(const_cast<clang::Expr *>(E));
+  return std::move(Visitor.ReferencedDecls);
+}
+
+// Checks whether the Expr begins within the source range of the Body Stmt.
+bool Extract::isSubExprOf(const clang::Stmt *Body,
+                          const SourceManager &SM) const {
+  const SourceRange BodyRng = Body->getSourceRange();
+  const SourceRange TargetRng = Node->ASTNode.getSourceRange();
+  return SM.isPointWithin(TargetRng.getBegin(), BodyRng.getBegin(),
+                          BodyRng.getEnd());
+}
+
+// Returns true if we will need braces after extraction
+bool Extract::needBraces(const Stmt *InsertionPoint,
+                         const SourceManager &SM) const {
+  // Sub-statements of InsertionPoint into which a declaration could only be
+  // inserted by wrapping them in braces. Entries may be null, which is why
+  // each one is checked before use below.
+  llvm::SmallVector<const clang::Stmt *, 2> Bodies;
+  if (const auto *For = llvm::dyn_cast_or_null<ForStmt>(InsertionPoint)) {
+    Bodies.push_back(For->getBody());
+  } else if (const auto *While =
+                 llvm::dyn_cast_or_null<WhileStmt>(InsertionPoint)) {
+    Bodies.push_back(While->getBody());
+  } else if (const auto *Do = llvm::dyn_cast_or_null<DoStmt>(InsertionPoint)) {
+    Bodies.push_back(Do->getBody());
+  } else if (const auto *Case =
+                 llvm::dyn_cast_or_null<CaseStmt>(InsertionPoint)) {
+    Bodies.push_back(Case->getSubStmt());
+  } else if (const auto *Default =
+                 llvm::dyn_cast_or_null<DefaultStmt>(InsertionPoint)) {
+    Bodies.push_back(Default->getSubStmt());
+  } else if (const auto *If = llvm::dyn_cast_or_null<IfStmt>(InsertionPoint)) {
+    Bodies.push_back(If->getThen());
+    Bodies.push_back(If->getElse());
+  }
+  return llvm::any_of(Bodies, [&](const clang::Stmt *Body) {
+    return Body && isSubExprOf(Body, SM);
+  });
+}
+
+// Checks whether to allow extraction from for(...).
+// The new declaration would be inserted before the whole for statement, so the
+// variables in the expression being extracted must not be declared in the init
+// of the for statement.
+bool Extract::checkForStmt(const ForStmt *F, const SourceManager &SM) const {
+  const auto *InitDecl = llvm::dyn_cast_or_null<DeclStmt>(F->getInit());
+  if (!InitDecl)
+    return true; // FIXME
+  // if any variable in Expr is declared in the DeclStmt, return false
+  return !referencesLocalDecls(InitDecl, SM);
+}
+
+// Checks whether to allow extraction from a declaration: the expression must
+// not reference anything declared in the same DeclStmt.
+bool Extract::checkDeclStmt(const DeclStmt *D, const SourceManager &SM) const {
+  return !referencesLocalDecls(D, SM);
+}
+
+// Checks whether extracting before InsertionPoint will take a variable out of
+// scope.
+bool Extract::extractionAllowed(const Stmt *InsertionPoint,
+                                const SourceManager &SM) const {
+  if (const auto *For = llvm::dyn_cast<ForStmt>(InsertionPoint))
+    return checkForStmt(For, SM);
+  if (const auto *DS = llvm::dyn_cast<DeclStmt>(InsertionPoint))
+    return checkDeclStmt(DS, SM);
+  return true;
+}
+
+/// Extracts an expression to the variable dummy
+/// Before:
+/// int x = 5 + 4 * 3;
+///             ^^^^^
+/// After:
+/// auto dummy = 4 * 3; int x = 5 + dummy;
+class ExtractVariable : public Tweak {
+public:
+  const char *id() const override final;
+  bool prepare(const Selection &Inputs) override;
+  Expected<Effect> apply(const Selection &Inputs) override;
+  std::string title() const override;
+  Intent intent() const override { return Refactor; }
+
+private:
+  // the expression to extract; set by a successful prepare()
+  llvm::Optional<Extract> Target;
+  // the statement before which variable will be extracted to
+  const clang::Stmt *InsertionPoint = nullptr;
+};
+REGISTER_TWEAK(ExtractVariable)
+
+// Return the Stmt before which we need to insert the extraction.
+// To find the Stmt, we go up the AST Tree and if the Parent of the current Stmt
+// is a CompoundStmt, we can extract inside this CompoundStmt just before the
+// current Stmt. We ALWAYS insert before a Stmt whose parent is a CompoundStmt
+//
+// Otherwise if we encounter an if/while/do-while/for/case Stmt, we must check
+// whether we are in their "body" or the "condition" part. If we are in the
+// "body", that means we need to insert braces i.e. create a CompoundStmt. For
+// now, we don't allow extraction if we need to insert braces. Else if we are in
+// the "condition" part, we can extract before the if/while/do-while/for/case
+// Stmt. Remember that we only insert before a stmt if its Parent Stmt is a
+// CompoundStmt, thus we continue looping to find a CompoundStmt.
+//
+// We have a special case for the Case Stmt constant for which we currently
+// don't offer extraction.
+//
+// we also check if doing the extraction will take a variable out of scope
+//
+// Returns nullptr if no suitable insertion point exists.
+static const clang::Stmt *getInsertionPoint(const Extract &Target,
+                                            const SourceManager &SM) {
+  for (const SelectionTree::Node *CurNode = Target.getNode(); CurNode->Parent;
+       CurNode = CurNode->Parent) {
+    const auto *Ancestor = CurNode->Parent->ASTNode.get<Stmt>();
+    // Ancestors that are not statements do not affect the search.
+    if (!Ancestor)
+      continue;
+    if (isa<CompoundStmt>(Ancestor))
+      return CurNode->ASTNode.get<Stmt>();
+    // We will need more functionality here later on
+    if (Target.needBraces(Ancestor, SM))
+      break;
+    // Give up if it's a case statement constant. needBraces() has already
+    // rejected targets inside the case body, so a target that reaches a
+    // CaseStmt ancestor here can only be part of its constant.
+    if (isa<CaseStmt>(Ancestor))
+      break;
+    // give up if extraction will take a variable out of scope
+    if (!Target.extractionAllowed(Ancestor, SM))
+      break;
+  }
+  return nullptr;
+}
+
+// Returns the file range for \p Range as computed by toHalfOpenFileRange, or
+// None if it cannot provide one.
+static llvm::Optional<SourceRange> getFileRange(SourceRange Range,
+                                                const ASTContext &Ctx) {
+  return toHalfOpenFileRange(Ctx.getSourceManager(), Ctx.getLangOpts(), Range);
+}
+
+// Returns the replacement for substituting the extracted Expr with VarName.
+// prepare() must have verified that the expression has a file range.
+static tooling::Replacement replaceExpression(llvm::StringRef VarName,
+                                              const Extract &Target,
+                                              const ASTContext &Ctx) {
+  const SourceManager &SM = Ctx.getSourceManager();
+  const llvm::Optional<SourceRange> ExpRng =
+      getFileRange(Target.getExpr()->getSourceRange(), Ctx);
+  assert(ExpRng && "prepare() should have rejected this expression");
+  const unsigned ExpLength =
+      SM.getFileOffset(ExpRng->getEnd()) - SM.getFileOffset(ExpRng->getBegin());
+  return tooling::Replacement(SM, ExpRng->getBegin(), ExpLength, VarName);
+}
+
+// Returns the Replacement for declaring a new variable storing the extracted
+// expression, inserted immediately before InsertionPoint.
+// prepare() must have verified that both ranges used here are available.
+static tooling::Replacement insertExtractedVar(llvm::StringRef VarName,
+                                               const Stmt *InsertionPoint,
+                                               const Extract &Target,
+                                               const ASTContext &Ctx) {
+  const SourceManager &SM = Ctx.getSourceManager();
+  const llvm::Optional<SourceRange> ExpRng =
+      getFileRange(Target.getExpr()->getSourceRange(), Ctx);
+  const llvm::Optional<SourceRange> InsertionRng =
+      getFileRange(InsertionPoint->getSourceRange(), Ctx);
+  assert(ExpRng && InsertionRng && "prepare() should have rejected this");
+  const llvm::StringRef ExpCode = toSourceCode(SM, *ExpRng);
+  // FIXME: Replace auto with explicit type and add &/&& as necessary
+  const std::string ExtractedVarDecl =
+      "auto " + VarName.str() + " = " + ExpCode.str() + "; ";
+  return tooling::Replacement(SM, InsertionRng->getBegin(), 0,
+                              ExtractedVarDecl);
+}
+
+// FIXME: case constant refactoring
+// FIXME: if, while, else, for, case, default statements without curly braces
+// and their combinations
+// FIXME: Ignore assignment (a = 1) Expr since it is extracted as dummy = a = 1
+bool ExtractVariable::prepare(const Selection &Inputs) {
+  // Forget anything computed by a previous call.
+  Target.reset();
+  InsertionPoint = nullptr;
+  const SelectionTree::Node *N = Inputs.ASTSelection.commonAncestor();
+  if (!N)
+    return false;
+  const auto *E = N->ASTNode.get<clang::Expr>();
+  if (!E)
+    return false;
+  const ASTContext &Ctx = Inputs.AST.getASTContext();
+  Target.emplace(E, N);
+  InsertionPoint = getInsertionPoint(*Target, Ctx.getSourceManager());
+  if (!InsertionPoint)
+    return false;
+  // apply() dereferences both file ranges, so only offer the tweak when they
+  // are available.
+  if (!getFileRange(E->getSourceRange(), Ctx))
+    return false;
+  if (!getFileRange(InsertionPoint->getSourceRange(), Ctx))
+    return false;
+  return true;
+}
+
+Expected<Tweak::Effect> ExtractVariable::apply(const Selection &Inputs) {
+  assert(Target && InsertionPoint && "apply() called without prepare()");
+  const ASTContext &Ctx = Inputs.AST.getASTContext();
+  const llvm::StringRef VarName = "dummy";
+  tooling::Replacements Result;
+  if (auto Err = Result.add(
+          insertExtractedVar(VarName, InsertionPoint, *Target, Ctx)))
+    return std::move(Err);
+  if (auto Err = Result.add(replaceExpression(VarName, *Target, Ctx)))
+    return std::move(Err);
+  return Effect::applyEdit(Result);
+}
+
+std::string ExtractVariable::title() const {
+  return "Extract subexpression to variable";
+}
+
+} // namespace
+} // namespace clangd
+} // namespace clang
diff --git a/clang-tools-extra/clangd/unittests/TweakTests.cpp b/clang-tools-extra/clangd/unittests/TweakTests.cpp
--- a/clang-tools-extra/clangd/unittests/TweakTests.cpp
+++ b/clang-tools-extra/clangd/unittests/TweakTests.cpp
@@ -239,13 +239,13 @@
   checkNotAvailable(ID, "/*c^omment*/ int foo() return 2 ^ + 2; }");
 
   const char *Input = "int x = 2 ^+ 2;";
-  auto result = getMessage(ID, Input);
-  EXPECT_THAT(result, ::testing::HasSubstr("BinaryOperator"));
-  EXPECT_THAT(result, ::testing::HasSubstr("'+'"));
-  EXPECT_THAT(result, ::testing::HasSubstr("|-IntegerLiteral"));
-  EXPECT_THAT(result,
+  auto Result = getMessage(ID, Input);
+  EXPECT_THAT(Result, ::testing::HasSubstr("BinaryOperator"));
+  EXPECT_THAT(Result, ::testing::HasSubstr("'+'"));
+  EXPECT_THAT(Result, ::testing::HasSubstr("|-IntegerLiteral"));
+  EXPECT_THAT(Result,
               ::testing::HasSubstr(" <col:9> 'int' 2\n`-IntegerLiteral"));
-  EXPECT_THAT(result, ::testing::HasSubstr(" <col:13> 'int' 2"));
+  EXPECT_THAT(Result, ::testing::HasSubstr(" <col:13> 'int' 2"));
 }
 
 TEST(TweakTest, ShowSelectionTree) {
@@ -277,6 +277,209 @@
   const char *Input = "struct ^X { int x; int y; }";
   EXPECT_THAT(getMessage(ID, Input), ::testing::HasSubstr("0 |   int x"));
 }
+TEST(TweakTest, ExtractVariable) {
+  llvm::StringLiteral ID = "ExtractVariable";
+  checkAvailable(ID, R"cpp(
+    int xyz() {
+      return 1;
+    }
+    void f() {
+      int a = 5 + [[4 * [[^xyz()]]]];
+      int x = ^1, y = x + 1, z = ^1;
+      switch(a) {
+        case 1: {
+          a = ^1;
+          break;
+        }
+        default: {
+          a = ^3;
+        }
+      }
+      // if testing
+      if(^1) {}
+      if(a < ^3)
+        if(a == 4)
+          a = 5;
+        else
+          a = 6;
+      else if (a < 4) {
+        a = ^4;
+      }
+      else {
+        a = ^5;
+      }
+      // for loop testing
+      for(a = ^1; a > ^3+^4; a++)
+        a = 2;
+      // while testing
+      while(a < ^1) {
+        ^a++;
+      }
+      // do while testing
+      do
+        a = 1;
+      while(a < ^3);
+    }
+  )cpp");
+  checkNotAvailable(ID, R"cpp(
+    void f(int b = ^1) {
+      int a = 5 + 4 * 3;
+      // check whether extraction breaks scope
+      int x = 1, y = ^x + 1;
+      // switch testing
+      switch(a) {
+        case ^1:
+          a = ^1;
+          break;
+        default:
+          a = ^3;
+      }
+      // if testing
+      if(a < 3)
+        if(a == ^4)
+          a = ^5;
+        else
+          a = ^6;
+      else if (a < ^4) {
+        a = 4;
+      }
+      else {
+        a = 5;
+      }
+      // for loop testing
+      for(int a = 1, b = 2, c = 3; ^a > ^b ^+ ^c; ^a++)
+        a = ^2;
+      // while testing
+      while(a < 1) {
+        a++;
+      }
+      // do while testing
+      do
+        a = ^1;
+      while(a < 3);
+      // testing in cases where braces are required
+      if (true)
+        do
+          a = 1;
+        while(a < ^1);
+    }
+  )cpp");
+  // vector of pairs of input and output strings
+  const std::vector<std::pair<llvm::StringLiteral, llvm::StringLiteral>>
+      InputOutputs = {
+          // extraction from variable declaration/assignment
+          {R"cpp(void varDecl() {
+                   int a = 5 * (4 + (3 [[- 1)]]);
+                 })cpp",
+           R"cpp(void varDecl() {
+                   auto dummy = (3 - 1); int a = 5 * (4 + dummy);
+                 })cpp"},
+          // extraction from for loop init/cond/incr
+          {R"cpp(void forLoop() {
+                   for(int a = 1; a < ^3; a++) {
+                     a = 5 + 4 * 3;
+                   }
+                 })cpp",
+           R"cpp(void forLoop() {
+                   auto dummy = 3; for(int a = 1; a < dummy; a++) {
+                     a = 5 + 4 * 3;
+                   }
+                 })cpp"},
+          // extraction inside for loop body
+          {R"cpp(void forBody() {
+                   for(int a = 1; a < 3; a++) {
+                     a = 5 + [[4 * 3]];
+                   }
+                 })cpp",
+           R"cpp(void forBody() {
+                   for(int a = 1; a < 3; a++) {
+                     auto dummy = 4 * 3; a = 5 + dummy;
+                   }
+                 })cpp"},
+          // extraction inside while loop condition
+          {R"cpp(void whileLoop(int a) {
+                   while(a < 5 + [[4 * 3]])
+                     a += 1;
+                 })cpp",
+           R"cpp(void whileLoop(int a) {
+                   auto dummy = 4 * 3; while(a < 5 + dummy)
+                     a += 1;
+                 })cpp"},
+          // extraction inside while body condition
+          {R"cpp(void whileBody(int a) {
+                   while(a < 1) {
+                     a += ^7 * 3;
+                   }
+                 })cpp",
+           R"cpp(void whileBody(int a) {
+                   while(a < 1) {
+                     auto dummy = 7; a += dummy * 3;
+                   }
+                 })cpp"},
+          // extraction inside do-while loop condition
+          {R"cpp(void doWhileLoop(int a) {
+                   do
+                     a += 3;
+                   while(a < ^1);
+                 })cpp",
+           R"cpp(void doWhileLoop(int a) {
+                   auto dummy = 1; do
+                     a += 3;
+                   while(a < dummy);
+                 })cpp"},
+          // extraction inside do-while body
+          {R"cpp(void doWhileBody(int a) {
+                   do {
+                     a += ^3;
+                   }
+                   while(a < 1);
+                 })cpp",
+           R"cpp(void doWhileBody(int a) {
+                   do {
+                     auto dummy = 3; a += dummy;
+                   }
+                   while(a < 1);
+                 })cpp"},
+          // extraction inside switch condition
+          {R"cpp(void switchLoop(int a) {
+                   switch(a = 1 + [[3 * 5]]) {
+                     default:
+                       break;
+                   }
+                 })cpp",
+           R"cpp(void switchLoop(int a) {
+                   auto dummy = 3 * 5; switch(a = 1 + dummy) {
+                     default:
+                       break;
+                   }
+                 })cpp"},
+          // extraction inside case body
+          {R"cpp(void caseBody(int a) {
+                   switch(1) {
+                     case 1: {
+                       a = ^1;
+                       break;
+                     }
+                     default:
+                       break;
+                   }
+                 })cpp",
+           R"cpp(void caseBody(int a) {
+                   switch(1) {
+                     case 1: {
+                       auto dummy = 1; a = dummy;
+                       break;
+                     }
+                     default:
+                       break;
+                   }
+                 })cpp"},
+      };
+  for (const auto &IO : InputOutputs) {
+    checkTransform(ID, IO.first, IO.second);
+  }
+
+}
 
 } // namespace
 } // namespace clangd