diff --git a/clang-tools-extra/clangd/index/remote/Client.h b/clang-tools-extra/clangd/index/remote/Client.h
--- a/clang-tools-extra/clangd/index/remote/Client.h
+++ b/clang-tools-extra/clangd/index/remote/Client.h
@@ -22,10 +22,13 @@
 /// described by the remote index. Paths returned by the index will be treated
 /// as relative to this directory.
 ///
-/// This method attempts to resolve the address and establish the connection.
+/// This method attempts to resolve the address and establish the connection by
+/// performing a handshake (sending a checksum and expecting it echoed back).
 ///
-/// \returns nullptr if the address is not resolved during the function call or
-/// if the project was compiled without Remote Index support.
+/// \returns nullptr if one of the following conditions holds:
+/// * Address is not resolved during the function call.
+/// * The handshake RPC fails or the server echoes back a different checksum.
+/// * Project was compiled without Remote Index support.
 std::unique_ptr getClient(llvm::StringRef Address,
                           llvm::StringRef IndexRoot);
 
diff --git a/clang-tools-extra/clangd/index/remote/Client.cpp b/clang-tools-extra/clangd/index/remote/Client.cpp
--- a/clang-tools-extra/clangd/index/remote/Client.cpp
+++ b/clang-tools-extra/clangd/index/remote/Client.cpp
@@ -14,11 +14,14 @@
 #include "marshalling/Marshalling.h"
 #include "support/Logger.h"
 #include "support/Trace.h"
+#include "clang/Basic/LLVM.h"
 #include "clang/Basic/Version.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/Support/Error.h"
 
 #include
+#include
+#include
 
 namespace clang {
 namespace clangd {
@@ -82,6 +85,35 @@
     assert(!ProjectRoot.empty());
   }
 
+  // Returns true iff the RPC succeeds and the server echoes Checksum back.
+  bool handshake(uint32_t Checksum) const {
+    trace::Span Tracer(HandshakeRequest::descriptor()->name());
+    HandshakeRequest RPCRequest;
+    RPCRequest.set_checksum(Checksum);
+    SPAN_ATTACH(Tracer, "Request", RPCRequest.DebugString());
+    grpc::ClientContext Context;
+    Context.AddMetadata("version", clang::getClangToolFullVersion("clangd"));
+    const std::chrono::system_clock::time_point StartTime =
+        std::chrono::system_clock::now();
+    Context.set_deadline(StartTime + DeadlineWaitingTime);
+    vlog("Sending RPC Request {0}: {1}", HandshakeRequest::descriptor()->name(),
+         RPCRequest.DebugString());
+    HandshakeReply Reply;
+    const grpc::Status Status = Stub->Handshake(&Context, RPCRequest, &Reply);
+    // Reply comes off the network: report a mismatch, never assert on it.
+    const bool OK = Status.ok() && Reply.checksum() == Checksum;
+    const auto Millis = std::chrono::duration_cast<std::chrono::milliseconds>(
+                            std::chrono::system_clock::now() - StartTime)
+                            .count();
+    const std::string Outcome =
+        !Status.ok() ? Status.error_message()
+                     : (OK ? "OK" : "CHECKSUM MISMATCH");
+    vlog("RPC {0} => {1}: {2}ms", HandshakeRequest::descriptor()->name(),
+         Outcome, Millis);
+    SPAN_ATTACH(Tracer, "Status", Status.ok());
+    return OK;
+  }
+
   void lookup(const clangd::LookupRequest &Request,
               llvm::function_ref Callback) const override {
@@ -130,8 +162,15 @@
   const auto Channel =
       grpc::CreateChannel(Address.str(), grpc::InsecureChannelCredentials());
   Channel->GetState(true);
-  return std::unique_ptr(
-      new IndexClient(Channel, ProjectRoot));
+  auto Idx = std::make_unique<IndexClient>(Channel, ProjectRoot);
+  // The value is arbitrary: the server only has to echo it back.
+  const uint32_t Checksum = 42;
+  if (!Idx->handshake(Checksum)) {
+    elog("Handshake did not finish correctly. Breaking the connection to "
+         "remote index.");
+    return nullptr; // Idx owns the client, so nothing leaks on this path.
+  }
+  return Idx;
 }
 
 } // namespace remote
diff --git a/clang-tools-extra/clangd/index/remote/Index.proto b/clang-tools-extra/clangd/index/remote/Index.proto
--- a/clang-tools-extra/clangd/index/remote/Index.proto
+++ b/clang-tools-extra/clangd/index/remote/Index.proto
@@ -10,6 +10,14 @@
 
 package clang.clangd.remote;
 
+message HandshakeRequest {
+  optional uint32 checksum = 1;
+}
+
+message HandshakeReply {
+  optional uint32 checksum = 1;
+}
+
 // Common final result for streaming requests.
 message FinalResult { optional bool has_more = 1; }
 
diff --git a/clang-tools-extra/clangd/index/remote/Service.proto b/clang-tools-extra/clangd/index/remote/Service.proto
--- a/clang-tools-extra/clangd/index/remote/Service.proto
+++ b/clang-tools-extra/clangd/index/remote/Service.proto
@@ -15,6 +15,8 @@
 // Semantics of SymbolIndex match clangd::SymbolIndex with all required
 // structures corresponding to their clangd::* counterparts.
service SymbolIndex {
+  rpc Handshake(HandshakeRequest) returns (HandshakeReply) {}
+
   rpc Lookup(LookupRequest) returns (stream LookupReply) {}
 
   rpc FuzzyFind(FuzzyFindRequest) returns (stream FuzzyFindReply) {}
diff --git a/clang-tools-extra/clangd/index/remote/server/Server.cpp b/clang-tools-extra/clangd/index/remote/server/Server.cpp
--- a/clang-tools-extra/clangd/index/remote/server/Server.cpp
+++ b/clang-tools-extra/clangd/index/remote/server/Server.cpp
@@ -32,6 +32,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 
@@ -96,6 +97,19 @@
 private:
   using stopwatch = std::chrono::steady_clock;
 
+  grpc::Status Handshake(grpc::ServerContext *Context,
+                         const HandshakeRequest *Request,
+                         HandshakeReply *Reply) override {
+    WithContextValue WithRequestContext(CurrentRequest, Context);
+    logRequest(*Request);
+    trace::Span Tracer(HandshakeRequest::descriptor()->name());
+    Reply->set_checksum(Request->checksum());
+    logResponse(*Reply);
+    log("[public] request {0} => OK", HandshakeRequest::descriptor()->name());
+    SPAN_ATTACH(Tracer, "Checksum", Reply->checksum());
+    return grpc::Status::OK;
+  }
+
   grpc::Status Lookup(grpc::ServerContext *Context,
                       const LookupRequest *Request,
                       grpc::ServerWriter *Reply) override {