[clangd] FileDistance: add an option to disallow down-traversal from root

Adds FileDistanceOptions::AllowDownTraversalFromRoot (default true, so the
existing behavior is unchanged). When it is false:
  * the constructor's parent -> child BFS still visits the children of "/"
    but does not lower their cost from the root's cost;
  * distance() returns Unreachable for a queried path whose upward walk
    reaches "/" before finding any cached ancestor.
The BFS loop now copies and pops the queue front before iterating its
children; since the queue is FIFO, the visiting order is unchanged.

NOTE(review): FileDistance::RootHash is a namespace-scope object initialized
with a call to hash_value(), which is presumably not constexpr, so it likely
needs a static constructor (discouraged by the LLVM Coding Standards).
Consider computing the root hash where it is used instead.

Index: clangd/FileDistance.h
===================================================================
--- clangd/FileDistance.h
+++ clangd/FileDistance.h
@@ -56,6 +56,7 @@
   unsigned UpCost = 2;      // |foo/bar.h -> foo|
   unsigned DownCost = 1;    // |foo -> foo/bar.h|
   unsigned IncludeCost = 2; // |foo.cc -> included_header.h|
+  bool AllowDownTraversalFromRoot = true; // | / -> /a |
 };
 
 struct SourceParams {
@@ -70,6 +71,7 @@
 class FileDistance {
 public:
   static constexpr unsigned Unreachable = std::numeric_limits<unsigned>::max();
+  static const llvm::hash_code RootHash;
 
   FileDistance(llvm::StringMap<SourceParams> Sources,
                const FileDistanceOptions &Opts = {});
Index: clangd/FileDistance.cpp
===================================================================
--- clangd/FileDistance.cpp
+++ clangd/FileDistance.cpp
@@ -54,6 +54,7 @@
 }
 
 constexpr const unsigned FileDistance::Unreachable;
+const llvm::hash_code FileDistance::RootHash = hash_value(StringRef("/"));
 
 FileDistance::FileDistance(StringMap<SourceParams> Sources,
                            const FileDistanceOptions &Opts)
@@ -99,15 +100,18 @@
   for (auto Child : DownEdges.lookup(hash_value(llvm::StringRef(""))))
     Next.push(Child);
   while (!Next.empty()) {
-    auto ParentCost = Cache.lookup(Next.front());
-    for (auto Child : DownEdges.lookup(Next.front())) {
-      auto &ChildCost =
-          Cache.try_emplace(Child, Unreachable).first->getSecond();
-      if (ParentCost + Opts.DownCost < ChildCost)
-        ChildCost = ParentCost + Opts.DownCost;
+    auto Parent = Next.front();
+    Next.pop();
+    auto ParentCost = Cache.lookup(Parent);
+    for (auto Child : DownEdges.lookup(Parent)) {
+      if (Parent != RootHash || Opts.AllowDownTraversalFromRoot) {
+        auto &ChildCost =
+            Cache.try_emplace(Child, Unreachable).first->getSecond();
+        if (ParentCost + Opts.DownCost < ChildCost)
+          ChildCost = ParentCost + Opts.DownCost;
+      }
       Next.push(Child);
     }
-    Next.pop();
   }
 }
 
@@ -119,6 +123,11 @@
   for (StringRef Rest = Canonical; !Rest.empty();
        Rest = parent_path(Rest, sys::path::Style::posix)) {
     auto Hash = hash_value(Rest);
+    if (Hash == RootHash && !Ancestors.empty() &&
+        !Opts.AllowDownTraversalFromRoot) {
+      Cost = Unreachable;
+      break;
+    }
     auto It = Cache.find(Hash);
     if (It != Cache.end()) {
       Cost = It->second;
Index: unittests/clangd/FileDistanceTests.cpp
===================================================================
--- unittests/clangd/FileDistanceTests.cpp
+++ unittests/clangd/FileDistanceTests.cpp
@@ -95,6 +95,20 @@
   EXPECT_EQ(D.distance("/a/b/z"), 2u);
 }
 
+TEST(FileDistance, DisallowDownTraversalsFromRoot) {
+  FileDistanceOptions Opts;
+  Opts.UpCost = Opts.DownCost = 1;
+  Opts.AllowDownTraversalFromRoot = false;
+  SourceParams CostLots;
+  CostLots.Cost = 100;
+
+  FileDistance D({{"/", SourceParams()}, {"/a/b/c", CostLots}}, Opts);
+  EXPECT_EQ(D.distance("/"), 0u);
+  EXPECT_EQ(D.distance("/a"), 102u);
+  EXPECT_EQ(D.distance("/a/b"), 101u);
+  EXPECT_EQ(D.distance("/x"), FileDistance::Unreachable);
+}
+
 } // namespace
 } // namespace clangd
 } // namespace clang