Index: clangd/CMakeLists.txt =================================================================== --- clangd/CMakeLists.txt +++ clangd/CMakeLists.txt @@ -9,6 +9,7 @@ add_clang_library(clangDaemon AST.cpp + Cancellation.cpp ClangdLSPServer.cpp ClangdServer.cpp ClangdUnit.cpp Index: clangd/Cancellation.h =================================================================== --- /dev/null +++ clangd/Cancellation.h @@ -0,0 +1,142 @@ +//===--- Cancellation.h -------------------------------------------*-C++-*-===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// Cancellation mechanism for async tasks. Roughly all the clients of this code +// can be classified into three categories: +// 1. The code that creates and schedules async tasks, e.g. TUScheduler. +// 2. The callers of the async method that can cancel some of the running tasks, +// e.g. `ClangdLSPServer` +// 3. The code running inside the async task itself, e.g. code completion or +// find definition implementations that run clang, etc. +// +// For (1), the guideline is to accept a callback for the result of the async +// operation and return a `TaskHandle` to allow cancelling the request. +// +// TaskHandle someAsyncMethod(Runnable T, +// function)> Callback) { +// auto TH = Task::createHandle(); +// WithContext ContextWithCancellationToken(setCurrentTask(TH)); +// auto run = [=]() { +// Callback(T()); +// }; +// // Start run() in a new async thread, and make sure to propagate Context. +// return TH; +// } +// +// The callers of async methods (2) can issue cancellations and should be +// prepared to handle a `CancelledError` result: +// +// void Caller() { +// // You should store this handle if you want to cancel the task later on.
+// TaskHandle TH = someAsyncMethod(Task, [](llvm::Expected R) { +// if(/*check for task cancellation error*/) +// // Handle the error +// // Do other things on R. +// }); +// // To cancel the task: +// sleep(5); +// TH->cancel(); +// } +// +// The worker code itself (3) should check for cancellations using +// `Task::isCancelled` on the Task returned by `getCurrentTask()`. +// +// llvm::Expected AsyncTask() { +// // You can either store a reference to the read-only Task returned by +// // getCurrentTask once and use it every time you want to check for +// // cancellation, or call isCancelled() every time. The former is more +// // efficient if you are going to have multiple checks. +// const auto &T = getCurrentTask(); +// // Do something... +// if(T.isCancelled()) { +// // Task has been cancelled, let's get out. +// return llvm::make_error<CancelledError>(); +// } +// // Do some more work... +// if(T.isCancelled()) { +// // Task has been cancelled, let's get out. +// return llvm::make_error<CancelledError>(); +// } +// return ResultType(...); +// } +// If the operation was cancelled before the task could run to completion, it +// should propagate a CancelledError as its result. + +#ifndef LLVM_CLANG_TOOLS_EXTRA_CLANGD_CANCELLATION_H +#define LLVM_CLANG_TOOLS_EXTRA_CLANGD_CANCELLATION_H + +#include "Context.h" +#include "llvm/Support/Error.h" +#include +#include +#include + +namespace clang { +namespace clangd { + +/// Enables signalling a cancellation on an async task or checking for +/// cancellation. It is thread-safe to trigger cancellation from multiple +/// threads or check for cancellation. The Task object for the currently +/// running task can be obtained via clangd::getCurrentTask(). +class Task { +public: + void cancel() { CT = true; } + /// If cancellation checks are rare, one could use the isCancelled() helper in + /// the namespace to simplify the code.
However, if checks are frequent, the + /// guideline is to first obtain the Task object for the currently running + /// task with getCurrentTask() and do cancel checks through it, avoiding + /// extra lookups in the Context. + bool isCancelled() const { return CT; } + + /// Creates a task handle that can be used by an async task to check for + /// information that can change during its runtime, like cancellation. + static std::shared_ptr createHandle() { + return std::shared_ptr(new Task()); + } + + Task(const Task &) = delete; + Task &operator=(const Task &) = delete; + Task(Task &&) = delete; + Task &operator=(Task &&) = delete; + +private: + Task() : CT(false) {} + std::atomic CT; +}; +using ConstTaskHandle = std::shared_ptr; +using TaskHandle = std::shared_ptr; + +/// Fetches current task information from Context. TaskHandle must have been +/// stashed into context beforehand. +const Task &getCurrentTask(); + +/// Derives a new context from the current one with \p TH as current task. +LLVM_NODISCARD Context setCurrentTask(ConstTaskHandle TH); + +/// Checks whether the current task has been cancelled or not. +/// Consider storing the Task reference returned by getCurrentTask and then +/// calling isCancelled through it. getCurrentTask is expensive since it does a +/// lookup in the context.
+inline bool isCancelled() { return getCurrentTask().isCancelled(); } + +class CancelledError : public llvm::ErrorInfo { +public: + static char ID; + + void log(llvm::raw_ostream &OS) const override { + OS << "Task was cancelled."; + } + std::error_code convertToErrorCode() const override { + return std::make_error_code(std::errc::operation_canceled); + } +}; + +} // namespace clangd +} // namespace clang + +#endif Index: clangd/Cancellation.cpp =================================================================== --- /dev/null +++ clangd/Cancellation.cpp @@ -0,0 +1,34 @@ +//===--- Cancellation.cpp -----------------------------------------*-C++-*-===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// + +#include "Cancellation.h" +#include + +namespace clang { +namespace clangd { + +namespace { +static Key TaskKey; +} // namespace + +char CancelledError::ID = 0; + +const Task &getCurrentTask() { + const auto &TH = Context::current().getExisting(TaskKey); + assert(TH && "Fetched a nullptr for TaskHandle from context."); + return *TH; +} + +Context setCurrentTask(ConstTaskHandle TH) { + assert(TH && "Trying to stash a nullptr as TaskHandle into context."); + return Context::current().derive(TaskKey, std::move(TH)); +} + +} // namespace clangd +} // namespace clang Index: clangd/ClangdLSPServer.h =================================================================== --- clangd/ClangdLSPServer.h +++ clangd/ClangdLSPServer.h @@ -75,6 +75,7 @@ void onRename(RenameParams &Parames) override; void onHover(TextDocumentPositionParams &Params) override; void onChangeConfiguration(DidChangeConfigurationParams &Params) override; + void onCancelRequest(CancelParams &Params) override; std::vector getFixes(StringRef File, const clangd::Diagnostic &D); @@ -167,8 +168,22 @@ // the worker thread
that may otherwise run an async callback on partially // destructed instance of ClangdLSPServer. ClangdServer Server; -}; + // Holds task handles for running requests. Key of the map is a serialized + // request id. + llvm::StringMap TaskHandles; + std::mutex TaskHandlesMutex; + + // The following three functions manage the TaskHandles map. They store or + // remove a task handle for the request-id stored in the current Context. + // FIXME(kadircet): Wrap the following three functions in an RAII object to + // make sure these do not get misused. The object might be stored in the + // Context of the thread or moved around until a reply is generated for the + // request. + void CleanupTaskHandle(); + void CreateSpaceForTaskHandle(); + void StoreTaskHandle(TaskHandle TH); +}; } // namespace clangd } // namespace clang Index: clangd/ClangdLSPServer.cpp =================================================================== --- clangd/ClangdLSPServer.cpp +++ clangd/ClangdLSPServer.cpp @@ -8,10 +8,12 @@ //===----------------------------------------------------------------------===// #include "ClangdLSPServer.h" +#include "Cancellation.h" #include "Diagnostics.h" #include "JSONRPCDispatcher.h" #include "SourceCode.h" #include "URI.h" +#include "llvm/ADT/ScopeExit.h" #include "llvm/Support/Errc.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/Path.h" @@ -69,6 +71,16 @@ return Defaults; } +/// Serializes a request id (string or number) into a TaskHandles map key. +std::string NormalizeRequestID(const json::Value &ID) { + std::string NormalizedID; + // Parse outside assert(): under NDEBUG its argument is never evaluated. + bool Parsed = + parseNumberOrString(json::Object{{"id", ID}}, NormalizedID, "id"); + assert(Parsed && "Was not able to parse RequestID"); + (void)Parsed; // Unused when NDEBUG is defined. + return NormalizedID; +} } // namespace void ClangdLSPServer::onInitialize(InitializeParams &Params) { @@ -337,17 +349,21 @@ } void ClangdLSPServer::onCompletion(TextDocumentPositionParams &Params) { - Server.codeComplete(Params.textDocument.uri.file(), Params.position, CCOpts, - [this](llvm::Expected List) { - if (!List) - return replyError(ErrorCode::InvalidParams,
llvm::toString(List.takeError())); - CompletionList LSPList; - LSPList.isIncomplete = List->HasMore; - for (const auto &R : List->Completions) - LSPList.items.push_back(R.render(CCOpts)); - reply(std::move(LSPList)); - }); + CreateSpaceForTaskHandle(); + TaskHandle TH = Server.codeComplete( + Params.textDocument.uri.file(), Params.position, CCOpts, + [this](llvm::Expected List) { + auto _ = llvm::make_scope_exit([this]() { CleanupTaskHandle(); }); + + if (!List) + return replyError(List.takeError()); + CompletionList LSPList; + LSPList.isIncomplete = List->HasMore; + for (const auto &R : List->Completions) + LSPList.items.push_back(R.render(CCOpts)); + return reply(std::move(LSPList)); + }); + StoreTaskHandle(std::move(TH)); } void ClangdLSPServer::onSignatureHelp(TextDocumentPositionParams &Params) { @@ -362,14 +378,14 @@ } void ClangdLSPServer::onGoToDefinition(TextDocumentPositionParams &Params) { - Server.findDefinitions( - Params.textDocument.uri.file(), Params.position, - [](llvm::Expected> Items) { - if (!Items) - return replyError(ErrorCode::InvalidParams, - llvm::toString(Items.takeError())); - reply(json::Array(*Items)); - }); + Server.findDefinitions(Params.textDocument.uri.file(), Params.position, + [](llvm::Expected> Items) { + if (!Items) + return replyError( + ErrorCode::InvalidParams, + llvm::toString(Items.takeError())); + reply(json::Array(*Items)); + }); } void ClangdLSPServer::onSwitchSourceHeader(TextDocumentIdentifier &Params) { @@ -604,3 +620,49 @@ return *CachingCDB; return *CDB; } + +void ClangdLSPServer::onCancelRequest(CancelParams &Params) { + std::lock_guard Lock(TaskHandlesMutex); + auto It = TaskHandles.find(Params.ID); + if (It == TaskHandles.end()) + return; + if (It->second) // Still nullptr if StoreTaskHandle has not run yet. + It->second->cancel(); + TaskHandles.erase(It); +} + +void ClangdLSPServer::CleanupTaskHandle() { + const json::Value *ID = getRequestId(); + if (!ID) + return; + std::string NormalizedID = NormalizeRequestID(*ID); + std::lock_guard
Lock(TaskHandlesMutex); + TaskHandles.erase(NormalizedID); +} + +void ClangdLSPServer::CreateSpaceForTaskHandle() { + const json::Value *ID = getRequestId(); + if (!ID) + return; + std::string NormalizedID = NormalizeRequestID(*ID); + std::lock_guard Lock(TaskHandlesMutex); + if (!TaskHandles.insert({NormalizedID, nullptr}).second) + elog("Creation of space for task handle: {0} failed.", NormalizedID); +} + +void ClangdLSPServer::StoreTaskHandle(TaskHandle TH) { + const json::Value *ID = getRequestId(); + if (!ID) + return; + std::string NormalizedID = NormalizeRequestID(*ID); + std::lock_guard Lock(TaskHandlesMutex); + auto It = TaskHandles.find(NormalizedID); + if (It == TaskHandles.end()) { + elog("CleanupTaskHandle called before store can happen for request:{0}.", + NormalizedID); + return; + } + if (It->second != nullptr) + elog("TaskHandle didn't get cleared for: {0}.", NormalizedID); + It->second = std::move(TH); +} Index: clangd/ClangdServer.h =================================================================== --- clangd/ClangdServer.h +++ clangd/ClangdServer.h @@ -10,6 +10,7 @@ #ifndef LLVM_CLANG_TOOLS_EXTRA_CLANGD_CLANGDSERVER_H #define LLVM_CLANG_TOOLS_EXTRA_CLANGD_CLANGDSERVER_H +#include "Cancellation.h" #include "ClangdUnit.h" #include "CodeComplete.h" #include "FSProvider.h" @@ -122,9 +123,10 @@ /// while returned future is not yet ready. /// A version of `codeComplete` that runs \p Callback on the processing thread /// when codeComplete results become available. - void codeComplete(PathRef File, Position Pos, - const clangd::CodeCompleteOptions &Opts, - Callback CB); + /// The returned TaskHandle can be used to request cancellation. + TaskHandle codeComplete(PathRef File, Position Pos, + const clangd::CodeCompleteOptions &Opts, + Callback CB); /// Provide signature help for \p File at \p Pos. This method should only be /// called for tracked files.
Index: clangd/ClangdServer.cpp =================================================================== --- clangd/ClangdServer.cpp +++ clangd/ClangdServer.cpp @@ -8,6 +8,7 @@ //===-------------------------------------------------------------------===// #include "ClangdServer.h" +#include "Cancellation.h" #include "CodeComplete.h" #include "FindSymbols.h" #include "Headers.h" @@ -140,20 +141,25 @@ WorkScheduler.remove(File); } -void ClangdServer::codeComplete(PathRef File, Position Pos, - const clangd::CodeCompleteOptions &Opts, - Callback CB) { +TaskHandle ClangdServer::codeComplete(PathRef File, Position Pos, + const clangd::CodeCompleteOptions &Opts, + Callback CB) { // Copy completion options for passing them to async task handler. auto CodeCompleteOpts = Opts; if (!CodeCompleteOpts.Index) // Respect overridden index. CodeCompleteOpts.Index = Index; + TaskHandle TH = Task::createHandle(); + WithContext ContextWithCancellation(setCurrentTask(TH)); // Copy PCHs to avoid accessing this->PCHs concurrently std::shared_ptr PCHs = this->PCHs; auto FS = FSProvider.getFileSystem(); auto Task = [PCHs, Pos, FS, CodeCompleteOpts](Path File, Callback CB, llvm::Expected IP) { + if (isCancelled()) // Cancelled before this scheduled task started running. + return CB(llvm::make_error()); + if (!IP) return CB(IP.takeError()); @@ -170,6 +176,7 @@ WorkScheduler.runWithPreamble("CodeComplete", File, Bind(Task, File.str(), std::move(CB))); + return TH; // Lets the caller request cancellation of this completion. } void ClangdServer::signatureHelp(PathRef File, Position Pos, Index: clangd/JSONRPCDispatcher.h =================================================================== --- clangd/JSONRPCDispatcher.h +++ clangd/JSONRPCDispatcher.h @@ -64,6 +64,13 @@ /// Sends an error response to the client, and logs it. /// Current context must derive from JSONRPCDispatcher::Handler. void replyError(ErrorCode Code, const llvm::StringRef &Message); +/// Like the overload above, but derives code and message from \p E: +/// CancelledError becomes ErrorCode::RequestCancelled, any other error
+/// becomes ErrorCode::InvalidParams; the message is E's message(). +void replyError(llvm::Error E); +/// Returns the id of the current request, or nullptr if the context has none +/// (e.g. notifications). Reply via reply()/replyError(), not this id directly. +const llvm::json::Value *getRequestId(); /// Sends a request to the client. /// Current context must derive from JSONRPCDispatcher::Handler. void call(llvm::StringRef Method, llvm::json::Value &&Params); @@ -110,7 +117,6 @@ void runLanguageServerLoop(std::FILE *In, JSONOutput &Out, JSONStreamStyle InputStyle, JSONRPCDispatcher &Dispatcher, bool &IsDone); - } // namespace clangd } // namespace clang Index: clangd/JSONRPCDispatcher.cpp =================================================================== --- clangd/JSONRPCDispatcher.cpp +++ clangd/JSONRPCDispatcher.cpp @@ -8,6 +8,7 @@ //===----------------------------------------------------------------------===// #include "JSONRPCDispatcher.h" +#include "Cancellation.h" #include "ProtocolHandlers.h" #include "Trace.h" #include "llvm/ADT/SmallString.h" @@ -93,7 +94,7 @@ } void clangd::reply(json::Value &&Result) { - auto ID = Context::current().get(RequestID); + auto ID = getRequestId(); if (!ID) { elog("Attempted to reply to a notification!"); return; @@ -116,7 +117,7 @@ {"message", Message.str()}}; }); - if (auto ID = Context::current().get(RequestID)) { + if (auto ID = getRequestId()) { log("--> reply({0}) error: {1}", *ID, Message); Context::current() .getExisting(RequestOut) @@ -129,6 +130,16 @@ } } +void clangd::replyError(Error E) { + handleAllErrors(std::move(E), + [](const CancelledError &TCE) { + replyError(ErrorCode::RequestCancelled, TCE.message()); + }, + [](const ErrorInfoBase &EIB) { + replyError(ErrorCode::InvalidParams, EIB.message()); + }); +} + void clangd::call(StringRef Method, json::Value &&Params) { RequestSpan::attach([&](json::Object &Args) { Args["Call"] = json::Object{{"method", Method.str()}, {"params", Params}};
@@ -366,3 +377,7 @@ } } } + +const json::Value *clangd::getRequestId() { + return Context::current().get(RequestID); +} Index: clangd/Protocol.h =================================================================== --- clangd/Protocol.h +++ clangd/Protocol.h @@ -867,6 +867,23 @@ llvm::json::Value toJSON(const DocumentHighlight &DH); llvm::raw_ostream &operator<<(llvm::raw_ostream &, const DocumentHighlight &); +/// Parameters of the `$/cancelRequest` message. +struct CancelParams { + /// The request id to cancel. + /// LSP ids may be numbers or strings; numeric ids are stored in their + /// decimal string form, so the id is always kept as a string. + std::string ID; +}; +llvm::json::Value toJSON(const CancelParams &); +llvm::raw_ostream &operator<<(llvm::raw_ostream &, const CancelParams &); +bool fromJSON(const llvm::json::Value &, CancelParams &); + +/// Parses Params[Field], which may be a string or an integer, into Parsed. +/// Integers are converted to their decimal string form. Returns true on +/// success, false otherwise (e.g. when Params is not an object). +bool parseNumberOrString(const llvm::json::Value &Params, std::string &Parsed, + const std::string &Field); + } // namespace clangd } // namespace clang Index: clangd/Protocol.cpp =================================================================== --- clangd/Protocol.cpp +++ clangd/Protocol.cpp @@ -615,5 +615,35 @@ O.map("compilationDatabaseChanges", CCPC.compilationDatabaseChanges); } +json::Value toJSON(const CancelParams &CP) { + return json::Object{{"id", CP.ID}}; +} + +llvm::raw_ostream &operator<<(llvm::raw_ostream &O, const CancelParams &CP) { + O << toJSON(CP); + return O; +} + +bool parseNumberOrString(const json::Value &Params, std::string &Parsed, + const std::string &Field) { + json::ObjectMapper O(Params); + if (!O) + return false; + // The field may hold either a string or an integer; try both.
+ if (O.map(Field, Parsed)) + return true; + + int64_t Number; + if (O.map(Field, Number)) { + Parsed = itostr(Number); // Signed: negative ids must not wrap around. + return true; + } + return false; +} + +bool fromJSON(const json::Value &Params, CancelParams &CP) { + return parseNumberOrString(Params, CP.ID, "id"); +} + } // namespace clangd } // namespace clang Index: clangd/ProtocolHandlers.h =================================================================== --- clangd/ProtocolHandlers.h +++ clangd/ProtocolHandlers.h @@ -55,6 +55,7 @@ virtual void onDocumentHighlight(TextDocumentPositionParams &Params) = 0; virtual void onHover(TextDocumentPositionParams &Params) = 0; virtual void onChangeConfiguration(DidChangeConfigurationParams &Params) = 0; + virtual void onCancelRequest(CancelParams &Params) = 0; }; void registerCallbackHandlers(JSONRPCDispatcher &Dispatcher, Index: clangd/ProtocolHandlers.cpp =================================================================== --- clangd/ProtocolHandlers.cpp +++ clangd/ProtocolHandlers.cpp @@ -75,4 +75,5 @@ Register("workspace/didChangeConfiguration", &ProtocolCallbacks::onChangeConfiguration); Register("workspace/symbol", &ProtocolCallbacks::onWorkspaceSymbol); + Register("$/cancelRequest", &ProtocolCallbacks::onCancelRequest); } Index: unittests/clangd/CMakeLists.txt =================================================================== --- unittests/clangd/CMakeLists.txt +++ unittests/clangd/CMakeLists.txt @@ -10,6 +10,7 @@ add_extra_unittest(ClangdTests Annotations.cpp + CancellationTests.cpp ClangdTests.cpp ClangdUnitTests.cpp CodeCompleteTests.cpp Index: unittests/clangd/CancellationTests.cpp =================================================================== --- /dev/null +++ unittests/clangd/CancellationTests.cpp @@ -0,0 +1,74 @@ +#include "Cancellation.h" +#include "Context.h" +#include "Threading.h" +#include "llvm/Support/Error.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include +#include +#include + +namespace clang { +namespace clangd
{ +namespace { + +TEST(CancellationTest, CancellationTest) { + TaskHandle TH = Task::createHandle(); + WithContext ContextWithCancellation(setCurrentTask(TH)); + EXPECT_FALSE(isCancelled()); + TH->cancel(); + EXPECT_TRUE(isCancelled()); +} + +TEST(CancellationTest, TaskTestHandleDiesContextLives) { + llvm::Optional ContextWithCancellation; + { + TaskHandle TH = Task::createHandle(); + ContextWithCancellation.emplace(setCurrentTask(TH)); + EXPECT_FALSE(isCancelled()); + TH->cancel(); + EXPECT_TRUE(isCancelled()); + } + EXPECT_TRUE(isCancelled()); +} + +TEST(CancellationTest, TaskContextDiesHandleLives) { + TaskHandle TH = Task::createHandle(); + { + WithContext ContextWithCancellation(setCurrentTask(TH)); + EXPECT_FALSE(isCancelled()); + TH->cancel(); + EXPECT_TRUE(isCancelled()); + } + // Still should be able to cancel without any problems. + TH->cancel(); +} + +TEST(CancellationTest, CancellationToken) { + TaskHandle TH = Task::createHandle(); + WithContext ContextWithCancellation(setCurrentTask(TH)); + const auto &CT = getCurrentTask(); + EXPECT_FALSE(CT.isCancelled()); + TH->cancel(); + EXPECT_TRUE(CT.isCancelled()); +} + +TEST(CancellationTest, AsyncCancellationTest) { + std::atomic HasCancelled(false); + Notification Cancelled; + auto TaskToBeCancelled = [&](ConstTaskHandle CT) { + WithContext ContextGuard(setCurrentTask(std::move(CT))); + Cancelled.wait(); + HasCancelled = isCancelled(); + }; + TaskHandle TH = Task::createHandle(); + std::thread AsyncTask(TaskToBeCancelled, TH); + TH->cancel(); + Cancelled.notify(); + AsyncTask.join(); + + EXPECT_TRUE(HasCancelled); +} +} // namespace +} // namespace clangd +} // namespace clang