diff --git a/mlir/include/mlir/Tools/mlir-lsp-server/MlirLspServerMain.h b/mlir/include/mlir/Tools/mlir-lsp-server/MlirLspServerMain.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Tools/mlir-lsp-server/MlirLspServerMain.h
@@ -0,0 +1,28 @@
+//===- MlirLspServerMain.h - MLIR Language Server main ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Main entry function for mlir-lsp-server for when built as standalone binary.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TOOLS_MLIRLSPSERVER_MLIRLSPSERVERMAIN_H
+#define MLIR_TOOLS_MLIRLSPSERVER_MLIRLSPSERVERMAIN_H
+
+namespace mlir {
+class DialectRegistry;
+struct LogicalResult;
+
+/// Implementation for tools like `mlir-lsp-server`: parses `argc`/`argv` and
+/// serves LSP requests over stdin/stdout. `registry` should contain all the
+/// dialects that can be parsed in source IR passed to the server.
+LogicalResult MlirLspServerMain(int argc, char **argv,
+                                DialectRegistry &registry);
+
+} // namespace mlir
+
+#endif // MLIR_TOOLS_MLIRLSPSERVER_MLIRLSPSERVERMAIN_H
diff --git a/mlir/lib/Tools/mlir-lsp-server/LSPServer.h b/mlir/lib/Tools/mlir-lsp-server/LSPServer.h
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Tools/mlir-lsp-server/LSPServer.h
@@ -0,0 +1,40 @@
+//===- LSPServer.h - MLIR LSP Server ----------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LIB_MLIR_TOOLS_MLIRLSPSERVER_LSPSERVER_H
+#define LIB_MLIR_TOOLS_MLIRLSPSERVER_LSPSERVER_H
+
+#include <memory>
+
+namespace mlir {
+struct LogicalResult;
+
+namespace lsp {
+class JSONTransport;
+class MLIRServer;
+
+/// This class represents the main LSP server, and handles communication with
+/// the LSP client.
+class LSPServer {
+public:
+  /// Construct a new language server with the given MLIR server.
+  LSPServer(MLIRServer &server, JSONTransport &transport);
+  ~LSPServer();
+
+  /// Run the main loop. Succeeds only if a 'shutdown' request was received.
+  LogicalResult run();
+
+private:
+  struct Impl;
+
+  std::unique_ptr<Impl> impl;
+};
+} // namespace lsp
+} // namespace mlir
+
+#endif // LIB_MLIR_TOOLS_MLIRLSPSERVER_LSPSERVER_H
diff --git a/mlir/lib/Tools/mlir-lsp-server/LSPServer.cpp b/mlir/lib/Tools/mlir-lsp-server/LSPServer.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Tools/mlir-lsp-server/LSPServer.cpp
@@ -0,0 +1,155 @@
+//===- LSPServer.cpp - MLIR Language Server -------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "LSPServer.h"
+#include "MLIRServer.h"
+#include "lsp/Logging.h"
+#include "lsp/Protocol.h"
+#include "lsp/Transport.h"
+#include "llvm/ADT/FunctionExtras.h"
+#include "llvm/ADT/StringMap.h"
+
+#define DEBUG_TYPE "mlir-lsp-server"
+
+using namespace mlir;
+using namespace mlir::lsp;
+
+//===----------------------------------------------------------------------===//
+// LSPServer::Impl
+//===----------------------------------------------------------------------===//
+
+struct LSPServer::Impl {
+  Impl(MLIRServer &server, JSONTransport &transport)
+      : server(server), transport(transport) {}
+
+  //===--------------------------------------------------------------------===//
+  // Initialization
+
+  void onInitialize(const InitializeParams &params,
+                    Callback<llvm::json::Value> reply);
+  void onInitialized(const InitializedParams &params);
+  void onShutdown(const NoParams &params, Callback<std::nullptr_t> reply);
+
+  //===--------------------------------------------------------------------===//
+  // Document Change
+
+  void onDocumentDidOpen(const DidOpenTextDocumentParams &params);
+  void onDocumentDidClose(const DidCloseTextDocumentParams &params);
+
+  //===--------------------------------------------------------------------===//
+  // Definitions and References
+
+  void onGoToDefinition(const TextDocumentPositionParams &params,
+                        Callback<std::vector<Location>> reply);
+  void onReference(const ReferenceParams &params,
+                   Callback<std::vector<Location>> reply);
+
+  MLIRServer &server;
+  JSONTransport &transport;
+
+  /// Used to indicate that the 'shutdown' request was received from the
+  /// Language Server client.
+  bool shutdownRequestReceived = false;
+};
+
+//===----------------------------------------------------------------------===//
+// Initialization
+
+void LSPServer::Impl::onInitialize(const InitializeParams &params,
+                                   Callback<llvm::json::Value> reply) {
+  llvm::json::Object serverCaps{
+      {"textDocumentSync",
+       llvm::json::Object{
+           {"openClose", true},
+           // TODO: No 'textDocument/didChange' handler is registered yet.
+           {"change", static_cast<int>(TextDocumentSyncKind::Full)},
+           {"save", true},
+       }},
+      {"definitionProvider", true},
+      {"referencesProvider", true},
+  };
+  llvm::json::Object result{
+      {{"serverInfo",
+        llvm::json::Object{{"name", "mlir-lsp-server"}, {"version", "0.0.0"}}},
+       {"capabilities", std::move(serverCaps)}}};
+  reply(std::move(result));
+}
+void LSPServer::Impl::onInitialized(const InitializedParams &) {}
+void LSPServer::Impl::onShutdown(const NoParams &,
+                                 Callback<std::nullptr_t> reply) {
+  shutdownRequestReceived = true;
+  reply(nullptr);
+}
+
+//===----------------------------------------------------------------------===//
+// Document Change
+
+void LSPServer::Impl::onDocumentDidOpen(
+    const DidOpenTextDocumentParams &params) {
+  server.addOrUpdateDocument(params.textDocument.uri, params.textDocument.text);
+}
+void LSPServer::Impl::onDocumentDidClose(
+    const DidCloseTextDocumentParams &params) {
+  server.removeDocument(params.textDocument.uri);
+}
+
+//===----------------------------------------------------------------------===//
+// Definitions and References
+
+void LSPServer::Impl::onGoToDefinition(const TextDocumentPositionParams &params,
+                                       Callback<std::vector<Location>> reply) {
+  std::vector<Location> locations;
+  server.getLocationsOf(params.textDocument.uri, params.position, locations);
+  reply(std::move(locations));
+}
+
+void LSPServer::Impl::onReference(const ReferenceParams &params,
+                                  Callback<std::vector<Location>> reply) {
+  std::vector<Location> locations;
+  server.findReferencesOf(params.textDocument.uri, params.position, locations);
+  reply(std::move(locations));
+}
+
+//===----------------------------------------------------------------------===//
+// LSPServer
+//===----------------------------------------------------------------------===//
+
+LSPServer::LSPServer(MLIRServer &server, JSONTransport &transport)
+    : impl(std::make_unique<Impl>(server, transport)) {}
+LSPServer::~LSPServer() = default;
+
+LogicalResult LSPServer::run() {
+  MessageHandler messageHandler(impl->transport);
+
+  // Initialization
+  messageHandler.method("initialize", impl.get(), &Impl::onInitialize);
+  messageHandler.notification("initialized", impl.get(), &Impl::onInitialized);
+  messageHandler.method("shutdown", impl.get(), &Impl::onShutdown);
+
+  // Document Changes
+  messageHandler.notification("textDocument/didOpen", impl.get(),
+                              &Impl::onDocumentDidOpen);
+  messageHandler.notification("textDocument/didClose", impl.get(),
+                              &Impl::onDocumentDidClose);
+
+  // Definitions and References
+  messageHandler.method("textDocument/definition", impl.get(),
+                        &Impl::onGoToDefinition);
+  messageHandler.method("textDocument/references", impl.get(),
+                        &Impl::onReference);
+
+  if (llvm::Error error = impl->transport.run(messageHandler)) {
+    Logger::error("Transport error: {0}", error);
+    llvm::consumeError(std::move(error));
+    return failure();
+  }
+
+  // Exiting cleanly only counts as success if the client sent a 'shutdown'
+  // request first, as required by the LSP lifecycle.
+  return success(impl->shutdownRequestReceived);
+}
diff --git a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.h b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.h
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.h
@@ -0,0 +1,55 @@
+//===- MLIRServer.h - MLIR General Language Server --------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LIB_MLIR_TOOLS_MLIRLSPSERVER_SERVER_H_
+#define LIB_MLIR_TOOLS_MLIRLSPSERVER_SERVER_H_
+
+#include "mlir/Support/LLVM.h"
+#include <memory>
+
+namespace mlir {
+class DialectRegistry;
+
+namespace lsp {
+struct Location;
+struct Position;
+class URIForFile;
+
+/// This class implements all of the MLIR related functionality necessary for a
+/// language server. This class allows for keeping the MLIR specific logic
+/// separate from the logic that involves LSP server/client communication.
+class MLIRServer {
+public:
+  /// Construct a new server with the given dialect registry. The registry is
+  /// held by reference, so it must outlive the server.
+  MLIRServer(DialectRegistry &registry);
+  ~MLIRServer();
+
+  /// Add or update the document at the given URI, (re)parsing `contents`.
+  void addOrUpdateDocument(const URIForFile &uri, StringRef contents);
+
+  /// Remove the document with the given uri.
+  void removeDocument(const URIForFile &uri);
+
+  /// Append the locations of the object at `defPos` to `locations`.
+  void getLocationsOf(const URIForFile &uri, const Position &defPos,
+                      std::vector<Location> &locations);
+
+  /// Append all references of the object at `pos` to `references`.
+  void findReferencesOf(const URIForFile &uri, const Position &pos,
+                        std::vector<Location> &references);
+
+private:
+  struct Impl;
+  std::unique_ptr<Impl> impl;
+};
+
+} // namespace lsp
+} // namespace mlir
+
+#endif // LIB_MLIR_TOOLS_MLIRLSPSERVER_SERVER_H_
diff --git a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp
@@ -0,0 +1,269 @@
+//===- MLIRServer.cpp - MLIR Generic Language Server ----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "MLIRServer.h"
+#include "lsp/Logging.h"
+#include "lsp/Protocol.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Parser.h"
+#include "mlir/Parser/AsmParserState.h"
+#include "llvm/Support/SourceMgr.h"
+
+using namespace mlir;
+
+/// Returns the zero-based language server position of the given location.
+static lsp::Position getPosFromLoc(llvm::SourceMgr &mgr, llvm::SMLoc loc) {
+  std::pair<unsigned, unsigned> lineAndCol = mgr.getLineAndColumn(loc);
+  lsp::Position pos;
+  pos.line = lineAndCol.first - 1;       // SourceMgr lines are 1-based.
+  pos.character = lineAndCol.second - 1; // SourceMgr columns are 1-based.
+  return pos;
+}
+
+/// Returns a language server range for the given source range.
+static lsp::Range getRangeFromLoc(llvm::SourceMgr &mgr, llvm::SMRange range) {
+  // SMRange and lsp::Range are both half-open (the end is exclusive, see
+  // lsp::Range::contains), so the end location maps directly onto the LSP end
+  // position without any adjustment.
+  return {getPosFromLoc(mgr, range.Start), getPosFromLoc(mgr, range.End)};
+}
+
+/// Returns a language server location from the given source range.
+static lsp::Location getLocationFromLoc(llvm::SourceMgr &mgr,
+                                        llvm::SMRange range,
+                                        const lsp::URIForFile &uri) {
+  return lsp::Location{uri, getRangeFromLoc(mgr, range)};
+}
+
+/// Returns a language server location from the given MLIR file location.
+static Optional<lsp::Location> getLocationFromLoc(FileLineColLoc loc) {
+  llvm::Expected<lsp::URIForFile> sourceURI =
+      lsp::URIForFile::fromFile(loc.getFilename());
+  if (!sourceURI) {
+    lsp::Logger::error("Failed to create URI for file `{0}`: {1}",
+                       loc.getFilename(),
+                       llvm::toString(sourceURI.takeError()));
+    return llvm::None;
+  }
+  // FileLineColLoc is 1-based and LSP is 0-based; clamp a 0 column to 0.
+  lsp::Position position;
+  position.line = loc.getLine() - 1;
+  position.character = loc.getColumn() ? loc.getColumn() - 1 : 0;
+  return lsp::Location{*sourceURI, lsp::Range{position, position}};
+}
+
+/// Collect all of the locations from the given MLIR location that are not
+/// contained within the given URI.
+static void collectLocationsFromLoc(Location loc,
+                                    std::vector<lsp::Location> &locations,
+                                    const lsp::URIForFile &uri) {
+  SetVector<Location> visitedLocs;
+  loc->walk([&](Location nestedLoc) {
+    FileLineColLoc fileLoc = nestedLoc.dyn_cast<FileLineColLoc>();
+    if (!fileLoc || !visitedLocs.insert(nestedLoc))
+      return WalkResult::advance();
+
+    Optional<lsp::Location> sourceLoc = getLocationFromLoc(fileLoc);
+    if (sourceLoc && sourceLoc->uri != uri)
+      locations.push_back(*sourceLoc);
+    return WalkResult::advance();
+  });
+}
+
+/// Returns true if the given range contains the given source location. Note
+/// that this has slightly different behavior than SMRange because it is
+/// inclusive of the end location.
+static bool contains(llvm::SMRange range, llvm::SMLoc loc) {
+  return range.Start.getPointer() <= loc.getPointer() &&
+         loc.getPointer() <= range.End.getPointer();
+}
+
+/// Returns true if the given location is contained by the definition or one of
+/// the uses of the given SMDefinition.
+static bool isDefOrUse(const AsmParserState::SMDefinition &def,
+                       llvm::SMLoc loc) {
+  auto isUseFn = [&](const llvm::SMRange &range) {
+    return contains(range, loc);
+  };
+  return contains(def.loc, loc) || llvm::any_of(def.uses, isUseFn);
+}
+
+//===----------------------------------------------------------------------===//
+// MLIRDocument
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// This class represents all of the information pertaining to a specific MLIR
+/// document.
+struct MLIRDocument {
+  MLIRDocument(const lsp::URIForFile &uri, StringRef contents,
+               DialectRegistry &registry);
+
+  void getLocationsOf(const lsp::URIForFile &uri, const lsp::Position &defPos,
+                      std::vector<lsp::Location> &locations);
+  void findReferencesOf(const lsp::URIForFile &uri, const lsp::Position &pos,
+                        std::vector<lsp::Location> &references);
+
+  /// The context used to hold the state contained by the parsed document.
+  MLIRContext context;
+
+  /// The high level parser state used to find definitions and references within
+  /// the source file.
+  AsmParserState asmState;
+
+  /// The container for the IR parsed from the input file.
+  Block parsedIR;
+
+  /// The source manager containing the contents of the input file.
+  llvm::SourceMgr sourceMgr;
+};
+} // namespace
+
+MLIRDocument::MLIRDocument(const lsp::URIForFile &uri, StringRef contents,
+                           DialectRegistry &registry)
+    : context(registry) {
+  context.allowUnregisteredDialects();
+  ScopedDiagnosticHandler handler(&context, [&](Diagnostic &diag) {
+    // TODO: What should we do with these diagnostics?
+    //  * Cache and show to the user?
+    //  * Ignore?
+    lsp::Logger::error("Error when parsing MLIR document `{0}`: `{1}`",
+                       uri.file(), diag.str());
+  });
+
+  // Try to parse the given IR string.
+  auto memBuffer = llvm::MemoryBuffer::getMemBufferCopy(contents, uri.file());
+  if (!memBuffer) {
+    lsp::Logger::error("Failed to create memory buffer for file `{0}`",
+                       uri.file());
+    return;
+  }
+
+  sourceMgr.AddNewSourceBuffer(std::move(memBuffer), llvm::SMLoc());
+  // Parse errors are reported through the diagnostic handler installed above.
+  (void)parseSourceFile(sourceMgr, &parsedIR, &context, nullptr, &asmState);
+}
+
+void MLIRDocument::getLocationsOf(const lsp::URIForFile &uri,
+                                  const lsp::Position &defPos,
+                                  std::vector<lsp::Location> &locations) {
+  // LSP positions are 0-based, SourceMgr lines and columns are 1-based.
+  llvm::SMLoc posLoc = sourceMgr.FindLocForLineAndColumn(
+      sourceMgr.getMainFileID(), defPos.line + 1, defPos.character + 1);
+  // Functor used to check if an SM definition contains the position.
+  auto checkSMDef = [&](const AsmParserState::SMDefinition &def) {
+    if (!isDefOrUse(def, posLoc))
+      return false;
+    locations.push_back(getLocationFromLoc(sourceMgr, def.loc, uri));
+    return true;
+  };
+
+  // Check all definitions related to operations.
+  for (const AsmParserState::OperationDefinition &op : asmState.getOpDefs()) {
+    if (contains(op.loc, posLoc))
+      return collectLocationsFromLoc(op.op->getLoc(), locations, uri);
+    for (const auto &result : op.resultGroups)
+      if (checkSMDef(result.second))
+        return collectLocationsFromLoc(op.op->getLoc(), locations, uri);
+  }
+
+  // Check all definitions related to blocks.
+  for (const AsmParserState::BlockDefinition &block : asmState.getBlockDefs()) {
+    if (checkSMDef(block.definition))
+      return;
+    for (const AsmParserState::SMDefinition &arg : block.arguments)
+      if (checkSMDef(arg))
+        return;
+  }
+}
+
+void MLIRDocument::findReferencesOf(const lsp::URIForFile &uri,
+                                    const lsp::Position &pos,
+                                    std::vector<lsp::Location> &references) {
+  // Functor used to append all of the definitions/uses of the given SM
+  // definition to the reference list.
+  auto appendSMDef = [&](const AsmParserState::SMDefinition &def) {
+    references.push_back(getLocationFromLoc(sourceMgr, def.loc, uri));
+    for (const llvm::SMRange &use : def.uses)
+      references.push_back(getLocationFromLoc(sourceMgr, use, uri));
+  };
+  // LSP positions are 0-based, SourceMgr lines and columns are 1-based.
+  llvm::SMLoc posLoc = sourceMgr.FindLocForLineAndColumn(
+      sourceMgr.getMainFileID(), pos.line + 1, pos.character + 1);
+
+  // Check all definitions related to operations.
+  for (const AsmParserState::OperationDefinition &op : asmState.getOpDefs()) {
+    if (contains(op.loc, posLoc)) {
+      for (const auto &result : op.resultGroups)
+        appendSMDef(result.second);
+      return;
+    }
+    for (const auto &result : op.resultGroups)
+      if (isDefOrUse(result.second, posLoc))
+        return appendSMDef(result.second);
+  }
+
+  // Check all definitions related to blocks.
+  for (const AsmParserState::BlockDefinition &block : asmState.getBlockDefs()) {
+    if (isDefOrUse(block.definition, posLoc))
+      return appendSMDef(block.definition);
+
+    for (const AsmParserState::SMDefinition &arg : block.arguments)
+      if (isDefOrUse(arg, posLoc))
+        return appendSMDef(arg);
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// MLIRServer::Impl
+//===----------------------------------------------------------------------===//
+
+struct lsp::MLIRServer::Impl {
+  explicit Impl(DialectRegistry &registry) : registry(registry) {}
+
+  /// The registry containing dialects that can be recognized in parsed .mlir
+  /// files.
+  DialectRegistry &registry;
+
+  /// The documents held by the server, mapped by their URI file name.
+  llvm::StringMap<std::unique_ptr<MLIRDocument>> documents;
+};
+
+//===----------------------------------------------------------------------===//
+// MLIRServer
+//===----------------------------------------------------------------------===//
+
+lsp::MLIRServer::MLIRServer(DialectRegistry &registry)
+    : impl(std::make_unique<Impl>(registry)) {}
+lsp::MLIRServer::~MLIRServer() = default;
+
+void lsp::MLIRServer::addOrUpdateDocument(const URIForFile &uri,
+                                          StringRef contents) {
+  impl->documents[uri.file()] =
+      std::make_unique<MLIRDocument>(uri, contents, impl->registry);
+}
+
+void lsp::MLIRServer::removeDocument(const URIForFile &uri) {
+  impl->documents.erase(uri.file());
+}
+
+void lsp::MLIRServer::getLocationsOf(const URIForFile &uri,
+                                     const Position &defPos,
+                                     std::vector<Location> &locations) {
+  auto fileIt = impl->documents.find(uri.file());
+  if (fileIt != impl->documents.end())
+    fileIt->second->getLocationsOf(uri, defPos, locations);
+}
+
+void lsp::MLIRServer::findReferencesOf(const URIForFile &uri,
+                                       const Position &pos,
+                                       std::vector<Location> &references) {
+  auto fileIt = impl->documents.find(uri.file());
+  if (fileIt != impl->documents.end())
+    fileIt->second->findReferencesOf(uri, pos, references);
+}
diff --git a/mlir/lib/Tools/mlir-lsp-server/MlirLspServerMain.cpp b/mlir/lib/Tools/mlir-lsp-server/MlirLspServerMain.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Tools/mlir-lsp-server/MlirLspServerMain.cpp
@@ -0,0 +1,75 @@
+//===- MlirLspServerMain.cpp - MLIR Language Server main ------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Tools/mlir-lsp-server/MlirLspServerMain.h"
+#include "LSPServer.h"
+#include "MLIRServer.h"
+#include "lsp/Logging.h"
+#include "lsp/Transport.h"
+#include "mlir/IR/Dialect.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Program.h"
+
+using namespace mlir;
+using namespace mlir::lsp;
+
+LogicalResult mlir::MlirLspServerMain(int argc, char **argv,
+                                      DialectRegistry &registry) {
+  llvm::cl::opt<JSONStreamStyle> inputStyle{
+      "input-style",
+      llvm::cl::desc("Input JSON stream encoding"),
+      llvm::cl::values(clEnumValN(JSONStreamStyle::Standard, "standard",
+                                  "usual LSP protocol"),
+                       clEnumValN(JSONStreamStyle::Delimited, "delimited",
+                                  "messages delimited by `// -----` lines, "
+                                  "with // comment support")),
+      llvm::cl::init(JSONStreamStyle::Standard),
+      llvm::cl::Hidden,
+  };
+  llvm::cl::opt<bool> litTest{
+      "lit-test",
+      llvm::cl::desc(
+          "Abbreviation for -input-style=delimited -pretty -log=verbose. "
+          "Intended to simplify lit tests"),
+      llvm::cl::init(false),
+  };
+  llvm::cl::opt<Logger::Level> logLevel{
+      "log",
+      llvm::cl::desc("Verbosity of log messages written to stderr"),
+      llvm::cl::values(
+          clEnumValN(Logger::Level::Error, "error", "Error messages only"),
+          clEnumValN(Logger::Level::Info, "info",
+                     "High level execution tracing"),
+          clEnumValN(Logger::Level::Debug, "verbose", "Low level details")),
+      llvm::cl::init(Logger::Level::Info),
+  };
+  llvm::cl::opt<bool> prettyPrint{
+      "pretty",
+      llvm::cl::desc("Pretty-print JSON output"),
+      llvm::cl::init(false),
+  };
+  llvm::cl::ParseCommandLineOptions(argc, argv, "MLIR LSP Language Server");
+
+  if (litTest) {
+    inputStyle = JSONStreamStyle::Delimited;
+    logLevel = Logger::Level::Debug;
+    prettyPrint = true;
+  }
+
+  // Configure the logger.
+  Logger::setLogLevel(logLevel);
+
+  // Configure the transport used for communication.
+  llvm::sys::ChangeStdinToBinary();
+  JSONTransport transport(stdin, llvm::outs(), inputStyle, prettyPrint);
+
+  // Configure the servers and start the main language server.
+  MLIRServer server(registry);
+  LSPServer lspServer(server, transport);
+  return lspServer.run();
+}
diff --git a/mlir/lib/Tools/mlir-lsp-server/lsp/Logging.h b/mlir/lib/Tools/mlir-lsp-server/lsp/Logging.h
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Tools/mlir-lsp-server/lsp/Logging.h
@@ -0,0 +1,65 @@
+//===- Logging.h - MLIR LSP Server Logging ----------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LIB_MLIR_TOOLS_MLIRLSPSERVER_LSP_LOGGING_H
+#define LIB_MLIR_TOOLS_MLIRLSPSERVER_LSP_LOGGING_H
+
+#include "mlir/Support/LLVM.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/FormatVariadic.h"
+#include <mutex>
+#include <utility>
+
+namespace mlir {
+namespace lsp {
+
+/// This class represents the main interface for logging, and allows for
+/// filtering logging based on different levels of severity or significance.
+class Logger {
+public:
+  /// The level of significance for a log message.
+  enum class Level { Debug, Info, Error };
+
+  /// Set the severity level of the logger.
+  static void setLogLevel(Level logLevel);
+
+  /// Initiate a log message at various severity levels. Messages below the
+  /// level configured with `setLogLevel` (Error by default) are dropped.
+  template <typename... Ts>
+  static void debug(const char *fmt, Ts &&...vals) {
+    log(Level::Debug, fmt, llvm::formatv(fmt, std::forward<Ts>(vals)...));
+  }
+  template <typename... Ts>
+  static void info(const char *fmt, Ts &&...vals) {
+    log(Level::Info, fmt, llvm::formatv(fmt, std::forward<Ts>(vals)...));
+  }
+  template <typename... Ts>
+  static void error(const char *fmt, Ts &&...vals) {
+    log(Level::Error, fmt, llvm::formatv(fmt, std::forward<Ts>(vals)...));
+  }
+
+private:
+  Logger() = default;
+
+  /// Return the main logger instance.
+  static Logger &get();
+
+  /// Start a log message with the given severity level.
+  static void log(Level logLevel, const char *fmt,
+                  const llvm::formatv_object_base &message);
+
+  /// The minimum logging level. Messages with lower level are ignored.
+  Level logLevel = Level::Error;
+
+  /// A mutex used to guard logging.
+  std::mutex mutex;
+};
+} // namespace lsp
+} // namespace mlir
+
+#endif // LIB_MLIR_TOOLS_MLIRLSPSERVER_LSP_LOGGING_H
diff --git a/mlir/lib/Tools/mlir-lsp-server/lsp/Logging.cpp b/mlir/lib/Tools/mlir-lsp-server/lsp/Logging.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Tools/mlir-lsp-server/lsp/Logging.cpp
@@ -0,0 +1,42 @@
+//===- Logging.cpp --------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "Logging.h"
+#include "llvm/Support/Chrono.h"
+#include "llvm/Support/ManagedStatic.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir;
+using namespace mlir::lsp;
+
+void Logger::setLogLevel(Level logLevel) { get().logLevel = logLevel; }
+
+Logger &Logger::get() {
+  static Logger logger;
+  return logger;
+}
+
+void Logger::log(Level logLevel, const char *fmt,
+                 const llvm::formatv_object_base &message) {
+  Logger &logger = get();
+
+  // Ignore messages with log levels below the current setting in the logger.
+  if (logLevel < logger.logLevel)
+    return;
+
+  // One indicator character per Level enumerator, in declaration order.
+  const char *logLevelIndicators = "DIE";
+
+  // Format the message and print to errs.
+  llvm::sys::TimePoint<> timestamp = std::chrono::system_clock::now();
+  std::lock_guard<std::mutex> logGuard(logger.mutex);
+  llvm::errs() << llvm::formatv(
+      "{0}[{1:%H:%M:%S.%L}] {2}\n",
+      logLevelIndicators[static_cast<unsigned>(logLevel)], timestamp, message);
+  llvm::errs().flush();
+}
diff --git a/mlir/lib/Tools/mlir-lsp-server/lsp/Protocol.h b/mlir/lib/Tools/mlir-lsp-server/lsp/Protocol.h
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Tools/mlir-lsp-server/lsp/Protocol.h
@@ -0,0 +1,348 @@
+//===--- Protocol.h - Language Server Protocol Implementation ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains structs based on the LSP specification at
+// https://github.com/Microsoft/language-server-protocol/blob/master/protocol.md
+//
+// This is not meant to be a complete implementation; new interfaces are added
+// when they're needed.
+//
+// Structs have toJSON and/or fromJSON functions that convert between the
+// struct and a JSON representation (see llvm/Support/JSON.h).
+//
+// Some structs also have operator<< serialization. This is for debugging and
+// tests, and is not generally machine-readable.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LIB_MLIR_TOOLS_MLIRLSPSERVER_LSP_PROTOCOL_H_
+#define LIB_MLIR_TOOLS_MLIRLSPSERVER_LSP_PROTOCOL_H_
+
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/Optional.h"
+#include "llvm/Support/JSON.h"
+#include "llvm/Support/raw_ostream.h"
+#include <string>
+#include <tuple>
+#include <utility>
+#include <vector>
+
+namespace mlir {
+namespace lsp {
+
+enum class ErrorCode {
+  // Defined by JSON RPC.
+  ParseError = -32700,
+  InvalidRequest = -32600,
+  MethodNotFound = -32601,
+  InvalidParams = -32602,
+  InternalError = -32603,
+
+  ServerNotInitialized = -32002,
+  UnknownErrorCode = -32001,
+
+  // Defined by the protocol.
+  RequestCancelled = -32800,
+  ContentModified = -32801,
+};
+
+/// Defines how the host (editor) should sync document changes to the language
+/// server.
+enum class TextDocumentSyncKind {
+  /// Documents should not be synced at all.
+  None = 0,
+
+  /// Documents are synced by always sending the full content of the document.
+  Full = 1,
+
+  /// Documents are synced by sending the full content on open. After that only
+  /// incremental updates to the document are sent.
+  Incremental = 2,
+};
+
+//===----------------------------------------------------------------------===//
+// LSPError
+//===----------------------------------------------------------------------===//
+
+/// This class models an LSP error as an llvm::Error.
+class LSPError : public llvm::ErrorInfo<LSPError> {
+public:
+  std::string message;
+  ErrorCode code;
+  static char ID;
+
+  LSPError(std::string message, ErrorCode code)
+      : message(std::move(message)), code(code) {}
+
+  void log(raw_ostream &os) const override {
+    os << int(code) << ": " << message;
+  }
+  std::error_code convertToErrorCode() const override {
+    return llvm::inconvertibleErrorCode();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// URIForFile
+//===----------------------------------------------------------------------===//
+
+/// URI in "file" scheme for a file.
+class URIForFile {
+public:
+  URIForFile() = default;
+
+  /// Try to build a URIForFile from the given URI string.
+  static llvm::Expected<URIForFile> fromURI(StringRef uri);
+
+  /// Try to build a URIForFile from the given absolute file path.
+  static llvm::Expected<URIForFile> fromFile(StringRef absoluteFilepath);
+
+  /// Returns the absolute path to the file.
+  StringRef file() const { return filePath; }
+
+  /// Returns the original uri of the file.
+  StringRef uri() const { return uriStr; }
+
+  explicit operator bool() const { return !filePath.empty(); }
+
+  friend bool operator==(const URIForFile &lhs, const URIForFile &rhs) {
+    return lhs.filePath == rhs.filePath;
+  }
+  friend bool operator!=(const URIForFile &lhs, const URIForFile &rhs) {
+    return !(lhs == rhs);
+  }
+  friend bool operator<(const URIForFile &lhs, const URIForFile &rhs) {
+    return lhs.filePath < rhs.filePath;
+  }
+
+private:
+  explicit URIForFile(std::string &&filePath, std::string &&uriStr)
+      : filePath(std::move(filePath)), uriStr(std::move(uriStr)) {}
+
+  std::string filePath;
+  std::string uriStr;
+};
+
+llvm::json::Value toJSON(const URIForFile &value);
+bool fromJSON(const llvm::json::Value &value, URIForFile &result,
+              llvm::json::Path path);
+raw_ostream &operator<<(raw_ostream &os, const URIForFile &value);
+
+//===----------------------------------------------------------------------===//
+// InitializeParams
+//===----------------------------------------------------------------------===//
+
+enum class TraceLevel {
+  Off = 0,
+  Messages = 1,
+  Verbose = 2,
+};
+bool fromJSON(const llvm::json::Value &value, TraceLevel &result,
+              llvm::json::Path path);
+
+struct InitializeParams {
+  /// The initial trace setting. If omitted trace is disabled ('off').
+  Optional<TraceLevel> trace;
+};
+bool fromJSON(const llvm::json::Value &value, InitializeParams &result,
+              llvm::json::Path path);
+
+//===----------------------------------------------------------------------===//
+// InitializedParams
+//===----------------------------------------------------------------------===//
+
+struct NoParams {};
+inline bool fromJSON(const llvm::json::Value &, NoParams &, llvm::json::Path) {
+  return true;
+}
+using InitializedParams = NoParams;
+
+//===----------------------------------------------------------------------===//
+// TextDocumentItem
+//===----------------------------------------------------------------------===//
+
+struct TextDocumentItem {
+  /// The text document's URI.
+  URIForFile uri;
+
+  /// The text document's language identifier.
+  std::string languageId;
+
+  /// The content of the opened text document.
+  std::string text;
+};
+bool fromJSON(const llvm::json::Value &value, TextDocumentItem &result,
+              llvm::json::Path path);
+
+//===----------------------------------------------------------------------===//
+// TextDocumentIdentifier
+//===----------------------------------------------------------------------===//
+
+struct TextDocumentIdentifier {
+  /// The text document's URI.
+  URIForFile uri;
+};
+llvm::json::Value toJSON(const TextDocumentIdentifier &value);
+bool fromJSON(const llvm::json::Value &value, TextDocumentIdentifier &result,
+              llvm::json::Path path);
+
+//===----------------------------------------------------------------------===//
+// Position
+//===----------------------------------------------------------------------===//
+
+struct Position {
+  /// Line position in a document (zero-based).
+  int line = 0;
+
+  /// Character offset on a line in a document (zero-based).
+  int character = 0;
+
+  friend bool operator==(const Position &lhs, const Position &rhs) {
+    return std::tie(lhs.line, lhs.character) ==
+           std::tie(rhs.line, rhs.character);
+  }
+  friend bool operator!=(const Position &lhs, const Position &rhs) {
+    return !(lhs == rhs);
+  }
+  friend bool operator<(const Position &lhs, const Position &rhs) {
+    return std::tie(lhs.line, lhs.character) <
+           std::tie(rhs.line, rhs.character);
+  }
+  friend bool operator<=(const Position &lhs, const Position &rhs) {
+    return std::tie(lhs.line, lhs.character) <=
+           std::tie(rhs.line, rhs.character);
+  }
+};
+bool fromJSON(const llvm::json::Value &value, Position &result,
+              llvm::json::Path path);
+llvm::json::Value toJSON(const Position &value);
+raw_ostream &operator<<(raw_ostream &os, const Position &value);
+
+//===----------------------------------------------------------------------===//
+// Range
+//===----------------------------------------------------------------------===//
+
+struct Range {
+  /// The range's start position.
+  Position start;
+
+  /// The range's end position.
+  Position end;
+
+  friend bool operator==(const Range &lhs, const Range &rhs) {
+    return std::tie(lhs.start, lhs.end) == std::tie(rhs.start, rhs.end);
+  }
+  friend bool operator!=(const Range &lhs, const Range &rhs) {
+    return !(lhs == rhs);
+  }
+  friend bool operator<(const Range &lhs, const Range &rhs) {
+    return std::tie(lhs.start, lhs.end) < std::tie(rhs.start, rhs.end);
+  }
+
+  bool contains(Position pos) const { return start <= pos && pos < end; }
+  bool contains(Range other) const {
+    return start <= other.start && other.end <= end;
+  }
+};
+bool fromJSON(const llvm::json::Value &value, Range &result,
+              llvm::json::Path path);
+llvm::json::Value toJSON(const Range &value);
+raw_ostream &operator<<(raw_ostream &os, const Range &value);
+
+//===----------------------------------------------------------------------===//
+// Location
+//===----------------------------------------------------------------------===//
+
+struct Location {
+  /// The text document's URI.
+  URIForFile uri;
+  /// The range within the text document.
+  Range range;
+
+  friend bool operator==(const Location &lhs, const Location &rhs) {
+    return lhs.uri == rhs.uri && lhs.range == rhs.range;
+  }
+  friend bool operator!=(const Location &lhs, const Location &rhs) {
+    return !(lhs == rhs);
+  }
+
+  friend bool operator<(const Location &lhs, const Location &rhs) {
+    return std::tie(lhs.uri, lhs.range) < std::tie(rhs.uri, rhs.range);
+  }
+};
+llvm::json::Value toJSON(const Location &value);
+raw_ostream &operator<<(raw_ostream &os, const Location &value);
+
+//===----------------------------------------------------------------------===//
+// TextDocumentPositionParams
+//===----------------------------------------------------------------------===//
+
+struct TextDocumentPositionParams {
+  /// The text document.
+  TextDocumentIdentifier textDocument;
+
+  /// The position inside the text document.
+ Position position; +}; +bool fromJSON(const llvm::json::Value &value, + TextDocumentPositionParams &result, llvm::json::Path path); + +//===----------------------------------------------------------------------===// +// ReferenceParams +//===----------------------------------------------------------------------===// + +struct ReferenceContext { + /// Include the declaration of the current symbol. + bool includeDeclaration = false; +}; +bool fromJSON(const llvm::json::Value &value, ReferenceContext &result, + llvm::json::Path path); + +struct ReferenceParams : public TextDocumentPositionParams { + ReferenceContext context; +}; +bool fromJSON(const llvm::json::Value &value, ReferenceParams &result, + llvm::json::Path path); + +//===----------------------------------------------------------------------===// +// DidOpenTextDocumentParams +//===----------------------------------------------------------------------===// + +struct DidOpenTextDocumentParams { + /// The document that was opened. + TextDocumentItem textDocument; +}; +bool fromJSON(const llvm::json::Value &value, DidOpenTextDocumentParams &result, + llvm::json::Path path); + +//===----------------------------------------------------------------------===// +// DidCloseTextDocumentParams +//===----------------------------------------------------------------------===// + +struct DidCloseTextDocumentParams { + /// The document that was closed. 
+ TextDocumentIdentifier textDocument; +}; +bool fromJSON(const llvm::json::Value &value, + DidCloseTextDocumentParams &result, llvm::json::Path path); + +} // namespace lsp +} // namespace mlir + +namespace llvm { +template <> +struct format_provider { + static void format(const mlir::lsp::Position &pos, raw_ostream &os, + StringRef style) { + assert(style.empty() && "style modifiers for this type are not supported"); + os << pos; + } +}; +} // namespace llvm + +#endif diff --git a/mlir/lib/Tools/mlir-lsp-server/lsp/Protocol.cpp b/mlir/lib/Tools/mlir-lsp-server/lsp/Protocol.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Tools/mlir-lsp-server/lsp/Protocol.cpp @@ -0,0 +1,416 @@ +//===--- Protocol.cpp - Language Server Protocol Implementation -----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains the serialization code for the LSP structs. +// +//===----------------------------------------------------------------------===// + +#include "Protocol.h" +#include "llvm/ADT/Hashing.h" +#include "llvm/ADT/SmallString.h" +#include "llvm/ADT/StringSwitch.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/Format.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/JSON.h" +#include "llvm/Support/Path.h" +#include "llvm/Support/raw_ostream.h" + +using namespace mlir; +using namespace mlir::lsp; + +// Helper that doesn't treat `null` and absent fields as failures. +template +static bool mapOptOrNull(const llvm::json::Value ¶ms, + llvm::StringLiteral prop, T &out, + llvm::json::Path path) { + const llvm::json::Object *o = params.getAsObject(); + assert(o); + + // Field is missing or null. 
+ auto *v = o->get(prop); + if (!v || v->getAsNull().hasValue()) + return true; + return fromJSON(*v, out, path.field(prop)); +} + +//===----------------------------------------------------------------------===// +// LSPError +//===----------------------------------------------------------------------===// + +char LSPError::ID; + +//===----------------------------------------------------------------------===// +// URIForFile +//===----------------------------------------------------------------------===// + +static bool isWindowsPath(StringRef path) { + return path.size() > 1 && llvm::isAlpha(path[0]) && path[1] == ':'; +} + +static bool isNetworkPath(StringRef path) { + return path.size() > 2 && path[0] == path[1] && + llvm::sys::path::is_separator(path[0]); +} + +static bool shouldEscapeInURI(unsigned char c) { + // Unreserved characters. + if ((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || + (c >= '0' && c <= '9')) + return false; + + switch (c) { + case '-': + case '_': + case '.': + case '~': + // '/' is only reserved when parsing. + case '/': + // ':' is only reserved for relative URI paths, which clangd doesn't produce. + case ':': + return false; + } + return true; +} + +/// Encodes a string according to percent-encoding. +/// - Unreserved characters are not escaped. +/// - Reserved characters always escaped with exceptions like '/'. +/// - All other characters are escaped. +static void percentEncode(StringRef content, std::string &out) { + for (unsigned char c : content) { + if (shouldEscapeInURI(c)) { + out.push_back('%'); + out.push_back(llvm::hexdigit(c / 16)); + out.push_back(llvm::hexdigit(c % 16)); + } else { + out.push_back(c); + } + } +} + +/// Decodes a string according to percent-encoding. 
+static std::string percentDecode(StringRef content) {
+  std::string result;
+  for (auto i = content.begin(), e = content.end(); i != e; ++i) {
+    if (*i != '%') {
+      result += *i;
+      continue;
+    }
+    if (i + 2 < e && llvm::isHexDigit(*(i + 1)) &&
+        llvm::isHexDigit(*(i + 2))) {
+      result.push_back(llvm::hexFromNibbles(*(i + 1), *(i + 2)));
+      i += 2;
+    } else {
+      result.push_back(*i);
+    }
+  }
+  return result;
+}
+
+static bool isValidScheme(StringRef scheme) {
+  if (scheme.empty())
+    return false;
+  if (!llvm::isAlpha(scheme[0]))
+    return false;
+  return std::all_of(scheme.begin() + 1, scheme.end(), [](char c) {
+    return llvm::isAlnum(c) || c == '+' || c == '.' || c == '-';
+  });
+}
+
+static llvm::Expected<std::string> uriFromAbsolutePath(StringRef absolutePath) {
+  std::string body;
+  StringRef authority;
+  StringRef root = llvm::sys::path::root_name(absolutePath);
+  if (isNetworkPath(root)) {
+    // Windows UNC paths e.g. \\server\share => file://server/share
+    authority = root.drop_front(2);
+    absolutePath.consume_front(root);
+  } else if (isWindowsPath(root)) {
+    // Windows paths e.g. X:\path => file:///X:/path
+    body = "/";
+  }
+  body += llvm::sys::path::convert_to_slash(absolutePath);
+
+  std::string uri = "file:";
+  if (authority.empty() && body.empty())
+    return uri;
+
+  // If authority is empty, we only print body if it starts with "/"; otherwise,
+  // the URI is invalid.
+  if (!authority.empty() || StringRef(body).startswith("/")) {
+    uri.append("//");
+    percentEncode(authority, uri);
+  }
+  percentEncode(body, uri);
+  return uri;
+}
+
+static llvm::Expected<std::string> getAbsolutePath(StringRef authority,
+                                                   StringRef body) {
+  if (!body.startswith("/"))
+    return llvm::createStringError(
+        llvm::inconvertibleErrorCode(),
+        "File scheme: expect body to be an absolute path starting "
+        "with '/': " +
+            body);
+  SmallString<128> path;
+  if (!authority.empty()) {
+    // Windows UNC paths e.g. file://server/share => \\server\share
+    ("//" + authority).toVector(path);
+  } else if (isWindowsPath(body.substr(1))) {
+    // Windows paths e.g. file:///X:/path => X:\path
+    body.consume_front("/");
+  }
+  path.append(body);
+  llvm::sys::path::native(path);
+  return std::string(path);
+}
+
+static llvm::Expected<std::string> parseFilePathFromURI(StringRef origUri) {
+  StringRef uri = origUri;
+
+  // Decode the scheme of the URI.
+  size_t pos = uri.find(':');
+  if (pos == StringRef::npos)
+    return llvm::createStringError(llvm::inconvertibleErrorCode(),
+                                   "Scheme must be provided in URI: " +
+                                       origUri);
+  StringRef schemeStr = uri.substr(0, pos);
+  std::string uriScheme = percentDecode(schemeStr);
+  if (!isValidScheme(uriScheme))
+    return llvm::createStringError(llvm::inconvertibleErrorCode(),
+                                   "Invalid scheme: " + schemeStr +
+                                       " (decoded: " + uriScheme + ")");
+  uri = uri.substr(pos + 1);
+
+  // Decode the authority of the URI.
+  std::string uriAuthority;
+  if (uri.consume_front("//")) {
+    pos = uri.find('/');
+    uriAuthority = percentDecode(uri.substr(0, pos));
+    uri = uri.substr(pos);
+  }
+
+  // Decode the body of the URI.
+  std::string uriBody = percentDecode(uri);
+
+  // Only 'file' (and the lit-only 'test') schemes map to an absolute path.
+  if (uriScheme != "file" && uriScheme != "test") {
+    return llvm::createStringError(
+        llvm::inconvertibleErrorCode(),
+        "mlir-lsp-server only supports 'file' URI scheme for workspace files");
+  }
+  return getAbsolutePath(uriAuthority, uriBody);
+}
+
+llvm::Expected<URIForFile> URIForFile::fromURI(StringRef uri) {
+  llvm::Expected<std::string> filePath = parseFilePathFromURI(uri);
+  if (!filePath)
+    return filePath.takeError();
+  return URIForFile(std::move(*filePath), uri.str());
+}
+
+llvm::Expected<URIForFile> URIForFile::fromFile(StringRef absoluteFilepath) {
+  llvm::Expected<std::string> uri = uriFromAbsolutePath(absoluteFilepath);
+  if (!uri)
+    return uri.takeError();
+  return fromURI(*uri);
+}
+
+bool mlir::lsp::fromJSON(const llvm::json::Value &value, URIForFile &result,
+                         llvm::json::Path path) {
+  if (Optional<StringRef> str = value.getAsString()) {
+    llvm::Expected<URIForFile> expectedURI = URIForFile::fromURI(*str);
+    if (!expectedURI) {
+      path.report("unresolvable URI");
+      consumeError(expectedURI.takeError());
+      return false;
+    }
+    result = std::move(*expectedURI);
+    return true;
+  }
+  return false;
+}
+
+llvm::json::Value mlir::lsp::toJSON(const URIForFile &value) {
+  return value.uri();
+}
+
+raw_ostream &mlir::lsp::operator<<(raw_ostream &os, const URIForFile &value) {
+  return os << value.uri();
+}
+
+//===----------------------------------------------------------------------===//
+// InitializeParams
+//===----------------------------------------------------------------------===//
+
+bool mlir::lsp::fromJSON(const llvm::json::Value &value, TraceLevel &result,
+                         llvm::json::Path path) {
+  if (Optional<StringRef> str = value.getAsString()) {
+    if (*str == "off") {
+      result = TraceLevel::Off;
+      return true;
+    }
+    if (*str == "messages") {
+      result = TraceLevel::Messages;
+      return true;
+    }
+    if (*str == "verbose") {
+      result = TraceLevel::Verbose;
+      return true;
+    }
+  }
+  return false;
+}
+
+bool mlir::lsp::fromJSON(const llvm::json::Value &value,
+                         InitializeParams &result, llvm::json::Path path) {
+  llvm::json::ObjectMapper o(value, path);
+  if (!o)
+    return false;
+  // We deliberately don't fail if we can't parse individual fields.
+  o.map("trace", result.trace);
+  return true;
+}
+
+//===----------------------------------------------------------------------===//
+// TextDocumentItem
+//===----------------------------------------------------------------------===//
+
+bool mlir::lsp::fromJSON(const llvm::json::Value &value,
+                         TextDocumentItem &result, llvm::json::Path path) {
+  llvm::json::ObjectMapper o(value, path);
+  return o && o.map("uri", result.uri) &&
+         o.map("languageId", result.languageId) && o.map("text", result.text);
+}
+
+//===----------------------------------------------------------------------===//
+// TextDocumentIdentifier
+//===----------------------------------------------------------------------===//
+
+llvm::json::Value mlir::lsp::toJSON(const TextDocumentIdentifier &value) {
+  return llvm::json::Object{{"uri", value.uri}};
+}
+
+bool mlir::lsp::fromJSON(const llvm::json::Value &value,
+                         TextDocumentIdentifier &result,
+                         llvm::json::Path path) {
+  llvm::json::ObjectMapper o(value, path);
+  return o && o.map("uri", result.uri);
+}
+
+//===----------------------------------------------------------------------===//
+// Position
+//===----------------------------------------------------------------------===//
+
+bool mlir::lsp::fromJSON(const llvm::json::Value &value, Position &result,
+                         llvm::json::Path path) {
+  llvm::json::ObjectMapper o(value, path);
+  return o && o.map("line", result.line) &&
+         o.map("character", result.character);
+}
+
+llvm::json::Value mlir::lsp::toJSON(const Position &value) {
+  return llvm::json::Object{
+      {"line", value.line},
+      {"character", value.character},
+  };
+}
+
+raw_ostream &mlir::lsp::operator<<(raw_ostream &os, const Position &value) {
+  return os << value.line << ':' << value.character;
+}
+
+//===----------------------------------------------------------------------===//
+// Range
+//===----------------------------------------------------------------------===//
+
+bool mlir::lsp::fromJSON(const llvm::json::Value &value, Range &result,
+                         llvm::json::Path path) {
+  llvm::json::ObjectMapper o(value, path);
+  return o && o.map("start", result.start) && o.map("end", result.end);
+}
+
+llvm::json::Value mlir::lsp::toJSON(const Range &value) {
+  return llvm::json::Object{
+      {"start", value.start},
+      {"end", value.end},
+  };
+}
+
+raw_ostream &mlir::lsp::operator<<(raw_ostream &os, const Range &value) {
+  return os << value.start << '-' << value.end;
+}
+
+//===----------------------------------------------------------------------===//
+// Location
+//===----------------------------------------------------------------------===//
+
+llvm::json::Value mlir::lsp::toJSON(const Location &value) {
+  return llvm::json::Object{
+      {"uri", value.uri},
+      {"range", value.range},
+  };
+}
+
+raw_ostream &mlir::lsp::operator<<(raw_ostream &os, const Location &value) {
+  return os << value.range << '@' << value.uri;
+}
+
+//===----------------------------------------------------------------------===//
+// TextDocumentPositionParams
+//===----------------------------------------------------------------------===//
+
+bool mlir::lsp::fromJSON(const llvm::json::Value &value,
+                         TextDocumentPositionParams &result,
+                         llvm::json::Path path) {
+  llvm::json::ObjectMapper o(value, path);
+  return o && o.map("textDocument", result.textDocument) &&
+         o.map("position", result.position);
+}
+
+//===----------------------------------------------------------------------===//
+// ReferenceParams
+//===----------------------------------------------------------------------===//
+
+bool mlir::lsp::fromJSON(const llvm::json::Value &value,
+                         ReferenceContext &result, llvm::json::Path path) {
+  llvm::json::ObjectMapper o(value, path);
+  return o && o.mapOptional("includeDeclaration", result.includeDeclaration);
+}
+
+bool mlir::lsp::fromJSON(const llvm::json::Value &value,
+                         ReferenceParams &result, llvm::json::Path path) {
+  TextDocumentPositionParams &base = result;
+  llvm::json::ObjectMapper o(value, path);
+  return fromJSON(value, base, path) && o &&
+         o.mapOptional("context", result.context);
+}
+
+//===----------------------------------------------------------------------===//
+// DidOpenTextDocumentParams
+//===----------------------------------------------------------------------===//
+
+bool mlir::lsp::fromJSON(const llvm::json::Value &value,
+                         DidOpenTextDocumentParams &result,
+                         llvm::json::Path path) {
+  llvm::json::ObjectMapper o(value, path);
+  return o && o.map("textDocument", result.textDocument);
+}
+
+//===----------------------------------------------------------------------===//
+// DidCloseTextDocumentParams
+//===----------------------------------------------------------------------===//
+
+bool mlir::lsp::fromJSON(const llvm::json::Value &value,
+                         DidCloseTextDocumentParams &result,
+                         llvm::json::Path path) {
+  llvm::json::ObjectMapper o(value, path);
+  return o && o.map("textDocument", result.textDocument);
+}
diff --git a/mlir/lib/Tools/mlir-lsp-server/lsp/Transport.h b/mlir/lib/Tools/mlir-lsp-server/lsp/Transport.h
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Tools/mlir-lsp-server/lsp/Transport.h
@@ -0,0 +1,190 @@
+//===--- Transport.h - Sending and Receiving LSP messages -------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// The language server protocol is usually implemented by writing messages as
+// JSON-RPC over the stdin/stdout of a subprocess. This file contains a JSON
+// transport interface that handles this communication.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LIB_MLIR_TOOLS_MLIRLSPSERVER_LSP_TRANSPORT_H_
+#define LIB_MLIR_TOOLS_MLIRLSPSERVER_LSP_TRANSPORT_H_
+
+#include "Protocol.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/FunctionExtras.h"
+#include "llvm/ADT/StringMap.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/FormatAdapters.h"
+#include "llvm/Support/JSON.h"
+#include "llvm/Support/raw_ostream.h"
+
+namespace mlir {
+namespace lsp {
+class JSONTransport;
+
+//===----------------------------------------------------------------------===//
+// Reply
+//===----------------------------------------------------------------------===//
+
+/// Function object to reply to an LSP call.
+/// Each instance must be called exactly once, otherwise:
+///   - if there was no reply, nothing is sent (there is no destructor fallback)
+///   - if there were multiple replies, only the first is sent
+class Reply {
+public:
+  Reply(const llvm::json::Value &id, StringRef method,
+        JSONTransport &transport);
+  Reply(Reply &&other);
+  Reply &operator=(Reply &&) = delete;
+  Reply(const Reply &) = delete;
+  Reply &operator=(const Reply &) = delete;
+
+  void operator()(llvm::Expected<llvm::json::Value> reply);
+
+private:
+  StringRef method;
+  std::atomic<bool> replied = {false};
+  llvm::json::Value id;
+  JSONTransport *transport;
+};
+
+//===----------------------------------------------------------------------===//
+// MessageHandler
+//===----------------------------------------------------------------------===//
+
+/// A Callback<T> is a void function that accepts Expected<T>. This is
+/// accepted by functions that logically return T.
+template <typename T>
+using Callback = llvm::unique_function<void(llvm::Expected<T>)>;
+
+/// A handler used to process the incoming transport messages.
+class MessageHandler {
+public:
+  MessageHandler(JSONTransport &transport) : transport(transport) {}
+
+  bool onNotify(StringRef method, llvm::json::Value value);
+  bool onCall(StringRef method, llvm::json::Value params, llvm::json::Value id);
+  bool onReply(llvm::json::Value id, llvm::Expected<llvm::json::Value> result);
+
+  template <typename T>
+  static llvm::Expected<T> parse(const llvm::json::Value &raw,
+                                 StringRef payloadName, StringRef payloadKind) {
+    T result;
+    llvm::json::Path::Root root;
+    if (fromJSON(raw, result, root))
+      return std::move(result);
+
+    // Dump the relevant parts of the broken message.
+    std::string context;
+    llvm::raw_string_ostream os(context);
+    root.printErrorContext(raw, os);
+
+    // Report the error (e.g. to the client).
+    return llvm::make_error<LSPError>(
+        llvm::formatv("failed to decode {0} {1}: {2}", payloadName, payloadKind,
+                      fmt_consume(root.getError())),
+        ErrorCode::InvalidParams);
+  }
+
+  template <typename Param, typename Result, typename ThisT>
+  void method(llvm::StringLiteral method, ThisT *thisPtr,
+              void (ThisT::*handler)(const Param &, Callback<Result>)) {
+    methodHandlers[method] = [method, handler,
+                              thisPtr](llvm::json::Value rawParams,
+                                       Callback<llvm::json::Value> reply) {
+      llvm::Expected<Param> param = parse<Param>(rawParams, method, "request");
+      if (!param)
+        return reply(param.takeError());
+      (thisPtr->*handler)(*param, std::move(reply));
+    };
+  }
+
+  template <typename Param, typename ThisT>
+  void notification(llvm::StringLiteral method, ThisT *thisPtr,
+                    void (ThisT::*handler)(const Param &)) {
+    notificationHandlers[method] = [method, handler,
+                                    thisPtr](llvm::json::Value rawParams) {
+      auto param = parse<Param>(rawParams, method, "notification");
+      if (!param)
+        return llvm::consumeError(param.takeError());
+      (thisPtr->*handler)(*param);
+    };
+  }
+
+private:
+  template <typename HandlerT>
+  using HandlerMap = llvm::StringMap<llvm::unique_function<HandlerT>>;
+
+  HandlerMap<void(llvm::json::Value)> notificationHandlers;
+  HandlerMap<void(llvm::json::Value, Callback<llvm::json::Value>)>
+      methodHandlers;
+
+  JSONTransport &transport;
+};
+
+//===----------------------------------------------------------------------===//
+// JSONTransport
+//===----------------------------------------------------------------------===//
+
+/// The encoding style of the JSON-RPC messages (both input and output).
+enum JSONStreamStyle {
+  /// Encoding per the LSP specification, with mandatory Content-Length header.
+  Standard,
+  /// Messages are delimited by a '// -----' line. Comment lines start with //.
+  Delimited
+};
+
+/// A transport class that performs the JSON-RPC communication with the LSP
+/// client.
+class JSONTransport {
+public:
+  JSONTransport(std::FILE *in, raw_ostream &out,
+                JSONStreamStyle style = JSONStreamStyle::Standard,
+                bool prettyOutput = false)
+      : in(in), out(out), style(style), prettyOutput(prettyOutput) {}
+
+  /// The following methods are used to send a message to the LSP client.
+  void notify(StringRef method, llvm::json::Value params);
+  void call(StringRef method, llvm::json::Value params, llvm::json::Value id);
+  void reply(llvm::json::Value id, llvm::Expected<llvm::json::Value> result);
+
+  /// Start executing the JSON-RPC transport.
+  llvm::Error run(MessageHandler &handler);
+
+private:
+  /// Dispatches the given incoming json message to the message handler.
+  bool handleMessage(llvm::json::Value msg, MessageHandler &handler);
+  /// Writes the given message to the output stream.
+  void sendMessage(llvm::json::Value msg);
+
+  /// Read in a message from the input stream.
+  LogicalResult readMessage(std::string &json) {
+    return style == JSONStreamStyle::Delimited ? readDelimitedMessage(json)
+                                               : readStandardMessage(json);
+  }
+  LogicalResult readDelimitedMessage(std::string &json);
+  LogicalResult readStandardMessage(std::string &json);
+
+  /// An output buffer used when building output messages.
+  SmallVector<char, 0> outputBuffer;
+  /// The input file stream.
+  std::FILE *in;
+  /// The output file stream.
+  raw_ostream &out;
+  /// The JSON stream style to use.
+  JSONStreamStyle style;
+  /// If the output JSON should be formatted for easier readability.
+  bool prettyOutput;
+};
+
+} // namespace lsp
+} // namespace mlir
+
+#endif
diff --git a/mlir/lib/Tools/mlir-lsp-server/lsp/Transport.cpp b/mlir/lib/Tools/mlir-lsp-server/lsp/Transport.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Tools/mlir-lsp-server/lsp/Transport.cpp
@@ -0,0 +1,326 @@
+//===--- JSONTransport.cpp - sending and receiving LSP messages over JSON -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "Transport.h"
+#include "Logging.h"
+#include "Protocol.h"
+#include "llvm/ADT/SmallString.h"
+#include "llvm/Support/Errno.h"
+#include "llvm/Support/Error.h"
+#include <system_error>
+
+using namespace mlir;
+using namespace mlir::lsp;
+
+//===----------------------------------------------------------------------===//
+// Reply
+//===----------------------------------------------------------------------===//
+
+Reply::Reply(const llvm::json::Value &id, llvm::StringRef method,
+             JSONTransport &transport)
+    : method(method), id(id), transport(&transport) {}
+
+Reply::Reply(Reply &&other)
+    : method(other.method), replied(other.replied.load()),
+      id(std::move(other.id)), transport(other.transport) {
+  other.transport = nullptr;
+}
+
+void Reply::operator()(llvm::Expected<llvm::json::Value> reply) {
+  if (replied.exchange(true)) {
+    Logger::error("Replied twice to message {0}({1})", method, id);
+    assert(false && "must reply to each call only once!");
+    return;
+  }
+  assert(transport && "expected valid transport to reply to");
+
+  if (reply) {
+    Logger::info("--> reply:{0}({1})", method, id);
+    transport->reply(std::move(id), std::move(reply));
+  } else {
+    llvm::Error error = reply.takeError();
+    Logger::info("--> reply:{0}({1}): {2}", method, id, error);
+    transport->reply(std::move(id), std::move(error));
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// MessageHandler
+//===----------------------------------------------------------------------===//
+
+bool MessageHandler::onNotify(llvm::StringRef method, llvm::json::Value value) {
+  Logger::info("--> {0}", method);
+
+  if (method == "exit")
+    return false;
+  if (method == "$cancel") {
+    // TODO: Add support for cancelling requests.
+  } else {
+    auto it = notificationHandlers.find(method);
+    if (it != notificationHandlers.end())
+      it->second(std::move(value));
+  }
+  return true;
+}
+
+bool MessageHandler::onCall(llvm::StringRef method, llvm::json::Value params,
+                            llvm::json::Value id) {
+  Logger::info("--> {0}({1})", method, id);
+
+  Reply reply(id, method, transport);
+
+  auto it = methodHandlers.find(method);
+  if (it != methodHandlers.end()) {
+    it->second(std::move(params), std::move(reply));
+  } else {
+    reply(llvm::make_error<LSPError>("method not found: " + method.str(),
+                                     ErrorCode::MethodNotFound));
+  }
+  return true;
+}
+
+bool MessageHandler::onReply(llvm::json::Value id,
+                             llvm::Expected<llvm::json::Value> result) {
+  // TODO: Add support for reply callbacks when support for outgoing messages is
+  // added. For now, we just log an error on any replies received.
+  Callback<llvm::json::Value> replyHandler =
+      [&id](llvm::Expected<llvm::json::Value> result) {
+        Logger::error(
+            "received a reply with ID {0}, but there was no such call", id);
+        if (!result)
+          llvm::consumeError(result.takeError());
+      };
+
+  // Log and run the reply handler.
+  if (result)
+    replyHandler(std::move(result));
+  else
+    replyHandler(result.takeError());
+  return true;
+}
+
+//===----------------------------------------------------------------------===//
+// JSONTransport
+//===----------------------------------------------------------------------===//
+
+/// Encode the given error as a JSON object.
+static llvm::json::Object encodeError(llvm::Error error) {
+  std::string message;
+  ErrorCode code = ErrorCode::UnknownErrorCode;
+  auto handlerFn = [&](const LSPError &lspError) -> llvm::Error {
+    message = lspError.message;
+    code = lspError.code;
+    return llvm::Error::success();
+  };
+  if (llvm::Error unhandled = llvm::handleErrors(std::move(error), handlerFn))
+    message = llvm::toString(std::move(unhandled));
+
+  return llvm::json::Object{
+      {"message", std::move(message)},
+      {"code", int64_t(code)},
+  };
+}
+
+/// Decode the given JSON object into an error.
+static llvm::Error decodeError(const llvm::json::Object &o) {
+  StringRef msg = o.getString("message").getValueOr("Unspecified error");
+  if (Optional<int64_t> code = o.getInteger("code"))
+    return llvm::make_error<LSPError>(msg.str(), ErrorCode(*code));
+  return llvm::make_error<llvm::StringError>(llvm::inconvertibleErrorCode(),
+                                             msg.str());
+}
+
+void JSONTransport::notify(StringRef method, llvm::json::Value params) {
+  sendMessage(llvm::json::Object{
+      {"jsonrpc", "2.0"},
+      {"method", method},
+      {"params", std::move(params)},
+  });
+}
+void JSONTransport::call(StringRef method, llvm::json::Value params,
+                         llvm::json::Value id) {
+  sendMessage(llvm::json::Object{
+      {"jsonrpc", "2.0"},
+      {"id", std::move(id)},
+      {"method", method},
+      {"params", std::move(params)},
+  });
+}
+void JSONTransport::reply(llvm::json::Value id,
+                          llvm::Expected<llvm::json::Value> result) {
+  if (result) {
+    return sendMessage(llvm::json::Object{
+        {"jsonrpc", "2.0"},
+        {"id", std::move(id)},
+        {"result", std::move(*result)},
+    });
+  }
+
+  sendMessage(llvm::json::Object{
+      {"jsonrpc", "2.0"},
+      {"id", std::move(id)},
+      {"error", encodeError(result.takeError())},
+  });
+}
+
+llvm::Error JSONTransport::run(MessageHandler &handler) {
+  std::string json;
+  while (!feof(in)) {
+    if (ferror(in)) {
+      return llvm::errorCodeToError(
+          std::error_code(errno, std::system_category()));
+    }
+
+    if (succeeded(readMessage(json))) {
+      if (llvm::Expected<llvm::json::Value> doc = llvm::json::parse(json)) {
+        if (!handleMessage(std::move(*doc), handler))
+          return llvm::Error::success();
+      } else {
+        Logger::error("JSON parse error: {0}", llvm::toString(doc.takeError()));
+      }
+    }
+  }
+  return llvm::errorCodeToError(std::make_error_code(std::errc::io_error));
+}
+
+void JSONTransport::sendMessage(llvm::json::Value msg) {
+  outputBuffer.clear();
+  llvm::raw_svector_ostream os(outputBuffer);
+  os << llvm::formatv(prettyOutput ? "{0:2}\n" : "{0}", msg);
+  out << "Content-Length: " << outputBuffer.size() << "\r\n\r\n"
+      << outputBuffer;
+  out.flush();
+}
+
+bool JSONTransport::handleMessage(llvm::json::Value msg,
+                                  MessageHandler &handler) {
+  // Message must be an object with "jsonrpc":"2.0".
+  llvm::json::Object *object = msg.getAsObject();
+  if (!object ||
+      object->getString("jsonrpc") != llvm::Optional<StringRef>("2.0"))
+    return false;
+
+  // `id` may be any JSON value. If absent, this is a notification.
+  llvm::Optional<llvm::json::Value> id;
+  if (llvm::json::Value *i = object->get("id"))
+    id = std::move(*i);
+  Optional<StringRef> method = object->getString("method");
+
+  // This is a response.
+  if (!method) {
+    if (!id)
+      return false;
+    if (auto *err = object->getObject("error"))
+      return handler.onReply(std::move(*id), decodeError(*err));
+    // result should be given, use null if not.
+    llvm::json::Value result = nullptr;
+    if (llvm::json::Value *r = object->get("result"))
+      result = std::move(*r);
+    return handler.onReply(std::move(*id), std::move(result));
+  }
+
+  // Params should be given, use null if not.
+  llvm::json::Value params = nullptr;
+  if (llvm::json::Value *p = object->get("params"))
+    params = std::move(*p);
+
+  if (id)
+    return handler.onCall(*method, std::move(params), std::move(*id));
+  return handler.onNotify(*method, std::move(params));
+}
+
+/// Tries to read a line up to and including \n.
+/// If failing, feof() or ferror() will be set.
+static LogicalResult readLine(std::FILE *in, SmallVectorImpl<char> &out) {
+  // Big enough to hold any reasonable header line. May not fit content lines
+  // in delimited mode, but performance doesn't matter for that mode.
+  static constexpr int bufSize = 128;
+  size_t size = 0;
+  out.clear();
+  for (;;) {
+    out.resize_for_overwrite(size + bufSize);
+    if (!std::fgets(&out[size], bufSize, in))
+      return failure();
+
+    clearerr(in);
+
+    // If the line contained null bytes, anything after it (including \n) will
+    // be ignored. Fortunately this is not a legal header or JSON.
+    size_t read = std::strlen(&out[size]);
+    if (read > 0 && out[size + read - 1] == '\n') {
+      out.resize(size + read);
+      return success();
+    }
+    size += read;
+  }
+}
+
+// Returns failure when ferror()/feof() is set, or when Content-Length is
+// missing, zero, or larger than 1 GiB (protocol error).
+LogicalResult JSONTransport::readStandardMessage(std::string &json) {
+  // A Language Server Protocol message starts with a set of HTTP headers,
+  // delimited by \r\n, and terminated by an empty line (\r\n).
+  unsigned long long contentLength = 0;
+  llvm::SmallString<128> line;
+  while (true) {
+    if (feof(in) || ferror(in) || failed(readLine(in, line)))
+      return failure();
+
+    // Content-Length is a mandatory header, and the only one we handle.
+    StringRef lineRef(line);
+    if (lineRef.consume_front("Content-Length: ")) {
+      llvm::getAsUnsignedInteger(lineRef.trim(), 0, contentLength);
+    } else if (!lineRef.trim().empty()) {
+      // It's another header, ignore it.
+      continue;
+    } else {
+      // An empty line indicates the end of headers. Go ahead and read the JSON.
+      break;
+    }
+  }
+
+  // The fuzzer likes crashing us by sending "Content-Length: 9999999999999999"
+  if (contentLength == 0 || contentLength > 1 << 30)
+    return failure();
+
+  json.resize(contentLength);
+  for (size_t pos = 0, read; pos < contentLength; pos += read) {
+    read = std::fread(&json[pos], 1, contentLength - pos, in);
+    if (read == 0)
+      return failure();
+
+    // If we're done, the error was transient. If we're not done, either it was
+    // transient or we'll see it again on retry.
+    clearerr(in);
+  }
+  return success();
+}
+
+/// For lit tests we support a simplified syntax:
+/// - messages are delimited by '// -----' on a line by itself
+/// - lines starting with // are ignored.
+/// This is a testing path, so favor simplicity over performance here.
+/// Returns failure only if ferror() is set on the input stream; reaching EOF
+/// still returns success with whatever JSON was accumulated.
+LogicalResult JSONTransport::readDelimitedMessage(std::string &json) {
+  json.clear();
+  llvm::SmallString<128> line;
+  while (succeeded(readLine(in, line))) {
+    StringRef lineRef = StringRef(line).trim();
+    if (lineRef.startswith("//")) {
+      // Found a delimiter for the message.
+      if (lineRef == "// -----")
+        break;
+      continue;
+    }
+
+    json += line;
+  }
+
+  return failure(ferror(in));
+}
diff --git a/mlir/test/mlir-lsp-server/definition.test b/mlir/test/mlir-lsp-server/definition.test
new file mode 100644
--- /dev/null
+++ b/mlir/test/mlir-lsp-server/definition.test
@@ -0,0 +1,34 @@
+// RUN: mlir-lsp-server -lit-test < %s | FileCheck -strict-whitespace %s
+{"jsonrpc":"2.0","id":0,"method":"initialize","params":{"processId":123,"rootPath":"clangd","capabilities":{},"trace":"off"}}
+// -----
+{"jsonrpc":"2.0","method":"textDocument/didOpen","params":{"textDocument":{
+  "uri":"test:///foo.mlir",
+  "languageId":"mlir",
+  "version":1,
+  "text":"func @foo() -> i1 {\n%value = constant true\nreturn %value : i1\n}"
+}}}
+// -----
+{"jsonrpc":"2.0","id":1,"method":"textDocument/definition","params":{
+  "textDocument":{"uri":"test:///foo.mlir"},
+  "position":{"line":2,"character":12}
+}}
+// CHECK: "id": 1
+// CHECK-NEXT: "jsonrpc": "2.0",
+// CHECK-NEXT: "result": [
+// CHECK-NEXT: {
+// CHECK-NEXT: "range": {
+// CHECK-NEXT: "end": {
+// CHECK-NEXT: "character": 6,
+// CHECK-NEXT: "line": 1
+// CHECK-NEXT: },
+// CHECK-NEXT: "start": {
+// CHECK-NEXT: "character": 1,
+// CHECK-NEXT: "line": 1
+// CHECK-NEXT: }
+// CHECK-NEXT: },
+// CHECK-NEXT: "uri": "{{.*}}/foo.mlir"
+// CHECK-NEXT: }
+// -----
+{"jsonrpc":"2.0","id":3,"method":"shutdown"} +// ----- +{"jsonrpc":"2.0","method":"exit"} diff --git a/mlir/test/mlir-lsp-server/exit-eof.test b/mlir/test/mlir-lsp-server/exit-eof.test new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-lsp-server/exit-eof.test @@ -0,0 +1,7 @@ +// RUN: not mlir-lsp-server < %s 2> %t.err +// RUN: FileCheck %s < %t.err +// +// No LSP messages here, just let mlir-lsp-server see the end-of-file +// CHECK: Transport error: +// (Typically "Transport error: Input/output error" but platform-dependent). + diff --git a/mlir/test/mlir-lsp-server/exit-with-shutdown.test b/mlir/test/mlir-lsp-server/exit-with-shutdown.test new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-lsp-server/exit-with-shutdown.test @@ -0,0 +1,6 @@ +// RUN: mlir-lsp-server -lit-test < %s +{"jsonrpc":"2.0","id":0,"method":"initialize","params":{"processId":123,"rootPath":"clangd","capabilities":{},"trace":"off"}} +// ----- +{"jsonrpc":"2.0","id":3,"method":"shutdown"} +// ----- +{"jsonrpc":"2.0","method":"exit"} diff --git a/mlir/test/mlir-lsp-server/exit-without-shutdown.test b/mlir/test/mlir-lsp-server/exit-without-shutdown.test new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-lsp-server/exit-without-shutdown.test @@ -0,0 +1,4 @@ +// RUN: not mlir-lsp-server -lit-test < %s +{"jsonrpc":"2.0","id":0,"method":"initialize","params":{"processId":123,"rootPath":"clangd","capabilities":{},"trace":"off"}} +// ----- +{"jsonrpc":"2.0","method":"exit"} diff --git a/mlir/test/mlir-lsp-server/initialize-params-invalid.test b/mlir/test/mlir-lsp-server/initialize-params-invalid.test new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-lsp-server/initialize-params-invalid.test @@ -0,0 +1,12 @@ +// RUN: mlir-lsp-server -lit-test < %s | FileCheck -strict-whitespace %s +// Test with invalid initialize request parameters +{"jsonrpc":"2.0","id":0,"method":"initialize","params":{"processId":"","rootUri":"test:///workspace","capabilities":{},"trace":"verbose"}} +// CHECK: 
"id": 0, +// CHECK-NEXT: "jsonrpc": "2.0", +// CHECK-NEXT: "result": { +// CHECK-NEXT: "capabilities": { +// ... +// ----- +{"jsonrpc":"2.0","id":3,"method":"shutdown"} +// ----- +{"jsonrpc":"2.0","method":"exit"} diff --git a/mlir/test/mlir-lsp-server/initialize-params.test b/mlir/test/mlir-lsp-server/initialize-params.test new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-lsp-server/initialize-params.test @@ -0,0 +1,27 @@ +// RUN: mlir-lsp-server -lit-test < %s | FileCheck -strict-whitespace %s +// Test initialize request parameters with rootUri +{"jsonrpc":"2.0","id":0,"method":"initialize","params":{"processId":123,"rootUri":"test:///workspace","capabilities":{},"trace":"off"}} +// CHECK: "id": 0, +// CHECK-NEXT: "jsonrpc": "2.0", +// CHECK-NEXT: "result": { +// CHECK-NEXT: "capabilities": { +// CHECK-NEXT: "definitionProvider": true, +// CHECK-NEXT: "referencesProvider": true, +// CHECK-NEXT: "textDocumentSync": { +// CHECK-NEXT: "change": 1, +// CHECK-NEXT: "openClose": true, +// CHECK-NEXT: "save": true +// CHECK-NEXT: } +// CHECK-NEXT: }, +// CHECK-NEXT: "serverInfo": { +// CHECK-NEXT: "name": "mlir-lsp-server", +// CHECK-NEXT: "version": "{{.*}}" +// CHECK-NEXT: } +// CHECK-NEXT: } +// ----- +{"jsonrpc":"2.0","id":3,"method":"shutdown"} +// CHECK: "id": 3, +// CHECK-NEXT: "jsonrpc": "2.0", +// CHECK-NEXT: "result": null +// ----- +{"jsonrpc":"2.0","method":"exit"} diff --git a/mlir/test/mlir-lsp-server/references.test b/mlir/test/mlir-lsp-server/references.test new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-lsp-server/references.test @@ -0,0 +1,49 @@ +// RUN: mlir-lsp-server -lit-test < %s | FileCheck -strict-whitespace %s +{"jsonrpc":"2.0","id":0,"method":"initialize","params":{"processId":123,"rootPath":"clangd","capabilities":{},"trace":"off"}} +// ----- +{"jsonrpc":"2.0","method":"textDocument/didOpen","params":{"textDocument":{ + "uri":"test:///foo.mlir", + "languageId":"mlir", + "version":1, + "text":"func @foo() -> i1 {\n%value = 
constant true\nreturn %value : i1\n}" +}}} +// ----- +{"jsonrpc":"2.0","id":1,"method":"textDocument/references","params":{ + "textDocument":{"uri":"test:///foo.mlir"}, + "position":{"line":1,"character":2}, + "context":{"includeDeclaration": false} +}} +// CHECK: "id": 1 +// CHECK-NEXT: "jsonrpc": "2.0", +// CHECK-NEXT: "result": [ +// CHECK-NEXT: { +// CHECK-NEXT: "range": { +// CHECK-NEXT: "end": { +// CHECK-NEXT: "character": 6, +// CHECK-NEXT: "line": 1 +// CHECK-NEXT: }, +// CHECK-NEXT: "start": { +// CHECK-NEXT: "character": 1, +// CHECK-NEXT: "line": 1 +// CHECK-NEXT: } +// CHECK-NEXT: }, +// CHECK-NEXT: "uri": "{{.*}}/foo.mlir" +// CHECK-NEXT: }, +// CHECK-NEXT: { +// CHECK-NEXT: "range": { +// CHECK-NEXT: "end": { +// CHECK-NEXT: "character": 13, +// CHECK-NEXT: "line": 2 +// CHECK-NEXT: }, +// CHECK-NEXT: "start": { +// CHECK-NEXT: "character": 8, +// CHECK-NEXT: "line": 2 +// CHECK-NEXT: } +// CHECK-NEXT: }, +// CHECK-NEXT: "uri": "{{.*}}/foo.mlir" +// CHECK-NEXT: } +// CHECK-NEXT: ] +// ----- +{"jsonrpc":"2.0","id":3,"method":"shutdown"} +// ----- +{"jsonrpc":"2.0","method":"exit"} diff --git a/mlir/tools/mlir-lsp-server/mlir-lsp-server.cpp b/mlir/tools/mlir-lsp-server/mlir-lsp-server.cpp new file mode 100644 --- /dev/null +++ b/mlir/tools/mlir-lsp-server/mlir-lsp-server.cpp @@ -0,0 +1,20 @@ +//===- mlir-lsp-server.cpp - MLIR Language Server -------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/InitAllDialects.h"
+#include "mlir/Tools/mlir-lsp-server/MlirLspServerMain.h"
+
+using namespace mlir;
+
+/// Entry point of the standalone `mlir-lsp-server` tool.
+int main(int argc, char **argv) {
+  // The server can only parse IR whose dialects are in this registry, so
+  // populate it through registerAllDialects before starting.
+  DialectRegistry registry;
+  registerAllDialects(registry);
+
+  // Map a failed server run to a non-zero process exit status.
+  LogicalResult serverResult = MlirLspServerMain(argc, argv, registry);
+  return failed(serverResult) ? 1 : 0;
+}