Index: clangd/CMakeLists.txt
===================================================================
--- clangd/CMakeLists.txt
+++ clangd/CMakeLists.txt
@@ -9,6 +9,7 @@
 
 add_clang_library(clangDaemon
   AST.cpp
+  Cancellation.cpp
   ClangdLSPServer.cpp
   ClangdServer.cpp
   ClangdUnit.cpp
Index: clangd/Cancellation.h
===================================================================
--- /dev/null
+++ clangd/Cancellation.h
@@ -0,0 +1,151 @@
+//===--- Cancellation.h -------------------------------------------*-C++-*-===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+// Cancellation mechanism for async tasks. Roughly all the clients of this code
+// can be classified into three categories:
+// 1. The code that creates and schedules async tasks, e.g. TUScheduler.
+// 2. The callers of the async method that can cancel some of the running tasks,
+//    e.g. `ClangdLSPServer`.
+// 3. The code running inside the async task itself, i.e. code completion or
+//    find definition implementation that runs clang, etc.
+//
+// For (1), the guideline is to accept a callback for the result of the async
+// operation and return a `TaskHandle` to allow cancelling the request.
+//
+//  TaskHandle someAsyncMethod(Task T,
+//      std::function<void(llvm::Expected<ResultType>)> Callback) {
+//    // Create the TaskHandle before starting the thread, so that a clone of
+//    // it can be handed over to the thread.
+//    auto TH = TaskHandle::create();
+//    auto Run = [T, Callback](TaskHandle TH) {
+//      WithContext ContextWithCancellationToken(
+//          setCurrentCancellationToken(std::move(TH)));
+//      Callback(T());
+//    };
+//    // Start Run() in a new thread, binding TH.clone() as its argument since
+//    // TaskHandle doesn't allow implicit copies.
+//    return TH;
+//  }
+//
+// The callers of async methods (2) can issue cancellations and should be
+// prepared to handle a `CancelledError` result:
+//
+//  void Caller() {
+//    // Store this handle if you want to cancel the task later on.
+//    TaskHandle TH = someAsyncMethod(Task, [](llvm::Expected<ResultType> R) {
+//      if (/* R holds a CancelledError */) {
+//        // Handle the cancellation.
+//      }
+//      // Do other things on R.
+//    });
+//    // To cancel the task:
+//    sleep(5);
+//    TH.cancel();
+//  }
+//
+// The worker code itself (3) should check for cancellations using the
+// `CancellationToken` returned by `clangd::isCancelled()`.
+//
+//  llvm::Expected<ResultType> Task() {
+//    // You can either store the read-only token by calling isCancelled() once
+//    // and test that variable every time you want to check for cancellation,
+//    // or call isCancelled() every time. The former is more efficient if you
+//    // are going to perform multiple checks.
+//    const auto CT = isCancelled();
+//    // DO SOMETHING...
+//    if (CT) {
+//      // Task has been cancelled, let's get out.
+//      return llvm::make_error<CancelledError>();
+//    }
+//    // DO SOME MORE THINGS...
+//    if (CT) {
+//      // Task has been cancelled, let's get out.
+//      return llvm::make_error<CancelledError>();
+//    }
+//    return ResultType(...);
+//  }
+// If the operation was cancelled before the task could run to completion, it
+// should propagate a CancelledError as its result.
+
+#ifndef LLVM_CLANG_TOOLS_EXTRA_CLANGD_CANCELLATION_H
+#define LLVM_CLANG_TOOLS_EXTRA_CLANGD_CANCELLATION_H
+
+#include "Context.h"
+#include "llvm/Support/Error.h"
+#include <atomic>
+#include <memory>
+#include <system_error>
+
+namespace clang {
+namespace clangd {
+
+/// Enables async tasks to check for a cancellation signal. Holds a read-only
+/// flag cached from the context; obtain one via clangd::isCancelled().
+/// Thread-safe: multiple threads can check for cancellation on the same token.
+class CancellationToken {
+public:
+  bool isCancelled() const { return Token ? static_cast<bool>(*Token) : false; }
+  explicit operator bool() const { return isCancelled(); }
+  friend CancellationToken isCancelled();
+
+private:
+  explicit CancellationToken(std::shared_ptr<const std::atomic<bool>> Token)
+      : Token(std::move(Token)) {}
+
+  std::shared_ptr<const std::atomic<bool>> Token;
+};
+
+/// Enables signalling a cancellation on an async task. Thread-safe: can
+/// trigger cancellation from multiple threads. All clones of a handle share
+/// the same flag, so implicit copies are disallowed to make that sharing
+/// explicit; use clone() to get another handle to the same task.
+class TaskHandle {
+public:
+  /// Marks the task as cancelled. Must not be called on a moved-from handle.
+  void cancel() const;
+  static TaskHandle create();
+  friend Context setCurrentCancellationToken(TaskHandle TH);
+
+  TaskHandle clone() const;
+
+  TaskHandle(const TaskHandle &) = delete;
+  TaskHandle &operator=(const TaskHandle &) = delete;
+  TaskHandle(TaskHandle &&) = default;
+  TaskHandle &operator=(TaskHandle &&) = default;
+
+private:
+  TaskHandle() : CT(std::make_shared<std::atomic<bool>>(false)) {}
+  explicit TaskHandle(std::shared_ptr<std::atomic<bool>> CT)
+      : CT(std::move(CT)) {}
+  std::shared_ptr<std::atomic<bool>> CT;
+};
+
+/// Returns a token for the cancellation flag attached to the current context.
+/// The token never reports cancellation if no flag is attached.
+CancellationToken isCancelled();
+/// Returns a context derived from the current one that carries \p TH's
+/// cancellation flag; install it with WithContext.
+LLVM_NODISCARD Context setCurrentCancellationToken(TaskHandle TH);
+
+/// Error reported by tasks that stopped because they were cancelled.
+class CancelledError : public llvm::ErrorInfo<CancelledError> {
+public:
+  static char ID;
+
+  void log(llvm::raw_ostream &OS) const override {
+    OS << "Task got cancelled.";
+  }
+  std::error_code convertToErrorCode() const override {
+    return std::make_error_code(std::errc::operation_canceled);
+  }
+};
+
+} // namespace clangd
+} // namespace clang
+
+#endif
Index: clangd/Cancellation.cpp
===================================================================
--- /dev/null
+++ clangd/Cancellation.cpp
@@ -0,0 +1,40 @@
+//===--- Cancellation.cpp -----------------------------------------*-C++-*-===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+#include "Cancellation.h"
+#include <atomic>
+
+namespace clang {
+namespace clangd {
+
+namespace {
+Key<std::shared_ptr<std::atomic<bool>>> CancellationTokenKey;
+} // namespace
+
+char CancelledError::ID = 0;
+
+CancellationToken isCancelled() {
+  const auto *CT = Context::current().get(CancellationTokenKey);
+  if (!CT)
+    return CancellationToken(nullptr);
+  return CancellationToken(*CT);
+}
+
+Context setCurrentCancellationToken(TaskHandle TH) {
+  return Context::current().derive(CancellationTokenKey, std::move(TH.CT));
+}
+
+void TaskHandle::cancel() const { *CT = true; }
+
+TaskHandle TaskHandle::create() { return TaskHandle(); }
+
+TaskHandle TaskHandle::clone() const { return TaskHandle(CT); }
+
+} // namespace clangd
+} // namespace clang
Index: clangd/ClangdLSPServer.h
===================================================================
--- clangd/ClangdLSPServer.h
+++ clangd/ClangdLSPServer.h
@@ -75,6 +75,7 @@
   void onRename(RenameParams &Parames) override;
   void onHover(TextDocumentPositionParams &Params) override;
   void onChangeConfiguration(DidChangeConfigurationParams &Params) override;
+  void onCancelRequest(CancelParams &Params) override;
 
   std::vector<Fix> getFixes(StringRef File, const clangd::Diagnostic &D);
 
@@ -167,6 +168,17 @@
   // the worker thread that may otherwise run an async callback on partially
   // destructed instance of ClangdLSPServer.
   ClangdServer Server;
+
+  // Holds task handles for running requests. Key of the map is a serialized
+  // request id (see NormalizeRequestID).
+  llvm::StringMap<TaskHandle> TaskHandles;
+  std::mutex TaskHandlesMutex;
+
+  // CleanupTaskHandle erases and StoreTaskHandle records the handle of the
+  // request handled in the current context, keyed by its normalized id.
+  // Both are no-ops when the context carries no request id.
+  void CleanupTaskHandle();
+  void StoreTaskHandle(TaskHandle TH);
 };
 
 } // namespace clangd
Index: clangd/ClangdLSPServer.cpp
===================================================================
--- clangd/ClangdLSPServer.cpp
+++ clangd/ClangdLSPServer.cpp
@@ -8,10 +8,13 @@
 //===----------------------------------------------------------------------===//
 
 #include "ClangdLSPServer.h"
+#include "Cancellation.h"
 #include "Diagnostics.h"
 #include "JSONRPCDispatcher.h"
 #include "SourceCode.h"
 #include "URI.h"
+#include "llvm/ADT/ScopeExit.h"
+#include "llvm/ADT/StringExtras.h"
 #include "llvm/Support/Errc.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/Path.h"
@@ -69,6 +72,21 @@
   return Defaults;
 }
 
+/// Converts an LSP request id (a JSON string or number) into the key used for
+/// TaskHandles. This must agree with fromJSON(CancelParams): strings are used
+/// verbatim and integers are printed in decimal.
+std::string NormalizeRequestID(const json::Value &ID) {
+  if (auto StrID = ID.getAsString())
+    return StrID->str();
+  if (auto IntID = ID.getAsInteger())
+    return llvm::itostr(*IntID);
+  // No other id kind can match a CancelParams id; fall back to its JSON form.
+  std::string NormalizedID;
+  llvm::raw_string_ostream OS(NormalizedID);
+  OS << ID;
+  return OS.str();
+}
+
 } // namespace
 
 void ClangdLSPServer::onInitialize(InitializeParams &Params) {
@@ -337,17 +355,23 @@
 }
 
 void ClangdLSPServer::onCompletion(TextDocumentPositionParams &Params) {
-  Server.codeComplete(Params.textDocument.uri.file(), Params.position, CCOpts,
-                      [this](llvm::Expected<CodeCompleteResult> List) {
-                        if (!List)
-                          return replyError(ErrorCode::InvalidParams,
-                                            llvm::toString(List.takeError()));
-                        CompletionList LSPList;
-                        LSPList.isIncomplete = List->HasMore;
-                        for (const auto &R : List->Completions)
-                          LSPList.items.push_back(R.render(CCOpts));
-                        reply(std::move(LSPList));
-                      });
+  TaskHandle TH = Server.codeComplete(
+      Params.textDocument.uri.file(), Params.position, CCOpts,
+      [this](llvm::Expected<CodeCompleteResult> List) {
+        // Forget the handle once this request is done, however it ends.
+        auto _ = llvm::make_scope_exit([this]() { CleanupTaskHandle(); });
+        if (!List)
+          return replyError(List.takeError());
+        CompletionList LSPList;
+        LSPList.isIncomplete = List->HasMore;
+        for (const auto &R : List->Completions)
+          LSPList.items.push_back(R.render(CCOpts));
+        return reply(std::move(LSPList));
+      });
+  // FIXME: the callback may finish, and run CleanupTaskHandle, before the
+  // handle is stored below (e.g. if the task runs synchronously). The handle
+  // is then never erased from TaskHandles.
+  StoreTaskHandle(std::move(TH));
 }
 
 void ClangdLSPServer::onSignatureHelp(TextDocumentPositionParams &Params) {
@@ -604,3 +628,31 @@
     return *CachingCDB;
   return *CDB;
 }
+
+void ClangdLSPServer::onCancelRequest(CancelParams &Params) {
+  std::lock_guard<std::mutex> Lock(TaskHandlesMutex);
+  auto It = TaskHandles.find(Params.ID);
+  // Unknown id, or the request already finished: nothing to cancel.
+  if (It == TaskHandles.end())
+    return;
+  It->second.cancel();
+  TaskHandles.erase(It);
+}
+
+void ClangdLSPServer::CleanupTaskHandle() {
+  const json::Value *ID = GetRequestId();
+  if (!ID)
+    return;
+  const std::string NormalizedID = NormalizeRequestID(*ID);
+  std::lock_guard<std::mutex> Lock(TaskHandlesMutex);
+  TaskHandles.erase(NormalizedID);
+}
+
+void ClangdLSPServer::StoreTaskHandle(TaskHandle TH) {
+  const json::Value *ID = GetRequestId();
+  if (!ID)
+    return;
+  const std::string NormalizedID = NormalizeRequestID(*ID);
+  std::lock_guard<std::mutex> Lock(TaskHandlesMutex);
+  TaskHandles.insert({NormalizedID, std::move(TH)});
+}
Index: clangd/ClangdServer.h
===================================================================
--- clangd/ClangdServer.h
+++ clangd/ClangdServer.h
@@ -10,6 +10,7 @@
 #ifndef LLVM_CLANG_TOOLS_EXTRA_CLANGD_CLANGDSERVER_H
 #define LLVM_CLANG_TOOLS_EXTRA_CLANGD_CLANGDSERVER_H
 
+#include "Cancellation.h"
 #include "ClangdUnit.h"
 #include "CodeComplete.h"
 #include "FSProvider.h"
@@ -122,9 +123,11 @@
   /// while returned future is not yet ready.
   /// A version of `codeComplete` that runs \p Callback on the processing thread
   /// when codeComplete results become available.
-  void codeComplete(PathRef File, Position Pos,
-                    const clangd::CodeCompleteOptions &Opts,
-                    Callback<CodeCompleteResult> CB);
+  /// Returns a TaskHandle that can cancel the request; \p CB then receives a
+  /// CancelledError if the cancellation is observed before completion starts.
+  TaskHandle codeComplete(PathRef File, Position Pos,
+                          const clangd::CodeCompleteOptions &Opts,
+                          Callback<CodeCompleteResult> CB);
 
   /// Provide signature help for \p File at \p Pos. This method should only be
   /// called for tracked files.
Index: clangd/ClangdServer.cpp
===================================================================
--- clangd/ClangdServer.cpp
+++ clangd/ClangdServer.cpp
@@ -8,6 +8,7 @@
 //===-------------------------------------------------------------------===//
 
 #include "ClangdServer.h"
+#include "Cancellation.h"
 #include "CodeComplete.h"
 #include "FindSymbols.h"
 #include "Headers.h"
@@ -140,25 +141,31 @@
   WorkScheduler.remove(File);
 }
 
-void ClangdServer::codeComplete(PathRef File, Position Pos,
-                                const clangd::CodeCompleteOptions &Opts,
-                                Callback<CodeCompleteResult> CB) {
+TaskHandle ClangdServer::codeComplete(PathRef File, Position Pos,
+                                      const clangd::CodeCompleteOptions &Opts,
+                                      Callback<CodeCompleteResult> CB) {
   // Copy completion options for passing them to async task handler.
   auto CodeCompleteOpts = Opts;
   if (!CodeCompleteOpts.Index) // Respect overridden index.
     CodeCompleteOpts.Index = Index;
 
+  auto CancellableTaskHandle = TaskHandle::create();
   // Copy PCHs to avoid accessing this->PCHs concurrently
   std::shared_ptr<PCHContainerOperations> PCHs = this->PCHs;
   auto FS = FSProvider.getFileSystem();
-  auto Task = [PCHs, Pos, FS,
-               CodeCompleteOpts](Path File, Callback<CodeCompleteResult> CB,
-                                 llvm::Expected<InputsAndPreamble> IP) {
+  auto Task = [PCHs, Pos, FS, CodeCompleteOpts](
+                  Path File, Callback<CodeCompleteResult> CB, TaskHandle TH,
+                  llvm::Expected<InputsAndPreamble> IP) {
     if (!IP)
       return CB(IP.takeError());
 
     auto PreambleData = IP->Preamble;
+    // Make TH's flag visible to isCancelled() for the rest of this task.
+    WithContext ContextWithCancellation(
+        setCurrentCancellationToken(std::move(TH)));
+    if (isCancelled())
+      return CB(llvm::make_error<CancelledError>());
 
     // FIXME(ibiryukov): even if Preamble is non-null, we may want to check
     // both the old and the new version in case only one of them matches.
     CodeCompleteResult Result = clangd::codeComplete(
@@ -168,8 +175,10 @@
     CB(std::move(Result));
   };
 
-  WorkScheduler.runWithPreamble("CodeComplete", File,
-                                Bind(Task, File.str(), std::move(CB)));
+  WorkScheduler.runWithPreamble(
+      "CodeComplete", File,
+      Bind(Task, File.str(), std::move(CB), CancellableTaskHandle.clone()));
+  return CancellableTaskHandle;
 }
 
 void ClangdServer::signatureHelp(PathRef File, Position Pos,
Index: clangd/JSONRPCDispatcher.h
===================================================================
--- clangd/JSONRPCDispatcher.h
+++ clangd/JSONRPCDispatcher.h
@@ -64,6 +64,10 @@
 /// Sends an error response to the client, and logs it.
 /// Current context must derive from JSONRPCDispatcher::Handler.
 void replyError(ErrorCode Code, const llvm::StringRef &Message);
+/// Sends an error response for \p E: RequestCancelled if \p E only carries
+/// a CancelledError, InvalidParams otherwise.
+/// Current context must derive from JSONRPCDispatcher::Handler.
+void replyError(llvm::Error E);
 /// Sends a request to the client.
 /// Current context must derive from JSONRPCDispatcher::Handler.
 void call(llvm::StringRef Method, llvm::json::Value &&Params);
@@ -111,6 +115,10 @@
                            JSONStreamStyle InputStyle,
                            JSONRPCDispatcher &Dispatcher, bool &IsDone);
 
+/// Returns the id of the request being handled in the current context, or
+/// nullptr if there is none.
+const llvm::json::Value *GetRequestId();
+
 } // namespace clangd
 } // namespace clang
 
Index: clangd/JSONRPCDispatcher.cpp
===================================================================
--- clangd/JSONRPCDispatcher.cpp
+++ clangd/JSONRPCDispatcher.cpp
@@ -8,6 +8,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "JSONRPCDispatcher.h"
+#include "Cancellation.h"
 #include "ProtocolHandlers.h"
 #include "Trace.h"
 #include "llvm/ADT/SmallString.h"
@@ -129,6 +130,19 @@
   }
 }
 
+void clangd::replyError(Error E) {
+  // Send exactly one reply: RequestCancelled if E only carries a
+  // CancelledError, InvalidParams describing the other error(s) otherwise.
+  bool Cancelled = false;
+  Error Rest = handleErrors(std::move(E),
+                            [&](const CancelledError &) { Cancelled = true; });
+  if (Rest)
+    return replyError(ErrorCode::InvalidParams,
+                      llvm::toString(std::move(Rest)));
+  if (Cancelled)
+    replyError(ErrorCode::RequestCancelled, CancelledError().message());
+}
+
 void clangd::call(StringRef Method, json::Value &&Params) {
   RequestSpan::attach([&](json::Object &Args) {
     Args["Call"] = json::Object{{"method", Method.str()}, {"params", Params}};
@@ -366,3 +380,7 @@
     }
   }
 }
+
+const json::Value *clangd::GetRequestId() {
+  return Context::current().get(RequestID);
+}
Index: clangd/Protocol.h
===================================================================
--- clangd/Protocol.h
+++ clangd/Protocol.h
@@ -867,6 +867,16 @@
 llvm::json::Value toJSON(const DocumentHighlight &DH);
 llvm::raw_ostream &operator<<(llvm::raw_ostream &, const DocumentHighlight &);
 
+/// Parameters of the `$/cancelRequest` notification.
+struct CancelParams {
+  /// Id of the request to cancel. LSP ids are numbers or strings; both are
+  /// stored as strings (integers in decimal form).
+  std::string ID;
+};
+llvm::json::Value toJSON(const CancelParams &);
+llvm::raw_ostream &operator<<(llvm::raw_ostream &, const CancelParams &);
+bool fromJSON(const llvm::json::Value &, CancelParams &);
+
 } // namespace clangd
 } // namespace clang
 
Index: clangd/Protocol.cpp
===================================================================
--- clangd/Protocol.cpp
+++ clangd/Protocol.cpp
@@ -615,5 +615,31 @@
          O.map("compilationDatabaseChanges", CCPC.compilationDatabaseChanges);
 }
 
+json::Value toJSON(const CancelParams &CP) {
+  return json::Object{{"id", CP.ID}};
+}
+
+llvm::raw_ostream &operator<<(llvm::raw_ostream &O, const CancelParams &CP) {
+  O << toJSON(CP);
+  return O;
+}
+
+bool fromJSON(const json::Value &Params, CancelParams &CP) {
+  json::ObjectMapper O(Params);
+  if (!O)
+    return false;
+  // The id is either a string or a number. Strings are kept verbatim and
+  // integers are printed in decimal, matching how ClangdLSPServer keys its
+  // task handles (signed, so negative ids round-trip).
+  if (O.map("id", CP.ID))
+    return true;
+  int64_t IDNumber;
+  if (O.map("id", IDNumber)) {
+    CP.ID = llvm::itostr(IDNumber);
+    return true;
+  }
+  return false;
+}
+
 } // namespace clangd
 } // namespace clang
Index: clangd/ProtocolHandlers.h
===================================================================
--- clangd/ProtocolHandlers.h
+++ clangd/ProtocolHandlers.h
@@ -55,6 +55,7 @@
   virtual void onDocumentHighlight(TextDocumentPositionParams &Params) = 0;
   virtual void onHover(TextDocumentPositionParams &Params) = 0;
   virtual void onChangeConfiguration(DidChangeConfigurationParams &Params) = 0;
+  virtual void onCancelRequest(CancelParams &Params) = 0;
 };
 
 void registerCallbackHandlers(JSONRPCDispatcher &Dispatcher,
Index: clangd/ProtocolHandlers.cpp
===================================================================
--- clangd/ProtocolHandlers.cpp
+++ clangd/ProtocolHandlers.cpp
@@ -75,4 +75,5 @@
   Register("workspace/didChangeConfiguration",
            &ProtocolCallbacks::onChangeConfiguration);
   Register("workspace/symbol", &ProtocolCallbacks::onWorkspaceSymbol);
+  Register("$/cancelRequest", &ProtocolCallbacks::onCancelRequest);
 }
Index: unittests/clangd/CMakeLists.txt
===================================================================
--- unittests/clangd/CMakeLists.txt
+++ unittests/clangd/CMakeLists.txt
@@ -10,6 +10,7 @@
 
 add_extra_unittest(ClangdTests
   Annotations.cpp
+  CancellationTests.cpp
   ClangdTests.cpp
   ClangdUnitTests.cpp
   CodeCompleteTests.cpp
Index: unittests/clangd/CancellationTests.cpp
===================================================================
--- /dev/null
+++ unittests/clangd/CancellationTests.cpp
@@ -0,0 +1,72 @@
+//===--- CancellationTests.cpp ------------------------------------*-C++-*-===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+#include "Cancellation.h"
+#include "Context.h"
+#include "llvm/ADT/Optional.h"
+#include "gtest/gtest.h"
+
+namespace clang {
+namespace clangd {
+namespace {
+
+TEST(CancellationTest, CancellationTest) {
+  {
+    TaskHandle TH = TaskHandle::create();
+    WithContext ContextWithCancellation(
+        setCurrentCancellationToken(TH.clone()));
+    EXPECT_FALSE(isCancelled());
+    TH.cancel();
+    EXPECT_TRUE(isCancelled());
+  }
+  EXPECT_FALSE(isCancelled());
+}
+
+TEST(CancellationTest, TaskHandleTestHandleDiesContextLives) {
+  llvm::Optional<WithContext> ContextWithCancellation;
+  {
+    auto CancellableTaskHandle = TaskHandle::create();
+    ContextWithCancellation.emplace(
+        setCurrentCancellationToken(CancellableTaskHandle.clone()));
+    EXPECT_FALSE(isCancelled());
+    CancellableTaskHandle.cancel();
+    EXPECT_TRUE(isCancelled());
+  }
+  EXPECT_TRUE(isCancelled());
+  ContextWithCancellation.reset();
+  EXPECT_FALSE(isCancelled());
+}
+
+TEST(CancellationTest, TaskHandleContextDiesHandleLives) {
+  {
+    auto CancellableTaskHandle = TaskHandle::create();
+    {
+      WithContext ContextWithCancellation(
+          setCurrentCancellationToken(CancellableTaskHandle.clone()));
+      EXPECT_FALSE(isCancelled());
+      CancellableTaskHandle.cancel();
+      EXPECT_TRUE(isCancelled());
+    }
+  }
+  EXPECT_FALSE(isCancelled());
+}
+
+TEST(CancellationTest, CancellationToken) {
+  auto CancellableTaskHandle = TaskHandle::create();
+  WithContext ContextWithCancellation(
+      setCurrentCancellationToken(CancellableTaskHandle.clone()));
+  auto CT = isCancelled();
+  EXPECT_FALSE(CT);
+  CancellableTaskHandle.cancel();
+  EXPECT_TRUE(CT);
+}
+
+} // namespace
+} // namespace clangd
+} // namespace clang