diff --git a/clang-tools-extra/clangd/index/Relation.h b/clang-tools-extra/clangd/index/Relation.h --- a/clang-tools-extra/clangd/index/Relation.h +++ b/clang-tools-extra/clangd/index/Relation.h @@ -21,6 +21,7 @@ enum class RelationKind : uint8_t { BaseOf, + OverridenBy, }; /// Represents a relation between two symbols. @@ -41,6 +42,8 @@ std::tie(Other.Subject, Other.Predicate, Other.Object); } }; +llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, RelationKind R); +llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, const Relation &R); class RelationSlab { public: diff --git a/clang-tools-extra/clangd/index/Relation.cpp b/clang-tools-extra/clangd/index/Relation.cpp --- a/clang-tools-extra/clangd/index/Relation.cpp +++ b/clang-tools-extra/clangd/index/Relation.cpp @@ -13,6 +13,20 @@ namespace clang { namespace clangd { +llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, RelationKind R) { + switch (R) { + case RelationKind::BaseOf: + return OS << "BaseOf"; + case RelationKind::OverridenBy: + return OS << "OverridenBy"; + } + llvm_unreachable("Unhandled RelationKind enum."); +} + +llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, const Relation &R) { + return OS << R.Subject << " " << R.Predicate << " " << R.Object; +} + llvm::iterator_range<RelationSlab::iterator> RelationSlab::lookup(const SymbolID &Subject, RelationKind Predicate) const { auto IterPair = std::equal_range(Relations.begin(), Relations.end(), diff --git a/clang-tools-extra/clangd/index/SymbolCollector.cpp b/clang-tools-extra/clangd/index/SymbolCollector.cpp --- a/clang-tools-extra/clangd/index/SymbolCollector.cpp +++ b/clang-tools-extra/clangd/index/SymbolCollector.cpp @@ -15,6 +15,7 @@ #include "SourceCode.h" #include "SymbolLocation.h" #include "URI.h" +#include "index/Relation.h" #include "index/SymbolID.h" #include "support/Logger.h" #include "clang/AST/Decl.h" @@ -187,9 +188,13 @@ return Result; } -bool shouldIndexRelation(const index::SymbolRelation &R) { - // We currently only index BaseOf
relations, for type hierarchy subtypes. - return R.Roles & static_cast<unsigned>(index::SymbolRole::RelationBaseOf); +/// Returns the RelationKind to store for R, or None if it is not indexed. +llvm::Optional<RelationKind> indexableRelation(const index::SymbolRelation &R) { + if (R.Roles & static_cast<unsigned>(index::SymbolRole::RelationBaseOf)) + return RelationKind::BaseOf; + if (R.Roles & static_cast<unsigned>(index::SymbolRole::RelationOverrideOf)) + return RelationKind::OverridenBy; + return None; } } // namespace @@ -486,14 +491,10 @@ void SymbolCollector::processRelations( const NamedDecl &ND, const SymbolID &ID, ArrayRef<index::SymbolRelation> Relations) { - // Store subtype relations. - if (!dyn_cast<TagDecl>(&ND)) - return; - for (const auto &R : Relations) { - if (!shouldIndexRelation(R)) + auto RKind = indexableRelation(R); + if (!RKind) continue; - const Decl *Object = R.RelatedSymbol; auto ObjectID = getSymbolID(Object); @@ -509,7 +510,12 @@ // in the index and find nothing, but that's a situation they // probably need to handle for other reasons anyways. // We currently do (B) because it's simpler. - this->Relations.insert(Relation{ID, RelationKind::BaseOf, ObjectID}); + // For BaseOf, ID is the base and ObjectID the derived class. For OverrideOf, + // ID overrides ObjectID, so the pair is stored as `ObjectID OverridenBy ID`. + if (*RKind == RelationKind::BaseOf) + this->Relations.insert({ID, *RKind, ObjectID}); + else if (*RKind == RelationKind::OverridenBy) + this->Relations.insert({ObjectID, *RKind, ID}); } } diff --git a/clang-tools-extra/clangd/unittests/SymbolCollectorTests.cpp b/clang-tools-extra/clangd/unittests/SymbolCollectorTests.cpp --- a/clang-tools-extra/clangd/unittests/SymbolCollectorTests.cpp +++ b/clang-tools-extra/clangd/unittests/SymbolCollectorTests.cpp @@ -97,6 +97,9 @@ const Range &Range = ::testing::get<1>(arg); return rangesMatch(Pos.Location, Range); } +MATCHER_P2(OverridenBy, Subject, Object, "") { + return arg == Relation{Subject.ID, RelationKind::OverridenBy, Object.ID}; +} ::testing::Matcher<const std::vector<Ref> &> HaveRanges(const std::vector<Range> Ranges) { return ::testing::UnorderedPointwise(RefRange(), Ranges); @@ -873,12 +876,12 @@ llvm::StringRef Main; llvm::StringRef TargetSymbolName; } TestCases[] = { - { - R"cpp( + { + R"cpp( struct Foo;
#define MACRO Foo )cpp", - R"cpp( + R"cpp( struct $spelled[[Foo]] { $spelled[[Foo]](); ~$spelled[[Foo]](); @@ -886,24 +889,24 @@ $spelled[[Foo]] Variable1; $implicit[[MACRO]] Variable2; )cpp", - "Foo", - }, - { - R"cpp( + "Foo", + }, + { + R"cpp( class Foo { public: Foo() = default; }; )cpp", - R"cpp( + R"cpp( void f() { Foo $implicit[[f]]; f = $spelled[[Foo]]();} )cpp", - "Foo::Foo" /// constructor. - }, + "Foo::Foo" /// constructor. + }, }; CollectorOpts.RefFilter = RefKind::All; CollectorOpts.RefsInHeaders = false; - for (const auto& T : TestCases) { + for (const auto &T : TestCases) { Annotations Header(T.Header); Annotations Main(T.Main); // Reset the file system. @@ -1031,7 +1034,7 @@ HaveRanges(Header.ranges("macro"))))); } -TEST_F(SymbolCollectorTest, Relations) { +TEST_F(SymbolCollectorTest, BaseOfRelations) { std::string Header = R"( class Base {}; class Derived : public Base {}; @@ -1043,6 +1046,77 @@ Contains(Relation{Base.ID, RelationKind::BaseOf, Derived.ID})); } +TEST_F(SymbolCollectorTest, OverrideRelationsSimpleInheritance) { + std::string Header = R"cpp( + class A { + virtual void foo(); + }; + class B : public A { + void foo() override; // A::foo + virtual void bar(); + }; + class C : public B { + void bar() override; // B::bar + }; + class D: public C { + void foo() override; // B::foo + void bar() override; // C::bar + }; + )cpp"; + runSymbolCollector(Header, /*Main=*/""); + const Symbol &AFoo = findSymbol(Symbols, "A::foo"); + const Symbol &BFoo = findSymbol(Symbols, "B::foo"); + const Symbol &DFoo = findSymbol(Symbols, "D::foo"); + + const Symbol &BBar = findSymbol(Symbols, "B::bar"); + const Symbol &CBar = findSymbol(Symbols, "C::bar"); + const Symbol &DBar = findSymbol(Symbols, "D::bar"); + + std::vector<Relation> Result; + for (const Relation &R : Relations) + if (R.Predicate == RelationKind::OverridenBy) + Result.push_back(R); + EXPECT_THAT(Result, UnorderedElementsAre( + OverridenBy(AFoo, BFoo), OverridenBy(BBar, CBar), + OverridenBy(BFoo, DFoo),
OverridenBy(CBar, DBar))); +} + +TEST_F(SymbolCollectorTest, OverrideRelationsMultipleInheritance) { + std::string Header = R"cpp( + class A { + virtual void foo(); + }; + class B { + virtual void bar(); + }; + class C : public B { + void bar() override; // B::bar + virtual void baz(); + }; + class D : public A, C { + void foo() override; // A::foo + void bar() override; // C::bar + void baz() override; // C::baz + }; + )cpp"; + runSymbolCollector(Header, /*Main=*/""); + const Symbol &AFoo = findSymbol(Symbols, "A::foo"); + const Symbol &BBar = findSymbol(Symbols, "B::bar"); + const Symbol &CBar = findSymbol(Symbols, "C::bar"); + const Symbol &CBaz = findSymbol(Symbols, "C::baz"); + const Symbol &DFoo = findSymbol(Symbols, "D::foo"); + const Symbol &DBar = findSymbol(Symbols, "D::bar"); + const Symbol &DBaz = findSymbol(Symbols, "D::baz"); + + std::vector<Relation> Result; + for (const Relation &R : Relations) + if (R.Predicate == RelationKind::OverridenBy) + Result.push_back(R); + EXPECT_THAT(Result, UnorderedElementsAre( + OverridenBy(BBar, CBar), OverridenBy(AFoo, DFoo), + OverridenBy(CBar, DBar), OverridenBy(CBaz, DBaz))); +} + TEST_F(SymbolCollectorTest, CountReferences) { const std::string Header = R"( class W;