diff --git a/mlir/docs/Tools/MLIRLSP.md b/mlir/docs/Tools/MLIRLSP.md
--- a/mlir/docs/Tools/MLIRLSP.md
+++ b/mlir/docs/Tools/MLIRLSP.md
@@ -320,6 +320,18 @@
 
 ![IMG](/tblgen-lsp-server/find_references.gif)
 
+#### Hover
+
+Hover over a symbol to see more information about it, such as its kind, type,
+and documentation.
+
+![IMG](/tblgen-lsp-server/hover_def.png)
+
+Hovering over an overridden field will also show information from its base
+definition, such as documentation:
+
+![IMG](/tblgen-lsp-server/hover_field.png)
+
 ## Language Server Design
 
 The design of the various language servers provided by MLIR are effectively the
diff --git a/mlir/lib/Tools/lsp-server-support/SourceMgrUtils.h b/mlir/lib/Tools/lsp-server-support/SourceMgrUtils.h
--- a/mlir/lib/Tools/lsp-server-support/SourceMgrUtils.h
+++ b/mlir/lib/Tools/lsp-server-support/SourceMgrUtils.h
@@ -28,6 +28,13 @@
 /// supports identifier-like tokens, strings, etc.
 SMRange convertTokenLocToRange(SMLoc loc);
 
+/// Extract a documentation comment for the given location within the source
+/// manager. The comment is assembled from the contiguous run of `//` comment
+/// lines that immediately precede the line containing `loc`, with the leading
+/// slashes of each line removed. Returns None if no comment could be computed.
+Optional<std::string> extractSourceDocComment(llvm::SourceMgr &sourceMgr,
+                                              SMLoc loc);
+
 //===----------------------------------------------------------------------===//
 // SourceMgrInclude
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Tools/lsp-server-support/SourceMgrUtils.cpp b/mlir/lib/Tools/lsp-server-support/SourceMgrUtils.cpp
--- a/mlir/lib/Tools/lsp-server-support/SourceMgrUtils.cpp
+++ b/mlir/lib/Tools/lsp-server-support/SourceMgrUtils.cpp
@@ -65,6 +65,50 @@
   return SMRange(loc, SMLoc::getFromPointer(curPtr));
 }
 
+Optional<std::string> lsp::extractSourceDocComment(llvm::SourceMgr &sourceMgr,
+                                                   SMLoc loc) {
+  // This is a heuristic, and isn't intended to cover every case, but should
+  // cover the most common. We essentially look for a comment preceding the
+  // line, and if we find one, use that as the documentation.
+  if (!loc.isValid())
+    return llvm::None;
+  int bufferId = sourceMgr.FindBufferContainingLoc(loc);
+  if (bufferId == 0)
+    return llvm::None;
+  const char *bufferStart =
+      sourceMgr.getMemoryBuffer(bufferId)->getBufferStart();
+  StringRef buffer(bufferStart, loc.getPointer() - bufferStart);
+
+  // Pop the last line from the buffer string.
+  auto popLastLine = [&]() -> Optional<StringRef> {
+    size_t newlineOffset = buffer.rfind('\n');
+    if (newlineOffset == StringRef::npos)
+      return llvm::None;
+    StringRef lastLine = buffer.drop_front(newlineOffset).trim();
+    buffer = buffer.take_front(newlineOffset);
+    return lastLine;
+  };
+
+  // Try to pop the current line.
+  if (!popLastLine())
+    return llvm::None;
+
+  // Try to parse a comment string from the source file.
+  SmallVector<StringRef> commentLines;
+  while (Optional<StringRef> line = popLastLine()) {
+    // Check for a comment at the beginning of the line.
+    if (!line->startswith("//"))
+      break;
+
+    // Extract the document string from the comment.
+    commentLines.push_back(line->drop_while([](char c) { return c == '/'; }));
+  }
+
+  if (commentLines.empty())
+    return llvm::None;
+  return llvm::join(llvm::reverse(commentLines), "\n");
+}
+
 //===----------------------------------------------------------------------===//
 // SourceMgrInclude
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp b/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp
--- a/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp
+++ b/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp
@@ -126,47 +126,8 @@
     return doc->str();
 
   // If the decl doesn't yet have documentation, try to extract it from the
-  // source file. This is a heuristic, and isn't intended to cover every case,
-  // but should cover the most common. We essentially look for a comment
-  // preceding the decl, and if we find one, use that as the documentation.
-  SMLoc startLoc = decl->getLoc().Start;
-  if (!startLoc.isValid())
-    return llvm::None;
-  int bufferId = sourceMgr.FindBufferContainingLoc(startLoc);
-  if (bufferId == 0)
-    return llvm::None;
-  const char *bufferStart =
-      sourceMgr.getMemoryBuffer(bufferId)->getBufferStart();
-  StringRef buffer(bufferStart, startLoc.getPointer() - bufferStart);
-
-  // Pop the last line from the buffer string.
-  auto popLastLine = [&]() -> Optional<StringRef> {
-    size_t newlineOffset = buffer.find_last_of("\n");
-    if (newlineOffset == StringRef::npos)
-      return llvm::None;
-    StringRef lastLine = buffer.drop_front(newlineOffset).trim();
-    buffer = buffer.take_front(newlineOffset);
-    return lastLine;
-  };
-
-  // Try to pop the current line, which contains the decl.
-  if (!popLastLine())
-    return llvm::None;
-
-  // Try to parse a comment string from the source file.
-  SmallVector<StringRef> commentLines;
-  while (Optional<StringRef> line = popLastLine()) {
-    // Check for a comment at the beginning of the line.
-    if (!line->startswith("//"))
-      break;
-
-    // Extract the document string from the comment.
-    commentLines.push_back(line->drop_while([](char c) { return c == '/'; }));
-  }
-
-  if (commentLines.empty())
-    return llvm::None;
-  return llvm::join(llvm::reverse(commentLines), "\n");
+  // source file.
+  return lsp::extractSourceDocComment(sourceMgr, decl->getLoc().Start);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Tools/tblgen-lsp-server/TableGenServer.cpp b/mlir/lib/Tools/tblgen-lsp-server/TableGenServer.cpp
--- a/mlir/lib/Tools/tblgen-lsp-server/TableGenServer.cpp
+++ b/mlir/lib/Tools/tblgen-lsp-server/TableGenServer.cpp
@@ -12,6 +12,7 @@
 #include "../lsp-server-support/Logging.h"
 #include "../lsp-server-support/Protocol.h"
 #include "../lsp-server-support/SourceMgrUtils.h"
+#include "mlir/Support/IndentedOstream.h"
 #include "mlir/Support/LogicalResult.h"
 #include "llvm/ADT/IntervalMap.h"
 #include "llvm/ADT/PointerUnion.h"
@@ -93,6 +94,39 @@
   return lspDiag;
 }
 
+/// Get the base definition of the given record value, along with the record
+/// that defines it, or {nullptr, nullptr} if one couldn't be found.
+static std::pair<const llvm::Record *, const llvm::RecordVal *>
+getBaseValue(const llvm::Record *record, const llvm::RecordVal *value) {
+  if (value->isTemplateArg())
+    return {nullptr, nullptr};
+
+  // Find a base value for the field in the super classes of the given record.
+  // On success, `curRecord` is updated to the parent record that defines it.
+  StringRef valueName = value->getName();
+  auto findValueInSupers =
+      [&](const llvm::Record *&curRecord) -> llvm::RecordVal * {
+    for (auto [parentRecord, loc] : curRecord->getSuperClasses()) {
+      if (auto *newBase = parentRecord->getValue(valueName)) {
+        curRecord = parentRecord;
+        return newBase;
+      }
+    }
+    return nullptr;
+  };
+
+  // Walk up the superclasses to find the original definition of the value.
+  std::pair<const llvm::Record *, const llvm::RecordVal *> baseValue = {};
+  while (const llvm::RecordVal *newBase = findValueInSupers(record))
+    baseValue = {record, newBase};
+
+  // Check that the base isn't the same as the current value (e.g. if the value
+  // wasn't overridden).
+  if (!baseValue.second || baseValue.second->getLoc() == value->getLoc())
+    return {nullptr, nullptr};
+  return baseValue;
+}
+
 //===----------------------------------------------------------------------===//
 // TableGenIndex
 //===----------------------------------------------------------------------===//
@@ -109,7 +143,8 @@
       : definition(value),
         defLoc(lsp::convertTokenLocToRange(value->getLoc())) {}
+  virtual ~TableGenIndexSymbol() = default;
 
   /// The main definition of the symbol.
   PointerUnion<const llvm::Record *, const llvm::RecordVal *> definition;
 
   /// The source location of the definition.
@@ -118,6 +153,40 @@
   /// The source location of the references of the definition.
   SmallVector<SMRange> references;
 };
+
+/// This class represents a single record symbol.
+struct TableGenRecordSymbol : public TableGenIndexSymbol {
+  explicit TableGenRecordSymbol(const llvm::Record *record)
+      : TableGenIndexSymbol(record) {}
+
+  static bool classof(const TableGenIndexSymbol *symbol) {
+    return symbol->definition.is<const llvm::Record *>();
+  }
+
+  /// Return the value of this symbol.
+  const llvm::Record *getValue() const {
+    return definition.get<const llvm::Record *>();
+  }
+};
+
+/// This class represents a single record value symbol.
+struct TableGenRecordValSymbol : public TableGenIndexSymbol {
+  TableGenRecordValSymbol(const llvm::Record *record,
+                          const llvm::RecordVal *value)
+      : TableGenIndexSymbol(value), record(record) {}
+
+  static bool classof(const TableGenIndexSymbol *symbol) {
+    return symbol->definition.is<const llvm::RecordVal *>();
+  }
+
+  /// Return the value of this symbol.
+  const llvm::RecordVal *getValue() const {
+    return definition.get<const llvm::RecordVal *>();
+  }
+
+  /// The parent record of this symbol.
+  const llvm::Record *record;
+};
 
 /// This class provides an index for definitions/uses within a TableGen
 /// document. It provides efficient lookup of a definition given an input source
@@ -144,6 +213,25 @@
                                        const TableGenIndexSymbol *>::LeafSize,
       llvm::IntervalMapHalfOpenInfo<const char *>>;
 
+  /// Get or insert a symbol for the given record.
+  TableGenIndexSymbol *getOrInsertDef(const llvm::Record *record) {
+    auto it = defToSymbol.try_emplace(record, nullptr);
+    if (it.second)
+      it.first->second = std::make_unique<TableGenRecordSymbol>(record);
+    return &*it.first->second;
+  }
+
+  /// Get or insert a symbol for the given record value.
+  TableGenIndexSymbol *getOrInsertDef(const llvm::Record *record,
+                                      const llvm::RecordVal *value) {
+    auto it = defToSymbol.try_emplace(value, nullptr);
+    if (it.second) {
+      it.first->second =
+          std::make_unique<TableGenRecordValSymbol>(record, value);
+    }
+    return &*it.first->second;
+  }
+
   /// An allocator for the interval map.
   MapT::Allocator allocator;
 
@@ -157,12 +245,6 @@
 } // namespace
 
 void TableGenIndex::initialize(const llvm::RecordKeeper &records) {
-  auto getOrInsertDef = [&](const auto *def) -> TableGenIndexSymbol * {
-    auto it = defToSymbol.try_emplace(def, nullptr);
-    if (it.second)
-      it.first->second = std::make_unique<TableGenIndexSymbol>(def);
-    return &*it.first->second;
-  };
   auto insertRef = [&](TableGenIndexSymbol *sym, SMRange refLoc,
                        bool isDef = false) {
     const char *startLoc = refLoc.Start.getPointer();
@@ -207,7 +289,7 @@
 
     // Add definitions for any values.
     for (const llvm::RecordVal &value : def.getValues()) {
-      auto *sym = getOrInsertDef(&value);
+      auto *sym = getOrInsertDef(&def, &value);
       insertRef(sym, sym->defLoc, /*isDef=*/true);
       for (SMRange refLoc : value.getReferenceLocs())
         insertRef(sym, refLoc);
@@ -272,6 +354,14 @@
 
   Optional<lsp::Hover> findHover(const lsp::URIForFile &uri,
                                  const lsp::Position &hoverPos);
+  lsp::Hover buildHoverForRecord(const llvm::Record *record,
+                                 const SMRange &hoverRange);
+  lsp::Hover buildHoverForTemplateArg(const llvm::Record *record,
+                                      const llvm::RecordVal *value,
+                                      const SMRange &hoverRange);
+  lsp::Hover buildHoverForField(const llvm::Record *record,
+                                const llvm::RecordVal *value,
+                                const SMRange &hoverRange);
 
 private:
   /// Initialize the text file from the given file contents.
@@ -422,7 +512,116 @@
   for (const lsp::SourceMgrInclude &include : parsedIncludes)
     if (include.range.contains(hoverPos))
       return include.buildHover();
-  return llvm::None;
+
+  // Find the symbol at the given location.
+  SMRange hoverRange;
+  SMLoc posLoc = hoverPos.getAsSMLoc(sourceMgr);
+  const TableGenIndexSymbol *symbol = index.lookup(posLoc, &hoverRange);
+  if (!symbol)
+    return llvm::None;
+
+  // Build hover for a Record.
+  if (auto *record = dyn_cast<TableGenRecordSymbol>(symbol))
+    return buildHoverForRecord(record->getValue(), hoverRange);
+
+  // Build hover for a RecordVal, which is either a template argument or a
+  // field.
+  auto *recordVal = cast<TableGenRecordValSymbol>(symbol);
+  const llvm::RecordVal *value = recordVal->getValue();
+  if (value->isTemplateArg())
+    return buildHoverForTemplateArg(recordVal->record, value, hoverRange);
+  return buildHoverForField(recordVal->record, value, hoverRange);
+}
+
+lsp::Hover TableGenTextFile::buildHoverForRecord(const llvm::Record *record,
+                                                 const SMRange &hoverRange) {
+  lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
+  {
+    llvm::raw_string_ostream hoverOS(hover.contents.value);
+
+    // Format the type of record this is.
+    if (record->isClass()) {
+      hoverOS << "**class** `" << record->getName() << "`";
+    } else if (record->isAnonymous()) {
+      hoverOS << "**anonymous class**";
+    } else {
+      hoverOS << "**def** `" << record->getName() << "`";
+    }
+    hoverOS << "\n***\n";
+
+    // Check if this record has summary/description fields. These are often used
+    // to hold documentation for the record.
+    auto printAndFormatField = [&](StringRef fieldName) {
+      // Check that the record actually has the given field, and that it's a
+      // string.
+      const llvm::RecordVal *value = record->getValue(fieldName);
+      if (!value || !value->getValue())
+        return;
+      auto *stringValue = dyn_cast<llvm::StringInit>(value->getValue());
+      if (!stringValue)
+        return;
+
+      raw_indented_ostream ros(hoverOS);
+      ros.printReindented(stringValue->getValue().rtrim(" \t"));
+      hoverOS << "\n***\n";
+    };
+    printAndFormatField("summary");
+    printAndFormatField("description");
+
+    // Check for documentation in the source file.
+    if (Optional<std::string> doc =
+            lsp::extractSourceDocComment(sourceMgr, record->getLoc().front())) {
+      hoverOS << "\n" << *doc << "\n";
+    }
+  }
+  return hover;
+}
+
+lsp::Hover
+TableGenTextFile::buildHoverForTemplateArg(const llvm::Record *record,
+                                           const llvm::RecordVal *value,
+                                           const SMRange &hoverRange) {
+  lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
+  {
+    llvm::raw_string_ostream hoverOS(hover.contents.value);
+    StringRef name = value->getName().rsplit(':').second;
+
+    hoverOS << "**template arg** `" << name << "`\n***\nType: `";
+    value->getType()->print(hoverOS);
+    hoverOS << "`\n";
+  }
+  return hover;
+}
+
+lsp::Hover TableGenTextFile::buildHoverForField(const llvm::Record *record,
+                                                const llvm::RecordVal *value,
+                                                const SMRange &hoverRange) {
+  lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
+  {
+    llvm::raw_string_ostream hoverOS(hover.contents.value);
+    hoverOS << "**field** `" << value->getName() << "`\n***\nType: `";
+    value->getType()->print(hoverOS);
+    hoverOS << "`\n***\n";
+
+    // Check for documentation in the source file.
+    if (Optional<std::string> doc =
+            lsp::extractSourceDocComment(sourceMgr, value->getLoc())) {
+      hoverOS << "\n" << *doc << "\n";
+      hoverOS << "\n***\n";
+    }
+
+    // Check to see if there is a base value that we can use for
+    // documentation.
+    auto [baseRecord, baseValue] = getBaseValue(record, value);
+    if (baseValue) {
+      if (Optional<std::string> doc =
+              lsp::extractSourceDocComment(sourceMgr, baseValue->getLoc())) {
+        hoverOS << "\n *From `" << baseRecord->getName() << "`*:\n\n"
+                << *doc << "\n";
+      }
+    }
+  }
+  return hover;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/tblgen-lsp-server/hover.test b/mlir/test/tblgen-lsp-server/hover.test
--- a/mlir/test/tblgen-lsp-server/hover.test
+++ b/mlir/test/tblgen-lsp-server/hover.test
@@ -5,7 +5,7 @@
   "uri":"test:///foo.td",
   "languageId":"tablegen",
   "version":1,
-  "text":"include \"include/included.td\""
+  "text":"include \"include/included.td\"\n// This is a def.\ndef Def : IncludedClass;\nclass Class<IncludedClass c> {\n  int arg = c.arg;\n}\n"
 }}}
 // -----
 // Hover on an include file.
@@ -32,6 +32,54 @@
 // CHECK-NEXT:    }
 // CHECK-NEXT:  }
 // -----
+// Hover on a record.
+{"jsonrpc":"2.0","id":1,"method":"textDocument/hover","params":{
+  "textDocument":{"uri":"test:///foo.td"},
+  "position":{"line":2,"character":6}
+}}
+//      CHECK:  "id": 1,
+// CHECK-NEXT:  "jsonrpc": "2.0",
+// CHECK-NEXT:  "result": {
+// CHECK-NEXT:    "contents": {
+// CHECK-NEXT:      "kind": "markdown",
+// CHECK-NEXT:      "value": "**def** `Def`\n***\n\n This is a def.\n"
+// CHECK-NEXT:    },
+// CHECK-NEXT:    "range": {
+// CHECK-NEXT:      "end": {
+// CHECK-NEXT:        "character": 7,
+// CHECK-NEXT:        "line": 2
+// CHECK-NEXT:      },
+// CHECK-NEXT:      "start": {
+// CHECK-NEXT:        "character": 4,
+// CHECK-NEXT:        "line": 2
+// CHECK-NEXT:      }
+// CHECK-NEXT:    }
+// CHECK-NEXT:  }
+// -----
+// Hover on a field.
+{"jsonrpc":"2.0","id":1,"method":"textDocument/hover","params":{
+  "textDocument":{"uri":"test:///foo.td"},
+  "position":{"line":4,"character":16}
+}}
+//      CHECK:  "id": 1,
+// CHECK-NEXT:  "jsonrpc": "2.0",
+// CHECK-NEXT:  "result": {
+// CHECK-NEXT:    "contents": {
+// CHECK-NEXT:      "kind": "markdown",
+// CHECK-NEXT:      "value": "**field** `arg`\n***\nType: `int`\n***\n\n This argument was defined on an included class.\n\n***\n"
+// CHECK-NEXT:    },
+// CHECK-NEXT:    "range": {
+// CHECK-NEXT:      "end": {
+// CHECK-NEXT:        "character": 17,
+// CHECK-NEXT:        "line": 4
+// CHECK-NEXT:      },
+// CHECK-NEXT:      "start": {
+// CHECK-NEXT:        "character": 14,
+// CHECK-NEXT:        "line": 4
+// CHECK-NEXT:      }
+// CHECK-NEXT:    }
+// CHECK-NEXT:  }
+// -----
 {"jsonrpc":"2.0","id":7,"method":"shutdown"}
 // -----
 {"jsonrpc":"2.0","method":"exit"}
diff --git a/mlir/test/tblgen-lsp-server/include/included.td b/mlir/test/tblgen-lsp-server/include/included.td
--- a/mlir/test/tblgen-lsp-server/include/included.td
+++ b/mlir/test/tblgen-lsp-server/include/included.td
@@ -1,3 +1,8 @@
-// This file is merely to test the processing of includes, it has
-// no other purpose or contents.
+// This file is used to test the processing of includes, and provides
+// definitions that are referenced from the including test files.
+
+class IncludedClass {
+  /// This argument was defined on an included class.
+  int arg = 10;
+}
 