Index: lib/Tooling/Refactoring/CMakeLists.txt
===================================================================
--- lib/Tooling/Refactoring/CMakeLists.txt
+++ lib/Tooling/Refactoring/CMakeLists.txt
@@ -4,6 +4,7 @@
   ASTSelection.cpp
   ASTSelectionRequirements.cpp
   AtomicChange.cpp
+  Extract/CaptureVariables.cpp
   Extract/Extract.cpp
   Extract/SourceExtraction.cpp
   RefactoringActions.cpp
Index: lib/Tooling/Refactoring/Extract/CaptureVariables.h
===================================================================
--- /dev/null
+++ lib/Tooling/Refactoring/Extract/CaptureVariables.h
@@ -0,0 +1,76 @@
+//===--- CaptureVariables.h - Clang refactoring library -------------------===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CLANG_LIB_TOOLING_REFACTORING_EXTRACT_CAPTURE_VARIABLES_H
+#define LLVM_CLANG_LIB_TOOLING_REFACTORING_EXTRACT_CAPTURE_VARIABLES_H
+
+#include "clang/Basic/LLVM.h"
+#include <vector>
+
+namespace clang {
+
+struct PrintingPolicy;
+class VarDecl;
+
+namespace tooling {
+
+class CodeRangeASTSelection;
+
+/// Represents a variable or a declaration that can be represented using
+/// a variable that was captured in the extracted code and that should be
+/// passed to the extracted function as an argument.
+class CapturedExtractedEntity {
+  enum EntityKind {
+    CapturedVarDecl,
+    // FIXME: Field/This/Self.
+  };
+
+public:
+  explicit CapturedExtractedEntity(const VarDecl *VD)
+      : Kind(CapturedVarDecl), VD(VD) {}
+
+  /// Print the parameter declaration for the captured entity.
+  ///
+  /// The declaration includes the type and the name of the parameter, and
+  /// doesn't include the trailing comma. This is an example of the output:
+  ///
+  ///   int *capture
+  void printParamDecl(llvm::raw_ostream &OS, const PrintingPolicy &PP) const;
+
+  /// Print the expression that should be used as an argument value when
+  /// calling the extracted function.
+  ///
+  /// The expression includes the name of the captured entity whose address
+  /// could be taken if needed. The trailing comma is not included. This is an
+  /// example of the output:
+  ///
+  ///   &capture
+  void printFunctionCallArg(llvm::raw_ostream &OS,
+                            const PrintingPolicy &PP) const;
+
+  /// Returns the name of the captured entity (e.g. the variable's name).
+  StringRef getName() const;
+
+private:
+  EntityKind Kind;
+  union {
+    const VarDecl *VD;
+  };
+  // FIXME: Track things like type & qualifiers.
+};
+
+/// Scans the extracted AST to determine which variables have to be captured
+/// and passed to the extracted function.
+std::vector<CapturedExtractedEntity>
+findCapturedExtractedEntities(const CodeRangeASTSelection &Code);
+
+} // end namespace tooling
+} // end namespace clang
+
+#endif // LLVM_CLANG_LIB_TOOLING_REFACTORING_EXTRACT_CAPTURE_VARIABLES_H
Index: lib/Tooling/Refactoring/Extract/CaptureVariables.cpp
===================================================================
--- /dev/null
+++ lib/Tooling/Refactoring/Extract/CaptureVariables.cpp
@@ -0,0 +1,127 @@
+//===--- CaptureVariables.cpp - Clang refactoring library -----------------===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+#include "CaptureVariables.h"
+#include "clang/AST/Decl.h"
+#include "clang/AST/RecursiveASTVisitor.h"
+#include "clang/Tooling/Refactoring/ASTSelection.h"
+#include "llvm/ADT/DenseMap.h"
+#include <algorithm>
+
+using namespace clang;
+
+namespace {
+
+// FIXME: Determine 'const' qualifier.
+// FIXME: Track variables defined in the extracted code.
+/// Scans the extracted AST to determine which variables have to be captured
+/// and passed to the extracted function.
+class ExtractedVariableCaptureVisitor
+    : public RecursiveASTVisitor<ExtractedVariableCaptureVisitor> {
+public:
+  struct ExtractedEntityInfo {
+    /// True if this entity is referenced or declared in the extracted code.
+    bool IsUsed = false;
+    /// True if this entity is defined in the extracted code.
+    bool IsDefined = false;
+  };
+
+  bool VisitDeclRefExpr(const DeclRefExpr *E) {
+    const VarDecl *VD = dyn_cast<VarDecl>(E->getDecl());
+    if (!VD)
+      return true;
+    // FIXME: Capture 'self'.
+    if (!VD->isLocalVarDeclOrParm())
+      return true;
+    captureVariable(VD);
+    return true;
+  }
+
+  bool VisitVarDecl(const VarDecl *VD) {
+    captureVariable(VD).IsDefined = true;
+    // FIXME: Track more information about variables defined in extracted code
+    // to support "use after defined in extracted" situation reasonably well.
+    return true;
+  }
+
+  const llvm::DenseMap<const VarDecl *, ExtractedEntityInfo> &
+  getVisitedVariables() const {
+    return Variables;
+  }
+
+private:
+  ExtractedEntityInfo &captureVariable(const VarDecl *VD) {
+    ExtractedEntityInfo &Result = Variables[VD];
+    Result.IsUsed = true;
+    return Result;
+  }
+
+  llvm::DenseMap<const VarDecl *, ExtractedEntityInfo> Variables;
+  // TODO: Track fields/this/self.
+};
+
+} // end anonymous namespace
+
+namespace clang {
+namespace tooling {
+
+void CapturedExtractedEntity::printParamDecl(llvm::raw_ostream &OS,
+                                             const PrintingPolicy &PP) const {
+  switch (Kind) {
+  case CapturedVarDecl:
+    VD->getType().print(OS, PP, /*PlaceHolder=*/VD->getName());
+    break;
+  }
+}
+
+void CapturedExtractedEntity::printFunctionCallArg(
+    llvm::raw_ostream &OS, const PrintingPolicy &PP) const {
+  // FIXME: Take address if needed.
+  switch (Kind) {
+  case CapturedVarDecl:
+    OS << VD->getName();
+    break;
+  }
+}
+
+StringRef CapturedExtractedEntity::getName() const {
+  switch (Kind) {
+  case CapturedVarDecl:
+    return VD->getName();
+  }
+  llvm_unreachable("invalid kind!");
+}
+
+/// Scans the extracted AST to determine which variables have to be captured
+/// and passed to the extracted function.
+std::vector<CapturedExtractedEntity>
+findCapturedExtractedEntities(const CodeRangeASTSelection &Code) {
+  ExtractedVariableCaptureVisitor Visitor;
+  for (size_t I = 0, E = Code.size(); I != E; ++I)
+    Visitor.TraverseStmt(const_cast<Stmt *>(Code[I]));
+  std::vector<CapturedExtractedEntity> Entities;
+  for (const auto &I : Visitor.getVisitedVariables()) {
+    if (!I.getSecond().IsDefined)
+      Entities.push_back(CapturedExtractedEntity(I.getFirst()));
+    // FIXME: Handle variables used after definition in extracted code.
+  }
+  // Sort the entities by name.
+  std::sort(
+      Entities.begin(), Entities.end(),
+      [](const CapturedExtractedEntity &X, const CapturedExtractedEntity &Y) {
+        return X.getName() < Y.getName();
+      });
+  // FIXME: Capture any field if necessary (method -> function extraction).
+  // FIXME: Capture 'this' / 'self' if necessary.
+  // FIXME: Compute the actual parameter types.
+  return Entities;
+}
+
+} // end namespace tooling
+} // end namespace clang
Index: lib/Tooling/Refactoring/Extract/Extract.cpp
===================================================================
--- lib/Tooling/Refactoring/Extract/Extract.cpp
+++ lib/Tooling/Refactoring/Extract/Extract.cpp
@@ -14,6 +14,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "clang/Tooling/Refactoring/Extract/Extract.h"
+#include "CaptureVariables.h"
 #include "SourceExtraction.h"
 #include "clang/AST/ASTContext.h"
 #include "clang/AST/DeclCXX.h"
@@ -111,7 +112,8 @@
   const LangOptions &LangOpts = AST.getLangOpts();
   Rewriter ExtractedCodeRewriter(SM, LangOpts);
 
-  // FIXME: Capture used variables.
+  std::vector<CapturedExtractedEntity> Captures =
+      findCapturedExtractedEntities(Code);
 
   // Compute the return type.
   QualType ReturnType = AST.VoidTy;
@@ -125,14 +127,6 @@
 
   // FIXME: Rewrite the extracted code performing any required adjustments.
 
-  // FIXME: Capture any field if necessary (method -> function extraction).
-
-  // FIXME: Sort captured variables by name.
-
-  // FIXME: Capture 'this' / 'self' if necessary.
-
-  // FIXME: Compute the actual parameter types.
-
   // Compute the location of the extracted declaration.
   SourceLocation ExtractedDeclLocation =
       computeFunctionExtractionLocation(ParentDecl);
@@ -157,7 +151,11 @@
   OS << "static ";
   ReturnType.print(OS, PP, DeclName);
   OS << '(';
-  // FIXME: Arguments.
+  for (const auto &Capture : llvm::enumerate(Captures)) {
+    if (Capture.index() > 0)
+      OS << ", ";
+    Capture.value().printParamDecl(OS, PP);
+  }
   OS << ')';
 
   // Function body.
@@ -179,7 +177,11 @@
   llvm::raw_string_ostream OS(ReplacedCode);
 
   OS << DeclName << '(';
-  // FIXME: Forward arguments.
+  for (const auto &Capture : llvm::enumerate(Captures)) {
+    if (Capture.index() > 0)
+      OS << ", ";
+    Capture.value().printFunctionCallArg(OS, PP);
+  }
   OS << ')';
   if (Semicolons.isNeededInOriginalFunction())
     OS << ';';
Index: test/Refactor/Extract/CaptureSimpleVariables.cpp
===================================================================
--- /dev/null
+++ test/Refactor/Extract/CaptureSimpleVariables.cpp
@@ -0,0 +1,40 @@
+// RUN: clang-refactor extract -selection=test:%s %s -- 2>&1 | grep -v CHECK | FileCheck %s
+
+void captureStaticVars() {
+  static int x;
+  /*range astaticvar=->+0:39*/int y = x;
+  x += 1;
+}
+// CHECK: 1 'astaticvar' results:
+// CHECK:      static void extracted(int x) {
+// CHECK-NEXT: int y = x;{{$}}
+// CHECK-NEXT: }{{[[:space:]].*}}
+// CHECK-NEXT: void captureStaticVars() {
+// CHECK-NEXT: static int x;
+// CHECK-NEXT: extracted(x);{{$}}
+
+typedef struct {
+  int width, height;
+} Rectangle;
+
+void basicTypes(int i, float f, char c, const int *ip, float *fp, const Rectangle *structPointer) {
+  /*range basictypes=->+0:73*/basicTypes(i, f, c, ip, fp, structPointer);
+}
+// CHECK: 1 'basictypes' results:
+// CHECK:      static void extracted(char c, float f, float *fp, int i, const int *ip, const Rectangle *structPointer) {
+// CHECK-NEXT: basicTypes(i, f, c, ip, fp, structPointer);{{$}}
+// CHECK-NEXT: }{{[[:space:]].*}}
+// CHECK-NEXT: void basicTypes(int i, float f, char c, const int *ip, float *fp, const Rectangle *structPointer) {
+// CHECK-NEXT: extracted(c, f, fp, i, ip, structPointer);{{$}}
+
+int noGlobalsPlease = 0;
+
+void cantCaptureGlobals() {
+  /*range cantCaptureGlobals=->+0:62*/int y = noGlobalsPlease;
+}
+// CHECK: 1 'cantCaptureGlobals' results:
+// CHECK:      static void extracted() {
+// CHECK-NEXT: int y = noGlobalsPlease;{{$}}
+// CHECK-NEXT: }{{[[:space:]].*}}
+// CHECK-NEXT: void cantCaptureGlobals() {
+// CHECK-NEXT: extracted();{{$}}
Index: test/Refactor/Extract/ExtractionSemicolonPolicy.cpp
===================================================================
--- test/Refactor/Extract/ExtractionSemicolonPolicy.cpp
+++ test/Refactor/Extract/ExtractionSemicolonPolicy.cpp
@@ -6,11 +6,11 @@
   /*range adeclstmt=->+0:59*/int area = r.width * r.height;
 }
 // CHECK: 1 'adeclstmt' results:
-// CHECK:      static void extracted() {
+// CHECK:      static void extracted(const Rectangle &r) {
 // CHECK-NEXT: int area = r.width * r.height;{{$}}
 // CHECK-NEXT: }{{[[:space:]].*}}
 // CHECK-NEXT: void extractStatement(const Rectangle &r) {
-// CHECK-NEXT: /*range adeclstmt=->+0:59*/extracted();{{$}}
+// CHECK-NEXT: /*range adeclstmt=->+0:59*/extracted(r);{{$}}
 // CHECK-NEXT: }
 
 void extractStatementNoSemiIf(const Rectangle &r) {
@@ -19,13 +19,13 @@
   }
 }
 // CHECK: 1 'bextractif' results:
-// CHECK:      static void extracted() {
+// CHECK:      static void extracted(const Rectangle &r) {
 // CHECK-NEXT: if (r.width) {
 // CHECK-NEXT: int x = r.height;
 // CHECK-NEXT: }{{$}}
 // CHECK-NEXT: }{{[[:space:]].*}}
 // CHECK-NEXT: void extractStatementNoSemiIf(const Rectangle &r) {
-// CHECK-NEXT: /*range bextractif=->+2:4*/extracted();{{$}}
+// CHECK-NEXT: /*range bextractif=->+2:4*/extracted(r);{{$}}
 // CHECK-NEXT: }
 
 void extractStatementDontExtraneousSemi(const Rectangle &r) {
@@ -34,13 +34,13 @@
   } ;
 } //^ This semicolon shouldn't be extracted.
 // CHECK: 1 'cextractif' results:
-// CHECK:      static void extracted() {
+// CHECK:      static void extracted(const Rectangle &r) {
 // CHECK-NEXT: if (r.width) {
 // CHECK-NEXT: int x = r.height;
 // CHECK-NEXT: }{{$}}
 // CHECK-NEXT: }{{[[:space:]].*}}
 // CHECK-NEXT: void extractStatementDontExtraneousSemi(const Rectangle &r) {
-// CHECK-NEXT: extracted(); ;{{$}}
+// CHECK-NEXT: extracted(r); ;{{$}}
 // CHECK-NEXT: }
 
 void extractStatementNotSemiSwitch() {
@@ -102,12 +102,12 @@
   }
 }
 // CHECK: 1 'gextract' results:
-// CHECK:      static void extracted() {
+// CHECK:      static void extracted(XS xs) {
 // CHECK-NEXT: for (int i : xs) {
 // CHECK-NEXT: }{{$}}
 // CHECK-NEXT: }{{[[:space:]].*}}
 // CHECK-NEXT: void extractStatementNotSemiRangedFor(XS xs) {
-// CHECK-NEXT: extracted();{{$}}
+// CHECK-NEXT: extracted(xs);{{$}}
 // CHECK-NEXT: }
 
 void extractStatementNotSemiRangedTryCatch() {
Index: test/Refactor/Extract/ExtractionSemicolonPolicy.m
===================================================================
--- test/Refactor/Extract/ExtractionSemicolonPolicy.m
+++ test/Refactor/Extract/ExtractionSemicolonPolicy.m
@@ -10,7 +10,7 @@
   }
 }
 // CHECK: 1 'astmt' results:
-// CHECK:      static void extracted() {
+// CHECK:      static void extracted(NSArray *array) {
 // CHECK-NEXT: for (id i in array) {
 // CHECK-NEXT: int x = 0;
 // CHECK-NEXT: }{{$}}
@@ -23,7 +23,7 @@
   }
 }
 // CHECK: 1 'bstmt' results:
-// CHECK:      static void extracted() {
+// CHECK:      static void extracted(id lock) {
 // CHECK-NEXT: @synchronized(lock) {
 // CHECK-NEXT: int x = 0;
 // CHECK-NEXT: }{{$}}