diff --git a/clang-tools-extra/clangd/refactor/tweaks/CMakeLists.txt b/clang-tools-extra/clangd/refactor/tweaks/CMakeLists.txt --- a/clang-tools-extra/clangd/refactor/tweaks/CMakeLists.txt +++ b/clang-tools-extra/clangd/refactor/tweaks/CMakeLists.txt @@ -16,6 +16,7 @@ DumpAST.cpp ExpandAutoType.cpp ExpandMacro.cpp + ExtractFunction.cpp ExtractVariable.cpp RawStringLiteral.cpp SwapIfBranches.cpp diff --git a/clang-tools-extra/clangd/refactor/tweaks/ExtractFunction.cpp b/clang-tools-extra/clangd/refactor/tweaks/ExtractFunction.cpp new file mode 100644 --- /dev/null +++ b/clang-tools-extra/clangd/refactor/tweaks/ExtractFunction.cpp @@ -0,0 +1,285 @@ +//===--- ExtractFunction.cpp ------------------------------------*- C++-*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#include "ClangdUnit.h" +#include "Logger.h" +#include "Selection.h" +#include "SourceCode.h" +#include "refactor/Tweak.h" +#include "clang/AST/ASTContext.h" +#include "clang/AST/RecursiveASTVisitor.h" +#include "clang/AST/Stmt.h" +#include "clang/Basic/LangOptions.h" +#include "clang/Basic/SourceLocation.h" +#include "clang/Basic/SourceManager.h" +#include "clang/Lex/Lexer.h" +#include "clang/Tooling/Core/Replacement.h" +#include "llvm/ADT/None.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/iterator_range.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Error.h" + +namespace clang { +namespace clangd { +namespace { + +using Node = SelectionTree::Node; +using std::shared_ptr; + +class ExtractionTarget { +public: + ExtractionTarget(const Node *N, const SourceManager &SM, + const ASTContext &Ctx); + + // Check if the location is part of the TargetContext and before Target. + bool isInPreTarget(SourceLocation Loc) const { + return !isOutsideTargetContext(Loc) && Loc < TargetRng.getBegin(); + } + + // Checks whether the Location is within the Target. + bool isInTarget(SourceLocation Loc) const { + return SM.isPointWithin(Loc, TargetRng.getBegin(), TargetRng.getEnd()); + } + + // Check if the location is part of the TargetContext and after Target. + bool isInPostTarget(SourceLocation Loc) const { + return !isOutsideTargetContext(Loc) && TargetRng.getEnd() < Loc; + } + // Check if the location is outside the TargetContext. + bool isOutsideTargetContext(SourceLocation Loc) const { + return !SM.isPointWithin(Loc, TargetContextRng.getBegin(), + TargetContextRng.getEnd()); + } + bool isExtractable() { return TargetRng.isValid() && TargetContext; } + + tooling::Replacement replaceWithFuncCall(llvm::StringRef FuncName) const; + tooling::Replacement createFunctionDefinition(llvm::StringRef FuncName) const; + // The function inside which our target resides. + const FunctionDecl *TargetContext; + +private: + const Node *CommonAnc; + const SourceManager &SM; + const ASTContext &Ctx; + // The range of the code being extracted. + SourceRange TargetRng; + SourceRange TargetContextRng; + + SourceRange computeTargetRange() const; +}; + +const FunctionDecl *computeTargetContext(const Node *CommonAnc) { + // Walk up the SelectionTree until we find a function Decl + for (const Node *CurNode = CommonAnc; CurNode; CurNode = CurNode->Parent) { + if (const FunctionDecl *FD = CurNode->ASTNode.get()) { + // FIXME: Support extraction from methods. + if (CurNode->ASTNode.get()) + return nullptr; + return FD; + } + } + return nullptr; +} + +ExtractionTarget::ExtractionTarget(const Node *CommonAnc, + const SourceManager &SM, + const ASTContext &Ctx) + : CommonAnc(CommonAnc), SM(SM), Ctx(Ctx) { + TargetRng = computeTargetRange(); + if ((TargetContext = computeTargetContext(CommonAnc))) + TargetContextRng = + TargetContext->getSourceRange(); // TODO: Use toHalfOpenFileRange? +} + +// TODO: Describe how this works +SourceRange ExtractionTarget::computeTargetRange() const { + const LangOptions &LangOpts = Ctx.getLangOpts(); + if (!CommonAnc) + return SourceRange(); + if (CommonAnc->Selected == SelectionTree::Selection::Complete) + return *toHalfOpenFileRange(SM, LangOpts, + CommonAnc->ASTNode.getSourceRange()); + // FIXME: Bail out when it's partial selected. Wait for selectiontree + // semicolon fix. + if (CommonAnc->Selected == SelectionTree::Selection::Partial && + !CommonAnc->ASTNode.get()) + return SourceRange(); + SourceRange SR; + for (const Node *Child : CommonAnc->Children) { + // We don't want to extract a partially selected child. + if (Child->Selected == SelectionTree::Selection::Partial || + (Child->Selected == SelectionTree::Selection::Unselected && + !Child->Children.empty())) + return SourceRange(); + // Completely selected child + if (SR.isInvalid()) + SR = *toHalfOpenFileRange(SM, LangOpts, Child->ASTNode.getSourceRange()); + else + SR.setEnd( + toHalfOpenFileRange(SM, LangOpts, Child->ASTNode.getSourceRange()) + ->getEnd()); + } + return SR; +} + +// TODO: Add support for function arguments as well as assigning function return +// value. +tooling::Replacement +ExtractionTarget::replaceWithFuncCall(llvm::StringRef FuncName) const { + std::string FuncCall = FuncName.str() + "()"; + return tooling::Replacement(SM, CharSourceRange(TargetRng, false), FuncCall, + Ctx.getLangOpts()); +} + +// TODO: add support for adding function parameters and other function return +// types besides void. +tooling::Replacement +ExtractionTarget::createFunctionDefinition(llvm::StringRef FuncName) const { + std::string NewFunction = "void " + FuncName.str() + "() {" + + toSourceCode(SM, TargetRng).str() + "}"; + return tooling::Replacement(SM, TargetContextRng.getBegin(), 0, NewFunction); +} + +// We use the ASTVisitor instead of using the selection tree since we need to +// find references in the Post-Target as well. +// FIXME: Check which statements we don't allow to extract. +class TargetAnalyzer : public clang::RecursiveASTVisitor { +public: + enum LocType { PRETARGET, TARGET, POSTTARGET, OUTSIDECONTEXT }; + struct Reference { + Decl *TheDecl = nullptr; + LocType DeclaredIn; + bool IsReferencedInPostTarget = false; + bool IsAssigned = false; + bool MaybeModifiedOutside = false; + Reference(){}; + Reference(Decl *TheDecl, LocType DeclaredIn) + : TheDecl(TheDecl), DeclaredIn(DeclaredIn){}; + }; + // True if target has a reference to it's extraction context + bool HasRecursiveCallToContext = false; + shared_ptr Target; + + TargetAnalyzer(shared_ptr Target) : Target(Target) { + TraverseDecl(const_cast(Target->TargetContext)); + } + // Get the location type for the given Decl. + LocType getLocType(SourceLocation Loc); + // Return reference for a Decl, adding it if not already present. + Reference &getReferenceFor(VarDecl *D); + // Adds/updates the reference corresponding to the Decl of the DeclRefExpr. + void updateReferenceFor(DeclRefExpr *DRE); + // Add/update the Reference for a DeclRefExpr + bool VisitDeclRefExpr(DeclRefExpr *DRE); // NOLINT + // check whether the DRE refers to the TargetContext. + void checkIfRecursiveCall(DeclRefExpr *DRE); + bool isTooComplexToExtract(); + +private: + llvm::DenseMap References; +}; + +TargetAnalyzer::Reference &TargetAnalyzer::getReferenceFor(VarDecl *D) { + assert(D && "D shouldn't be null!"); + // Don't add Decls that are outside the extraction context or in PostTarget. + if (References.find(D) == References.end()) + References.insert({D, Reference(D, getLocType(D->getBeginLoc()))}); + return References[D]; +} + +void TargetAnalyzer::checkIfRecursiveCall(DeclRefExpr *DRE) { + if (Target->isInTarget(DRE->getBeginLoc()) && + DRE->getDecl() == Target->TargetContext) + HasRecursiveCallToContext = true; +} + +void TargetAnalyzer::updateReferenceFor(DeclRefExpr *DRE) { + VarDecl *D = dyn_cast_or_null(DRE->getDecl()); + if (!D) + return; + LocType DRELocType = getLocType(DRE->getBeginLoc()); + LocType DeclLocType = getLocType(D->getBeginLoc()); + // Ensure we only add Decl where the DeclRef is in the next block. + // Note that DRELocType is never OUTSIDECONTEXT. + if (DeclLocType + 1 != DRELocType) + return; + Reference &Ref = getReferenceFor(D); + if (DRELocType == POSTTARGET) + Ref.IsReferencedInPostTarget = true; +} + +bool TargetAnalyzer::isTooComplexToExtract() { + // If there is any reference or target has a recursive call, we don't extract + // for now. + return !References.empty() || HasRecursiveCallToContext; +} +bool TargetAnalyzer::VisitDeclRefExpr(DeclRefExpr *DRE) { // NOLINT + checkIfRecursiveCall(DRE); + updateReferenceFor(DRE); + return true; +} + +TargetAnalyzer::LocType TargetAnalyzer::getLocType(SourceLocation Loc) { + // FIXME: isInPreTarget, isInTarget and isInPostTarget all check if the Loc is + // outsideContext. + if (Target->isInPreTarget(Loc)) + return PRETARGET; + if (Target->isInTarget(Loc)) + return TARGET; + if (Target->isInPostTarget(Loc)) + return POSTTARGET; + return OUTSIDECONTEXT; +} + +/// Extracts statements to a new function and replaces the statements with a +/// call to the new function. +class ExtractFunction : public Tweak { +public: + const char *id() const override final; + + bool prepare(const Selection &Inputs) override; + Expected apply(const Selection &Inputs) override; + std::string title() const override { return "Extract to function"; } + Intent intent() const override { return Refactor; } + shared_ptr Target; + +private: +}; + +REGISTER_TWEAK(ExtractFunction) + +bool ExtractFunction::prepare(const Selection &Inputs) { + const Node *CommonAnc = Inputs.ASTSelection.commonAncestor(); + const SourceManager &SM = Inputs.AST.getSourceManager(); + const ASTContext &Ctx = Inputs.AST.getASTContext(); + Target = std::make_shared(CommonAnc, SM, Ctx); + return Target->isExtractable(); +} + +Expected ExtractFunction::apply(const Selection &Inputs) { + TargetAnalyzer Analyzer(Target); + if(Analyzer.isTooComplexToExtract()) + return llvm::createStringError(llvm::inconvertibleErrorCode(), + "Too complex to extract."); + tooling::Replacements Result; + // FIXME: get variable name from user or suggest based on type + std::string FuncName = "dummyFunc"; + // insert new variable declaration + if (auto Err = Result.add(Target->createFunctionDefinition(FuncName))) + return std::move(Err); + // replace expression with variable name + if (auto Err = Result.add(Target->replaceWithFuncCall(FuncName))) + return std::move(Err); + return Effect::applyEdit(Result); +} + +} // namespace +} // namespace clangd +} // namespace clang