Index: clangd/CMakeLists.txt
===================================================================
--- clangd/CMakeLists.txt
+++ clangd/CMakeLists.txt
@@ -9,6 +9,7 @@
 
 add_clang_library(clangDaemon
   AST.cpp
+  Cancellation.cpp
   ClangdLSPServer.cpp
   ClangdServer.cpp
   ClangdUnit.cpp
Index: clangd/Cancellation.h
===================================================================
--- /dev/null
+++ clangd/Cancellation.h
@@ -0,0 +1,120 @@
+//===--- Cancellation.h -------------------------------------------*-C++-*-===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+// CancellationToken mechanism for async threads. The caller can generate a
+// TaskHandle for cancellable tasks, then bind that handle to the current
+// context and check it from every task that runs within that context.
+// Later on, the client can call cancel() on that handle to tell the async
+// task that it has been cancelled. Example use case:
+//
+// void Caller() {
+//   // You should store this handle if you want to cancel the task later on.
+//   TaskHandle TH = StartAsyncTask(Task);
+//   // To cancel the task:
+//   TH.cancel();
+// }
+//
+// TaskHandle StartAsyncTask(Task T) {
+//   // Make sure the TaskHandle is created before starting the thread.
+//   // Otherwise the CancellationToken might not get copied into the thread.
+//   auto TH = TaskHandle::createCancellableTaskHandle();
+//   auto run = [TH]() {
+//     WithContext ContextWithCancellationToken(
+//         CancellationHandler::setCurrentCancellationToken(TH));
+//     T();
+//   };
+//   // Start run() in a new thread.
+//   return TH;
+// }
+//
+// void Task() {
+//   // You can either store the read-only token by calling isCancelled once
+//   // and just use the variable every time you want to check for
+//   // cancellation, or call isCancelled every time. The former is more
+//   // efficient if you are going to have multiple checks.
+//   const auto CT = CancellationHandler::isCancelled();
+//   // Do something...
+//   if (CT) {
+//     // Task has been cancelled, let's get out.
+//     return;
+//   }
+//   // Do some more things...
+// }
+
+#ifndef LLVM_CLANG_TOOLS_EXTRA_CLANGD_CANCELLATION_H
+#define LLVM_CLANG_TOOLS_EXTRA_CLANGD_CANCELLATION_H
+
+#include "Context.h"
+#include "llvm/Support/Error.h"
+#include <atomic>
+#include <memory>
+#include <system_error>
+#include <utility>
+
+namespace clang {
+namespace clangd {
+
+/// A read-only view of a task's cancellation flag. A token constructed from a
+/// null pointer never reports cancellation.
+class CancellationToken {
+private:
+  std::shared_ptr<std::atomic<bool>> Token;
+
+public:
+  bool isCancelled() const { return Token ? static_cast<bool>(*Token) : false; }
+  operator bool() const { return isCancelled(); }
+  explicit CancellationToken(std::shared_ptr<std::atomic<bool>> Token)
+      : Token(std::move(Token)) {}
+};
+
+/// Owner side of a cancellable task. Copies share one cancellation flag, so
+/// cancel() on any copy is observed by every CancellationToken created for
+/// the same task.
+class TaskHandle {
+public:
+  /// Marks the task as cancelled.
+  void cancel();
+  /// Creates a handle for a task that is not cancelled yet.
+  static TaskHandle createCancellableTaskHandle();
+  friend class CancellationHandler;
+
+private:
+  TaskHandle() : CT(std::make_shared<std::atomic<bool>>(false)) {}
+  std::shared_ptr<std::atomic<bool>> CT;
+};
+
+/// Binds TaskHandles to Contexts and queries the binding of the current one.
+class CancellationHandler {
+public:
+  /// Returns a token observing the TaskHandle bound to the current Context,
+  /// or one that never reports cancellation if no handle is bound.
+  static CancellationToken isCancelled();
+  /// Returns a Context derived from the current one with \p TH bound to it.
+  /// Install it (e.g. with WithContext) for the binding to take effect.
+  LLVM_NODISCARD static Context setCurrentCancellationToken(TaskHandle TH);
+  /// Returns a TaskCancelledError.
+  static llvm::Error getCancellationError();
+};
+
+/// Error returned by CancellationHandler::getCancellationError().
+class TaskCancelledError : public llvm::ErrorInfo<TaskCancelledError> {
+public:
+  static char ID;
+
+  void log(llvm::raw_ostream &OS) const override {
+    OS << "Task got cancelled.";
+  }
+  std::error_code convertToErrorCode() const override {
+    return std::make_error_code(std::errc::operation_canceled);
+  }
+};
+
+} // namespace clangd
+} // namespace clang
+
+#endif
Index: clangd/Cancellation.cpp
===================================================================
--- /dev/null
+++ clangd/Cancellation.cpp
@@ -0,0 +1,42 @@
+//===--- Cancellation.cpp -----------------------------------------*-C++-*-===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+#include "Cancellation.h"
+#include <atomic>
+
+namespace clang {
+namespace clangd {
+
+namespace {
+Key<std::shared_ptr<std::atomic<bool>>> CancellationTokenKey;
+} // namespace
+
+char TaskCancelledError::ID = 0;
+
+CancellationToken CancellationHandler::isCancelled() {
+  const auto *CT = Context::current().get(CancellationTokenKey);
+  if (!CT)
+    return CancellationToken(nullptr);
+  return CancellationToken(*CT);
+}
+
+Context CancellationHandler::setCurrentCancellationToken(TaskHandle TH) {
+  return Context::current().derive(CancellationTokenKey, std::move(TH.CT));
+}
+
+llvm::Error CancellationHandler::getCancellationError() {
+  return llvm::make_error<TaskCancelledError>();
+}
+
+void TaskHandle::cancel() { *CT = true; }
+
+TaskHandle TaskHandle::createCancellableTaskHandle() { return TaskHandle(); }
+
+} // namespace clangd
+} // namespace clang
Index: clangd/ClangdLSPServer.h
===================================================================
--- clangd/ClangdLSPServer.h
+++ clangd/ClangdLSPServer.h
@@ -75,6 +75,7 @@
   void onRename(RenameParams &Parames) override;
   void onHover(TextDocumentPositionParams &Params) override;
   void onChangeConfiguration(DidChangeConfigurationParams &Params) override;
+  void onCancelRequest(CancelParams &Params) override;
 
   std::vector<Fix> getFixes(StringRef File, const clangd::Diagnostic &D);
 
@@ -167,6 +168,17 @@
   // the worker thread that may otherwise run an async callback on partially
   // destructed instance of ClangdLSPServer.
   ClangdServer Server;
+
+  // Holds task handles for running requests, keyed by the request ID as
+  // serialized by NormalizeRequestID() in ClangdLSPServer.cpp.
+  llvm::StringMap<TaskHandle> TaskHandles;
+  std::mutex TaskHandlesMutex; // Guards TaskHandles.
+
+  // Following two functions are context-aware: they key TaskHandles by the
+  // request ID of the current Context (GetRequestId()) and do nothing when
+  // there is no such ID.
+  void CleanupTaskHandle();
+  void StoreTaskHandle(TaskHandle TH);
 };
 
 } // namespace clangd
Index: clangd/ClangdLSPServer.cpp
===================================================================
--- clangd/ClangdLSPServer.cpp
+++ clangd/ClangdLSPServer.cpp
@@ -8,10 +8,12 @@
 //===---------------------------------------------------------------------===//
 
 #include "ClangdLSPServer.h"
+#include "Cancellation.h"
 #include "Diagnostics.h"
 #include "JSONRPCDispatcher.h"
 #include "SourceCode.h"
 #include "URI.h"
+#include "llvm/ADT/ScopeExit.h"
 #include "llvm/Support/Errc.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/Path.h"
@@ -69,6 +71,17 @@
   return Defaults;
 }
 
+// Serializes a JSON-RPC request ID into a key for TaskHandles. String IDs are
+// used verbatim so that they match CancelParams::ID (fromJSON stores string
+// IDs without quotes); any other ID, e.g. a number, is printed as JSON.
+std::string NormalizeRequestID(const json::Value &ID) {
+  if (auto S = ID.getAsString())
+    return S->str();
+  std::string NormalizedID;
+  llvm::raw_string_ostream OS(NormalizedID);
+  OS << ID;
+  return OS.str();
+}
 } // namespace
 
 void ClangdLSPServer::onInitialize(InitializeParams &Params) {
@@ -337,17 +350,33 @@
 }
 
 void ClangdLSPServer::onCompletion(TextDocumentPositionParams &Params) {
-  Server.codeComplete(Params.textDocument.uri.file(), Params.position, CCOpts,
-                      [this](llvm::Expected<CodeCompleteResult> List) {
-                        if (!List)
-                          return replyError(ErrorCode::InvalidParams,
-                                            llvm::toString(List.takeError()));
-                        CompletionList LSPList;
-                        LSPList.isIncomplete = List->HasMore;
-                        for (const auto &R : List->Completions)
-                          LSPList.items.push_back(R.render(CCOpts));
-                        reply(std::move(LSPList));
-                      });
+  TaskHandle TH = Server.codeComplete(
+      Params.textDocument.uri.file(), Params.position, CCOpts,
+      [this](llvm::Expected<CodeCompleteResult> List) {
+        auto _ = llvm::make_scope_exit([this]() { CleanupTaskHandle(); });
+
+        if (!List) {
+          Error Uncaught =
+              handleErrors(List.takeError(), [](const TaskCancelledError &) {
+                replyError(ErrorCode::RequestCancelled,
+                           "Request got cancelled");
+              });
+          // Any other error is reported as-is; toString consumes it, which an
+          // llvm::Error requires before it is destroyed.
+          if (Uncaught)
+            replyError(ErrorCode::InvalidParams,
+                       llvm::toString(std::move(Uncaught)));
+          return;
+        }
+        CompletionList LSPList;
+        LSPList.isIncomplete = List->HasMore;
+        for (const auto &R : List->Completions)
+          LSPList.items.push_back(R.render(CCOpts));
+        reply(std::move(LSPList));
+      });
+  // FIXME: the callback may run (and call CleanupTaskHandle) before the
+  // handle is stored below, leaving a stale entry in TaskHandles.
+  StoreTaskHandle(std::move(TH));
 }
 
 void ClangdLSPServer::onSignatureHelp(TextDocumentPositionParams &Params) {
@@ -602,3 +631,30 @@
     return *CachingCDB;
   return *CDB;
 }
+
+void ClangdLSPServer::onCancelRequest(CancelParams &Params) {
+  std::lock_guard<std::mutex> _(TaskHandlesMutex);
+  auto It = TaskHandles.find(Params.ID);
+  if (It != TaskHandles.end()) {
+    It->second.cancel();
+    TaskHandles.erase(It);
+  }
+}
+
+void ClangdLSPServer::CleanupTaskHandle() {
+  const json::Value *ID = GetRequestId();
+  if (!ID)
+    return;
+  std::string NormalizedID = NormalizeRequestID(*ID);
+  std::lock_guard<std::mutex> _(TaskHandlesMutex);
+  TaskHandles.erase(NormalizedID);
+}
+
+void ClangdLSPServer::StoreTaskHandle(TaskHandle TH) {
+  const json::Value *ID = GetRequestId();
+  if (!ID)
+    return;
+  std::string NormalizedID = NormalizeRequestID(*ID);
+  std::lock_guard<std::mutex> _(TaskHandlesMutex);
+  TaskHandles.insert({NormalizedID, std::move(TH)});
+}
Index: clangd/ClangdServer.h
===================================================================
--- clangd/ClangdServer.h
+++ clangd/ClangdServer.h
@@ -10,6 +10,7 @@
 #ifndef LLVM_CLANG_TOOLS_EXTRA_CLANGD_CLANGDSERVER_H
 #define LLVM_CLANG_TOOLS_EXTRA_CLANGD_CLANGDSERVER_H
 
+#include "Cancellation.h"
 #include "ClangdUnit.h"
 #include "CodeComplete.h"
 #include "FSProvider.h"
@@ -122,9 +123,12 @@
   /// while returned future is not yet ready.
   /// A version of `codeComplete` that runs \p Callback on the processing thread
   /// when codeComplete results become available.
-  void codeComplete(PathRef File, Position Pos,
-                    const clangd::CodeCompleteOptions &Opts,
-                    Callback<CodeCompleteResult> CB);
+  /// The returned TaskHandle can be used to cancel the request: if it is
+  /// cancelled before code completion starts running, \p CB receives a
+  /// TaskCancelledError instead of results.
+  TaskHandle codeComplete(PathRef File, Position Pos,
+                          const clangd::CodeCompleteOptions &Opts,
+                          Callback<CodeCompleteResult> CB);
 
   /// Provide signature help for \p File at \p Pos. This method should only be
   /// called for tracked files.
Index: clangd/ClangdServer.cpp
===================================================================
--- clangd/ClangdServer.cpp
+++ clangd/ClangdServer.cpp
@@ -8,6 +8,7 @@
 //===-------------------------------------------------------------------===//
 
 #include "ClangdServer.h"
+#include "Cancellation.h"
 #include "CodeComplete.h"
 #include "FindSymbols.h"
 #include "Headers.h"
@@ -140,25 +141,33 @@
   WorkScheduler.remove(File);
 }
 
-void ClangdServer::codeComplete(PathRef File, Position Pos,
-                                const clangd::CodeCompleteOptions &Opts,
-                                Callback<CodeCompleteResult> CB) {
+TaskHandle ClangdServer::codeComplete(PathRef File, Position Pos,
+                                      const clangd::CodeCompleteOptions &Opts,
+                                      Callback<CodeCompleteResult> CB) {
   // Copy completion options for passing them to async task handler.
   auto CodeCompleteOpts = Opts;
   if (!CodeCompleteOpts.Index) // Respect overridden index.
     CodeCompleteOpts.Index = Index;
 
+  auto CancellableTaskHandle = TaskHandle::createCancellableTaskHandle();
   // Copy PCHs to avoid accessing this->PCHs concurrently
   std::shared_ptr<PCHContainerOperations> PCHs = this->PCHs;
   auto FS = FSProvider.getFileSystem();
-  auto Task = [PCHs, Pos, FS,
-               CodeCompleteOpts](Path File, Callback<CodeCompleteResult> CB,
-                                 llvm::Expected<InputsAndPreamble> IP) {
+  auto Task = [PCHs, Pos, FS, CodeCompleteOpts, CancellableTaskHandle](
+                  Path File, Callback<CodeCompleteResult> CB,
+                  llvm::Expected<InputsAndPreamble> IP) {
     if (!IP)
       return CB(IP.takeError());
 
     auto PreambleData = IP->Preamble;
 
+    // Bind the handle to this task's Context, and bail out if the request has
+    // already been cancelled.
+    WithContext ContextWithCancellation(
+        CancellationHandler::setCurrentCancellationToken(
+            CancellableTaskHandle));
+    if (CancellationHandler::isCancelled())
+      return CB(CancellationHandler::getCancellationError());
     // FIXME(ibiryukov): even if Preamble is non-null, we may want to check
     // both the old and the new version in case only one of them matches.
     CodeCompleteResult Result = clangd::codeComplete(
@@ -170,6 +179,7 @@
 
   WorkScheduler.runWithPreamble("CodeComplete", File,
                                 Bind(Task, File.str(), std::move(CB)));
+  return CancellableTaskHandle;
 }
 
 void ClangdServer::signatureHelp(PathRef File, Position Pos,
Index: clangd/JSONRPCDispatcher.h
===================================================================
--- clangd/JSONRPCDispatcher.h
+++ clangd/JSONRPCDispatcher.h
@@ -111,6 +111,10 @@
                            JSONStreamStyle InputStyle,
                            JSONRPCDispatcher &Dispatcher, bool &IsDone);
 
+/// Returns the JSON-RPC request ID stored in the current Context, or nullptr
+/// if the Context has none.
+const llvm::json::Value *GetRequestId();
+
 } // namespace clangd
 } // namespace clang
 
Index: clangd/JSONRPCDispatcher.cpp
===================================================================
--- clangd/JSONRPCDispatcher.cpp
+++ clangd/JSONRPCDispatcher.cpp
@@ -366,3 +366,7 @@
     }
   }
 }
+
+const json::Value *clangd::GetRequestId() {
+  return Context::current().get(RequestID);
+}
Index: clangd/Protocol.h
===================================================================
--- clangd/Protocol.h
+++ clangd/Protocol.h
@@ -861,6 +861,16 @@
 llvm::json::Value toJSON(const DocumentHighlight &DH);
 llvm::raw_ostream &operator<<(llvm::raw_ostream &, const DocumentHighlight &);
 
+/// Parameters of the `$/cancelRequest` notification.
+struct CancelParams {
+  /// ID of the request to cancel. Numeric IDs are stored in decimal string
+  /// form, string IDs verbatim.
+  std::string ID;
+};
+llvm::json::Value toJSON(const CancelParams &);
+llvm::raw_ostream &operator<<(llvm::raw_ostream &, const CancelParams &);
+bool fromJSON(const llvm::json::Value &, CancelParams &);
+
 } // namespace clangd
 } // namespace clang
 
Index: clangd/Protocol.cpp
===================================================================
--- clangd/Protocol.cpp
+++ clangd/Protocol.cpp
@@ -615,5 +615,31 @@
          O.map("compilationDatabaseChanges", CCPC.compilationDatabaseChanges);
 }
 
+json::Value toJSON(const CancelParams &CP) {
+  return json::Object{{"id", CP.ID}};
+}
+
+llvm::raw_ostream &operator<<(llvm::raw_ostream &O, const CancelParams &CP) {
+  O << toJSON(CP);
+  return O;
+}
+
+bool fromJSON(const json::Value &Params, CancelParams &CP) {
+  json::ObjectMapper O(Params);
+  if (!O)
+    return false;
+  // ID is either a number or a string, check for both. Numbers are converted
+  // with itostr (signed), so negative IDs are not wrapped around.
+  if (O.map("id", CP.ID))
+    return true;
+
+  int64_t NumericID = 0;
+  if (O.map("id", NumericID)) {
+    CP.ID = llvm::itostr(NumericID);
+    return true;
+  }
+  return false;
+}
+
 } // namespace clangd
 } // namespace clang
Index: clangd/ProtocolHandlers.h
===================================================================
--- clangd/ProtocolHandlers.h
+++ clangd/ProtocolHandlers.h
@@ -55,6 +55,7 @@
   virtual void onDocumentHighlight(TextDocumentPositionParams &Params) = 0;
   virtual void onHover(TextDocumentPositionParams &Params) = 0;
   virtual void onChangeConfiguration(DidChangeConfigurationParams &Params) = 0;
+  virtual void onCancelRequest(CancelParams &Params) = 0;
 };
 
 void registerCallbackHandlers(JSONRPCDispatcher &Dispatcher,
Index: clangd/ProtocolHandlers.cpp
===================================================================
--- clangd/ProtocolHandlers.cpp
+++ clangd/ProtocolHandlers.cpp
@@ -75,4 +75,5 @@
   Register("workspace/didChangeConfiguration",
            &ProtocolCallbacks::onChangeConfiguration);
   Register("workspace/symbol", &ProtocolCallbacks::onWorkspaceSymbol);
+  Register("$/cancelRequest", &ProtocolCallbacks::onCancelRequest);
 }
Index: unittests/clangd/CMakeLists.txt
===================================================================
--- unittests/clangd/CMakeLists.txt
+++ unittests/clangd/CMakeLists.txt
@@ -10,6 +10,7 @@
 
 add_extra_unittest(ClangdTests
   Annotations.cpp
+  CancellationTests.cpp
   ClangdTests.cpp
   ClangdUnitTests.cpp
   CodeCompleteTests.cpp
Index: unittests/clangd/CancellationTests.cpp
===================================================================
--- /dev/null
+++ unittests/clangd/CancellationTests.cpp
@@ -0,0 +1,76 @@
+#include "Cancellation.h"
+#include "Context.h"
+#include "llvm/Support/Error.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+#include <atomic>
+#include <memory>
+#include <thread>
+
+namespace clang {
+namespace clangd {
+namespace {
+
+TEST(CancellationTest, CancellationTest) {
+  {
+    TaskHandle TH = TaskHandle::createCancellableTaskHandle();
+    WithContext ContextWithCancellation(
+        CancellationHandler::setCurrentCancellationToken(TH));
+    EXPECT_FALSE(CancellationHandler::isCancelled());
+    TH.cancel();
+    EXPECT_TRUE(CancellationHandler::isCancelled());
+  }
+  EXPECT_FALSE(CancellationHandler::isCancelled());
+}
+
+TEST(CancellationTest, CheckForError) {
+  llvm::Error Err = handleErrors(CancellationHandler::getCancellationError(),
+                                 [](const TaskCancelledError &) {});
+  EXPECT_FALSE(Err);
+}
+
+TEST(CancellationTest, TaskHandleTestHandleDiesContextLives) {
+  llvm::Optional<WithContext> ContextWithCancellation;
+  {
+    auto CancellableTaskHandle = TaskHandle::createCancellableTaskHandle();
+    ContextWithCancellation.emplace(
+        CancellationHandler::setCurrentCancellationToken(
+            CancellableTaskHandle));
+    EXPECT_FALSE(CancellationHandler::isCancelled());
+    CancellableTaskHandle.cancel();
+    EXPECT_TRUE(CancellationHandler::isCancelled());
+  }
+  EXPECT_TRUE(CancellationHandler::isCancelled());
+  ContextWithCancellation.reset();
+  EXPECT_FALSE(CancellationHandler::isCancelled());
+}
+
+TEST(CancellationTest, TaskHandleContextDiesHandleLives) {
+  {
+    auto CancellableTaskHandle = TaskHandle::createCancellableTaskHandle();
+    {
+      WithContext ContextWithCancellation(
+          CancellationHandler::setCurrentCancellationToken(
+              CancellableTaskHandle));
+      EXPECT_FALSE(CancellationHandler::isCancelled());
+      CancellableTaskHandle.cancel();
+      EXPECT_TRUE(CancellationHandler::isCancelled());
+    }
+  }
+  EXPECT_FALSE(CancellationHandler::isCancelled());
+}
+
+TEST(CancellationTest, CancellationToken) {
+  auto CancellableTaskHandle = TaskHandle::createCancellableTaskHandle();
+  WithContext ContextWithCancellation(
+      CancellationHandler::setCurrentCancellationToken(CancellableTaskHandle));
+  auto CT = CancellationHandler::isCancelled();
+  EXPECT_FALSE(CT);
+  CancellableTaskHandle.cancel();
+  EXPECT_TRUE(CT);
+}
+
+} // namespace
+} // namespace clangd
+} // namespace clang