Index: clangd/CMakeLists.txt =================================================================== --- clangd/CMakeLists.txt +++ clangd/CMakeLists.txt @@ -9,6 +9,7 @@ add_clang_library(clangDaemon AST.cpp + Cancellation.cpp ClangdLSPServer.cpp ClangdServer.cpp ClangdUnit.cpp Index: clangd/Cancellation.h =================================================================== --- /dev/null +++ clangd/Cancellation.h @@ -0,0 +1,154 @@ +//===--- Cancellation.h -------------------------------------------*-C++-*-===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// Cancellation mechanism for async tasks. Roughly all the clients of this code +// can be classified into three categories: +// 1. The code that creates and schedules async tasks, e.g. TUScheduler. +// 2. The callers of the async method that can cancel some of the running tasks, +// e.g. `ClangdLSPServer`. +// 3. The code running inside the async task itself, e.g. code completion or +// find-definition implementations that run clang. +// +// For (1), the guideline is to accept a callback for the result of the async +// operation and return a `TaskHandle` to allow cancelling the request. +// +// TaskHandle someAsyncMethod(Task T, +// function<void(llvm::Expected<ResultType>)> Callback) { +// // Make sure the TaskHandle is created before starting the thread. +// // Otherwise the CancellationToken might not get copied into the thread. +// auto TH = TaskHandle::create(); +// auto run = [T, Callback](CancellationToken CT) { +// WithContext WithCT(setCurrentCancellationToken(std::move(CT))); +// Callback(T()); +// }; +// // Start run() in a new thread, binding TH.createCancellationToken() to +// // it; the TaskHandle itself is move-only and is returned to the caller.
+// return TH; +// } +// +// The callers of async methods (2) can issue cancellations and should be +// prepared to handle a `CancelledError` result: +// +// void Caller() { +// // You should store this handle if you want to cancel the task later on. +// TaskHandle TH = someAsyncMethod(Task, [](llvm::Expected<ResultType> R) { +// if(/*check for task cancellation error*/) +// // Handle the error +// // Do other things on R. +// }); +// // To cancel the task: +// sleep(5); +// TH.cancel(); +// } +// +// The worker code itself (3) should check for cancellations using the +// `CancellationToken` retrieved via `getCurrentCancellationToken()`, or via +// the `isCancelled()` helper. +// +// llvm::Expected<ResultType> Task() { +// // You can either fetch the read-only token once with +// // getCurrentCancellationToken() and reuse it for every check, or call +// // isCancelled() every time. The former is more efficient if you are +// // going to have multiple checks. +// const auto &CT = getCurrentCancellationToken(); +// // DO SOMETHING... +// if(CT) { +// // Task has been cancelled, let's get out. +// return llvm::make_error<CancelledError>(); +// } +// // DO SOME MORE THINGS... +// if(CT) { +// // Task has been cancelled, let's get out. +// return llvm::make_error<CancelledError>(); +// } +// return ResultType(...); +// } +// If the operation was cancelled before the task could run to completion, it +// should propagate the CancelledError as a result. + +#ifndef LLVM_CLANG_TOOLS_EXTRA_CLANGD_CANCELLATION_H +#define LLVM_CLANG_TOOLS_EXTRA_CLANGD_CANCELLATION_H + +#include "Context.h" +#include "llvm/Support/Error.h" +#include <atomic> +#include <memory> +#include <system_error> + +namespace clang { +namespace clangd { + +/// Enables async tasks to check for a cancellation signal. Contains a +/// read-only token cached from the context. Tokens for the currently running +/// task can be obtained via clangd::getCurrentCancellationToken(). It is +/// thread-safe: multiple threads can check for cancellation on the same token.
+/// +/// If cancellation checks are rare, one could use the isCancelled() helper in +/// the namespace to simplify the code. However, if cancellation checks are +/// frequent, the guideline is to first obtain the CancellationToken for the +/// currently running task with getCurrentCancellationToken() and do cancel +/// checks using it to avoid extra lookups in the Context. +class CancellationToken { +public: + bool isCancelled() const { return Token ? Token->load() : false; } + friend class TaskHandle; + operator bool() const { return isCancelled(); } + + CancellationToken(const CancellationToken &) = delete; + CancellationToken &operator=(const CancellationToken &) = delete; + CancellationToken(CancellationToken &&) = default; + CancellationToken &operator=(CancellationToken &&) = default; + +private: + explicit CancellationToken(std::shared_ptr<std::atomic<bool>> Token) + : Token(std::move(Token)) {} + + std::shared_ptr<std::atomic<bool>> Token; +}; + +/// Enables signalling a cancellation on an async task. It is thread-safe to +/// trigger cancellation from multiple threads or create cancellation tokens.
+class TaskHandle { +public: + void cancel(); + static TaskHandle create(); + CancellationToken createCancellationToken() const; + + TaskHandle(const TaskHandle &) = delete; + TaskHandle &operator=(const TaskHandle &) = delete; + TaskHandle(TaskHandle &&) = default; + TaskHandle &operator=(TaskHandle &&) = default; + +private: + TaskHandle() : CT(std::make_shared<std::atomic<bool>>(false)) {} + explicit TaskHandle(std::shared_ptr<std::atomic<bool>> CT) : CT(std::move(CT)) {} + std::shared_ptr<std::atomic<bool>> CT; +}; + +const CancellationToken &getCurrentCancellationToken(); +LLVM_NODISCARD Context setCurrentCancellationToken(CancellationToken CT); +inline bool isCancelled() { + return getCurrentCancellationToken().isCancelled(); +} + +class CancelledError : public llvm::ErrorInfo<CancelledError> { +public: + static char ID; + + void log(llvm::raw_ostream &OS) const override { + OS << "Task was cancelled."; + } + std::error_code convertToErrorCode() const override { + return std::make_error_code(std::errc::operation_canceled); + } +}; + +} // namespace clangd +} // namespace clang + +#endif // LLVM_CLANG_TOOLS_EXTRA_CLANGD_CANCELLATION_H Index: clangd/Cancellation.cpp =================================================================== --- /dev/null +++ clangd/Cancellation.cpp @@ -0,0 +1,39 @@ +//===--- Cancellation.cpp -----------------------------------------*-C++-*-===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details.
+// +//===----------------------------------------------------------------------===// + +#include "Cancellation.h" +#include <atomic> + +namespace clang { +namespace clangd { + +namespace { +Key<CancellationToken> CancellationTokenKey; +} // namespace + +char CancelledError::ID = 0; + +const CancellationToken &getCurrentCancellationToken() { + return Context::current().getExisting(CancellationTokenKey); +} + +Context setCurrentCancellationToken(CancellationToken CT) { + return Context::current().derive(CancellationTokenKey, std::move(CT)); +} + +void TaskHandle::cancel() { *CT = true; } + +TaskHandle TaskHandle::create() { return TaskHandle(); } + +CancellationToken TaskHandle::createCancellationToken() const { + return CancellationToken(CT); +} + +} // namespace clangd +} // namespace clang Index: clangd/ClangdLSPServer.h =================================================================== --- clangd/ClangdLSPServer.h +++ clangd/ClangdLSPServer.h @@ -75,6 +75,7 @@ void onRename(RenameParams &Parames) override; void onHover(TextDocumentPositionParams &Params) override; void onChangeConfiguration(DidChangeConfigurationParams &Params) override; + void onCancelRequest(CancelParams &Params) override; std::vector getFixes(StringRef File, const clangd::Diagnostic &D); @@ -167,8 +168,20 @@ // the worker thread that may otherwise run an async callback on partially // destructed instance of ClangdLSPServer. ClangdServer Server; -}; + // Holds task handles for running requests. Key of the map is a serialized + // request id. + llvm::StringMap<TaskHandle> TaskHandles; + std::mutex TaskHandlesMutex; + + // The following two functions manage the TaskHandles map: they store or + // remove the task handle for the request-id stored in the current Context. + // FIXME(kadircet): Wrap them in an RAII object kept in the request's Context + // until a reply is generated. As is, a fast async callback can run + // cleanupTaskHandle() before storeTaskHandle(), leaving a stale entry.
+ void cleanupTaskHandle(); + void storeTaskHandle(TaskHandle TH); +}; } // namespace clangd } // namespace clang Index: clangd/ClangdLSPServer.cpp =================================================================== --- clangd/ClangdLSPServer.cpp +++ clangd/ClangdLSPServer.cpp @@ -8,10 +8,12 @@ //===----------------------------------------------------------------------===// #include "ClangdLSPServer.h" +#include "Cancellation.h" #include "Diagnostics.h" #include "JSONRPCDispatcher.h" #include "SourceCode.h" #include "URI.h" +#include "llvm/ADT/ScopeExit.h" #include "llvm/Support/Errc.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/Path.h" @@ -69,6 +71,15 @@ return Defaults; } +std::string normalizeRequestID(const json::Value &ID) { + // Use string IDs verbatim (unquoted, unescaped) to match CancelParams::ID. + if (auto Str = ID.getAsString()) + return Str->str(); + std::string NormalizedID; + llvm::raw_string_ostream OS(NormalizedID); + OS << ID; + return OS.str(); +} } // namespace void ClangdLSPServer::onInitialize(InitializeParams &Params) { @@ -337,17 +348,20 @@ } void ClangdLSPServer::onCompletion(TextDocumentPositionParams &Params) { - Server.codeComplete(Params.textDocument.uri.file(), Params.position, CCOpts, - [this](llvm::Expected<CodeCompleteResult> List) { - if (!List) - return replyError(ErrorCode::InvalidParams, - llvm::toString(List.takeError())); - CompletionList LSPList; - LSPList.isIncomplete = List->HasMore; - for (const auto &R : List->Completions) - LSPList.items.push_back(R.render(CCOpts)); - reply(std::move(LSPList)); - }); + TaskHandle TH = Server.codeComplete( + Params.textDocument.uri.file(), Params.position, CCOpts, + [this](llvm::Expected<CodeCompleteResult> List) { + auto CleanupOnExit = llvm::make_scope_exit([this] { cleanupTaskHandle(); }); + + if (!List) + return replyError(List.takeError()); + CompletionList LSPList; + LSPList.isIncomplete = List->HasMore; + for (const auto &R : List->Completions) + LSPList.items.push_back(R.render(CCOpts)); + return reply(std::move(LSPList));
+ }); + storeTaskHandle(std::move(TH)); } void ClangdLSPServer::onSignatureHelp(TextDocumentPositionParams &Params) { @@ -362,14 +376,14 @@ } void ClangdLSPServer::onGoToDefinition(TextDocumentPositionParams &Params) { - Server.findDefinitions( - Params.textDocument.uri.file(), Params.position, - [](llvm::Expected<std::vector<Location>> Items) { - if (!Items) - return replyError(ErrorCode::InvalidParams, - llvm::toString(Items.takeError())); - reply(json::Array(*Items)); - }); + Server.findDefinitions(Params.textDocument.uri.file(), Params.position, + [](llvm::Expected<std::vector<Location>> Items) { + if (!Items) + return replyError( + ErrorCode::InvalidParams, + llvm::toString(Items.takeError())); + reply(json::Array(*Items)); + }); } void ClangdLSPServer::onSwitchSourceHeader(TextDocumentIdentifier &Params) { @@ -604,3 +618,31 @@ return *CachingCDB; return *CDB; } + +void ClangdLSPServer::onCancelRequest(CancelParams &Params) { + std::lock_guard<std::mutex> Lock(TaskHandlesMutex); + auto It = TaskHandles.find(Params.ID); + if (It != TaskHandles.end()) { + It->second.cancel(); + TaskHandles.erase(It); + } +} + +void ClangdLSPServer::cleanupTaskHandle() { + const json::Value *ID = getRequestId(); + if (!ID) + return; + std::string NormalizedID = normalizeRequestID(*ID); + std::lock_guard<std::mutex> Lock(TaskHandlesMutex); + TaskHandles.erase(NormalizedID); +} + +void ClangdLSPServer::storeTaskHandle(TaskHandle TH) { + const json::Value *ID = getRequestId(); + if (!ID) + return; + std::string NormalizedID = normalizeRequestID(*ID); + std::lock_guard<std::mutex> Lock(TaskHandlesMutex); + if (!TaskHandles.insert({NormalizedID, std::move(TH)}).second) + elog("Insertion of TaskHandle for request: {0} failed.", NormalizedID); +} Index: clangd/ClangdServer.h =================================================================== --- clangd/ClangdServer.h +++ clangd/ClangdServer.h @@ -10,6 +10,7 @@ #ifndef LLVM_CLANG_TOOLS_EXTRA_CLANGD_CLANGDSERVER_H #define LLVM_CLANG_TOOLS_EXTRA_CLANGD_CLANGDSERVER_H +#include "Cancellation.h" #include
"ClangdUnit.h" #include "CodeComplete.h" #include "FSProvider.h" @@ -122,9 +123,10 @@ /// while returned future is not yet ready. /// A version of `codeComplete` that runs \p Callback on the processing thread /// when codeComplete results become available. - void codeComplete(PathRef File, Position Pos, - const clangd::CodeCompleteOptions &Opts, - Callback<CodeCompleteResult> CB); + /// Returns a TaskHandle that can be used to cancel the request. + TaskHandle codeComplete(PathRef File, Position Pos, + const clangd::CodeCompleteOptions &Opts, + Callback<CodeCompleteResult> CB); /// Provide signature help for \p File at \p Pos. This method should only be /// called for tracked files. Index: clangd/ClangdServer.cpp =================================================================== --- clangd/ClangdServer.cpp +++ clangd/ClangdServer.cpp @@ -8,6 +8,7 @@ //===-------------------------------------------------------------------===// #include "ClangdServer.h" +#include "Cancellation.h" #include "CodeComplete.h" #include "FindSymbols.h" #include "Headers.h" @@ -140,14 +141,17 @@ WorkScheduler.remove(File); } -void ClangdServer::codeComplete(PathRef File, Position Pos, - const clangd::CodeCompleteOptions &Opts, - Callback<CodeCompleteResult> CB) { +TaskHandle ClangdServer::codeComplete(PathRef File, Position Pos, + const clangd::CodeCompleteOptions &Opts, + Callback<CodeCompleteResult> CB) { // Copy completion options for passing them to async task handler. auto CodeCompleteOpts = Opts; if (!CodeCompleteOpts.Index) // Respect overridden index. CodeCompleteOpts.Index = Index; + auto TH = TaskHandle::create(); + WithContext ContextWithCancellation( + setCurrentCancellationToken(TH.createCancellationToken())); // Copy PCHs to avoid accessing this->PCHs concurrently std::shared_ptr<PCHContainerOperations> PCHs = this->PCHs; auto FS = FSProvider.getFileSystem(); @@ -159,6 +163,8 @@ auto PreambleData = IP->Preamble; + if (isCancelled()) + return CB(llvm::make_error<CancelledError>()); // FIXME(ibiryukov): even if Preamble is non-null, we may want to check // both the old and the new version in case only one of them matches.
CodeCompleteResult Result = clangd::codeComplete( @@ -170,6 +176,7 @@ WorkScheduler.runWithPreamble("CodeComplete", File, Bind(Task, File.str(), std::move(CB))); + return TH; } void ClangdServer::signatureHelp(PathRef File, Position Pos, Index: clangd/JSONRPCDispatcher.h =================================================================== --- clangd/JSONRPCDispatcher.h +++ clangd/JSONRPCDispatcher.h @@ -64,6 +64,13 @@ /// Sends an error response to the client, and logs it. /// Current context must derive from JSONRPCDispatcher::Handler. void replyError(ErrorCode Code, const llvm::StringRef &Message); +/// Sends an error response built from \p E to the client, and logs it. +/// CancelledError maps to ErrorCode::RequestCancelled; any other error maps to +/// ErrorCode::InvalidParams. The error's message() is used as the message. +void replyError(llvm::Error E); +/// Returns the request-id of the current request, or null for notifications. +/// Should not be used for replying to requests; use the functions above. +const llvm::json::Value *getRequestId(); /// Sends a request to the client. /// Current context must derive from JSONRPCDispatcher::Handler.
void call(llvm::StringRef Method, llvm::json::Value &&Params); @@ -110,7 +117,6 @@ void runLanguageServerLoop(std::FILE *In, JSONOutput &Out, JSONStreamStyle InputStyle, JSONRPCDispatcher &Dispatcher, bool &IsDone); - } // namespace clangd } // namespace clang Index: clangd/JSONRPCDispatcher.cpp =================================================================== --- clangd/JSONRPCDispatcher.cpp +++ clangd/JSONRPCDispatcher.cpp @@ -8,6 +8,7 @@ //===----------------------------------------------------------------------===// #include "JSONRPCDispatcher.h" +#include "Cancellation.h" #include "ProtocolHandlers.h" #include "Trace.h" #include "llvm/ADT/SmallString.h" @@ -93,7 +94,7 @@ } void clangd::reply(json::Value &&Result) { - auto ID = Context::current().get(RequestID); + auto ID = getRequestId(); if (!ID) { elog("Attempted to reply to a notification!"); return; @@ -116,7 +117,7 @@ {"message", Message.str()}}; }); - if (auto ID = Context::current().get(RequestID)) { + if (auto ID = getRequestId()) { log("--> reply({0}) error: {1}", *ID, Message); Context::current() .getExisting(RequestOut) @@ -129,6 +130,16 @@ } } +void clangd::replyError(Error E) { + handleAllErrors(std::move(E), + [](const CancelledError &CE) { + replyError(ErrorCode::RequestCancelled, CE.message()); + }, + [](const ErrorInfoBase &EIB) { + replyError(ErrorCode::InvalidParams, EIB.message()); + }); +} + void clangd::call(StringRef Method, json::Value &&Params) { RequestSpan::attach([&](json::Object &Args) { Args["Call"] = json::Object{{"method", Method.str()}, {"params", Params}}; @@ -366,3 +377,7 @@ } } } + +const json::Value *clangd::getRequestId() { + return Context::current().get(RequestID); +} Index: clangd/Protocol.h =================================================================== --- clangd/Protocol.h +++ clangd/Protocol.h @@ -867,6 +867,16 @@ llvm::json::Value toJSON(const DocumentHighlight &DH); llvm::raw_ostream &operator<<(llvm::raw_ostream &, const DocumentHighlight &); +struct
CancelParams { + /// The request id to cancel. + /// In LSP the id is either a number or a string; numbers are stored in their + /// decimal string form, so the ID is always kept as a string. + std::string ID; +}; +llvm::json::Value toJSON(const CancelParams &); +llvm::raw_ostream &operator<<(llvm::raw_ostream &, const CancelParams &); +bool fromJSON(const llvm::json::Value &, CancelParams &); + } // namespace clangd } // namespace clang Index: clangd/Protocol.cpp =================================================================== --- clangd/Protocol.cpp +++ clangd/Protocol.cpp @@ -615,5 +615,30 @@ O.map("compilationDatabaseChanges", CCPC.compilationDatabaseChanges); } +json::Value toJSON(const CancelParams &CP) { + return json::Object{{"id", CP.ID}}; +} + +llvm::raw_ostream &operator<<(llvm::raw_ostream &O, const CancelParams &CP) { + O << toJSON(CP); + return O; +} + +bool fromJSON(const json::Value &Params, CancelParams &CP) { + json::ObjectMapper O(Params); + if (!O) + return false; + // ID is either a number or a string; check for both (numbers may be signed).
+ if (O.map("id", CP.ID)) + return true; + + int64_t IDNumber; + if (O.map("id", IDNumber)) { + CP.ID = itostr(IDNumber); + return true; + } + return false; +} + } // namespace clangd } // namespace clang Index: clangd/ProtocolHandlers.h =================================================================== --- clangd/ProtocolHandlers.h +++ clangd/ProtocolHandlers.h @@ -55,6 +55,7 @@ virtual void onDocumentHighlight(TextDocumentPositionParams &Params) = 0; virtual void onHover(TextDocumentPositionParams &Params) = 0; virtual void onChangeConfiguration(DidChangeConfigurationParams &Params) = 0; + virtual void onCancelRequest(CancelParams &Params) = 0; }; void registerCallbackHandlers(JSONRPCDispatcher &Dispatcher, Index: clangd/ProtocolHandlers.cpp =================================================================== --- clangd/ProtocolHandlers.cpp +++ clangd/ProtocolHandlers.cpp @@ -75,4 +75,5 @@ Register("workspace/didChangeConfiguration", &ProtocolCallbacks::onChangeConfiguration); Register("workspace/symbol", &ProtocolCallbacks::onWorkspaceSymbol); + Register("$/cancelRequest", &ProtocolCallbacks::onCancelRequest); } Index: unittests/clangd/CMakeLists.txt =================================================================== --- unittests/clangd/CMakeLists.txt +++ unittests/clangd/CMakeLists.txt @@ -10,6 +10,7 @@ add_extra_unittest(ClangdTests Annotations.cpp + CancellationTests.cpp ClangdTests.cpp ClangdUnitTests.cpp CodeCompleteTests.cpp Index: unittests/clangd/CancellationTests.cpp =================================================================== --- /dev/null +++ unittests/clangd/CancellationTests.cpp @@ -0,0 +1,79 @@ +#include "Cancellation.h" +#include "Context.h" +#include "Threading.h" +#include "llvm/Support/Error.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include <atomic> +#include <memory> +#include <thread> + +namespace clang { +namespace clangd { +namespace { + +TEST(CancellationTest, CancellationTest) { + TaskHandle TH = TaskHandle::create(); + WithContext
ContextWithCancellation( + setCurrentCancellationToken(TH.createCancellationToken())); + EXPECT_FALSE(isCancelled()); + TH.cancel(); + EXPECT_TRUE(isCancelled()); +} + +TEST(CancellationTest, TaskHandleTestHandleDiesContextLives) { + llvm::Optional<WithContext> ContextWithCancellation; + { + auto CancellableTaskHandle = TaskHandle::create(); + ContextWithCancellation.emplace(setCurrentCancellationToken( + CancellableTaskHandle.createCancellationToken())); + EXPECT_FALSE(isCancelled()); + CancellableTaskHandle.cancel(); + EXPECT_TRUE(isCancelled()); + } + EXPECT_TRUE(isCancelled()); +} + +TEST(CancellationTest, TaskHandleContextDiesHandleLives) { + auto CancellableTaskHandle = TaskHandle::create(); + { + WithContext ContextWithCancellation(setCurrentCancellationToken( + CancellableTaskHandle.createCancellationToken())); + EXPECT_FALSE(isCancelled()); + CancellableTaskHandle.cancel(); + EXPECT_TRUE(isCancelled()); + } + // Still should be able to cancel without any problems. + CancellableTaskHandle.cancel(); +} + +TEST(CancellationTest, CancellationToken) { + auto CancellableTaskHandle = TaskHandle::create(); + WithContext ContextWithCancellation(setCurrentCancellationToken( + CancellableTaskHandle.createCancellationToken())); + const auto &CT = getCurrentCancellationToken(); + EXPECT_FALSE(CT); + CancellableTaskHandle.cancel(); + EXPECT_TRUE(CT); +} + +TEST(CancellationTest, AsyncCancellationTest) { + std::atomic<bool> HasCancelled(false); + Notification Cancelled; + auto TaskToBeCancelled = [&](CancellationToken CT) { + WithContext ContextGuard(setCurrentCancellationToken(std::move(CT))); + Cancelled.wait(); + HasCancelled = isCancelled(); + }; + TaskHandle TH = TaskHandle::create(); + std::thread AsyncTask(TaskToBeCancelled, TH.createCancellationToken()); + TH.cancel(); + Cancelled.notify(); + AsyncTask.join(); + + EXPECT_TRUE(HasCancelled); +} + +} // namespace +} // namespace clangd +} // namespace clang