diff --git a/clang-tools-extra/clangd/index/remote/marshalling/Marshalling.h b/clang-tools-extra/clangd/index/remote/marshalling/Marshalling.h
--- a/clang-tools-extra/clangd/index/remote/marshalling/Marshalling.h
+++ b/clang-tools-extra/clangd/index/remote/marshalling/Marshalling.h
@@ -38,10 +38,17 @@
   Marshaller() = delete;
   Marshaller(llvm::StringRef RemoteIndexRoot, llvm::StringRef LocalIndexRoot);
 
-  clangd::FuzzyFindRequest fromProtobuf(const FuzzyFindRequest *Request);
   llvm::Optional<clangd::Symbol> fromProtobuf(const Symbol &Message);
   llvm::Optional<clangd::Ref> fromProtobuf(const Ref &Message);
 
+  /// Request deserializers: Lookup and Refs requests yield llvm::None when any
+  /// SymbolID in the message fails to parse.
+  llvm::Optional<clangd::LookupRequest>
+  fromProtobuf(const LookupRequest *Message);
+  llvm::Optional<clangd::FuzzyFindRequest>
+  fromProtobuf(const FuzzyFindRequest *Message);
+  llvm::Optional<clangd::RefsRequest> fromProtobuf(const RefsRequest *Message);
+
   /// toProtobuf() functions serialize native clangd types and strip IndexRoot
   /// from the file paths specific to indexing machine. fromProtobuf() functions
   /// deserialize clangd types and translate relative paths into machine-native
diff --git a/clang-tools-extra/clangd/index/remote/marshalling/Marshalling.cpp b/clang-tools-extra/clangd/index/remote/marshalling/Marshalling.cpp
--- a/clang-tools-extra/clangd/index/remote/marshalling/Marshalling.cpp
+++ b/clang-tools-extra/clangd/index/remote/marshalling/Marshalling.cpp
@@ -17,6 +17,7 @@
 #include "index/SymbolOrigin.h"
 #include "support/Logger.h"
 #include "clang/Index/IndexSymbol.h"
+#include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/None.h"
 #include "llvm/ADT/Optional.h"
 #include "llvm/ADT/SmallString.h"
@@ -30,6 +31,26 @@
 namespace clangd {
 namespace remote {
 
+namespace {
+
+/// Parses every entry of Message->ids() into a SymbolID. Logs the parse error
+/// and returns llvm::None if any entry is malformed.
+template <typename MessageT>
+llvm::Optional<llvm::DenseSet<SymbolID>> getIDs(MessageT *Message) {
+  llvm::DenseSet<SymbolID> Result;
+  for (const auto &ID : Message->ids()) {
+    auto SID = SymbolID::fromStr(StringRef(ID));
+    if (!SID) {
+      elog("Invalid SymbolID {0}", SID.takeError());
+      return llvm::None;
+    }
+    Result.insert(*SID);
+  }
+  return Result;
+}
+
+} // namespace
+
 Marshaller::Marshaller(llvm::StringRef RemoteIndexRoot,
                        llvm::StringRef LocalIndexRoot)
     : Strings(Arena) {
@@ -49,27 +70,54 @@
   assert(!RemoteIndexRoot.empty() || !LocalIndexRoot.empty());
 }
 
-clangd::FuzzyFindRequest
-Marshaller::fromProtobuf(const FuzzyFindRequest *Request) {
+llvm::Optional<clangd::LookupRequest>
+Marshaller::fromProtobuf(const LookupRequest *Message) {
+  clangd::LookupRequest Req;
+  auto IDs = getIDs(Message);
+  if (!IDs) {
+    elog("Cannot parse LookupRequest from protobuf: invalid Symbol IDs");
+    return llvm::None;
+  }
+  Req.IDs = std::move(*IDs);
+  return Req;
+}
+
+llvm::Optional<clangd::FuzzyFindRequest>
+Marshaller::fromProtobuf(const FuzzyFindRequest *Message) {
   assert(RemoteIndexRoot);
   clangd::FuzzyFindRequest Result;
-  Result.Query = Request->query();
-  for (const auto &Scope : Request->scopes())
+  Result.Query = Message->query();
+  for (const auto &Scope : Message->scopes())
     Result.Scopes.push_back(Scope);
-  Result.AnyScope = Request->any_scope();
-  if (Request->limit())
-    Result.Limit = Request->limit();
-  Result.RestrictForCodeCompletion = Request->restricted_for_code_completion();
-  for (const auto &Path : Request->proximity_paths()) {
+  Result.AnyScope = Message->any_scope();
+  if (Message->limit())
+    Result.Limit = Message->limit();
+  Result.RestrictForCodeCompletion = Message->restricted_for_code_completion();
+  for (const auto &Path : Message->proximity_paths()) {
     llvm::SmallString<256> LocalPath = llvm::StringRef(*RemoteIndexRoot);
     llvm::sys::path::append(LocalPath, Path);
     Result.ProximityPaths.push_back(std::string(LocalPath));
   }
-  for (const auto &Type : Request->preferred_types())
-    Result.ProximityPaths.push_back(Type);
+  for (const auto &Type : Message->preferred_types())
+    Result.PreferredTypes.push_back(Type);
   return Result;
 }
 
+llvm::Optional<clangd::RefsRequest>
+Marshaller::fromProtobuf(const RefsRequest *Message) {
+  clangd::RefsRequest Req;
+  auto IDs = getIDs(Message);
+  if (!IDs) {
+    elog("Cannot parse RefsRequest from protobuf: invalid Symbol IDs");
+    return llvm::None;
+  }
+  Req.IDs = std::move(*IDs);
+  Req.Filter = static_cast<clangd::RefKind>(Message->filter());
+  if (Message->limit())
+    Req.Limit = Message->limit();
+  return Req;
+}
+
 llvm::Optional<clangd::Symbol> Marshaller::fromProtobuf(const Symbol &Message) {
   if (!Message.has_info() || !Message.has_canonical_declaration()) {
     elog("Cannot convert Symbol from protobuf (missing info, definition or "
@@ -157,8 +205,7 @@
   RPCRequest.set_restricted_for_code_completion(From.RestrictForCodeCompletion);
   for (const auto &Path : From.ProximityPaths) {
     llvm::SmallString<256> RelativePath = llvm::StringRef(Path);
-    if (llvm::sys::path::replace_path_prefix(RelativePath, *LocalIndexRoot,
-                                             ""))
+    if (llvm::sys::path::replace_path_prefix(RelativePath, *LocalIndexRoot, ""))
       RPCRequest.add_proximity_paths(llvm::sys::path::convert_to_slash(
           RelativePath, llvm::sys::path::Style::posix));
   }
diff --git a/clang-tools-extra/clangd/index/remote/server/Server.cpp b/clang-tools-extra/clangd/index/remote/server/Server.cpp
--- a/clang-tools-extra/clangd/index/remote/server/Server.cpp
+++ b/clang-tools-extra/clangd/index/remote/server/Server.cpp
@@ -59,14 +59,10 @@
   grpc::Status Lookup(grpc::ServerContext *Context,
                       const LookupRequest *Request,
                       grpc::ServerWriter<LookupReply> *Reply) override {
-    clangd::LookupRequest Req;
-    for (const auto &ID : Request->ids()) {
-      auto SID = SymbolID::fromStr(StringRef(ID));
-      if (!SID)
-        return grpc::Status::CANCELLED;
-      Req.IDs.insert(*SID);
-    }
-    Index->lookup(Req, [&](const clangd::Symbol &Sym) {
+    const auto Req = ProtobufMarshaller->fromProtobuf(Request);
+    if (!Req)
+      return grpc::Status::CANCELLED;
+    Index->lookup(*Req, [&](const clangd::Symbol &Sym) {
       auto SerializedSymbol = ProtobufMarshaller->toProtobuf(Sym);
       if (!SerializedSymbol)
         return;
@@ -84,7 +80,9 @@
                          const FuzzyFindRequest *Request,
                          grpc::ServerWriter<FuzzyFindReply> *Reply) override {
     const auto Req = ProtobufMarshaller->fromProtobuf(Request);
-    bool HasMore = Index->fuzzyFind(Req, [&](const clangd::Symbol &Sym) {
+    if (!Req)
+      return grpc::Status::CANCELLED;
+    bool HasMore = Index->fuzzyFind(*Req, [&](const clangd::Symbol &Sym) {
       auto SerializedSymbol = ProtobufMarshaller->toProtobuf(Sym);
       if (!SerializedSymbol)
         return;
@@ -100,14 +98,10 @@
 
   grpc::Status Refs(grpc::ServerContext *Context, const RefsRequest *Request,
                     grpc::ServerWriter<RefsReply> *Reply) override {
-    clangd::RefsRequest Req;
-    for (const auto &ID : Request->ids()) {
-      auto SID = SymbolID::fromStr(StringRef(ID));
-      if (!SID)
-        return grpc::Status::CANCELLED;
-      Req.IDs.insert(*SID);
-    }
-    bool HasMore = Index->refs(Req, [&](const clangd::Ref &Reference) {
+    const auto Req = ProtobufMarshaller->fromProtobuf(Request);
+    if (!Req)
+      return grpc::Status::CANCELLED;
+    bool HasMore = Index->refs(*Req, [&](const clangd::Ref &Reference) {
       auto SerializedRef = ProtobufMarshaller->toProtobuf(Reference);
       if (!SerializedRef)
         return;
diff --git a/clang-tools-extra/clangd/unittests/remote/MarshallingTests.cpp b/clang-tools-extra/clangd/unittests/remote/MarshallingTests.cpp
--- a/clang-tools-extra/clangd/unittests/remote/MarshallingTests.cpp
+++ b/clang-tools-extra/clangd/unittests/remote/MarshallingTests.cpp
@@ -289,7 +289,8 @@
   auto Serialized = ProtobufMarshaller.toProtobuf(Request);
   EXPECT_EQ(Serialized.proximity_paths_size(), 2);
   auto Deserialized = ProtobufMarshaller.fromProtobuf(&Serialized);
-  EXPECT_THAT(Deserialized.ProximityPaths,
+  ASSERT_TRUE(Deserialized);
+  EXPECT_THAT(Deserialized->ProximityPaths,
               testing::ElementsAre(testPath("remote/Header.h"),
                                    testPath("remote/subdir/OtherHeader.h")));
 }