Index: include/clang/Tooling/Refactoring/RefactoringResultConsumer.h =================================================================== --- include/clang/Tooling/Refactoring/RefactoringResultConsumer.h +++ include/clang/Tooling/Refactoring/RefactoringResultConsumer.h @@ -12,6 +12,7 @@ #include "clang/Basic/LLVM.h" #include "clang/Tooling/Refactoring/AtomicChange.h" +#include "clang/Tooling/Refactoring/Rename/SymbolOccurrences.h" #include "llvm/Support/Error.h" namespace clang { @@ -37,6 +38,10 @@ /// Handles the source replacements that are produced by a refactoring action. virtual void handle(AtomicChanges SourceReplacements) = 0; + + /// Handles the symbol occurrences that are found by an interactive + /// refactoring action. + virtual void handle(SymbolOccurrences Occurrences) = 0; }; namespace traits { Index: unittests/Tooling/RefactoringActionRulesTest.cpp =================================================================== --- unittests/Tooling/RefactoringActionRulesTest.cpp +++ unittests/Tooling/RefactoringActionRulesTest.cpp @@ -11,6 +11,7 @@ #include "RewriterTestContext.h" #include "clang/Tooling/Refactoring.h" #include "clang/Tooling/Refactoring/RefactoringActionRules.h" +#include "clang/Tooling/Refactoring/Rename/SymbolName.h" #include "clang/Tooling/Tooling.h" #include "llvm/Support/Errc.h" #include "gtest/gtest.h" @@ -32,10 +33,19 @@ std::string DefaultCode = std::string(100, 'a'); }; +class DefaultRefactoringResultConsumer : public RefactoringResultConsumer { +public: + void handleInitiationFailure() override {} + void handleInitiationError(llvm::Error) override {} + void handleInvocationError(llvm::Error) override {} + void handle(AtomicChanges) override {} + void handle(SymbolOccurrences) override {} +}; + Expected> createReplacements(const std::unique_ptr &Rule, RefactoringRuleContext &Context) { - class Consumer final : public RefactoringResultConsumer { + class Consumer final : public DefaultRefactoringResultConsumer { void handleInitiationFailure() { Result = Expected>(None); } @@ -181,4 +191,48 @@ EXPECT_EQ(Message, "bad selection"); } +Optional findOccurrences(RefactoringActionRule &Rule, + RefactoringRuleContext &Context) { + class Consumer final : public DefaultRefactoringResultConsumer { + void handle(SymbolOccurrences Occurrences) override { + Result = std::move(Occurrences); + } + + public: + Optional Result; + }; + + Consumer C; + Rule.invoke(C, Context); + return std::move(C.Result); +} + +TEST_F(RefactoringActionRulesTest, ReturnSymbolOccurrences) { + auto Rule = createRefactoringRule( + [](selection::SourceSelectionRange Selection) + -> Expected { + SymbolOccurrences Occurrences; + Occurrences.push_back(SymbolOccurrence( + SymbolName("test"), SymbolOccurrence::MatchingSymbol, + Selection.getRange().getBegin())); + return Occurrences; + }, + requiredSelection( + selection::identity())); + + RefactoringRuleContext RefContext(Context.Sources); + SourceLocation Cursor = + Context.Sources.getLocForStartOfFile(Context.Sources.getMainFileID()); + RefContext.setSelectionRange({Cursor, Cursor}); + Optional Result = findOccurrences(*Rule, RefContext); + + ASSERT_FALSE(!Result); + SymbolOccurrences Occurrences = std::move(*Result); + EXPECT_EQ(Occurrences.size(), 1u); + EXPECT_EQ(Occurrences[0].getKind(), SymbolOccurrence::MatchingSymbol); + EXPECT_EQ(Occurrences[0].getNameRanges().size(), 1u); + EXPECT_EQ(Occurrences[0].getNameRanges()[0], + SourceRange(Cursor, Cursor.getLocWithOffset(strlen("test")))); +} + } // end anonymous namespace