diff --git a/clang-tools-extra/clangd/index/remote/marshalling/Marshalling.h b/clang-tools-extra/clangd/index/remote/marshalling/Marshalling.h
--- a/clang-tools-extra/clangd/index/remote/marshalling/Marshalling.h
+++ b/clang-tools-extra/clangd/index/remote/marshalling/Marshalling.h
@@ -38,10 +38,15 @@
   Marshaller() = delete;
   Marshaller(llvm::StringRef RemoteIndexRoot, llvm::StringRef LocalIndexRoot);
 
-  clangd::FuzzyFindRequest fromProtobuf(const FuzzyFindRequest *Request);
   llvm::Optional<clangd::Symbol> fromProtobuf(const Symbol &Message);
   llvm::Optional<clangd::Ref> fromProtobuf(const Ref &Message);
 
+  llvm::Expected<clangd::LookupRequest>
+  fromProtobuf(const LookupRequest *Message);
+  llvm::Expected<clangd::FuzzyFindRequest>
+  fromProtobuf(const FuzzyFindRequest *Message);
+  llvm::Expected<clangd::RefsRequest> fromProtobuf(const RefsRequest *Message);
+
   /// toProtobuf() functions serialize native clangd types and strip IndexRoot
   /// from the file paths specific to indexing machine. fromProtobuf() functions
   /// deserialize clangd types and translate relative paths into machine-native
diff --git a/clang-tools-extra/clangd/index/remote/marshalling/Marshalling.cpp b/clang-tools-extra/clangd/index/remote/marshalling/Marshalling.cpp
--- a/clang-tools-extra/clangd/index/remote/marshalling/Marshalling.cpp
+++ b/clang-tools-extra/clangd/index/remote/marshalling/Marshalling.cpp
@@ -17,6 +17,7 @@
 #include "index/SymbolOrigin.h"
 #include "support/Logger.h"
 #include "clang/Index/IndexSymbol.h"
+#include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/None.h"
 #include "llvm/ADT/Optional.h"
 #include "llvm/ADT/SmallString.h"
@@ -30,6 +31,24 @@
 namespace clangd {
 namespace remote {
 
+namespace {
+
+/// Converts every entry of Message->ids() into a SymbolID. Returns the error
+/// of the first malformed ID, so one bad ID makes the whole conversion fail.
+template <typename MessageT>
+llvm::Expected<llvm::DenseSet<SymbolID>> getIDs(MessageT *Message) {
+  llvm::DenseSet<SymbolID> Result;
+  for (const auto &ID : Message->ids()) {
+    auto SID = SymbolID::fromStr(StringRef(ID));
+    if (!SID)
+      return SID.takeError();
+    Result.insert(*SID);
+  }
+  return Result;
+}
+
+} // namespace
+
 Marshaller::Marshaller(llvm::StringRef RemoteIndexRoot,
                        llvm::StringRef LocalIndexRoot)
     : Strings(Arena) {
@@ -49,27 +68,50 @@
   assert(!RemoteIndexRoot.empty() || !LocalIndexRoot.empty());
 }
 
-clangd::FuzzyFindRequest
-Marshaller::fromProtobuf(const FuzzyFindRequest *Request) {
+llvm::Expected<clangd::LookupRequest>
+Marshaller::fromProtobuf(const LookupRequest *Message) {
+  clangd::LookupRequest Req;
+  auto IDs = getIDs(Message);
+  if (!IDs)
+    return IDs.takeError();
+  Req.IDs = std::move(*IDs);
+  return Req;
+}
+
+llvm::Expected<clangd::FuzzyFindRequest>
+Marshaller::fromProtobuf(const FuzzyFindRequest *Message) {
   assert(RemoteIndexRoot);
   clangd::FuzzyFindRequest Result;
-  Result.Query = Request->query();
-  for (const auto &Scope : Request->scopes())
+  Result.Query = Message->query();
+  for (const auto &Scope : Message->scopes())
     Result.Scopes.push_back(Scope);
-  Result.AnyScope = Request->any_scope();
-  if (Request->limit())
-    Result.Limit = Request->limit();
-  Result.RestrictForCodeCompletion = Request->restricted_for_code_completion();
-  for (const auto &Path : Request->proximity_paths()) {
+  Result.AnyScope = Message->any_scope();
+  if (Message->limit())
+    Result.Limit = Message->limit();
+  Result.RestrictForCodeCompletion = Message->restricted_for_code_completion();
+  for (const auto &Path : Message->proximity_paths()) {
     llvm::SmallString<256> LocalPath = llvm::StringRef(*RemoteIndexRoot);
     llvm::sys::path::append(LocalPath, Path);
     Result.ProximityPaths.push_back(std::string(LocalPath));
   }
-  for (const auto &Type : Request->preferred_types())
-    Result.ProximityPaths.push_back(Type);
+  for (const auto &Type : Message->preferred_types())
+    Result.PreferredTypes.push_back(Type);
   return Result;
 }
 
+llvm::Expected<clangd::RefsRequest>
+Marshaller::fromProtobuf(const RefsRequest *Message) {
+  clangd::RefsRequest Req;
+  auto IDs = getIDs(Message);
+  if (!IDs)
+    return IDs.takeError();
+  Req.IDs = std::move(*IDs);
+  Req.Filter = static_cast<clangd::RefKind>(Message->filter());
+  if (Message->limit())
+    Req.Limit = Message->limit();
+  return Req;
+}
+
 llvm::Optional<clangd::Symbol> Marshaller::fromProtobuf(const Symbol &Message) {
   if (!Message.has_info() || !Message.has_canonical_declaration()) {
     elog("Cannot convert Symbol from protobuf (missing info, definition or "
@@ -157,8 +199,7 @@
   RPCRequest.set_restricted_for_code_completion(From.RestrictForCodeCompletion);
   for (const auto &Path : From.ProximityPaths) {
     llvm::SmallString<256> RelativePath = llvm::StringRef(Path);
-    if (llvm::sys::path::replace_path_prefix(RelativePath, *LocalIndexRoot,
-                                             ""))
+    if (llvm::sys::path::replace_path_prefix(RelativePath, *LocalIndexRoot, ""))
       RPCRequest.add_proximity_paths(llvm::sys::path::convert_to_slash(
           RelativePath, llvm::sys::path::Style::posix));
   }
diff --git a/clang-tools-extra/clangd/index/remote/server/Server.cpp b/clang-tools-extra/clangd/index/remote/server/Server.cpp
--- a/clang-tools-extra/clangd/index/remote/server/Server.cpp
+++ b/clang-tools-extra/clangd/index/remote/server/Server.cpp
@@ -9,6 +9,7 @@
 #include "index/Index.h"
 #include "index/Serialization.h"
 #include "index/remote/marshalling/Marshalling.h"
+#include "support/Logger.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Path.h"
@@ -59,14 +60,12 @@
   grpc::Status Lookup(grpc::ServerContext *Context,
                       const LookupRequest *Request,
                       grpc::ServerWriter<LookupReply> *Reply) override {
-    clangd::LookupRequest Req;
-    for (const auto &ID : Request->ids()) {
-      auto SID = SymbolID::fromStr(StringRef(ID));
-      if (!SID)
-        return grpc::Status::CANCELLED;
-      Req.IDs.insert(*SID);
+    auto Req = ProtobufMarshaller->fromProtobuf(Request);
+    if (!Req) {
+      elog("Can not parse LookupRequest from protobuf: {0}", Req.takeError());
+      return grpc::Status::CANCELLED;
     }
-    Index->lookup(Req, [&](const clangd::Symbol &Sym) {
+    Index->lookup(*Req, [&](const clangd::Symbol &Sym) {
       auto SerializedSymbol = ProtobufMarshaller->toProtobuf(Sym);
       if (!SerializedSymbol)
         return;
@@ -83,8 +82,13 @@
   grpc::Status FuzzyFind(grpc::ServerContext *Context,
                          const FuzzyFindRequest *Request,
                          grpc::ServerWriter<FuzzyFindReply> *Reply) override {
-    const auto Req = ProtobufMarshaller->fromProtobuf(Request);
-    bool HasMore = Index->fuzzyFind(Req, [&](const clangd::Symbol &Sym) {
+    auto Req = ProtobufMarshaller->fromProtobuf(Request);
+    if (!Req) {
+      elog("Can not parse FuzzyFindRequest from protobuf: {0}",
+           Req.takeError());
+      return grpc::Status::CANCELLED;
+    }
+    bool HasMore = Index->fuzzyFind(*Req, [&](const clangd::Symbol &Sym) {
       auto SerializedSymbol = ProtobufMarshaller->toProtobuf(Sym);
       if (!SerializedSymbol)
         return;
@@ -100,14 +104,12 @@
   grpc::Status Refs(grpc::ServerContext *Context,
                     const RefsRequest *Request,
                     grpc::ServerWriter<RefsReply> *Reply) override {
-    clangd::RefsRequest Req;
-    for (const auto &ID : Request->ids()) {
-      auto SID = SymbolID::fromStr(StringRef(ID));
-      if (!SID)
-        return grpc::Status::CANCELLED;
-      Req.IDs.insert(*SID);
+    auto Req = ProtobufMarshaller->fromProtobuf(Request);
+    if (!Req) {
+      elog("Can not parse RefsRequest from protobuf: {0}", Req.takeError());
+      return grpc::Status::CANCELLED;
     }
-    bool HasMore = Index->refs(Req, [&](const clangd::Ref &Reference) {
+    bool HasMore = Index->refs(*Req, [&](const clangd::Ref &Reference) {
       auto SerializedRef = ProtobufMarshaller->toProtobuf(Reference);
       if (!SerializedRef)
         return;
diff --git a/clang-tools-extra/clangd/unittests/remote/MarshallingTests.cpp b/clang-tools-extra/clangd/unittests/remote/MarshallingTests.cpp
--- a/clang-tools-extra/clangd/unittests/remote/MarshallingTests.cpp
+++ b/clang-tools-extra/clangd/unittests/remote/MarshallingTests.cpp
@@ -18,6 +18,7 @@
 #include "clang/Index/IndexSymbol.h"
 #include "llvm/ADT/SmallString.h"
 #include "llvm/ADT/StringRef.h"
+#include "llvm/Support/Error.h"
 #include "llvm/Support/Path.h"
 #include "llvm/Support/StringSaver.h"
 #include "gmock/gmock.h"
@@ -271,6 +272,30 @@
   EXPECT_EQ(toYAML(Sym), toYAML(*Deserialized));
 }
 
+TEST(RemoteMarshallingTest, LookupRequestSerialization) {
+  clangd::LookupRequest Request;
+  Request.IDs.insert(llvm::cantFail(SymbolID::fromStr("0000000000000001")));
+  Request.IDs.insert(llvm::cantFail(SymbolID::fromStr("0000000000000002")));
+
+  Marshaller ProtobufMarshaller(testPath("remote/"), testPath("local/"));
+
+  auto Serialized = ProtobufMarshaller.toProtobuf(Request);
+  EXPECT_EQ(static_cast<unsigned>(Serialized.ids_size()), Request.IDs.size());
+  auto Deserialized = ProtobufMarshaller.fromProtobuf(&Serialized);
+  ASSERT_TRUE(bool(Deserialized));
+  EXPECT_EQ(Deserialized->IDs, Request.IDs);
+}
+
+TEST(RemoteMarshallingTest, LookupRequestFailingSerialization) {
+  clangd::LookupRequest Request;
+  Marshaller ProtobufMarshaller(testPath("remote/"), testPath("local/"));
+  auto Serialized = ProtobufMarshaller.toProtobuf(Request);
+  Serialized.add_ids("Invalid Symbol ID");
+  auto Deserialized = ProtobufMarshaller.fromProtobuf(&Serialized);
+  EXPECT_FALSE(Deserialized);
+  llvm::consumeError(Deserialized.takeError());
+}
+
 TEST(RemoteMarshallingTest, FuzzyFindRequestSerialization) {
   clangd::FuzzyFindRequest Request;
   Request.ProximityPaths = {testPath("local/Header.h"),
@@ -280,11 +305,43 @@
   auto Serialized = ProtobufMarshaller.toProtobuf(Request);
   EXPECT_EQ(Serialized.proximity_paths_size(), 2);
   auto Deserialized = ProtobufMarshaller.fromProtobuf(&Serialized);
-  EXPECT_THAT(Deserialized.ProximityPaths,
+  ASSERT_TRUE(bool(Deserialized));
+  EXPECT_THAT(Deserialized->ProximityPaths,
               testing::ElementsAre(testPath("remote/Header.h"),
                                    testPath("remote/subdir/OtherHeader.h")));
 }
 
+TEST(RemoteMarshallingTest, RefsRequestSerialization) {
+  clangd::RefsRequest Request;
+  Request.IDs.insert(llvm::cantFail(SymbolID::fromStr("0000000000000001")));
+  Request.IDs.insert(llvm::cantFail(SymbolID::fromStr("0000000000000002")));
+
+  Request.Limit = 9000;
+  Request.Filter = RefKind::Spelled | RefKind::Declaration;
+
+  Marshaller ProtobufMarshaller(testPath("remote/"), testPath("local/"));
+
+  auto Serialized = ProtobufMarshaller.toProtobuf(Request);
+  EXPECT_EQ(static_cast<unsigned>(Serialized.ids_size()), Request.IDs.size());
+  EXPECT_EQ(Serialized.limit(), Request.Limit);
+  auto Deserialized = ProtobufMarshaller.fromProtobuf(&Serialized);
+  ASSERT_TRUE(bool(Deserialized));
+  EXPECT_EQ(Deserialized->IDs, Request.IDs);
+  ASSERT_TRUE(Deserialized->Limit);
+  EXPECT_EQ(*Deserialized->Limit, Request.Limit);
+  EXPECT_EQ(Deserialized->Filter, Request.Filter);
+}
+
+TEST(RemoteMarshallingTest, RefsRequestFailingSerialization) {
+  clangd::RefsRequest Request;
+  Marshaller ProtobufMarshaller(testPath("remote/"), testPath("local/"));
+  auto Serialized = ProtobufMarshaller.toProtobuf(Request);
+  Serialized.add_ids("Invalid Symbol ID");
+  auto Deserialized = ProtobufMarshaller.fromProtobuf(&Serialized);
+  EXPECT_FALSE(Deserialized);
+  llvm::consumeError(Deserialized.takeError());
+}
+
 TEST(RemoteMarshallingTest, RelativePathToURITranslation) {
   Marshaller ProtobufMarshaller(/*RemoteIndexRoot=*/"",
                                 /*LocalIndexRoot=*/testPath("home/project/"));