diff --git a/clang/include/clang/Tooling/Refactoring/RangeSelector.h b/clang/include/clang/Tooling/Refactoring/RangeSelector.h
--- a/clang/include/clang/Tooling/Refactoring/RangeSelector.h
+++ b/clang/include/clang/Tooling/Refactoring/RangeSelector.h
@@ -37,6 +37,17 @@
 /// Convenience version of \c range where end-points are bound nodes.
 RangeSelector range(std::string BeginID, std::string EndID);
 
+/// Selects the (empty) range [B,B) when \p Selector selects the range [B,E).
+RangeSelector before(RangeSelector Selector);
+
+/// Selects the point immediately following \p Selector. That is, the
+/// (empty) range [E,E), when \p Selector selects either
+/// * the CharRange [B,E) or
+/// * the TokenRange [B,E'] where the token at E' spans the range [E',E).
+/// Fails if the end of a token range cannot be computed, for example when
+/// its last token lies in the middle of a macro expansion.
+RangeSelector after(RangeSelector Selector);
+
 /// Selects a node, including trailing semicolon (for non-expression
 /// statements). \p ID is the node's binding in the match result.
 RangeSelector node(std::string ID);
diff --git a/clang/lib/Tooling/Refactoring/RangeSelector.cpp b/clang/lib/Tooling/Refactoring/RangeSelector.cpp
--- a/clang/lib/Tooling/Refactoring/RangeSelector.cpp
+++ b/clang/lib/Tooling/Refactoring/RangeSelector.cpp
@@ -104,6 +104,37 @@
   return findPreviousTokenKind(EndLoc, SM, LangOpts, tok::TokenKind::l_paren);
 }
 
+RangeSelector tooling::before(RangeSelector Selector) {
+  return [Selector](const MatchResult &Result) -> Expected<CharSourceRange> {
+    Expected<CharSourceRange> SelectedRange = Selector(Result);
+    if (!SelectedRange)
+      return SelectedRange.takeError();
+    return CharSourceRange::getCharRange(SelectedRange->getBegin());
+  };
+}
+
+RangeSelector tooling::after(RangeSelector Selector) {
+  return [Selector](const MatchResult &Result) -> Expected<CharSourceRange> {
+    Expected<CharSourceRange> SelectedRange = Selector(Result);
+    if (!SelectedRange)
+      return SelectedRange.takeError();
+    if (SelectedRange->isCharRange())
+      return CharSourceRange::getCharRange(SelectedRange->getEnd());
+    // A token range ends at the *start* of its last token, so step past that
+    // token. The lexer returns an invalid location when the token sits inside
+    // a macro expansion (other than at its end); report that as an error
+    // instead of silently returning an invalid range.
+    SourceLocation End = Lexer::getLocForEndOfToken(
+        SelectedRange->getEnd(), 0, Result.Context->getSourceManager(),
+        Result.Context->getLangOpts());
+    if (End.isInvalid())
+      return llvm::make_error<llvm::StringError>(
+          llvm::errc::invalid_argument,
+          "after: cannot compute the end of the selected token range");
+    return CharSourceRange::getCharRange(End);
+  };
+}
+
 RangeSelector tooling::node(std::string ID) {
   return [ID](const MatchResult &Result) -> Expected<CharSourceRange> {
     Expected<DynTypedNode> Node = getNode(Result.Nodes, ID);
diff --git a/clang/unittests/Tooling/RangeSelectorTest.cpp b/clang/unittests/Tooling/RangeSelectorTest.cpp
--- a/clang/unittests/Tooling/RangeSelectorTest.cpp
+++ b/clang/unittests/Tooling/RangeSelectorTest.cpp
@@ -21,13 +21,15 @@
 using namespace ast_matchers;
 
 namespace {
-using ::testing::AllOf;
-using ::testing::HasSubstr;
-using MatchResult = MatchFinder::MatchResult;
 using ::llvm::Expected;
 using ::llvm::Failed;
 using ::llvm::HasValue;
 using ::llvm::StringError;
+using ::testing::AllOf;
+using ::testing::HasSubstr;
+using ::testing::Property;
+
+using MatchResult = MatchFinder::MatchResult;
 
 struct TestMatch {
   // The AST unit from which `result` is built. We bundle it because it backs
@@ -117,6 +119,53 @@
                        Failed<StringError>(withUnboundNodeMessage()));
 }
 
+MATCHER_P(EqualsCharSourceRange, Range, "") {
+  return Range.getAsRange() == arg.getAsRange() &&
+         Range.isTokenRange() == arg.isTokenRange();
+}
+
+TEST(RangeSelectorTest, BeforeOp) {
+  StringRef Code = R"cc(
+    int f(int x, int y, int z) { return 3; }
+    int g() { return f(/* comment */ 3, 7 /* comment */, 9); }
+  )cc";
+  StringRef Call = "call";
+  TestMatch Match = matchCode(Code, callExpr().bind(Call));
+  const auto *E = Match.Result.Nodes.getNodeAs<CallExpr>(Call);
+  ASSERT_NE(E, nullptr);
+  auto ExprBegin = E->getSourceRange().getBegin();
+  EXPECT_THAT_EXPECTED(
+      before(node(Call))(Match.Result),
+      HasValue(EqualsCharSourceRange(
+          CharSourceRange::getCharRange(ExprBegin, ExprBegin))));
+}
+
+TEST(RangeSelectorTest, AfterOp) {
+  StringRef Code = R"cc(
+    int f(int x, int y, int z) { return 3; }
+    int g() { return f(/* comment */ 3, 7 /* comment */, 9); }
+  )cc";
+  StringRef Call = "call";
+  TestMatch Match = matchCode(Code, callExpr().bind(Call));
+  const auto *E = Match.Result.Nodes.getNodeAs<CallExpr>(Call);
+  ASSERT_NE(E, nullptr);
+  const SourceRange Range = E->getSourceRange();
+  // The end token, a right paren, is one character wide, so advance by one,
+  // bringing us to the semicolon.
+  const SourceLocation SemiLoc = Range.getEnd().getLocWithOffset(1);
+  const auto ExpectedAfter = CharSourceRange::getCharRange(SemiLoc, SemiLoc);
+
+  // Test with a char range.
+  auto CharRange = CharSourceRange::getCharRange(Range.getBegin(), SemiLoc);
+  EXPECT_THAT_EXPECTED(after(charRange(CharRange))(Match.Result),
+                       HasValue(EqualsCharSourceRange(ExpectedAfter)));
+
+  // Test with a token range.
+  auto TokenRange = CharSourceRange::getTokenRange(Range);
+  EXPECT_THAT_EXPECTED(after(charRange(TokenRange))(Match.Result),
+                       HasValue(EqualsCharSourceRange(ExpectedAfter)));
+}
+
 TEST(RangeSelectorTest, RangeOp) {
   StringRef Code = R"cc(
     int f(int x, int y, int z) { return 3; }