Index: clangd/CMakeLists.txt
===================================================================
--- clangd/CMakeLists.txt
+++ clangd/CMakeLists.txt
@@ -9,6 +9,7 @@
 
 add_clang_library(clangDaemon
   AST.cpp
+  Cancellation.cpp
   ClangdLSPServer.cpp
   ClangdServer.cpp
   ClangdUnit.cpp
Index: clangd/Cancellation.h
===================================================================
--- /dev/null
+++ clangd/Cancellation.h
@@ -0,0 +1,55 @@
+//===--- Cancellation.h -------------------------------------------*-C++-*-===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CLANG_TOOLS_EXTRA_CLANGD_CANCELLATION_H
+#define LLVM_CLANG_TOOLS_EXTRA_CLANGD_CANCELLATION_H
+
+#include "Context.h"
+#include "llvm/Support/Error.h"
+#include <atomic>
+#include <memory>
+
+namespace clang {
+namespace clangd {
+
+/// Helpers for cooperative cancellation. A cancellation token is a shared
+/// atomic flag attached to a Context; code running under that Context polls
+/// HasCancelled() and stops early once the flag is set.
+class CancellationHandler {
+public:
+  /// Returns true if the current Context carries a token that has been set,
+  /// false if the token is unset or no token is attached.
+  static bool HasCancelled();
+  /// Returns a Context derived from the current one that carries
+  /// \p CancellationToken. Install the result (e.g. with WithContext) so that
+  /// HasCancelled() observes the token.
+  LLVM_NODISCARD static Context SetCurrentCancellationToken(
+      std::shared_ptr<std::atomic<bool>> CancellationToken);
+  /// Creates the error a cancelled task reports to its caller.
+  static llvm::Error GetCancellationError();
+};
+
+/// Error reported by a task that stopped because it was cancelled.
+class TaskCancelledError : public llvm::ErrorInfo<TaskCancelledError> {
+public:
+  static char ID;
+
+  void log(llvm::raw_ostream &OS) const override {
+    OS << "Task got cancelled.";
+  }
+  // A cancellation has no matching std::error_code; return LLVM's sentinel
+  // rather than aborting if a caller asks for one.
+  std::error_code convertToErrorCode() const override {
+    return llvm::inconvertibleErrorCode();
+  }
+};
+} // namespace clangd
+} // namespace clang
+
+#endif
Index: clangd/Cancellation.cpp
===================================================================
--- /dev/null
+++ clangd/Cancellation.cpp
@@ -0,0 +1,41 @@
+//===--- Cancellation.cpp -----------------------------------------*-C++-*-===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+#include "Cancellation.h"
+#include <atomic>
+
+namespace clang {
+namespace clangd {
+
+namespace {
+// Context slot holding the token installed by SetCurrentCancellationToken().
+Key<std::shared_ptr<std::atomic<bool>>> CancellationTokenKey;
+} // namespace
+
+char TaskCancelledError::ID = 0;
+
+bool CancellationHandler::HasCancelled() {
+  const auto *CancellationToken = Context::current().get(CancellationTokenKey);
+  if (!CancellationToken)
+    return false;
+  return **CancellationToken;
+}
+
+Context CancellationHandler::SetCurrentCancellationToken(
+    std::shared_ptr<std::atomic<bool>> CancellationToken) {
+  return Context::current().derive(CancellationTokenKey,
+                                   std::move(CancellationToken));
+}
+
+llvm::Error CancellationHandler::GetCancellationError() {
+  return llvm::make_error<TaskCancelledError>();
+}
+
+} // namespace clangd
+} // namespace clang
Index: clangd/ClangdLSPServer.h
===================================================================
--- clangd/ClangdLSPServer.h
+++ clangd/ClangdLSPServer.h
@@ -75,6 +75,7 @@
   void onRename(RenameParams &Parames) override;
   void onHover(TextDocumentPositionParams &Params) override;
   void onChangeConfiguration(DidChangeConfigurationParams &Params) override;
+  void onCancelRequest(CancelParams &Params) override;
 
   std::vector<Fix> getFixes(StringRef File, const clangd::Diagnostic &D);
 
@@ -165,8 +166,18 @@
   // the worker thread that may otherwise run an async callback on partially
   // destructed instance of ClangdLSPServer.
   ClangdServer Server;
-};
 
+  // Cancellation tokens of in-flight requests, keyed by normalized request ID.
+  // Guarded by CancellationTokensMutex.
+  llvm::StringMap<std::shared_ptr<std::atomic<bool>>> CancellationTokens;
+  std::mutex CancellationTokensMutex;
+
+  // GenerateCancellationToken() registers a token for the request ID stored in
+  // the current Context and CleanupCancellationToken() removes it; neither
+  // touches CancellationTokens when the Context has no request ID.
+  void CleanupCancellationToken();
+  Context GenerateCancellationToken();
+};
 } // namespace clangd
 } // namespace clang
 
Index: clangd/ClangdLSPServer.cpp
===================================================================
--- clangd/ClangdLSPServer.cpp
+++ clangd/ClangdLSPServer.cpp
@@ -8,10 +8,12 @@
 //===---------------------------------------------------------------------===//
 
 #include "ClangdLSPServer.h"
+#include "Cancellation.h"
 #include "Diagnostics.h"
 #include "JSONRPCDispatcher.h"
 #include "SourceCode.h"
 #include "URI.h"
+#include "llvm/ADT/ScopeExit.h"
 #include "llvm/Support/Errc.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/Path.h"
@@ -69,6 +71,15 @@
   return Defaults;
 }
 
+// Converts a JSON-RPC request ID (an integer or a string) into the string key
+// used by CancellationTokens. Returns an empty string for any other value.
+std::string NormalizeRequestID(const json::Value &ID) {
+  if (auto Number = ID.getAsInteger())
+    return itostr(*Number);
+  if (auto Str = ID.getAsString())
+    return Str->str();
+  return "";
+}
 } // namespace
 
 void ClangdLSPServer::onInitialize(InitializeParams &Params) {
@@ -335,11 +346,27 @@
 }
 
 void ClangdLSPServer::onCompletion(TextDocumentPositionParams &Params) {
+  WithContext ContextWithCancellation(GenerateCancellationToken());
   Server.codeComplete(Params.textDocument.uri.file(), Params.position, CCOpts,
                       [this](llvm::Expected<CodeCompleteResult> List) {
-                        if (!List)
-                          return replyError(ErrorCode::InvalidParams,
-                                            llvm::toString(List.takeError()));
+                        // The request is finished once this callback runs,
+                        // so its token is no longer needed.
+                        auto _ = llvm::make_scope_exit(
+                            [this]() { CleanupCancellationToken(); });
+
+                        if (!List) {
+                          Error Uncaught = handleErrors(
+                              List.takeError(), [](const TaskCancelledError &) {
+                                replyError(ErrorCode::RequestCancelled,
+                                           "Request got cancelled");
+                              });
+                          // Anything handleErrors did not consume is a real
+                          // failure; List no longer holds the error.
+                          if (Uncaught)
+                            replyError(ErrorCode::InvalidParams,
+                                       llvm::toString(std::move(Uncaught)));
+                          return;
+                        }
                         CompletionList LSPList;
                         LSPList.isIncomplete = List->HasMore;
                         for (const auto &R : List->Completions)
@@ -586,3 +613,37 @@
   return *CachingCDB;
   return *CDB;
 }
+
+void ClangdLSPServer::onCancelRequest(CancelParams &Params) {
+  std::lock_guard<std::mutex> Lock(CancellationTokensMutex);
+  auto It = CancellationTokens.find(Params.ID);
+  // Unknown ID: the request already finished or never registered a token.
+  if (It == CancellationTokens.end())
+    return;
+  // Signal the running task, then drop our reference to its token.
+  *It->second = true;
+  CancellationTokens.erase(It);
+}
+
+void ClangdLSPServer::CleanupCancellationToken() {
+  const json::Value *ID = GetRequestId();
+  if (!ID)
+    return;
+  std::string NormalizedID = NormalizeRequestID(*ID);
+  std::lock_guard<std::mutex> Lock(CancellationTokensMutex);
+  CancellationTokens.erase(NormalizedID);
+}
+
+Context ClangdLSPServer::GenerateCancellationToken() {
+  const json::Value *ID = GetRequestId();
+  if (!ID)
+    return Context::current().clone();
+  std::string NormalizedID = NormalizeRequestID(*ID);
+  auto CancellationToken = std::make_shared<std::atomic<bool>>(false);
+  {
+    std::lock_guard<std::mutex> Lock(CancellationTokensMutex);
+    CancellationTokens[NormalizedID] = CancellationToken;
+  }
+  return CancellationHandler::SetCurrentCancellationToken(
+      std::move(CancellationToken));
+}
Index: clangd/ClangdServer.cpp
===================================================================
--- clangd/ClangdServer.cpp
+++ clangd/ClangdServer.cpp
@@ -8,6 +8,7 @@
 //===-------------------------------------------------------------------===//
 
 #include "ClangdServer.h"
+#include "Cancellation.h"
 #include "CodeComplete.h"
 #include "FindSymbols.h"
 #include "Headers.h"
@@ -159,6 +160,9 @@
 
     auto PreambleData = IP->Preamble;
 
+    // Skip the expensive completion if the request was already cancelled.
+    if (CancellationHandler::HasCancelled())
+      return CB(CancellationHandler::GetCancellationError());
     // FIXME(ibiryukov): even if Preamble is non-null, we may want to check
     // both the old and the new version in case only one of them matches.
     CodeCompleteResult Result = clangd::codeComplete(
Index: clangd/JSONRPCDispatcher.h
===================================================================
--- clangd/JSONRPCDispatcher.h
+++ clangd/JSONRPCDispatcher.h
@@ -111,6 +111,9 @@
                            JSONStreamStyle InputStyle,
                            JSONRPCDispatcher &Dispatcher, bool &IsDone);
 
+/// Returns the JSON-RPC request ID stored in the current Context, or nullptr
+/// if no ID is set.
+const llvm::json::Value *GetRequestId();
 } // namespace clangd
 } // namespace clang
 
Index: clangd/JSONRPCDispatcher.cpp
===================================================================
--- clangd/JSONRPCDispatcher.cpp
+++ clangd/JSONRPCDispatcher.cpp
@@ -366,3 +366,7 @@
     }
   }
 }
+
+const json::Value *clangd::GetRequestId() {
+  return Context::current().get(RequestID);
+}
Index: clangd/Protocol.h
===================================================================
--- clangd/Protocol.h
+++ clangd/Protocol.h
@@ -846,6 +846,15 @@
 llvm::json::Value toJSON(const DocumentHighlight &DH);
 llvm::raw_ostream &operator<<(llvm::raw_ostream &, const DocumentHighlight &);
 
+/// Parameters of the `$/cancelRequest` notification.
+struct CancelParams {
+  /// ID of the request to cancel. Numeric IDs are stored in decimal form.
+  std::string ID;
+};
+llvm::json::Value toJSON(const CancelParams &);
+llvm::raw_ostream &operator<<(llvm::raw_ostream &, const CancelParams &);
+bool fromJSON(const llvm::json::Value &, CancelParams &);
+
 } // namespace clangd
 } // namespace clang
 
Index: clangd/Protocol.cpp
===================================================================
--- clangd/Protocol.cpp
+++ clangd/Protocol.cpp
@@ -605,5 +605,30 @@
   O.map("compilationDatabaseChanges", CCPC.compilationDatabaseChanges);
 }
 
+json::Value toJSON(const CancelParams &CP) {
+  return json::Object{{"id", CP.ID}};
+}
+
+llvm::raw_ostream &operator<<(llvm::raw_ostream &O, const CancelParams &CP) {
+  O << toJSON(CP);
+  return O;
+}
+
+bool fromJSON(const json::Value &Params, CancelParams &CP) {
+  json::ObjectMapper O(Params);
+  if (!O)
+    return false;
+  // The ID is either a string or an integer; integers are converted to their
+  // decimal string form.
+  if (O.map("id", CP.ID))
+    return true;
+  int64_t IDNumber = 0;
+  if (O.map("id", IDNumber)) {
+    CP.ID = itostr(IDNumber);
+    return true;
+  }
+  return false;
+}
+
 } // namespace clangd
 } // namespace clang
Index: clangd/ProtocolHandlers.h
===================================================================
--- clangd/ProtocolHandlers.h
+++ clangd/ProtocolHandlers.h
@@ -55,6 +55,7 @@
   virtual void onDocumentHighlight(TextDocumentPositionParams &Params) = 0;
   virtual void onHover(TextDocumentPositionParams &Params) = 0;
   virtual void onChangeConfiguration(DidChangeConfigurationParams &Params) = 0;
+  virtual void onCancelRequest(CancelParams &Params) = 0;
 };
 
 void registerCallbackHandlers(JSONRPCDispatcher &Dispatcher,
Index: clangd/ProtocolHandlers.cpp
===================================================================
--- clangd/ProtocolHandlers.cpp
+++ clangd/ProtocolHandlers.cpp
@@ -75,4 +75,5 @@
   Register("workspace/didChangeConfiguration",
            &ProtocolCallbacks::onChangeConfiguration);
   Register("workspace/symbol", &ProtocolCallbacks::onWorkspaceSymbol);
+  Register("$/cancelRequest", &ProtocolCallbacks::onCancelRequest);
 }
Index: unittests/clangd/CMakeLists.txt
===================================================================
--- unittests/clangd/CMakeLists.txt
+++ unittests/clangd/CMakeLists.txt
@@ -10,6 +10,7 @@
 
 add_extra_unittest(ClangdTests
   Annotations.cpp
+  CancellationTests.cpp
   ClangdTests.cpp
   ClangdUnitTests.cpp
   CodeCompleteTests.cpp
Index: unittests/clangd/CancellationTests.cpp
===================================================================
--- /dev/null
+++ unittests/clangd/CancellationTests.cpp
@@ -0,0 +1,45 @@
+//===-- CancellationTests.cpp -----------------------------------*- C++ -*-===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+#include "Cancellation.h"
+#include "Context.h"
+#include "llvm/Support/Error.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+#include <atomic>
+#include <memory>
+
+namespace clang {
+namespace clangd {
+namespace {
+
+TEST(CancellationTest, CancellationTest) {
+  {
+    auto CancellationToken = std::make_shared<std::atomic<bool>>(false);
+    WithContext ContextWithCancellation(
+        CancellationHandler::SetCurrentCancellationToken(CancellationToken));
+    // The token starts out unset.
+    EXPECT_FALSE(CancellationHandler::HasCancelled());
+    *CancellationToken = true;
+    EXPECT_TRUE(CancellationHandler::HasCancelled());
+  }
+  // Leaving the WithContext scope restores a Context without the token.
+  EXPECT_FALSE(CancellationHandler::HasCancelled());
+}
+
+TEST(CancellationTest, CheckForError) {
+  llvm::Error Err = handleErrors(CancellationHandler::GetCancellationError(),
+                                 [](const TaskCancelledError &) {});
+  EXPECT_TRUE(!Err);
+}
+
+} // namespace
+} // namespace clangd
+} // namespace clang