diff --git a/mlir/include/mlir/Parser/AsmParserState.h b/mlir/include/mlir/Parser/AsmParserState.h --- a/mlir/include/mlir/Parser/AsmParserState.h +++ b/mlir/include/mlir/Parser/AsmParserState.h @@ -83,6 +83,9 @@ AsmParserState(); ~AsmParserState(); + /// Clear out any internal data held by this parser state. + void clear(); + //===--------------------------------------------------------------------===// // Access State //===--------------------------------------------------------------------===// diff --git a/mlir/lib/Parser/AsmParserState.cpp b/mlir/lib/Parser/AsmParserState.cpp --- a/mlir/lib/Parser/AsmParserState.cpp +++ b/mlir/lib/Parser/AsmParserState.cpp @@ -53,6 +53,8 @@ AsmParserState::AsmParserState() : impl(std::make_unique()) {} AsmParserState::~AsmParserState() {} +void AsmParserState::clear() { impl = std::make_unique(); } + //===----------------------------------------------------------------------===// // Access State diff --git a/mlir/lib/Tools/mlir-lsp-server/LSPServer.cpp b/mlir/lib/Tools/mlir-lsp-server/LSPServer.cpp --- a/mlir/lib/Tools/mlir-lsp-server/LSPServer.cpp +++ b/mlir/lib/Tools/mlir-lsp-server/LSPServer.cpp @@ -59,6 +59,10 @@ MLIRServer &server; JSONTransport &transport; + /// An outgoing notification used to send diagnostics to the client when they + /// are ready to be processed. + OutgoingNotification publishDiagnostics; + /// Used to indicate that the 'shutdown' request was received from the /// Language Server client. bool shutdownRequestReceived = false; @@ -99,11 +103,21 @@ void LSPServer::Impl::onDocumentDidOpen( const DidOpenTextDocumentParams &params) { - server.addOrUpdateDocument(params.textDocument.uri, params.textDocument.text); + PublishDiagnosticsParams diagParams(params.textDocument.uri); + server.addOrUpdateDocument(params.textDocument.uri, params.textDocument.text, + diagParams.diagnostics); + + // Publish any recorded diagnostics.
+ publishDiagnostics(diagParams); } void LSPServer::Impl::onDocumentDidClose( const DidCloseTextDocumentParams &params) { server.removeDocument(params.textDocument.uri); + + // Empty out the diagnostics shown for this document. This will clear out + // anything currently displayed by the client for this document (e.g. in the + // "Problems" pane of VSCode). + publishDiagnostics(PublishDiagnosticsParams(params.textDocument.uri)); } void LSPServer::Impl::onDocumentDidChange( const DidChangeTextDocumentParams &params) { @@ -111,8 +125,13 @@ // to avoid this. if (params.contentChanges.size() != 1) return; + PublishDiagnosticsParams diagParams(params.textDocument.uri); server.addOrUpdateDocument(params.textDocument.uri, - params.contentChanges.front().text); + params.contentChanges.front().text, + diagParams.diagnostics); + + // Publish any recorded diagnostics. + publishDiagnostics(diagParams); } //===----------------------------------------------------------------------===// @@ -173,6 +192,11 @@ // Hover messageHandler.method("textDocument/hover", impl.get(), &Impl::onHover); + // Diagnostics + impl->publishDiagnostics = + messageHandler.outgoingNotification( + "textDocument/publishDiagnostics"); + // Run the main loop of the transport. LogicalResult result = success(); if (llvm::Error error = impl->transport.run(messageHandler)) { diff --git a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.h b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.h --- a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.h +++ b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.h @@ -16,6 +16,7 @@ class DialectRegistry; namespace lsp { +struct Diagnostic; struct Hover; struct Location; struct Position; @@ -30,8 +31,10 @@ MLIRServer(DialectRegistry &registry); ~MLIRServer(); - /// Add or update the document at the given URI. - void addOrUpdateDocument(const URIForFile &uri, StringRef contents); + /// Add or update the document at the given URI.
 Any diagnostics emitted for + /// this document should be added to `diagnostics`. + void addOrUpdateDocument(const URIForFile &uri, StringRef contents, + std::vector &diagnostics); /// Remove the document with the given uri. void removeDocument(const URIForFile &uri); diff --git a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp --- a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp +++ b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp @@ -60,7 +60,28 @@ lsp::Position position; position.line = loc.getLine() - 1; position.character = loc.getColumn(); - return lsp::Location{*sourceURI, lsp::Range{position, position}}; + return lsp::Location{*sourceURI, lsp::Range(position)}; +} + +/// Returns a language server location from the given MLIR location, or None if +/// one couldn't be created. `uri` is an optional additional filter that, when +/// present, is used to filter sub locations that do not share the same uri. +static Optional +getLocationFromLoc(Location loc, const lsp::URIForFile *uri = nullptr) { + Optional location; + loc->walk([&](Location nestedLoc) { + FileLineColLoc fileLoc = nestedLoc.dyn_cast(); + if (!fileLoc) + return WalkResult::advance(); + + Optional sourceLoc = getLocationFromLoc(fileLoc); + if (sourceLoc && (!uri || sourceLoc->uri == *uri)) { + location = *sourceLoc; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + return location; } /// Collect all of the locations from the given MLIR location that are not @@ -173,6 +194,57 @@ printDefBlockName(os, def.block, def.definition.loc); } +/// Convert the given MLIR diagnostic to the lsp form. +static lsp::Diagnostic getLspDiagnosticFromDiag(Diagnostic &diag, + const lsp::URIForFile &uri) { + lsp::Diagnostic lspDiag; + lspDiag.source = "mlir"; + + // Note: Right now all of the diagnostics are treated as parser issues, but + // conceptually some are parser and some are verifier.
+ lspDiag.category = "Parse Error"; + + // Try to grab a file location for this diagnostic. + // TODO: For simplicity, we just grab the first one. It may be likely that we + // will need a more interesting heuristic here. + Optional lspLocation = + getLocationFromLoc(diag.getLocation(), &uri); + if (lspLocation) + lspDiag.range = lspLocation->range; + + // Convert the severity for the diagnostic. + switch (diag.getSeverity()) { + case DiagnosticSeverity::Note: + assert(0 && "expected notes to be handled separately"); + break; + case DiagnosticSeverity::Warning: + lspDiag.severity = lsp::DiagnosticSeverity::Warning; + break; + case DiagnosticSeverity::Error: + lspDiag.severity = lsp::DiagnosticSeverity::Error; + break; + case DiagnosticSeverity::Remark: + lspDiag.severity = lsp::DiagnosticSeverity::Information; + break; + } + lspDiag.message = diag.str(); + + // Attach any notes to the main diagnostic as related information. + std::vector relatedDiags; + for (Diagnostic &note : diag.getNotes()) { + lsp::Location noteLoc; + if (Optional loc = getLocationFromLoc(note.getLocation())) + noteLoc = *loc; + else + noteLoc.uri = uri; + relatedDiags.emplace_back(noteLoc, note.str()); + } + if (!relatedDiags.empty()) + lspDiag.relatedInformation = std::move(relatedDiags); + + return lspDiag; +} + //===----------------------------------------------------------------------===// // MLIRDocument //===----------------------------------------------------------------------===// @@ -182,7 +254,8 @@ /// document.
 struct MLIRDocument { MLIRDocument(const lsp::URIForFile &uri, StringRef contents, - DialectRegistry &registry); + DialectRegistry &registry, + std::vector &diagnostics); //===--------------------------------------------------------------------===// // Definitions and References @@ -227,15 +300,12 @@ } // namespace MLIRDocument::MLIRDocument(const lsp::URIForFile &uri, StringRef contents, - DialectRegistry &registry) + DialectRegistry &registry, + std::vector &diagnostics) : context(registry) { context.allowUnregisteredDialects(); ScopedDiagnosticHandler handler(&context, [&](Diagnostic &diag) { - // TODO: What should we do with these diagnostics? - // * Cache and show to the user? - // * Ignore? - lsp::Logger::error("Error when parsing MLIR document `{0}`: `{1}`", - uri.file(), diag.str()); + diagnostics.push_back(getLspDiagnosticFromDiag(diag, uri)); }); // Try to parsed the given IR string. @@ -246,9 +316,13 @@ } sourceMgr.AddNewSourceBuffer(std::move(memBuffer), llvm::SMLoc()); - if (failed( - parseSourceFile(sourceMgr, &parsedIR, &context, nullptr, &asmState))) + if (failed(parseSourceFile(sourceMgr, &parsedIR, &context, nullptr, + &asmState))) { + // If parsing failed, clear out any of the current state.
+ parsedIR.clear(); + asmState.clear(); return; + } } //===----------------------------------------------------------------------===// @@ -495,10 +569,11 @@ : impl(std::make_unique(registry)) {} lsp::MLIRServer::~MLIRServer() {} -void lsp::MLIRServer::addOrUpdateDocument(const URIForFile &uri, - StringRef contents) { - impl->documents[uri.file()] = - std::make_unique(uri, contents, impl->registry); +void lsp::MLIRServer::addOrUpdateDocument( + const URIForFile &uri, StringRef contents, + std::vector &diagnostics) { + impl->documents[uri.file()] = std::make_unique( + uri, contents, impl->registry, diagnostics); } void lsp::MLIRServer::removeDocument(const URIForFile &uri) { diff --git a/mlir/lib/Tools/mlir-lsp-server/lsp/Protocol.h b/mlir/lib/Tools/mlir-lsp-server/lsp/Protocol.h --- a/mlir/lib/Tools/mlir-lsp-server/lsp/Protocol.h +++ b/mlir/lib/Tools/mlir-lsp-server/lsp/Protocol.h @@ -228,6 +228,10 @@ //===----------------------------------------------------------------------===// struct Range { + Range() = default; + Range(Position start, Position end) : start(start), end(end) {} + Range(Position loc) : Range(loc, loc) {} + /// The range's start position. Position start; @@ -393,6 +397,79 @@ }; llvm::json::Value toJSON(const Hover &hover); +//===----------------------------------------------------------------------===// +// DiagnosticRelatedInformation +//===----------------------------------------------------------------------===// + +/// Represents a related message and source code location for a diagnostic. +/// This should be used to point to code locations that cause or are related to +/// a diagnostic, e.g. when duplicating a symbol in a scope. +struct DiagnosticRelatedInformation { + DiagnosticRelatedInformation(Location location, std::string message) + : location(location), message(std::move(message)) {} + + /// The location of this related diagnostic information. + Location location; + /// The message of this related diagnostic information.
+ std::string message; +}; +llvm::json::Value toJSON(const DiagnosticRelatedInformation &info); + +//===----------------------------------------------------------------------===// +// Diagnostic +//===----------------------------------------------------------------------===// + +enum class DiagnosticSeverity { + /// It is up to the client to interpret diagnostics as error, warning, info or + /// hint. + Undetermined = 0, + Error = 1, + Warning = 2, + Information = 3, + Hint = 4 +}; + +struct Diagnostic { + /// The source range where the message applies. + Range range; + + /// The diagnostic's severity. Can be omitted. If omitted it is up to the + /// client to interpret diagnostics as error, warning, info or hint. + DiagnosticSeverity severity = DiagnosticSeverity::Undetermined; + + /// A human-readable string describing the source of this diagnostic, e.g. + /// 'typescript' or 'super lint'. + std::string source; + + /// The diagnostic's message. + std::string message; + + /// An array of related diagnostic information, e.g. when symbol-names within + /// a scope collide all definitions can be marked via this property. + Optional> relatedInformation; + + /// The diagnostic's category. Can be omitted. + /// An LSP extension that's used to send the name of the category over to the + /// client. The category typically describes the compilation stage during + /// which the issue was produced, e.g. "Semantic Issue" or "Parse Issue". + Optional category; +}; +llvm::json::Value toJSON(const Diagnostic &diag); + +//===----------------------------------------------------------------------===// +// PublishDiagnosticsParams +//===----------------------------------------------------------------------===// + +struct PublishDiagnosticsParams { + PublishDiagnosticsParams(URIForFile uri) : uri(std::move(uri)) {} + + /// The URI for which diagnostic information is reported. + URIForFile uri; + /// The list of reported diagnostics.
+ std::vector diagnostics; +}; +llvm::json::Value toJSON(const PublishDiagnosticsParams &params); + } // namespace lsp } // namespace mlir diff --git a/mlir/lib/Tools/mlir-lsp-server/lsp/Protocol.cpp b/mlir/lib/Tools/mlir-lsp-server/lsp/Protocol.cpp --- a/mlir/lib/Tools/mlir-lsp-server/lsp/Protocol.cpp +++ b/mlir/lib/Tools/mlir-lsp-server/lsp/Protocol.cpp @@ -473,3 +473,44 @@ result["range"] = toJSON(*hover.range); return std::move(result); } + +//===----------------------------------------------------------------------===// +// DiagnosticRelatedInformation +//===----------------------------------------------------------------------===// + +llvm::json::Value mlir::lsp::toJSON(const DiagnosticRelatedInformation &info) { + return llvm::json::Object{ + {"location", info.location}, + {"message", info.message}, + }; +} + +//===----------------------------------------------------------------------===// +// Diagnostic +//===----------------------------------------------------------------------===// + +llvm::json::Value mlir::lsp::toJSON(const Diagnostic &diag) { + llvm::json::Object result{ + {"range", diag.range}, + {"severity", (int)diag.severity}, + {"message", diag.message}, + }; + if (diag.category) + result["category"] = *diag.category; + if (!diag.source.empty()) + result["source"] = diag.source; + if (diag.relatedInformation) + result["relatedInformation"] = *diag.relatedInformation; + return std::move(result); +} + +//===----------------------------------------------------------------------===// +// PublishDiagnosticsParams +//===----------------------------------------------------------------------===// + +llvm::json::Value mlir::lsp::toJSON(const PublishDiagnosticsParams &params) { + return llvm::json::Object{ + {"uri", params.uri}, + {"diagnostics", params.diagnostics}, + }; +} diff --git a/mlir/lib/Tools/mlir-lsp-server/lsp/Transport.h b/mlir/lib/Tools/mlir-lsp-server/lsp/Transport.h --- a/mlir/lib/Tools/mlir-lsp-server/lsp/Transport.h +++
b/mlir/lib/Tools/mlir-lsp-server/lsp/Transport.h @@ -28,32 +28,61 @@ namespace mlir { namespace lsp { -class JSONTransport; +class MessageHandler; //===----------------------------------------------------------------------===// -// Reply +// JSONTransport //===----------------------------------------------------------------------===// -/// Function object to reply to an LSP call. -/// Each instance must be called exactly once, otherwise: -/// - if there was no reply, an error reply is sent -/// - if there were multiple replies, only the first is sent -class Reply { +/// The encoding style of the JSON-RPC messages (both input and output). +enum JSONStreamStyle { + /// Encoding per the LSP specification, with mandatory Content-Length header. + Standard, + /// Messages are delimited by a '// -----' line. Comment lines start with //. + Delimited +}; + +/// A transport class that performs the JSON-RPC communication with the LSP +/// client. +class JSONTransport { public: - Reply(const llvm::json::Value &id, StringRef method, - JSONTransport &transport); - Reply(Reply &&other); - Reply &operator=(Reply &&) = delete; - Reply(const Reply &) = delete; - Reply &operator=(const Reply &) = delete; + JSONTransport(std::FILE *in, raw_ostream &out, + JSONStreamStyle style = JSONStreamStyle::Standard, + bool prettyOutput = false) + : in(in), out(out), style(style), prettyOutput(prettyOutput) {} - void operator()(llvm::Expected reply); + /// The following methods are used to send a message to the LSP client. + void notify(StringRef method, llvm::json::Value params); + void call(StringRef method, llvm::json::Value params, llvm::json::Value id); + void reply(llvm::json::Value id, llvm::Expected result); + + /// Start executing the JSON-RPC transport. + llvm::Error run(MessageHandler &handler); private: - StringRef method; - std::atomic replied = {false}; - llvm::json::Value id; - JSONTransport *transport; + /// Dispatches the given incoming json message to the message handler.
+ bool handleMessage(llvm::json::Value msg, MessageHandler &handler); + /// Writes the given message to the output stream. + void sendMessage(llvm::json::Value msg); + + /// Read in a message from the input stream. + LogicalResult readMessage(std::string &json) { + return style == JSONStreamStyle::Delimited ? readDelimitedMessage(json) + : readStandardMessage(json); + } + LogicalResult readDelimitedMessage(std::string &json); + LogicalResult readStandardMessage(std::string &json); + + /// An output buffer used when building output messages. + SmallVector outputBuffer; + /// The input file stream. + std::FILE *in; + /// The output file stream. + raw_ostream &out; + /// The JSON stream style to use. + JSONStreamStyle style; + /// If the output JSON should be formatted for easier readability. + bool prettyOutput; }; //===----------------------------------------------------------------------===// @@ -65,6 +94,11 @@ template using Callback = llvm::unique_function)>; +/// An OutgoingNotification is a function used for outgoing notifications +/// sent to the client. +template +using OutgoingNotification = llvm::unique_function; + /// A handler used to process the incoming transport messages. class MessageHandler { public: @@ -119,6 +153,14 @@ }; } + /// Create an OutgoingNotification object used for the given method. + template + OutgoingNotification outgoingNotification(llvm::StringLiteral method) { + return [&, method](const T &params) { + transport.notify(method, llvm::json::Value(params)); + }; + } + private: template using HandlerMap = llvm::StringMap>; @@ -130,61 +172,6 @@ JSONTransport &transport; }; -//===----------------------------------------------------------------------===// -// JSONTransport -//===----------------------------------------------------------------------===// - -/// The encoding style of the JSON-RPC messages (both input and output). -enum JSONStreamStyle { - /// Encoding per the LSP specification, with mandatory Content-Length header.
- Standard, - /// Messages are delimited by a '// -----' line. Comment lines start with //. - Delimited -}; - -/// A transport class that performs the JSON-RPC communication with the LSP -/// client. -class JSONTransport { -public: - JSONTransport(std::FILE *in, raw_ostream &out, - JSONStreamStyle style = JSONStreamStyle::Standard, - bool prettyOutput = false) - : in(in), out(out), style(style), prettyOutput(prettyOutput) {} - - /// The following methods are used to send a message to the LSP client. - void notify(StringRef method, llvm::json::Value params); - void call(StringRef method, llvm::json::Value params, llvm::json::Value id); - void reply(llvm::json::Value id, llvm::Expected result); - - /// Start executing the JSON-RPC transport. - llvm::Error run(MessageHandler &handler); - -private: - /// Dispatches the given incoming json message to the message handler. - bool handleMessage(llvm::json::Value msg, MessageHandler &handler); - /// Writes the given message to the output stream. - void sendMessage(llvm::json::Value msg); - - /// Read in a message from the input stream. - LogicalResult readMessage(std::string &json) { - return style == JSONStreamStyle::Delimited ? readDelimitedMessage(json) - : readStandardMessage(json); - } - LogicalResult readDelimitedMessage(std::string &json); - LogicalResult readStandardMessage(std::string &json); - - /// An output buffer used when building output messages. - SmallVector outputBuffer; - /// The input file stream. - std::FILE *in; - /// The output file stream. - raw_ostream &out; - /// The JSON stream style to use. - JSONStreamStyle style; - /// If the output JSON should be formatted for easier readability.
- bool prettyOutput; -}; - } // namespace lsp } // namespace mlir diff --git a/mlir/lib/Tools/mlir-lsp-server/lsp/Transport.cpp b/mlir/lib/Tools/mlir-lsp-server/lsp/Transport.cpp --- a/mlir/lib/Tools/mlir-lsp-server/lsp/Transport.cpp +++ b/mlir/lib/Tools/mlir-lsp-server/lsp/Transport.cpp @@ -21,6 +21,30 @@ // Reply //===----------------------------------------------------------------------===// +namespace { +/// Function object to reply to an LSP call. +/// Each instance must be called exactly once, otherwise: +/// - if there was no reply, an error reply is sent +/// - if there were multiple replies, only the first is sent +class Reply { +public: + Reply(const llvm::json::Value &id, StringRef method, + JSONTransport &transport); + Reply(Reply &&other); + Reply &operator=(Reply &&) = delete; + Reply(const Reply &) = delete; + Reply &operator=(const Reply &) = delete; + + void operator()(llvm::Expected reply); + +private: + StringRef method; + std::atomic replied = {false}; + llvm::json::Value id; + JSONTransport *transport; +}; +} // namespace + Reply::Reply(const llvm::json::Value &id, llvm::StringRef method, JSONTransport &transport) : id(id), transport(&transport) {} diff --git a/mlir/test/mlir-lsp-server/diagnostics.test b/mlir/test/mlir-lsp-server/diagnostics.test new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-lsp-server/diagnostics.test @@ -0,0 +1,35 @@ +// RUN: mlir-lsp-server -lit-test < %s | FileCheck -strict-whitespace %s +{"jsonrpc":"2.0","id":0,"method":"initialize","params":{"processId":123,"rootPath":"mlir","capabilities":{},"trace":"off"}} +// ----- +{"jsonrpc":"2.0","method":"textDocument/didOpen","params":{"textDocument":{ + "uri":"test:///foo.mlir", + "languageId":"mlir", + "version":1, + "text":"func ()" +}}} +// CHECK: "method": "textDocument/publishDiagnostics", +// CHECK-NEXT: "params": { +// CHECK-NEXT: "diagnostics": [ +// CHECK-NEXT: { +// CHECK-NEXT: "category": "Parse Error", +// CHECK-NEXT: "message": "custom op 'func' expected
valid '@'-identifier for symbol name", +// CHECK-NEXT: "range": { +// CHECK-NEXT: "end": { +// CHECK-NEXT: "character": 6, +// CHECK-NEXT: "line": 0 +// CHECK-NEXT: }, +// CHECK-NEXT: "start": { +// CHECK-NEXT: "character": 6, +// CHECK-NEXT: "line": 0 +// CHECK-NEXT: } +// CHECK-NEXT: }, +// CHECK-NEXT: "severity": 1, +// CHECK-NEXT: "source": "mlir" +// CHECK-NEXT: } +// CHECK-NEXT: ], +// CHECK-NEXT: "uri": "test:///foo.mlir" +// CHECK-NEXT: } +// ----- +{"jsonrpc":"2.0","id":3,"method":"shutdown"} +// ----- +{"jsonrpc":"2.0","method":"exit"}