diff --git a/mlir/docs/Tools/MLIRLSP.md b/mlir/docs/Tools/MLIRLSP.md
--- a/mlir/docs/Tools/MLIRLSP.md
+++ b/mlir/docs/Tools/MLIRLSP.md
@@ -109,3 +109,6 @@
 *   Syntax highlighting for .mlir files and `mlir` markdown blocks
 *   go-to-definition and cross references
     *   Definitions include the source file locations of operations in the .mlir
+*   Hover over IR entities to see more information about them
+    *   e.g. for a Block, you can see its block number as well as any
+        predecessors or successors.
diff --git a/mlir/include/mlir/Parser/AsmParserState.h b/mlir/include/mlir/Parser/AsmParserState.h
--- a/mlir/include/mlir/Parser/AsmParserState.h
+++ b/mlir/include/mlir/Parser/AsmParserState.h
@@ -95,6 +95,10 @@
   /// Return a range of the BlockDefinitions held by the current parser state.
   iterator_range<BlockDefIterator> getBlockDefs() const;
 
+  /// Return the definition for the given block, or nullptr if the given
+  /// block does not have a definition.
+  const BlockDefinition *getBlockDef(Block *block) const;
+
   /// Return a range of the OperationDefinitions held by the current parser
   /// state.
   iterator_range<OperationDefIterator> getOpDefs() const;
diff --git a/mlir/lib/Parser/AsmParserState.cpp b/mlir/lib/Parser/AsmParserState.cpp
--- a/mlir/lib/Parser/AsmParserState.cpp
+++ b/mlir/lib/Parser/AsmParserState.cpp
@@ -60,6 +60,12 @@
   return llvm::make_pointee_range(llvm::makeArrayRef(impl->blocks));
 }
 
+auto AsmParserState::getBlockDef(Block *block) const
+    -> const BlockDefinition * {
+  auto it = impl->blocksToIdx.find(block);
+  return it == impl->blocksToIdx.end() ? nullptr : &*impl->blocks[it->second];
+}
+
 auto AsmParserState::getOpDefs() const -> iterator_range<OperationDefIterator> {
   return llvm::make_pointee_range(llvm::makeArrayRef(impl->operations));
 }
diff --git a/mlir/lib/Tools/mlir-lsp-server/LSPServer.cpp b/mlir/lib/Tools/mlir-lsp-server/LSPServer.cpp
--- a/mlir/lib/Tools/mlir-lsp-server/LSPServer.cpp
+++ b/mlir/lib/Tools/mlir-lsp-server/LSPServer.cpp
@@ -50,6 +50,12 @@
   void onReference(const ReferenceParams &params,
                    Callback<std::vector<Location>> reply);
 
+  //===--------------------------------------------------------------------===//
+  // Hover
+
+  void onHover(const TextDocumentPositionParams &params,
+               Callback<Optional<Hover>> reply);
+
   MLIRServer &server;
   JSONTransport &transport;
 
@@ -72,6 +78,7 @@
        }},
       {"definitionProvider", true},
       {"referencesProvider", true},
+      {"hoverProvider", true},
   };
 
   llvm::json::Object result{
@@ -125,6 +132,14 @@
   reply(std::move(locations));
 }
 
+//===----------------------------------------------------------------------===//
+// Hover
+
+void LSPServer::Impl::onHover(const TextDocumentPositionParams &params,
+                              Callback<Optional<Hover>> reply) {
+  reply(server.findHover(params.textDocument.uri, params.position));
+}
+
 //===----------------------------------------------------------------------===//
 // LSPServer
 //===----------------------------------------------------------------------===//
@@ -155,7 +170,10 @@
   messageHandler.method("textDocument/references", impl.get(),
                         &Impl::onReference);
 
+  // Hover
+  messageHandler.method("textDocument/hover", impl.get(), &Impl::onHover);
+
   // Run the main loop of the transport.
   LogicalResult result = success();
   if (llvm::Error error = impl->transport.run(messageHandler)) {
     Logger::error("Transport error: {0}", error);
diff --git a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.h b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.h
--- a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.h
+++ b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.h
@@ -16,6 +16,7 @@
 class DialectRegistry;
 
 namespace lsp {
+struct Hover;
 struct Location;
 struct Position;
 class URIForFile;
@@ -43,6 +44,10 @@
   void findReferencesOf(const URIForFile &uri, const Position &pos,
                         std::vector<Location> &references);
 
+  /// Find a hover description for the given hover position, or None if one
+  /// couldn't be found.
+  Optional<Hover> findHover(const URIForFile &uri, const Position &hoverPos);
+
 private:
   struct Impl;
 
diff --git a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp
--- a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp
+++ b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp
@@ -90,13 +90,84 @@
 }
 
 /// Returns true if the given location is contained by the definition or one of
-/// the uses of the given SMDefinition.
-static bool isDefOrUse(const AsmParserState::SMDefinition &def,
-                       llvm::SMLoc loc) {
-  auto isUseFn = [&](const llvm::SMRange &range) {
+/// the uses of the given SMDefinition. If provided, `overlappedRange` is set to
+/// the range within `def` that the provided `loc` overlapped with.
+static bool isDefOrUse(const AsmParserState::SMDefinition &def, llvm::SMLoc loc,
+                       llvm::SMRange *overlappedRange = nullptr) {
+  // Check the main definition.
+  if (contains(def.loc, loc)) {
+    if (overlappedRange)
+      *overlappedRange = def.loc;
+    return true;
+  }
+
+  // Check the uses.
+  auto useIt = llvm::find_if(def.uses, [&](const llvm::SMRange &range) {
     return contains(range, loc);
+  });
+  if (useIt != def.uses.end()) {
+    if (overlappedRange)
+      *overlappedRange = *useIt;
+    return true;
+  }
+  return false;
+}
+
+/// Given a location pointing to a result, return the result number it refers
+/// to or None if it refers to all of the results.
+static Optional<unsigned> getResultNumberFromLoc(llvm::SMLoc loc) {
+  // Skip all of the identifier characters.
+  auto isIdentifierChar = [](char c) {
+    return llvm::isAlnum(c) || c == '%' || c == '$' || c == '.' || c == '_' ||
+           c == '-';
   };
-  return contains(def.loc, loc) || llvm::any_of(def.uses, isUseFn);
+  const char *curPtr = loc.getPointer();
+  while (isIdentifierChar(*curPtr))
+    ++curPtr;
+  if (*curPtr != '#')
+    return llvm::None;
+
+  // Compute the number from the string.
+  const char *numberStart = ++curPtr;
+  while (llvm::isDigit(*curPtr))
+    ++curPtr;
+  StringRef numberStr(numberStart, curPtr - numberStart);
+  unsigned resultNumber = 0;
+  return numberStr.consumeInteger(10, resultNumber) ? Optional<unsigned>()
+                                                    : resultNumber;
+}
+
+/// Given a source location range, return the text covered by the given range.
+/// If the range is invalid, returns None.
+static Optional<StringRef> getTextFromRange(llvm::SMRange range) {
+  if (!range.isValid())
+    return None;
+  const char *startPtr = range.Start.getPointer();
+  return StringRef(startPtr, range.End.getPointer() - startPtr);
+}
+
+/// Given a block, return its position in its parent region.
+static unsigned getBlockNumber(Block *block) {
+  return std::distance(block->getParent()->begin(), block->getIterator());
+}
+
+/// Given a block and source location, print the source name of the block to the
+/// given output stream.
+static void printDefBlockName(raw_ostream &os, Block *block,
+                              llvm::SMRange loc = {}) {
+  // Try to extract a name from the source location.
+  Optional<StringRef> text = getTextFromRange(loc);
+  if (text && text->startswith("^")) {
+    os << *text;
+    return;
+  }
+
+  // Otherwise, we don't have a name so print the block number.
+  os << "<Block #" << getBlockNumber(block) << ">";
+}
+static void printDefBlockName(raw_ostream &os,
+                              const AsmParserState::BlockDefinition &def) {
+  printDefBlockName(os, def.block, def.definition.loc);
 }
 
 //===----------------------------------------------------------------------===//
@@ -110,11 +181,33 @@
   MLIRDocument(const lsp::URIForFile &uri, StringRef contents,
                DialectRegistry &registry);
 
+  //===--------------------------------------------------------------------===//
+  // Definitions and References
+  //===--------------------------------------------------------------------===//
+
   void getLocationsOf(const lsp::URIForFile &uri, const lsp::Position &defPos,
                       std::vector<lsp::Location> &locations);
   void findReferencesOf(const lsp::URIForFile &uri, const lsp::Position &pos,
                         std::vector<lsp::Location> &references);
 
+  //===--------------------------------------------------------------------===//
+  // Hover
+  //===--------------------------------------------------------------------===//
+
+  Optional<lsp::Hover> findHover(const lsp::URIForFile &uri,
+                                 const lsp::Position &hoverPos);
+  Optional<lsp::Hover>
+  buildHoverForOperation(const AsmParserState::OperationDefinition &op);
+  lsp::Hover buildHoverForOperationResult(llvm::SMRange hoverRange,
+                                          Operation *op, unsigned resultStart,
+                                          unsigned resultEnd,
+                                          llvm::SMLoc posLoc);
+  lsp::Hover buildHoverForBlock(llvm::SMRange hoverRange,
+                                const AsmParserState::BlockDefinition &block);
+  lsp::Hover
+  buildHoverForBlockArgument(llvm::SMRange hoverRange, BlockArgument arg,
+                             const AsmParserState::BlockDefinition &block);
+
   /// The context used to hold the state contained by the parsed document.
   MLIRContext context;
 
@@ -155,6 +248,10 @@
     return;
 }
 
+//===----------------------------------------------------------------------===//
+// MLIRDocument: Definitions and References
+//===----------------------------------------------------------------------===//
+
 void MLIRDocument::getLocationsOf(const lsp::URIForFile &uri,
                                   const lsp::Position &defPos,
                                   std::vector<lsp::Location> &locations) {
@@ -223,6 +320,154 @@
   }
 }
 
+//===----------------------------------------------------------------------===//
+// MLIRDocument: Hover
+//===----------------------------------------------------------------------===//
+
+Optional<lsp::Hover> MLIRDocument::findHover(const lsp::URIForFile &uri,
+                                             const lsp::Position &hoverPos) {
+  llvm::SMLoc posLoc = getPosFromLoc(sourceMgr, hoverPos);
+  llvm::SMRange hoverRange;
+
+  // Check for Hovers on operations and results.
+  for (const AsmParserState::OperationDefinition &op : asmState.getOpDefs()) {
+    // Check if the position points at this operation.
+    if (contains(op.loc, posLoc))
+      return buildHoverForOperation(op);
+
+    // Check if the position points at a result group.
+    for (unsigned i = 0, e = op.resultGroups.size(); i < e; ++i) {
+      const auto &result = op.resultGroups[i];
+      if (!isDefOrUse(result.second, posLoc, &hoverRange))
+        continue;
+
+      // Get the range of results covered by the hover position.
+      unsigned resultStart = result.first;
+      unsigned resultEnd =
+          (i == e - 1) ? op.op->getNumResults() : op.resultGroups[i + 1].first;
+      return buildHoverForOperationResult(hoverRange, op.op, resultStart,
+                                          resultEnd, posLoc);
+    }
+  }
+
+  // Check to see if the hover is over a block or one of its arguments.
+  for (const AsmParserState::BlockDefinition &block : asmState.getBlockDefs()) {
+    if (isDefOrUse(block.definition, posLoc, &hoverRange))
+      return buildHoverForBlock(hoverRange, block);
+
+    for (const auto &arg : llvm::enumerate(block.arguments)) {
+      if (!isDefOrUse(arg.value(), posLoc, &hoverRange))
+        continue;
+
+      return buildHoverForBlockArgument(
+          hoverRange, block.block->getArgument(arg.index()), block);
+    }
+  }
+  return llvm::None;
+}
+
+Optional<lsp::Hover> MLIRDocument::buildHoverForOperation(
+    const AsmParserState::OperationDefinition &op) {
+  // Don't show hovers for operations with regions to avoid huge hover blocks.
+  // TODO: Should we add support for printing an op without its regions?
+  if (llvm::any_of(op.op->getRegions(),
+                   [](Region &region) { return !region.empty(); }))
+    return llvm::None;
+
+  lsp::Hover hover(getRangeFromLoc(sourceMgr, op.loc));
+  llvm::raw_string_ostream os(hover.contents.value);
+
+  // For hovers on an operation, show the generic form.
+  os << "```mlir\n";
+  op.op->print(os, OpPrintingFlags().printGenericOpForm());
+  os << "\n```\n";
+
+  return hover;
+}
+
+lsp::Hover MLIRDocument::buildHoverForOperationResult(llvm::SMRange hoverRange,
+                                                      Operation *op,
+                                                      unsigned resultStart,
+                                                      unsigned resultEnd,
+                                                      llvm::SMLoc posLoc) {
+  lsp::Hover hover(getRangeFromLoc(sourceMgr, hoverRange));
+  llvm::raw_string_ostream os(hover.contents.value);
+
+  // Add the parent operation name to the hover.
+  os << "Operation: \"" << op->getName() << "\"\n\n";
+
+  // Check to see if the location points to a specific result within the
+  // group.
+  if (Optional<unsigned> resultNumber = getResultNumberFromLoc(posLoc)) {
+    if ((resultStart + *resultNumber) < resultEnd) {
+      resultStart += *resultNumber;
+      resultEnd = resultStart + 1;
+    }
+  }
+
+  // Add the range of results and their types to the hover info.
+  if ((resultStart + 1) == resultEnd) {
+    os << "Result #" << resultStart << "\n\n"
+       << "Type: `" << op->getResult(resultStart).getType() << "`\n\n";
+  } else {
+    os << "Result #[" << resultStart << ", " << (resultEnd - 1) << "]\n\n"
+       << "Types: ";
+    llvm::interleaveComma(
+        op->getResults().slice(resultStart, resultEnd - resultStart), os,
+        [&](Value result) { os << "`" << result.getType() << "`"; });
+  }
+
+  return hover;
+}
+
+lsp::Hover
+MLIRDocument::buildHoverForBlock(llvm::SMRange hoverRange,
+                                 const AsmParserState::BlockDefinition &block) {
+  lsp::Hover hover(getRangeFromLoc(sourceMgr, hoverRange));
+  llvm::raw_string_ostream os(hover.contents.value);
+
+  // Print the given block to the hover output stream.
+  auto printBlockToHover = [&](Block *newBlock) {
+    if (const auto *def = asmState.getBlockDef(newBlock))
+      printDefBlockName(os, *def);
+    else
+      printDefBlockName(os, newBlock);
+  };
+
+  // Display the parent operation, block number, predecessors, and successors.
+  os << "Operation: \"" << block.block->getParentOp()->getName() << "\"\n\n"
+     << "Block #" << getBlockNumber(block.block) << "\n\n";
+  if (!block.block->hasNoPredecessors()) {
+    os << "Predecessors: ";
+    llvm::interleaveComma(block.block->getPredecessors(), os,
+                          printBlockToHover);
+    os << "\n\n";
+  }
+  if (!block.block->hasNoSuccessors()) {
+    os << "Successors: ";
+    llvm::interleaveComma(block.block->getSuccessors(), os, printBlockToHover);
+    os << "\n\n";
+  }
+
+  return hover;
+}
+
+lsp::Hover MLIRDocument::buildHoverForBlockArgument(
+    llvm::SMRange hoverRange, BlockArgument arg,
+    const AsmParserState::BlockDefinition &block) {
+  lsp::Hover hover(getRangeFromLoc(sourceMgr, hoverRange));
+  llvm::raw_string_ostream os(hover.contents.value);
+
+  // Display the parent operation, block, the argument number, and the type.
+  os << "Operation: \"" << block.block->getParentOp()->getName() << "\"\n\n"
+     << "Block: ";
+  printDefBlockName(os, block);
+  os << "\n\nArgument #" << arg.getArgNumber() << "\n\n"
+     << "Type: `" << arg.getType() << "`\n\n";
+
+  return hover;
+}
+
 //===----------------------------------------------------------------------===//
 // MLIRServer::Impl
 //===----------------------------------------------------------------------===//
@@ -271,3 +516,11 @@
   if (fileIt != impl->documents.end())
     fileIt->second->findReferencesOf(uri, pos, references);
 }
+
+Optional<lsp::Hover> lsp::MLIRServer::findHover(const URIForFile &uri,
+                                                const Position &hoverPos) {
+  auto fileIt = impl->documents.find(uri.file());
+  if (fileIt != impl->documents.end())
+    return fileIt->second->findHover(uri, hoverPos);
+  return llvm::None;
+}
diff --git a/mlir/lib/Tools/mlir-lsp-server/lsp/Protocol.h b/mlir/lib/Tools/mlir-lsp-server/lsp/Protocol.h
--- a/mlir/lib/Tools/mlir-lsp-server/lsp/Protocol.h
+++ b/mlir/lib/Tools/mlir-lsp-server/lsp/Protocol.h
@@ -358,6 +358,41 @@
 bool fromJSON(const llvm::json::Value &value,
               DidChangeTextDocumentParams &result, llvm::json::Path path);
 
+//===----------------------------------------------------------------------===//
+// MarkupContent
+//===----------------------------------------------------------------------===//
+
+/// Describes the content type that a client supports in various result literals
+/// like `Hover`.
+enum class MarkupKind {
+  PlainText,
+  Markdown,
+};
+raw_ostream &operator<<(raw_ostream &os, MarkupKind kind);
+
+struct MarkupContent {
+  MarkupKind kind = MarkupKind::PlainText;
+  std::string value;
+};
+llvm::json::Value toJSON(const MarkupContent &mc);
+
+//===----------------------------------------------------------------------===//
+// Hover
+//===----------------------------------------------------------------------===//
+
+struct Hover {
+  /// Construct a default hover with the given range that uses Markdown content.
+  Hover(Range range) : contents{MarkupKind::Markdown, ""}, range(range) {}
+
+  /// The hover's content.
+  MarkupContent contents;
+
+  /// An optional range is a range inside a text document that is used to
+  /// visualize a hover, e.g. by changing the background color.
+  Optional<Range> range;
+};
+llvm::json::Value toJSON(const Hover &hover);
+
 } // namespace lsp
 } // namespace mlir
 
diff --git a/mlir/lib/Tools/mlir-lsp-server/lsp/Protocol.cpp b/mlir/lib/Tools/mlir-lsp-server/lsp/Protocol.cpp
--- a/mlir/lib/Tools/mlir-lsp-server/lsp/Protocol.cpp
+++ b/mlir/lib/Tools/mlir-lsp-server/lsp/Protocol.cpp
@@ -434,3 +434,42 @@
   return o && o.map("textDocument", result.textDocument) &&
          o.map("contentChanges", result.contentChanges);
 }
+
+//===----------------------------------------------------------------------===//
+// MarkupContent
+//===----------------------------------------------------------------------===//
+
+static llvm::StringRef toTextKind(MarkupKind kind) {
+  switch (kind) {
+  case MarkupKind::PlainText:
+    return "plaintext";
+  case MarkupKind::Markdown:
+    return "markdown";
+  }
+  llvm_unreachable("Invalid MarkupKind");
+}
+
+raw_ostream &mlir::lsp::operator<<(raw_ostream &os, MarkupKind kind) {
+  return os << toTextKind(kind);
+}
+
+llvm::json::Value mlir::lsp::toJSON(const MarkupContent &mc) {
+  if (mc.value.empty())
+    return nullptr;
+
+  return llvm::json::Object{
+      {"kind", toTextKind(mc.kind)},
+      {"value", mc.value},
+  };
+}
+
+//===----------------------------------------------------------------------===//
+// Hover
+//===----------------------------------------------------------------------===//
+
+llvm::json::Value mlir::lsp::toJSON(const Hover &hover) {
+  llvm::json::Object result{{"contents", toJSON(hover.contents)}};
+  if (hover.range.hasValue())
+    result["range"] = toJSON(*hover.range);
+  return std::move(result);
+}
diff --git a/mlir/test/mlir-lsp-server/hover.test b/mlir/test/mlir-lsp-server/hover.test
new file mode 100644
--- /dev/null
+++ b/mlir/test/mlir-lsp-server/hover.test
@@ -0,0 +1,109 @@
+// RUN: mlir-lsp-server -lit-test < %s | FileCheck -strict-whitespace %s
+{"jsonrpc":"2.0","id":0,"method":"initialize","params":{"processId":123,"rootPath":"mlir","capabilities":{},"trace":"off"}}
+// -----
+{"jsonrpc":"2.0","method":"textDocument/didOpen","params":{"textDocument":{
+  "uri":"test:///foo.mlir",
+  "languageId":"mlir",
+  "version":1,
+  "text":"func @foo(%arg: i1) {\n%value = constant true\nbr ^bb2\n^bb2:\nreturn\n}"
+}}}
+// -----
+// Hover on an operation.
+{"jsonrpc":"2.0","id":1,"method":"textDocument/hover","params":{
+  "textDocument":{"uri":"test:///foo.mlir"},
+  "position":{"line":1,"character":12}
+}}
+//      CHECK:  "id": 1,
+// CHECK-NEXT:  "jsonrpc": "2.0",
+// CHECK-NEXT:  "result": {
+// CHECK-NEXT:    "contents": {
+// CHECK-NEXT:      "kind": "markdown",
+// CHECK-NEXT:      "value": "```mlir\n%true = \"std.constant\"() {value = true} : () -> i1\n```\n"
+// CHECK-NEXT:    },
+// CHECK-NEXT:    "range": {
+// CHECK-NEXT:      "end": {
+// CHECK-NEXT:        "character": 17,
+// CHECK-NEXT:        "line": 1
+// CHECK-NEXT:      },
+// CHECK-NEXT:      "start": {
+// CHECK-NEXT:        "character": 10,
+// CHECK-NEXT:        "line": 1
+// CHECK-NEXT:      }
+// CHECK-NEXT:    }
+// CHECK-NEXT:  }
+// -----
+// Hover on an operation result.
+{"jsonrpc":"2.0","id":1,"method":"textDocument/hover","params":{
+  "textDocument":{"uri":"test:///foo.mlir"},
+  "position":{"line":1,"character":2}
+}}
+//      CHECK:  "id": 1,
+// CHECK-NEXT:  "jsonrpc": "2.0",
+// CHECK-NEXT:  "result": {
+// CHECK-NEXT:    "contents": {
+// CHECK-NEXT:      "kind": "markdown",
+// CHECK-NEXT:      "value": "Operation: \"std.constant\"\n\nResult #0\n\nType: `i1`\n\n"
+// CHECK-NEXT:    },
+// CHECK-NEXT:    "range": {
+// CHECK-NEXT:      "end": {
+// CHECK-NEXT:        "character": 6,
+// CHECK-NEXT:        "line": 1
+// CHECK-NEXT:      },
+// CHECK-NEXT:      "start": {
+// CHECK-NEXT:        "character": 1,
+// CHECK-NEXT:        "line": 1
+// CHECK-NEXT:      }
+// CHECK-NEXT:    }
+// CHECK-NEXT:  }
+// -----
+// Hover on a Block.
+{"jsonrpc":"2.0","id":1,"method":"textDocument/hover","params":{
+  "textDocument":{"uri":"test:///foo.mlir"},
+  "position":{"line":3,"character":2}
+}}
+//      CHECK:  "id": 1,
+// CHECK-NEXT:  "jsonrpc": "2.0",
+// CHECK-NEXT:  "result": {
+// CHECK-NEXT:    "contents": {
+// CHECK-NEXT:      "kind": "markdown",
+// CHECK-NEXT:      "value": "Operation: \"func\"\n\nBlock #1\n\nPredecessors: <Block #0>\n\n"
+// CHECK-NEXT:    },
+// CHECK-NEXT:    "range": {
+// CHECK-NEXT:      "end": {
+// CHECK-NEXT:        "character": 4,
+// CHECK-NEXT:        "line": 3
+// CHECK-NEXT:      },
+// CHECK-NEXT:      "start": {
+// CHECK-NEXT:        "character": 1,
+// CHECK-NEXT:        "line": 3
+// CHECK-NEXT:      }
+// CHECK-NEXT:    }
+// CHECK-NEXT:  }
+// -----
+// Hover on a Block argument.
+{"jsonrpc":"2.0","id":1,"method":"textDocument/hover","params":{
+  "textDocument":{"uri":"test:///foo.mlir"},
+  "position":{"line":0,"character":12}
+}}
+//      CHECK:  "id": 1,
+// CHECK-NEXT:  "jsonrpc": "2.0",
+// CHECK-NEXT:  "result": {
+// CHECK-NEXT:    "contents": {
+// CHECK-NEXT:      "kind": "markdown",
+// CHECK-NEXT:      "value": "Operation: \"func\"\n\nBlock: <Block #0>\n\nArgument #0\n\nType: `i1`\n\n"
+// CHECK-NEXT:    },
+// CHECK-NEXT:    "range": {
+// CHECK-NEXT:      "end": {
+// CHECK-NEXT:        "character": 14,
+// CHECK-NEXT:        "line": 0
+// CHECK-NEXT:      },
+// CHECK-NEXT:      "start": {
+// CHECK-NEXT:        "character": 11,
+// CHECK-NEXT:        "line": 0
+// CHECK-NEXT:      }
+// CHECK-NEXT:    }
+// CHECK-NEXT:  }
+// -----
+{"jsonrpc":"2.0","id":3,"method":"shutdown"}
+// -----
+{"jsonrpc":"2.0","method":"exit"}
diff --git a/mlir/test/mlir-lsp-server/initialize-params.test b/mlir/test/mlir-lsp-server/initialize-params.test
--- a/mlir/test/mlir-lsp-server/initialize-params.test
+++ b/mlir/test/mlir-lsp-server/initialize-params.test
@@ -6,6 +6,7 @@
 // CHECK-NEXT:  "result": {
 // CHECK-NEXT:    "capabilities": {
 // CHECK-NEXT:      "definitionProvider": true,
+// CHECK-NEXT:      "hoverProvider": true,
 // CHECK-NEXT:      "referencesProvider": true,
 // CHECK-NEXT:      "textDocumentSync": {
 // CHECK-NEXT:        "change": 1,