Index: include/clang/Tooling/Refactoring/RefactoringResultConsumer.h
===================================================================
--- include/clang/Tooling/Refactoring/RefactoringResultConsumer.h
+++ include/clang/Tooling/Refactoring/RefactoringResultConsumer.h
@@ -12,6 +12,7 @@
 
 #include "clang/Basic/LLVM.h"
 #include "clang/Tooling/Refactoring/AtomicChange.h"
+#include "clang/Tooling/Refactoring/Rename/SymbolOccurrences.h"
 #include "llvm/Support/Error.h"
 
 namespace clang {
@@ -34,6 +35,10 @@
     defaultResultHandler();
   }
 
+  /// Handles the symbol occurrences that are found by an interactive
+  /// refactoring action.
+  virtual void handle(SymbolOccurrences Occurrences) { defaultResultHandler(); }
+
 private:
   void defaultResultHandler() {
     handleError(llvm::make_error<llvm::StringError>(
Index: unittests/Tooling/RefactoringActionRulesTest.cpp
===================================================================
--- unittests/Tooling/RefactoringActionRulesTest.cpp
+++ unittests/Tooling/RefactoringActionRulesTest.cpp
@@ -11,6 +11,7 @@
 #include "RewriterTestContext.h"
 #include "clang/Tooling/Refactoring.h"
 #include "clang/Tooling/Refactoring/RefactoringActionRules.h"
+#include "clang/Tooling/Refactoring/Rename/SymbolName.h"
 #include "clang/Tooling/Tooling.h"
 #include "llvm/Support/Errc.h"
 #include "gtest/gtest.h"
@@ -175,4 +176,61 @@
   EXPECT_EQ(Message, "bad selection");
 }
 
+/// Invokes \p Rule in \p Context and returns the symbol occurrences that the
+/// rule passes to the consumer's handle(SymbolOccurrences) method, or None if
+/// the rule reports an error or any other kind of result.
+Optional<SymbolOccurrences> findOccurrences(RefactoringActionRule &Rule,
+                                            RefactoringRuleContext &Context) {
+  class Consumer final : public RefactoringResultConsumer {
+    // Failures surface to the caller as an empty Result. The error must still
+    // be consumed: destroying an unchecked llvm::Error aborts in builds with
+    // LLVM_ENABLE_ABI_BREAKING_CHECKS.
+    void handleError(llvm::Error Err) override {
+      llvm::consumeError(std::move(Err));
+    }
+    void handle(SymbolOccurrences Occurrences) override {
+      Result = std::move(Occurrences);
+    }
+
+  public:
+    Optional<SymbolOccurrences> Result;
+  };
+
+  Consumer C;
+  Rule.invoke(C, Context);
+  // C.Result is a member of a local object, so it is not implicitly moved.
+  return std::move(C.Result);
+}
+
+TEST_F(RefactoringActionRulesTest, ReturnSymbolOccurrences) {
+  // The rule reports a single "test" occurrence at the start of the selection.
+  auto Rule = createRefactoringRule(
+      [](selection::SourceSelectionRange Selection)
+          -> Expected<SymbolOccurrences> {
+        SymbolOccurrences Occurrences;
+        Occurrences.push_back(SymbolOccurrence(
+            SymbolName("test"), SymbolOccurrence::MatchingSymbol,
+            Selection.getRange().getBegin()));
+        return Occurrences;
+      },
+      requiredSelection(
+          selection::identity<selection::SourceSelectionRange>()));
+
+  RefactoringRuleContext RefContext(Context.Sources);
+  SourceLocation Cursor =
+      Context.Sources.getLocForStartOfFile(Context.Sources.getMainFileID());
+  RefContext.setSelectionRange({Cursor, Cursor});
+  Optional<SymbolOccurrences> Result = findOccurrences(*Rule, RefContext);
+
+  ASSERT_TRUE(Result.hasValue());
+  SymbolOccurrences Occurrences = std::move(*Result);
+  // ASSERT (not EXPECT) on the sizes: the checks below index element 0 and
+  // must not run when the vectors are empty.
+  ASSERT_EQ(Occurrences.size(), 1u);
+  EXPECT_EQ(Occurrences[0].getKind(), SymbolOccurrence::MatchingSymbol);
+  ASSERT_EQ(Occurrences[0].getNameRanges().size(), 1u);
+  EXPECT_EQ(Occurrences[0].getNameRanges()[0],
+            SourceRange(Cursor, Cursor.getLocWithOffset(strlen("test"))));
+}
+
 } // end anonymous namespace