diff --git a/clang-tools-extra/clangd/refactor/tweaks/CMakeLists.txt b/clang-tools-extra/clangd/refactor/tweaks/CMakeLists.txt --- a/clang-tools-extra/clangd/refactor/tweaks/CMakeLists.txt +++ b/clang-tools-extra/clangd/refactor/tweaks/CMakeLists.txt @@ -14,6 +14,7 @@ add_clang_library(clangDaemonTweaks OBJECT RawStringLiteral.cpp SwapIfBranches.cpp + ExtractVariable.cpp LINK_LIBS clangAST diff --git a/clang-tools-extra/clangd/refactor/tweaks/ExtractVariable.cpp b/clang-tools-extra/clangd/refactor/tweaks/ExtractVariable.cpp new file mode 100644 --- /dev/null +++ b/clang-tools-extra/clangd/refactor/tweaks/ExtractVariable.cpp @@ -0,0 +1,193 @@ +//===--- ExtractVariable.cpp ------------------------------------*- C++-*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#include "ClangdUnit.h" +#include "Logger.h" +#include "Protocol.h" +#include "Selection.h" +#include "SourceCode.h" +#include "refactor/Tweak.h" +#include "clang/AST/ASTContext.h" +#include "clang/AST/Expr.h" +#include "clang/AST/RecursiveASTVisitor.h" +#include "clang/AST/Stmt.h" +#include "clang/Basic/LangOptions.h" +#include "clang/Basic/SourceLocation.h" +#include "clang/Basic/SourceManager.h" +#include "clang/Tooling/Core/Replacement.h" +#include "llvm/ADT/None.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Error.h" + +namespace clang { +namespace clangd { +namespace { + +// +/// Extracts an expression to the variable dummy +/// Before: +/// int x = 5 + 4 * 3; +/// ^^^^^ +/// After: +/// auto dummy = 5 + 4; +/// int x = dummy * 3; +class ExtractVariable : public Tweak { +public: + const char *id() const override final; + bool prepare(const Selection &Inputs) override; + Expected apply(const Selection 
&Inputs) override; + std::string title() const override; + +private: + // the expression to extract + const clang::Expr *Exp = nullptr; + // the statement before which variable will be extracted to + const clang::Stmt *Stm = nullptr; +}; +REGISTER_TWEAK(ExtractVariable) + +// RAV subclass to find all DeclRefs in a given Stmt +class FindDeclRefsVisitor + : public clang::RecursiveASTVisitor { +public: + FindDeclRefsVisitor(const SourceManager &M) : M(&M){}; + + // checks whether any variable in a given expr is declared in the DeclStmt + bool isDeclaredIn(const DeclStmt *Decl, const Expr *Exp) { + this->Decl = Decl; + // if complete traversal is unsuccessful, return true + return !TraverseStmt(const_cast(dyn_cast(Exp))); + } + // Visit each DeclRefExpr and check whether it was declared in Decl + bool VisitDeclRefExpr(DeclRefExpr *DeclRef) { // NOLINT + // Beginning location of the ValueDecl of the DeclRef + auto ValueDeclLoc = DeclRef->getDecl()->getBeginLoc(); + // return false if this ValueDecl location is within the DeclStmt + return !M->isPointWithin(ValueDeclLoc, Decl->getBeginLoc(), + Decl->getEndLoc()); + } + +private: + const DeclStmt *Decl; + const SourceManager *M; +}; + +// Here T is a Stmt subclass and getBody is a member function that returns its +// "body" Stmt. +// Checks whether N is inside the body of Stm. 
+template <typename T, const Stmt *(T::*getBody)() const>
+static bool isInBodyOf(const SelectionTree::Node *N, const Stmt *Stm,
+                       const SourceManager &M) {
+  const auto *CastStm = dyn_cast_or_null<T>(Stm);
+  if (!CastStm)
+    return false; // Stm is not a T.
+  // Some bodies are optional: IfStmt::getElse() is null when there is no
+  // else branch. N cannot be inside a body that does not exist, and
+  // dereferencing it would crash.
+  const Stmt *Body = (CastStm->*getBody)();
+  if (!Body)
+    return false;
+  SourceRange BodyRng = Body->getSourceRange();
+  SourceRange NodeRng = N->ASTNode.getSourceRange();
+  return M.isPointWithin(NodeRng.getBegin(), BodyRng.getBegin(),
+                         BodyRng.getEnd());
+}
+
+// Returns true if N lies in a statement body of Stm that need not be a
+// CompoundStmt. Inserting a declaration there would require adding braces,
+// which is not supported yet.
+static bool needBraces(const SelectionTree::Node *N, const Stmt *Stm,
+                       const SourceManager &M) {
+  return isInBodyOf<ForStmt, &ForStmt::getBody>(N, Stm, M) ||
+         // An unbraced range-for body may reference the loop variable, which
+         // is not in scope before the loop.
+         isInBodyOf<CXXForRangeStmt, &CXXForRangeStmt::getBody>(N, Stm, M) ||
+         isInBodyOf<WhileStmt, &WhileStmt::getBody>(N, Stm, M) ||
+         isInBodyOf<DoStmt, &DoStmt::getBody>(N, Stm, M) ||
+         // For an if statement we have two bodies: Then and Else.
+         isInBodyOf<IfStmt, &IfStmt::getThen>(N, Stm, M) ||
+         isInBodyOf<IfStmt, &IfStmt::getElse>(N, Stm, M) ||
+         isInBodyOf<CaseStmt, &CaseStmt::getSubStmt>(N, Stm, M) ||
+         isInBodyOf<DefaultStmt, &DefaultStmt::getSubStmt>(N, Stm, M);
+}
+
+// Returns true if Exp may be extracted to before the for loop F. This holds
+// when Exp references no variable declared in F's init-statement.
+static bool checkFor(const ForStmt *F, const Expr *Exp,
+                     const SourceManager &M) {
+  // A missing init-statement, or one that is not a declaration (e.g. `a = 1`),
+  // introduces no variables, so Exp cannot depend on one.
+  const auto *Init = dyn_cast_or_null<DeclStmt>(F->getInit());
+  if (!Init)
+    return true;
+  FindDeclRefsVisitor Visitor(M);
+  // If any variable in Exp is declared in the init DeclStmt, we can't extract.
+  return !Visitor.isDeclaredIn(Init, Exp);
+}
+
+// Returns the Stmt before which the extracted variable declaration should be
+// inserted, or nullptr if extraction is not possible.
+static const clang::Stmt *getStmt(const SelectionTree::Node *N, + const SourceManager &M) { + auto CurN = N; + while (CurN->Parent) { + auto ParStmt = CurN->Parent->ASTNode.get(); + if (!ParStmt) { + } else if (isa(ParStmt)) { + return CurN->ASTNode.get(); + } else if (needBraces(N, ParStmt, M)) { + // We will need more functionality here later on + llvm::errs() << "Need braces.\n"; + break; + } else if (isa(ParStmt) && + !isInBodyOf(N, ParStmt, M)) { + // give up if it's a Case Statement constant + break; + } else if (isa(ParStmt) && !checkFor(dyn_cast(ParStmt), + N->ASTNode.get(), M)) { + // Check whether the expression references any variable in the for + // initializer and if so, we can't extract + llvm::errs() << "Can't extract from for\n"; + break; + } + CurN = CurN->Parent; + } + return nullptr; +} + +// FIXME(Bug in selection tree): doesn't work for int a = 6, b = 8; +// FIXME: case constant refactoring +// FIXME: if, while, else, for, case, default statements without curly braces +// and their combinations +// FIXME: Ignore assignment (a = 1) Expr since it is extracted as dummy = a = 1 +bool ExtractVariable::prepare(const Selection &Inputs) { + const auto &M = Inputs.AST.getSourceManager(); + const SelectionTree::Node *N = Inputs.ASTSelection.commonAncestor(); + if (!N) + return false; + Exp = N->ASTNode.get(); + if (!Exp) + return false; + return (Stm = getStmt(N, M)); +} + +Expected +ExtractVariable::apply(const Selection &Inputs) { + auto &Ctx = Inputs.AST.getASTContext(); + const SourceManager &SrcMgr = Inputs.AST.getSourceManager(); + auto StmRng = + toHalfOpenFileRange(SrcMgr, Ctx.getLangOpts(), Stm->getSourceRange()); + auto ExpRng = + toHalfOpenFileRange(SrcMgr, Ctx.getLangOpts(), Exp->getSourceRange()); + auto ExpCode = toSourceCode(SrcMgr, *ExpRng); + std::string VarName = "dummy"; + std::string ExtractedVarDecl = + std::string("auto ") + VarName + " = " + ExpCode.str() + "; "; + tooling::Replacements Result; + // insert new variable declaration 
+ if (auto Err = Result.add(tooling::Replacement( + Ctx.getSourceManager(), StmRng->getBegin(), 0, ExtractedVarDecl))) + return std::move(Err); + // replace expression with variable name + if (auto Err = Result.add(tooling::Replacement( + Ctx.getSourceManager(), ExpRng->getBegin(), ExpCode.size(), "dummy"))) + return std::move(Err); + return Result; +} + +std::string ExtractVariable::title() const { + return "Extract to dummy variable"; +} + +} // namespace +} // namespace clangd +} // namespace clang diff --git a/clang-tools-extra/clangd/unittests/TweakTests.cpp b/clang-tools-extra/clangd/unittests/TweakTests.cpp --- a/clang-tools-extra/clangd/unittests/TweakTests.cpp +++ b/clang-tools-extra/clangd/unittests/TweakTests.cpp @@ -216,6 +216,207 @@ checkTransform(ID, Input, Output); } +TEST(TweakTest, ExtractVariable) { + llvm::StringLiteral ID = "ExtractVariable"; + + checkAvailable(ID, R"cpp( + int xyz() { + return 1; + } + void f() { + int a = 5 + [[4 * [[^xyz()]]]]; + // FIXME: add test case for multiple variable initialization once + // SelectionTree commonAncestor bug is fixed + switch(a) { + case 1: { + a = ^1; + break; + } + default: { + a = ^3; + } + } + if(a < ^3) + if(a == 4) + a = 5; + else + a = 6; + else if (a < 4) { + a = ^4; + } + else { + a = ^5; + } + // for loop testing + for(a = ^1; a > ^3+^4; a++) + a = 2; + // while testing + while(a < ^1) { + ^a++; + } + // do while testing + do + a = 1; + while(a < ^3); + } + )cpp"); + checkNotAvailable(ID, R"cpp( + void f(int b = ^1) { + int a = 5 + 4 * 3; + // switch testing + switch(a) { + case 1: + a = ^1; + break; + default: + a = ^3; + } + // if testing + if(a < 3) + if(a == ^4) + a = ^5; + else + a = ^6; + else if (a < ^4) { + a = 4; + } + else { + a = 5; + } + // for loop testing + for(int a = 1, b = 2, c = 3; ^a > ^b ^+ ^c; ^a++) + a = ^2; + // while testing + while(a < 1) { + a++; + } + // do while testing + do + a = ^1; + while(a < 3); + // testing in cases where braces are required + if (true) + do 
+ a = 1; + while(a < ^1); + } + )cpp"); + // vector of pairs of input and output strings + const std::vector> + InputOutputs = { + // extraction from variable declaration/assignment + {R"cpp(void varDecl() { + int a = 5 * (4 + (3 [[- 1)]]); + })cpp", + R"cpp(void varDecl() { + auto dummy = (3 - 1); int a = 5 * (4 + dummy); + })cpp"}, + // extraction from for loop init/cond/incr + {R"cpp(void forLoop() { + for(int a = 1; a < ^3; a++) { + a = 5 + 4 * 3; + } + })cpp", + R"cpp(void forLoop() { + auto dummy = 3; for(int a = 1; a < dummy; a++) { + a = 5 + 4 * 3; + } + })cpp"}, + // extraction inside for loop body + {R"cpp(void forBody() { + for(int a = 1; a < 3; a++) { + a = 5 + [[4 * 3]]; + } + })cpp", + R"cpp(void forBody() { + for(int a = 1; a < 3; a++) { + auto dummy = 4 * 3; a = 5 + dummy; + } + })cpp"}, + // extraction inside while loop condition + {R"cpp(void whileLoop(int a) { + while(a < 5 + [[4 * 3]]) + a += 1; + })cpp", + R"cpp(void whileLoop(int a) { + auto dummy = 4 * 3; while(a < 5 + dummy) + a += 1; + })cpp"}, + // extraction inside while body condition + {R"cpp(void whileBody(int a) { + while(a < 1) { + a += ^7 * 3; + } + })cpp", + R"cpp(void whileBody(int a) { + while(a < 1) { + auto dummy = 7; a += dummy * 3; + } + })cpp"}, + // extraction inside do-while loop condition + {R"cpp(void doWhileLoop(int a) { + do + a += 3; + while(a < ^1); + })cpp", + R"cpp(void doWhileLoop(int a) { + auto dummy = 1; do + a += 3; + while(a < dummy); + })cpp"}, + // extraction inside do-while body + {R"cpp(void doWhileBody(int a) { + do { + a += ^3; + } + while(a < 1); + })cpp", + R"cpp(void doWhileBody(int a) { + do { + auto dummy = 3; a += dummy; + } + while(a < 1); + })cpp"}, + // extraction inside switch condition + {R"cpp(void switchLoop(int a) { + switch(a = 1 + [[3 * 5]]) { + default: + break; + } + })cpp", + R"cpp(void switchLoop(int a) { + auto dummy = 3 * 5; switch(a = 1 + dummy) { + default: + break; + } + })cpp"}, + // extraction inside case body + {R"cpp(void 
caseBody(int a) { + switch(1) { + case 1: { + a = ^1; + break; + } + default: + break; + } + })cpp", + R"cpp(void caseBody(int a) { + switch(1) { + case 1: { + auto dummy = 1; a = dummy; + break; + } + default: + break; + } + })cpp"}, + }; + for (const auto &IO : InputOutputs) { + checkTransform(ID, IO.first, IO.second); + } +} + } // namespace } // namespace clangd } // namespace clang