diff --git a/clang/include/clang/Tooling/FixIt.h b/clang/include/clang/Tooling/FixIt.h --- a/clang/include/clang/Tooling/FixIt.h +++ b/clang/include/clang/Tooling/FixIt.h @@ -26,21 +26,23 @@ namespace fixit { namespace internal { -StringRef getText(SourceRange Range, const ASTContext &Context); +StringRef getText(CharSourceRange Range, const ASTContext &Context); -/// Returns the SourceRange of a SourceRange. This identity function is +/// Returns the token CharSourceRange covering a SourceRange. This function is /// used by the following template abstractions. -inline SourceRange getSourceRange(const SourceRange &Range) { return Range; } +inline CharSourceRange getSourceRange(const SourceRange &Range) { + return CharSourceRange::getTokenRange(Range); +} -/// Returns the SourceRange of the token at Location \p Loc. -inline SourceRange getSourceRange(const SourceLocation &Loc) { - return SourceRange(Loc); +/// Returns the CharSourceRange of the token at Location \p Loc. +inline CharSourceRange getSourceRange(const SourceLocation &Loc) { + return CharSourceRange::getTokenRange(Loc, Loc); } -/// Returns the SourceRange of an given Node. \p Node is typically a +/// Returns the token CharSourceRange of a given Node. \p Node is typically a /// 'Stmt', 'Expr' or a 'Decl'. -template <typename T> SourceRange getSourceRange(const T &Node) { - return Node.getSourceRange(); +template <typename T> CharSourceRange getSourceRange(const T &Node) { + return CharSourceRange::getTokenRange(Node.getSourceRange()); } } // end namespace internal @@ -50,6 +52,32 @@ return internal::getText(internal::getSourceRange(Node), Context); } +// Returns the source range spanning the statement and any trailing semicolon +// that belongs with that statement. +// +// N.B. The API of this function is still evolving and might change in the +// future to include more associated text (like comments). +CharSourceRange getSourceRangeAuto(const Stmt &S, ASTContext &Context); + +CharSourceRange getSourceRangeAuto(const ast_type_traits::DynTypedNode &Node, + ASTContext &Context); +// Catch all for any nodes that aren't DynTypedNode or derived from Stmt.
+template <typename T, typename = typename std::enable_if< +              (!std::is_base_of<Stmt, T>::value)>::type> +CharSourceRange getSourceRangeAuto(const T &Node, ASTContext &Context) { + return internal::getSourceRange(Node); +} + +// Gets the source text of the node, taking into account the node's type and +// context. In contrast with \p getText(), this function selects a source range +// "automatically", extracting text that a reader might intuitively associate +// with a node. Currently, only specialized for \p clang::Stmt, where it will +// include any associated trailing semicolon. +template <typename T> +StringRef getTextAuto(const T &Node, ASTContext &Context) { + return internal::getText(getSourceRangeAuto(Node, Context), Context); +} + // Returns a FixItHint to remove \p Node. // TODO: Add support for related syntactical elements (i.e. comments, ...). template <typename T> FixItHint createRemoval(const T &Node) { diff --git a/clang/lib/Tooling/FixIt.cpp b/clang/lib/Tooling/FixIt.cpp --- a/clang/lib/Tooling/FixIt.cpp +++ b/clang/lib/Tooling/FixIt.cpp @@ -11,6 +11,8 @@ // //===----------------------------------------------------------------------===// #include "clang/Tooling/FixIt.h" +#include "clang/ASTMatchers/ASTMatchFinder.h" +#include "clang/ASTMatchers/ASTMatchers.h" #include "clang/Lex/Lexer.h" namespace clang { @@ -18,13 +20,66 @@ namespace fixit { namespace internal { -StringRef getText(SourceRange Range, const ASTContext &Context) { - return Lexer::getSourceText(CharSourceRange::getTokenRange(Range), - Context.getSourceManager(), +StringRef getText(CharSourceRange Range, const ASTContext &Context) { + return Lexer::getSourceText(Range, Context.getSourceManager(), Context.getLangOpts()); } } // end namespace internal +// Returns the token immediately following the token at \p EndLoc if and only +// if that token is a semicolon; otherwise returns None.
+static Optional<Token> getTrailingSemi(SourceLocation EndLoc, + const ASTContext &Context) { + if (Optional<Token> Next = Lexer::findNextToken( + EndLoc, Context.getSourceManager(), Context.getLangOpts())) { + return Next->is(clang::tok::TokenKind::semi) ? Next : None; + } + return None; +} + +// Determines whether `S` is strictly a statement. Clang's class hierarchy +// implicitly categorizes all expressions as statements. This function +// distinguishes expressions that appear in a context for which they are +// first-class statements. In essence, this excludes expressions that are +// either part of another expression, the condition of a statement or a +// for-loop's init or increment. +static bool isStrictStatement(const Stmt &S, ASTContext &Context) { + using namespace ast_matchers; + + if (!isa<Expr>(S)) + return true; + + auto NotCondition = unless(hasCondition(equalsNode(&S))); + auto Standalone = + stmt(hasParent(stmt(anyOf(compoundStmt(), ifStmt(NotCondition), + whileStmt(NotCondition), doStmt(NotCondition), + switchStmt(NotCondition), switchCase(), + forStmt(unless(hasLoopInit(equalsNode(&S))), + unless(hasCondition(equalsNode(&S))), + unless(hasIncrement(equalsNode(&S)))), + labelStmt())))) + .bind("stmt"); + return !match(Standalone, S, Context).empty(); +} + +CharSourceRange getSourceRangeAuto(const Stmt &S, ASTContext &Context) { + // Only exclude non-statement expressions. + if (!isa<Expr>(S) || isStrictStatement(S, Context)) { + // TODO: exclude case where last token is a right brace?
+ if (auto Tok = getTrailingSemi(S.getEndLoc(), Context)) + return CharSourceRange::getTokenRange(S.getBeginLoc(), + Tok->getLocation()); + } + return CharSourceRange::getTokenRange(S.getSourceRange()); +} + +CharSourceRange getSourceRangeAuto(const ast_type_traits::DynTypedNode &Node, + ASTContext &Context) { + if (const auto *S = Node.get<Stmt>()) + return getSourceRangeAuto(*S, Context); + return CharSourceRange::getTokenRange(Node.getSourceRange()); +} + } // end namespace fixit } // end namespace tooling } // end namespace clang
diff --git a/clang/unittests/Tooling/FixItTest.cpp b/clang/unittests/Tooling/FixItTest.cpp --- a/clang/unittests/Tooling/FixItTest.cpp +++ b/clang/unittests/Tooling/FixItTest.cpp @@ -13,6 +13,7 @@ using namespace clang; using tooling::fixit::getText; +using tooling::fixit::getTextAuto; using tooling::fixit::createRemoval; using tooling::fixit::createReplacement; @@ -27,6 +28,24 @@ std::function<void(CallExpr *, ASTContext *Context)> OnCall; }; +struct IfVisitor : TestVisitor<IfVisitor> { + bool VisitIfStmt(IfStmt *S) { + OnIfStmt(S, Context); + return true; + } + + std::function<void(IfStmt *, ASTContext *Context)> OnIfStmt; +}; + +struct VarDeclVisitor : TestVisitor<VarDeclVisitor> { + bool VisitVarDecl(VarDecl *Decl) { + OnVarDecl(Decl, Context); + return true; + } + + std::function<void(VarDecl *, ASTContext *Context)> OnVarDecl; +}; + std::string LocationToString(SourceLocation Loc, ASTContext *Context) { return Loc.printToString(Context->getSourceManager()); } @@ -77,6 +96,57 @@ "void foo(int x, int y) { FOO(x,y) }"); } +TEST(FixItTest, getTextAuto) { + CallsVisitor Visitor; + + Visitor.OnCall = [](CallExpr *CE, ASTContext *Context) { + EXPECT_EQ("foo(x, y);", getTextAuto(*CE, *Context)); + + Expr *P0 = CE->getArg(0); + Expr *P1 = CE->getArg(1); + EXPECT_EQ("x", getTextAuto(*P0, *Context)); + EXPECT_EQ("y", getTextAuto(*P1, *Context)); + }; + Visitor.runOver("void foo(int x, int y) { foo(x, y); }"); + Visitor.runOver("void foo(int x, int y) { if (true) foo(x, y); }"); + Visitor.runOver("void foo(int x, int y) { switch(x) foo(x, y); }"); + Visitor.runOver("void foo(int x, int y) { switch(x) case 3: foo(x, y); }"); + Visitor.runOver("void foo(int x, int y) { switch(x) default: foo(x, y); }"); + Visitor.runOver("void foo(int x, int y) { while (true) foo(x, y); }"); + Visitor.runOver("void foo(int x, int y) { do foo(x, y); while (true); }"); + Visitor.runOver("void foo(int x, int y) { for (;;) foo(x, y); }"); + Visitor.runOver("void foo(int x, int y) { bar: foo(x, y); }"); + + Visitor.OnCall = [](CallExpr *CE, ASTContext *Context) { + EXPECT_EQ("foo()", getTextAuto(*CE, *Context)); + }; + Visitor.runOver("int foo() { return foo(); }"); + Visitor.runOver("int foo() { 3 + foo(); return 0; }"); + Visitor.runOver("bool foo() { if (foo()) true; return true; }"); + Visitor.runOver("bool foo() { switch(foo()) true; return true; }"); + Visitor.runOver("bool foo() { while (foo()) true; return true; }"); + Visitor.runOver("bool foo() { do true; while (foo()); return true; }"); + Visitor.runOver("void foo() { for (foo();;) true; }"); + Visitor.runOver("bool foo() { for (;foo();) true; return true; }"); + Visitor.runOver("void foo() { for (;; foo()) true; }"); +} + +TEST(FixItTest, getTextAutoNonStatement) { + VarDeclVisitor Visitor; + Visitor.OnVarDecl = [](VarDecl *D, ASTContext *Context) { + EXPECT_EQ(getText(*D, *Context), getTextAuto(*D, *Context)); + }; + Visitor.runOver("int foo() { int x = 3; return x; }"); +} + +TEST(FixItTest, getTextAutoNonExprStatement) { + IfVisitor Visitor; + Visitor.OnIfStmt = [](IfStmt *S, ASTContext *Context) { + EXPECT_EQ(getText(*S, *Context), getTextAuto(*S, *Context)); + }; + Visitor.runOver("int foo(int x) { if (x) { return 3; } return x; }"); +} + TEST(FixItTest, createRemoval) { CallsVisitor Visitor;