Index: clangd/CMakeLists.txt
===================================================================
--- clangd/CMakeLists.txt
+++ clangd/CMakeLists.txt
@@ -25,5 +25,6 @@
   clangSerialization
   clangTooling
   clangToolingCore
+  clangToolingRefactor
   ${LLVM_PTHREAD_LIB}
   )
Index: clangd/ClangdLSPServer.h
===================================================================
--- clangd/ClangdLSPServer.h
+++ clangd/ClangdLSPServer.h
@@ -66,3 +66,4 @@
   void onGoToDefinition(Ctx C, TextDocumentPositionParams &Params) override;
   void onSwitchSourceHeader(Ctx C, TextDocumentIdentifier &Params) override;
   void onFileEvent(Ctx C, DidChangeWatchedFilesParams &Params) override;
+  void onWorkspaceExecuteCommand(Ctx C, ExecuteCommandParams &Params) override;
Index: clangd/ClangdLSPServer.cpp
===================================================================
--- clangd/ClangdLSPServer.cpp
+++ clangd/ClangdLSPServer.cpp
@@ -127,6 +127,19 @@
           R"(", [)" + Edits +
           R"(]]},)";
   }
+  // Offer every refactoring that is available for the selected range. The
+  // document is passed as the first argument so that the command can later be
+  // executed without any additional server-side state.
+  for (const auto &Cmd :
+       Server.getCommands(Params.textDocument.uri.file, Params.range)) {
+    std::string Arguments = CommandArgument::unparse(
+        CommandArgument::makeDocumentID(Params.textDocument));
+    for (const auto &Arg : Cmd.arguments)
+      Arguments += "," + CommandArgument::unparse(Arg);
+    Commands += R"({"title":")" + llvm::yaml::escape(Cmd.title) +
+                R"(", "command": "clangd.)" + Cmd.command +
+                R"(", "arguments": [)" + Arguments + R"(]},)";
+  }
   if (!Commands.empty())
     Commands.pop_back();
   C.reply("[" + Commands + "]");
@@ -187,6 +200,43 @@
   C.reply(Result ? URI::unparse(URI::fromFile(*Result)) : R"("")");
 }
 
+void ClangdLSPServer::onWorkspaceExecuteCommand(Ctx C,
+                                                ExecuteCommandParams &Params) {
+  // onCodeAction advertises every command with a "clangd." prefix; the rest of
+  // the name identifies the command for ClangdServer.
+  StringRef CommandName = Params.command;
+  if (!CommandName.consume_front("clangd.")) {
+    C.replyError(ErrorCode::InvalidParams,
+                 "Unsupported command: " + Params.command);
+    return;
+  }
+  llvm::Optional<Range> SelectionRange;
+  llvm::Optional<TextDocumentIdentifier> DocID;
+  for (const auto &Arg : Params.arguments) {
+    switch (Arg.ArgumentKind) {
+    case CommandArgument::SelectionRangeKind:
+      SelectionRange = Arg.getSelectionRange();
+      break;
+    case CommandArgument::TextDocumentIdentifierKind:
+      DocID = Arg.getDocumentID();
+      break;
+    }
+  }
+  // Every request must be answered, otherwise the client waits forever.
+  if (!SelectionRange || !DocID) {
+    C.replyError(ErrorCode::InvalidParams,
+                 "Expected a document and a selection argument");
+    return;
+  }
+  std::string Code = Server.getDocument(DocID->uri.file);
+  std::string Edits = replacementsToEdits(
+      Code,
+      Server.executeCommand(CommandName, DocID->uri.file, *SelectionRange));
+  // FIXME: LSP expects the server to apply the edits through a
+  // workspace/applyEdit request instead of returning them as the result.
+  C.reply("[" + Edits + "]");
+}
+
 ClangdLSPServer::ClangdLSPServer(JSONOutput &Out, unsigned AsyncThreadsCount,
                                  bool SnippetCompletions,
                                  llvm::Optional<StringRef> ResourceDir,
Index: clangd/ClangdServer.h
===================================================================
--- clangd/ClangdServer.h
+++ clangd/ClangdServer.h
@@ -15,6 +15,7 @@
 #include "GlobalCompilationDatabase.h"
 #include "clang/Tooling/CompilationDatabase.h"
 #include "clang/Tooling/Core/Replacement.h"
+#include "clang/Tooling/Refactoring/EditorClient.h"
 #include "llvm/ADT/IntrusiveRefCntPtr.h"
 #include "llvm/ADT/Optional.h"
 #include "llvm/ADT/StringRef.h"
@@ -291,4 +292,14 @@
   /// Called when an event occurs for a watched file in the workspace.
   void onFileEvent(const DidChangeWatchedFilesParams &Params);
 
+  /// Returns the refactoring commands that are available for
+  /// \p SelectionRange in \p File.
+  std::vector<Command> getCommands(PathRef File, Range SelectionRange);
+
+  /// Executes \p CommandName (the `command` of a Command returned by
+  /// getCommands) on \p SelectionRange in \p File and returns the resulting
+  /// source replacements. Returns an empty list on failure.
+  std::vector<tooling::Replacement>
+  executeCommand(StringRef CommandName, PathRef File, Range SelectionRange);
+
 private:
@@ -319,3 +330,9 @@
+  /// The editor client that allows Clangd to use Clang's refactoring actions.
+  /// It is created eagerly so that executeCommand does not depend on
+  /// getCommands having been called first.
+  std::unique_ptr<tooling::RefactoringEditorClient> RefactoringClient =
+      llvm::make_unique<tooling::RefactoringEditorClient>();
+
   // WorkScheduler has to be the last member, because its destructor has to be
   // called before all other members to stop the worker thread that references
   // ClangdServer
Index: clangd/ClangdServer.cpp
===================================================================
--- clangd/ClangdServer.cpp
+++ clangd/ClangdServer.cpp
@@ -483,3 +483,40 @@
   // FIXME: Do nothing for now. This will be used for indexing and potentially
   // invalidating other caches.
 }
+
+std::vector<Command> ClangdServer::getCommands(PathRef File,
+                                               Range SelectionRange) {
+  assert(DraftMgr.getDraft(File).Draft &&
+         "getCommands is called for non-added document");
+  std::shared_ptr<CppFile> Resources = Units.getFile(File);
+  assert(Resources && "Calling getCommands on non-added file");
+
+  std::vector<Command> Results;
+  Resources->getAST().get()->runUnderLock(
+      [SelectionRange, &Results, this](ParsedAST *AST) {
+        if (!AST)
+          return;
+        Results = clangd::findAvailableRefactoringCommands(
+            *RefactoringClient, *AST, SelectionRange, Logger);
+      });
+  return Results;
+}
+
+std::vector<tooling::Replacement>
+ClangdServer::executeCommand(StringRef CommandName, PathRef File,
+                             Range SelectionRange) {
+  assert(DraftMgr.getDraft(File).Draft &&
+         "executeCommand is called for non-added document");
+  std::shared_ptr<CppFile> Resources = Units.getFile(File);
+  assert(Resources && "Calling executeCommand on non-added file");
+
+  std::vector<tooling::Replacement> Results;
+  Resources->getAST().get()->runUnderLock(
+      [SelectionRange, CommandName, &Results, this](ParsedAST *AST) {
+        if (!AST)
+          return;
+        Results = clangd::performRefactoringCommand(
+            *RefactoringClient, CommandName, *AST, SelectionRange, Logger);
+      });
+  return Results;
+}
Index: clangd/ClangdUnit.h
===================================================================
--- clangd/ClangdUnit.h
+++ clangd/ClangdUnit.h
@@ -36,6 +36,7 @@
 
 namespace tooling {
 struct CompileCommand;
+class RefactoringEditorClient;
 }
 
 namespace clangd {
@@ -275,6 +276,22 @@
 /// unserialized Decls, so use with care.
 void dumpAST(ParsedAST &AST, llvm::raw_ostream &OS);
 
+/// Returns the refactoring commands that \p Client can perform on
+/// \p SelectionRange in the main file of \p AST. The command names carry a
+/// "refactor." prefix and the selection is attached as the only argument.
+std::vector<Command>
+findAvailableRefactoringCommands(tooling::RefactoringEditorClient &Client,
+                                 ParsedAST &AST, Range SelectionRange,
+                                 clangd::Logger &Logger);
+
+/// Performs the refactoring command \p CommandName (a name returned by
+/// findAvailableRefactoringCommands) on \p SelectionRange and returns the
+/// resulting source replacements. Failures are logged and yield an empty list.
+std::vector<tooling::Replacement>
+performRefactoringCommand(tooling::RefactoringEditorClient &Client,
+                          StringRef CommandName, ParsedAST &AST,
+                          Range SelectionRange, clangd::Logger &Logger);
+
 } // namespace clangd
 } // namespace clang
 #endif
Index: clangd/ClangdUnit.cpp
===================================================================
--- clangd/ClangdUnit.cpp
+++ clangd/ClangdUnit.cpp
@@ -23,6 +23,9 @@
 #include "clang/Sema/Sema.h"
 #include "clang/Serialization/ASTWriter.h"
 #include "clang/Tooling/CompilationDatabase.h"
+#include "clang/Tooling/Refactoring/AtomicChange.h"
+#include "clang/Tooling/Refactoring/EditorClient.h"
+#include "clang/Tooling/Refactoring/EditorCommands.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/CrashRecoveryContext.h"
@@ -1022,6 +1025,76 @@
   return DeclLocationsFinder->takeLocations();
 }
 
+/// Prefix of the command names produced by findAvailableRefactoringCommands.
+static const char RefactorCommandPrefix[] = "refactor.";
+
+/// Converts \p SelectionRange in the main file of \p AST to a SourceRange.
+/// Returns None if the main file has no FileEntry.
+static llvm::Optional<SourceRange>
+getMainFileSelection(ParsedAST &AST, Range SelectionRange) {
+  const SourceManager &SourceMgr = AST.getASTContext().getSourceManager();
+  const FileEntry *FE = SourceMgr.getFileEntryForID(SourceMgr.getMainFileID());
+  if (!FE)
+    return llvm::None;
+  SourceLocation Begin =
+      getMacroArgExpandedLocation(SourceMgr, FE, SelectionRange.start);
+  SourceLocation End =
+      getMacroArgExpandedLocation(SourceMgr, FE, SelectionRange.end);
+  return SourceRange(Begin, End);
+}
+
+std::vector<Command> clangd::findAvailableRefactoringCommands(
+    tooling::RefactoringEditorClient &Client, ParsedAST &AST,
+    Range SelectionRange, clangd::Logger &Logger) {
+  llvm::Optional<SourceRange> Selection =
+      getMainFileSelection(AST, SelectionRange);
+  if (!Selection)
+    return {};
+  std::vector<Command> Results;
+  for (const tooling::EditorCommand *Cmd :
+       Client.getAvailableRefactorings(AST.getASTContext(), *Selection)) {
+    Command Result;
+    Result.title = Cmd->getTitle();
+    Result.command =
+        (llvm::Twine(RefactorCommandPrefix) + Cmd->getName()).str();
+    Result.arguments = {CommandArgument::makeSelectionRange(SelectionRange)};
+    Results.push_back(std::move(Result));
+  }
+  return Results;
+}
+
+std::vector<tooling::Replacement> clangd::performRefactoringCommand(
+    tooling::RefactoringEditorClient &Client, StringRef CommandName,
+    ParsedAST &AST, Range SelectionRange, clangd::Logger &Logger) {
+  // Only names produced by findAvailableRefactoringCommands are accepted.
+  StringRef RefactoringName = CommandName;
+  if (!RefactoringName.consume_front(RefactorCommandPrefix)) {
+    Logger.log("Unknown refactoring command: " + CommandName);
+    return {};
+  }
+  llvm::Optional<SourceRange> Selection =
+      getMainFileSelection(AST, SelectionRange);
+  if (!Selection)
+    return {};
+  auto Changes = Client.performRefactoring(AST.getASTContext(),
+                                           RefactoringName, *Selection);
+  if (!Changes) {
+    // FIXME: Propagate the errors to the user.
+    Logger.log("Refactoring " + RefactoringName +
+               " failed: " + llvm::toString(Changes.takeError()));
+    return {};
+  }
+  // FIXME: Changes to files other than the main file are not filtered out,
+  // although clangd currently converts the replacements against the main file.
+  std::vector<tooling::Replacement> Replacements;
+  for (const tooling::AtomicChange &Change : *Changes) {
+    const tooling::Replacements &ChangeReps = Change.getReplacements();
+    Replacements.insert(Replacements.end(), ChangeReps.begin(),
+                        ChangeReps.end());
+  }
+  return Replacements;
+}
+
 void ParsedAST::ensurePreambleDeclsDeserialized() {
   if (PendingTopLevelDecls.empty())
     return;
Index: clangd/Protocol.h
===================================================================
--- clangd/Protocol.h
+++ clangd/Protocol.h
@@ -530,5 +530,74 @@
   static std::string unparse(const SignatureHelp &);
 };
 
+/// Represents an argument of an editor command. Exactly one value is held; its
+/// kind is given by ArgumentKind.
+struct CommandArgument {
+  enum Kind { SelectionRangeKind, TextDocumentIdentifierKind };
+  Kind ArgumentKind = SelectionRangeKind;
+
+  static CommandArgument makeSelectionRange(Range R) {
+    CommandArgument Result;
+    Result.SelectionRange = R;
+    Result.ArgumentKind = SelectionRangeKind;
+    return Result;
+  }
+
+  const Range &getSelectionRange() const {
+    assert(ArgumentKind == SelectionRangeKind);
+    return SelectionRange;
+  }
+
+  static CommandArgument makeDocumentID(TextDocumentIdentifier ID) {
+    CommandArgument Result;
+    Result.DocumentID = std::move(ID);
+    Result.ArgumentKind = TextDocumentIdentifierKind;
+    return Result;
+  }
+
+  const TextDocumentIdentifier &getDocumentID() const {
+    assert(ArgumentKind == TextDocumentIdentifierKind);
+    return DocumentID;
+  }
+
+  static std::string unparse(const CommandArgument &);
+
+  /// Parses an object of the form produced by unparse(). Returns None when
+  /// the object does not contain a recognized argument.
+  static llvm::Optional<CommandArgument> parse(llvm::yaml::MappingNode *Params,
+                                               clangd::Logger &Logger);
+
+private:
+  // Plain members instead of a union over raw storage: TextDocumentIdentifier
+  // owns std::strings that must be constructed, copied and destroyed
+  // properly. Only the member selected by ArgumentKind is meaningful.
+  Range SelectionRange = {};
+  TextDocumentIdentifier DocumentID;
+};
+
+/// Represents an editor command.
+struct Command {
+  /// The title of the command.
+  std::string title;
+
+  /// The identifier of the actual command handler.
+  std::string command;
+
+  /// The arguments that the command handler should be invoked with.
+  std::vector<CommandArgument> arguments;
+};
+
+/// Represents an editor command execution request.
+struct ExecuteCommandParams {
+  /// The identifier of the actual command handler.
+  std::string command;
+
+  /// Arguments that the command handler should be invoked with.
+  std::vector<CommandArgument> arguments;
+
+  static llvm::Optional<ExecuteCommandParams>
+  parse(llvm::yaml::MappingNode *Params, clangd::Logger &Logger);
+};
+
 } // namespace clangd
 } // namespace clang
Index: clangd/Protocol.cpp
===================================================================
--- clangd/Protocol.cpp
+++ clangd/Protocol.cpp
@@ -973,3 +973,99 @@
   Result.push_back('}');
   return Result;
 }
+
+std::string CommandArgument::unparse(const CommandArgument &Argument) {
+  switch (Argument.ArgumentKind) {
+  case SelectionRangeKind:
+    return std::string("{\"selection\":") +
+           Range::unparse(Argument.getSelectionRange()) + "}";
+  case TextDocumentIdentifierKind:
+    return std::string("{\"doc\":{\"uri\":") +
+           URI::unparse(Argument.getDocumentID().uri) + "}}";
+  }
+  llvm_unreachable("invalid kind");
+}
+
+llvm::Optional<CommandArgument>
+CommandArgument::parse(llvm::yaml::MappingNode *Params,
+                       clangd::Logger &Logger) {
+  llvm::Optional<CommandArgument> Result;
+  for (auto &NextKeyValue : *Params) {
+    auto *KeyString = dyn_cast<llvm::yaml::ScalarNode>(NextKeyValue.getKey());
+    if (!KeyString)
+      return llvm::None;
+
+    llvm::SmallString<10> KeyStorage;
+    StringRef KeyValue = KeyString->getValue(KeyStorage);
+
+    // The keys must match the ones emitted by unparse() so that a client can
+    // send the arguments of a code action back verbatim.
+    if (KeyValue == "selection") {
+      auto *Value =
+          dyn_cast_or_null<llvm::yaml::MappingNode>(NextKeyValue.getValue());
+      if (!Value)
+        return llvm::None;
+      llvm::Optional<Range> SelectionRange = Range::parse(Value, Logger);
+      if (!SelectionRange)
+        return llvm::None;
+      Result = makeSelectionRange(std::move(*SelectionRange));
+    } else if (KeyValue == "doc") {
+      auto *Value =
+          dyn_cast_or_null<llvm::yaml::MappingNode>(NextKeyValue.getValue());
+      if (!Value)
+        return llvm::None;
+      llvm::Optional<TextDocumentIdentifier> ID =
+          TextDocumentIdentifier::parse(Value, Logger);
+      if (!ID)
+        return llvm::None;
+      Result = makeDocumentID(std::move(*ID));
+    } else {
+      logIgnoredField(KeyValue, Logger);
+    }
+  }
+  // An object without any recognized key does not describe an argument.
+  return Result;
+}
+
+llvm::Optional<ExecuteCommandParams>
+ExecuteCommandParams::parse(llvm::yaml::MappingNode *Params,
+                            clangd::Logger &Logger) {
+  ExecuteCommandParams Result;
+  for (auto &NextKeyValue : *Params) {
+    auto *KeyString = dyn_cast<llvm::yaml::ScalarNode>(NextKeyValue.getKey());
+    if (!KeyString)
+      return llvm::None;
+
+    llvm::SmallString<10> KeyStorage;
+    StringRef KeyValue = KeyString->getValue(KeyStorage);
+
+    if (KeyValue == "command") {
+      auto *Value =
+          dyn_cast_or_null<llvm::yaml::ScalarNode>(NextKeyValue.getValue());
+      if (!Value)
+        return llvm::None;
+      llvm::SmallString<20> ValueStorage;
+      Result.command = Value->getValue(ValueStorage).str();
+    } else if (KeyValue == "arguments") {
+      auto *Seq =
+          dyn_cast_or_null<llvm::yaml::SequenceNode>(NextKeyValue.getValue());
+      if (!Seq)
+        return llvm::None;
+      std::vector<CommandArgument> Arguments;
+      for (auto &Item : *Seq) {
+        auto *Node = dyn_cast<llvm::yaml::MappingNode>(&Item);
+        if (!Node)
+          return llvm::None;
+        llvm::Optional<CommandArgument> Argument =
+            CommandArgument::parse(Node, Logger);
+        if (!Argument)
+          return llvm::None;
+        Arguments.push_back(std::move(*Argument));
+      }
+      Result.arguments = std::move(Arguments);
+    } else {
+      logIgnoredField(KeyValue, Logger);
+    }
+  }
+  return Result;
+}
Index: clangd/ProtocolHandlers.h
===================================================================
--- clangd/ProtocolHandlers.h
+++ clangd/ProtocolHandlers.h
@@ -51,4 +51,6 @@
   virtual void onGoToDefinition(Ctx C, TextDocumentPositionParams &Params) = 0;
   virtual void onSwitchSourceHeader(Ctx C, TextDocumentIdentifier &Params) = 0;
   virtual void onFileEvent(Ctx C, DidChangeWatchedFilesParams &Params) = 0;
+  virtual void onWorkspaceExecuteCommand(Ctx C,
+                                         ExecuteCommandParams &Params) = 0;
 };
Index: clangd/ProtocolHandlers.cpp
===================================================================
--- clangd/ProtocolHandlers.cpp
+++ clangd/ProtocolHandlers.cpp
@@ -67,4 +67,6 @@
   Register("textDocument/switchSourceHeader",
            &ProtocolCallbacks::onSwitchSourceHeader);
   Register("workspace/didChangeWatchedFiles", &ProtocolCallbacks::onFileEvent);
+  Register("workspace/executeCommand",
+           &ProtocolCallbacks::onWorkspaceExecuteCommand);
 }
Index: test/clangd/refactoring.test
===================================================================
--- /dev/null
+++ test/clangd/refactoring.test
@@ -0,0 +1,28 @@
+# RUN: clangd -run-synchronously < %s | FileCheck %s
+# It is absolutely vital that this file has CRLF line endings.
+#
+Content-Length: 125
+
+{"jsonrpc":"2.0","id":0,"method":"initialize","params":{"processId":123,"rootPath":"clangd","capabilities":{},"trace":"off"}}
+#
+Content-Length: 181
+
+{"jsonrpc":"2.0","method":"textDocument/didOpen","params":{"textDocument":{"uri":"file:///foo.c","languageId":"c","version":1,"text":"int main(int i, char **a) { if (i == 2) {}}"}}}
+#
+# CHECK: {"jsonrpc":"2.0","method":"textDocument/publishDiagnostics","params":{"uri":"file:///foo.c","diagnostics":[]}}
+#
+Content-Length: 214
+
+{"jsonrpc":"2.0","id":2,"method":"textDocument/codeAction","params":{"textDocument":{"uri":"file:///foo.c"},"range":{"start":{"line":0,"character":32},"end":{"line":0,"character":38}},"context":{"diagnostics":[]}}}
+#
+# CHECK: {"jsonrpc":"2.0","id":2,"result":[{"title":"Extract Function", "command": "clangd.refactor.ExtractFunction", "arguments": [{"doc":{"uri":"file:///foo.c"}},{"selection":{"start": {"line": 0, "character": 32}, "end": {"line": 0, "character": 38}}}]}]}
+#
+Content-Length: 243
+
+{"jsonrpc":"2.0","id":3,"method":"workspace/executeCommand","params":{"command":"clangd.refactor.ExtractFunction","arguments":[{"doc":{"uri":"file:///foo.c"}},{"selection":{"start":{"line":0,"character":32},"end":{"line":0,"character":38}}}]}}
+#
+# CHECK: {"jsonrpc":"2.0","id":3,"result":[{"range": {"start": {"line": 0, "character": 0}, "end": {"line": 0, "character": 0}}, "newText": "static int extracted() {\nreturn i == 2;\n}\n\n"},{"range": {"start": {"line": 0, "character": 32}, "end": {"line": 0, "character": 38}}, "newText": "extracted()"}]}
+#
+Content-Length: 44
+
+{"jsonrpc":"2.0","id":4,"method":"shutdown"}