diff --git a/mlir/lib/Tools/mlir-tblgen-lsp-server/LSPServer.cpp b/mlir/lib/Tools/mlir-tblgen-lsp-server/LSPServer.cpp --- a/mlir/lib/Tools/mlir-tblgen-lsp-server/LSPServer.cpp +++ b/mlir/lib/Tools/mlir-tblgen-lsp-server/LSPServer.cpp @@ -42,6 +42,14 @@ void onDocumentDidClose(const DidCloseTextDocumentParams &params); void onDocumentDidChange(const DidChangeTextDocumentParams &params); + //===--------------------------------------------------------------------===// + // Definitions and References + + void onGoToDefinition(const TextDocumentPositionParams &params, + Callback<std::vector<Location>> reply); + void onReference(const ReferenceParams &params, + Callback<std::vector<Location>> reply); + //===----------------------------------------------------------------------===// // DocumentLink @@ -84,6 +92,8 @@ {"change", (int)TextDocumentSyncKind::Full}, {"save", true}, }}, + {"definitionProvider", true}, + {"referencesProvider", true}, {"documentLinkProvider", llvm::json::Object{ {"resolveProvider", false}, @@ -142,6 +152,23 @@ publishDiagnostics(diagParams); } +//===----------------------------------------------------------------------===// +// Definitions and References + +void LSPServer::onGoToDefinition(const TextDocumentPositionParams &params, + Callback<std::vector<Location>> reply) { + std::vector<Location> locations; + server.getLocationsOf(params.textDocument.uri, params.position, locations); + reply(std::move(locations)); +} + +void LSPServer::onReference(const ReferenceParams &params, + Callback<std::vector<Location>> reply) { + std::vector<Location> locations; + server.findReferencesOf(params.textDocument.uri, params.position, locations); + reply(std::move(locations)); +} + //===----------------------------------------------------------------------===// // DocumentLink @@ -183,6 +210,12 @@ messageHandler.notification("textDocument/didChange", &lspServer, &LSPServer::onDocumentDidChange); + // Definitions and References + messageHandler.method("textDocument/definition", &lspServer, + &LSPServer::onGoToDefinition); + messageHandler.method("textDocument/references", &lspServer,
&LSPServer::onReference); + // Document Link messageHandler.method("textDocument/documentLink", &lspServer, &LSPServer::onDocumentLink); diff --git a/mlir/lib/Tools/mlir-tblgen-lsp-server/TableGenServer.h b/mlir/lib/Tools/mlir-tblgen-lsp-server/TableGenServer.h --- a/mlir/lib/Tools/mlir-tblgen-lsp-server/TableGenServer.h +++ b/mlir/lib/Tools/mlir-tblgen-lsp-server/TableGenServer.h @@ -19,6 +19,7 @@ struct Diagnostic; struct DocumentLink; struct Hover; +struct Location; struct Position; class URIForFile; @@ -55,6 +56,14 @@ /// the server. Optional<int64_t> removeDocument(const URIForFile &uri); + /// Return the locations of the object pointed at by the given position. + void getLocationsOf(const URIForFile &uri, const Position &defPos, + std::vector<Location> &locations); + + /// Find all references of the object pointed at by the given position. + void findReferencesOf(const URIForFile &uri, const Position &pos, + std::vector<Location> &references); + /// Return the document links referenced by the given file. void getDocumentLinks(const URIForFile &uri, std::vector<DocumentLink> &documentLinks); diff --git a/mlir/lib/Tools/mlir-tblgen-lsp-server/TableGenServer.cpp b/mlir/lib/Tools/mlir-tblgen-lsp-server/TableGenServer.cpp --- a/mlir/lib/Tools/mlir-tblgen-lsp-server/TableGenServer.cpp +++ b/mlir/lib/Tools/mlir-tblgen-lsp-server/TableGenServer.cpp @@ -13,6 +13,7 @@ #include "../lsp-server-support/Protocol.h" #include "../lsp-server-support/SourceMgrUtils.h" #include "llvm/ADT/IntervalMap.h" +#include "llvm/ADT/PointerUnion.h" #include "llvm/ADT/StringMap.h" #include "llvm/ADT/StringSet.h" #include "llvm/ADT/TypeSwitch.h" @@ -40,10 +41,14 @@ } /// Returns a language server location from the given source range.
+static lsp::Location getLocationFromLoc(llvm::SourceMgr &mgr, SMRange loc, + const lsp::URIForFile &uri) { + return lsp::Location(getURIFromLoc(mgr, loc.Start, uri), + lsp::Range(mgr, loc)); +} static lsp::Location getLocationFromLoc(llvm::SourceMgr &mgr, SMLoc loc, const lsp::URIForFile &uri) { - return lsp::Location(getURIFromLoc(mgr, loc, uri), - lsp::Range(mgr, lsp::convertTokenLocToRange(loc))); + return getLocationFromLoc(mgr, lsp::convertTokenLocToRange(loc), uri); } /// Convert the given TableGen diagnostic to the LSP form. @@ -90,6 +95,133 @@ return lspDiag; } +//===----------------------------------------------------------------------===// +// TableGenIndex +//===----------------------------------------------------------------------===// + +namespace { +/// This class represents a single symbol definition within a TableGen index. It +/// contains the definition of the symbol, the location of the symbol, and any +/// recorded references. +struct TableGenIndexSymbol { + TableGenIndexSymbol(const llvm::Record *record) + : definition(record), + defLoc(lsp::convertTokenLocToRange(record->getLoc().front())) {} + TableGenIndexSymbol(const llvm::RecordVal *value) + : definition(value), + defLoc(lsp::convertTokenLocToRange(value->getLoc())) {} + + /// The main definition of the symbol. + PointerUnion<const llvm::Record *, const llvm::RecordVal *> definition; + + /// The source location of the definition. + SMRange defLoc; + + /// The source location of the references of the definition. + SmallVector<SMRange> references; +}; + +/// This class provides an index for definitions/uses within a TableGen +/// document. It provides efficient lookup of a definition given an input source +/// range. +class TableGenIndex { +public: + TableGenIndex() : intervalMap(allocator) {} + + /// Initialize the index with the given RecordKeeper. + void initialize(const llvm::RecordKeeper &records); + + /// Lookup a symbol for the given location. Returns nullptr if no symbol could + /// be found.
If provided, `overlappedRange` is set to the range that the + /// provided `loc` overlapped with. + const TableGenIndexSymbol *lookup(SMLoc loc, + SMRange *overlappedRange = nullptr) const; + +private: + /// The type of interval map used to store source references. SMRange is + /// half-open, so we also need to use a half-open interval map. + using MapT = llvm::IntervalMap< + const char *, const TableGenIndexSymbol *, + llvm::IntervalMapImpl::NodeSizer<const char *, const TableGenIndexSymbol *>::LeafSize, + llvm::IntervalMapHalfOpenInfo<const char *>>; + + /// An allocator for the interval map. + MapT::Allocator allocator; + + /// An interval map containing a corresponding definition mapped to a source + /// interval. + MapT intervalMap; + + /// A mapping between definitions and their corresponding symbol. + DenseMap<const void *, std::unique_ptr<TableGenIndexSymbol>> defToSymbol; +}; +} // namespace + +void TableGenIndex::initialize(const llvm::RecordKeeper &records) { + auto getOrInsertDef = [&](const auto *def) -> TableGenIndexSymbol * { + auto it = defToSymbol.try_emplace(def, nullptr); + if (it.second) + it.first->second = std::make_unique<TableGenIndexSymbol>(def); + return &*it.first->second; + }; + auto insertRef = [&](TableGenIndexSymbol *sym, SMRange refLoc, + bool isDef = false) { + const char *startLoc = refLoc.Start.getPointer(); + const char *endLoc = refLoc.End.getPointer(); + + if (startLoc == endLoc) { + refLoc = lsp::convertTokenLocToRange(SMLoc::getFromPointer(startLoc)); + startLoc = refLoc.Start.getPointer(); + endLoc = refLoc.End.getPointer(); + if (startLoc == endLoc) + return; + } + + if (!intervalMap.overlaps(startLoc, endLoc)) + intervalMap.insert(startLoc, endLoc, sym); + + if (!isDef) + sym->references.push_back(refLoc); + }; + auto classes = + llvm::make_pointee_range(llvm::make_second_range(records.getClasses())); + auto defs = + llvm::make_pointee_range(llvm::make_second_range(records.getDefs())); + for (const llvm::Record &def : llvm::concat<const llvm::Record>(classes, defs)) { + auto *sym = getOrInsertDef(&def); + insertRef(sym, sym->defLoc, /*isDef=*/true); + + // Add
references to the definition. + for (SMLoc loc : def.getLoc().drop_front()) + insertRef(sym, lsp::convertTokenLocToRange(loc)); + + // Add references to any super classes. + for (auto &it : def.getSuperClasses()) + insertRef(getOrInsertDef(it.first), + lsp::convertTokenLocToRange(it.second.Start)); + + // Add definitions for any values. + for (const llvm::RecordVal &value : def.getValues()) { + auto *sym = getOrInsertDef(&value); + insertRef(sym, sym->defLoc, /*isDef=*/true); + } + } +} + +const TableGenIndexSymbol * +TableGenIndex::lookup(SMLoc loc, SMRange *overlappedRange) const { + auto it = intervalMap.find(loc.getPointer()); + if (!it.valid() || loc.getPointer() < it.start()) + return nullptr; + + if (overlappedRange) { + *overlappedRange = SMRange(SMLoc::getFromPointer(it.start()), + SMLoc::getFromPointer(it.stop())); + } + return it.value(); +} + //===----------------------------------------------------------------------===// // TableGenTextFile //===----------------------------------------------------------------------===// @@ -106,6 +238,15 @@ /// Return the current version of this text file. int64_t getVersion() const { return version; } + //===--------------------------------------------------------------------===// + // Definitions and References + //===--------------------------------------------------------------------===// + + void getLocationsOf(const lsp::URIForFile &uri, const lsp::Position &defPos, + std::vector<lsp::Location> &locations); + void findReferencesOf(const lsp::URIForFile &uri, const lsp::Position &pos, + std::vector<lsp::Location> &references); + //===--------------------------------------------------------------------===// // Document Links //===--------------------------------------------------------------------===// @@ -136,6 +277,9 @@ /// The record keeper containing the parsed tablegen constructs. llvm::RecordKeeper recordKeeper; + /// The index of the parsed file. + TableGenIndex index; + /// The set of includes of the parsed file.
SmallVector<lsp::SourceMgrInclude> parsedIncludes; }; @@ -183,6 +327,37 @@ lsp::gatherIncludeFiles(sourceMgr, parsedIncludes); if (failedToParse) return; + + // If we successfully parsed the file, we can now build the index. + index.initialize(recordKeeper); +} + +//===----------------------------------------------------------------------===// +// TableGenTextFile: Definitions and References +//===----------------------------------------------------------------------===// + +void TableGenTextFile::getLocationsOf(const lsp::URIForFile &uri, + const lsp::Position &defPos, + std::vector<lsp::Location> &locations) { + SMLoc posLoc = defPos.getAsSMLoc(sourceMgr); + const TableGenIndexSymbol *symbol = index.lookup(posLoc); + if (!symbol) + return; + + locations.push_back(getLocationFromLoc(sourceMgr, symbol->defLoc, uri)); +} + +void TableGenTextFile::findReferencesOf( + const lsp::URIForFile &uri, const lsp::Position &pos, + std::vector<lsp::Location> &references) { + SMLoc posLoc = pos.getAsSMLoc(sourceMgr); + const TableGenIndexSymbol *symbol = index.lookup(posLoc); + if (!symbol) + return; + + references.push_back(getLocationFromLoc(sourceMgr, symbol->defLoc, uri)); + for (SMRange refLoc : symbol->references) + references.push_back(getLocationFromLoc(sourceMgr, refLoc, uri)); } //===--------------------------------------------------------------------===// @@ -258,6 +433,22 @@ return version; } +void lsp::TableGenServer::getLocationsOf(const URIForFile &uri, + const Position &defPos, + std::vector<Location> &locations) { + auto fileIt = impl->files.find(uri.file()); + if (fileIt != impl->files.end()) + fileIt->second->getLocationsOf(uri, defPos, locations); +} + +void lsp::TableGenServer::findReferencesOf(const URIForFile &uri, + const Position &pos, + std::vector<Location> &references) { + auto fileIt = impl->files.find(uri.file()); + if (fileIt != impl->files.end()) + fileIt->second->findReferencesOf(uri, pos, references); +} + void lsp::TableGenServer::getDocumentLinks( const URIForFile &uri, std::vector<DocumentLink> &documentLinks) { auto fileIt =
impl->files.find(uri.file()); diff --git a/mlir/test/mlir-tblgen-lsp-server/definition.test b/mlir/test/mlir-tblgen-lsp-server/definition.test new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-tblgen-lsp-server/definition.test @@ -0,0 +1,55 @@ +// RUN: mlir-tblgen-lsp-server -lit-test < %s | FileCheck %s +{"jsonrpc":"2.0","id":0,"method":"initialize","params":{"processId":123,"rootPath":"tblgen","capabilities":{},"trace":"off"}} +// ----- +{"jsonrpc":"2.0","method":"textDocument/didOpen","params":{"textDocument":{ + "uri":"test:///foo.td", + "languageId":"tblgen", + "version":1, + "text":"class Foo {\n int field1 = ?;\n}\ndef FooDerived : Foo {\n let field1 = 10;\n }\n" +}}} +// ----- +{"jsonrpc":"2.0","id":1,"method":"textDocument/definition","params":{ + "textDocument":{"uri":"test:///foo.td"}, + "position":{"line":3,"character":19} +}} +// CHECK: "id": 1 +// CHECK-NEXT: "jsonrpc": "2.0", +// CHECK-NEXT: "result": [ +// CHECK-NEXT: { +// CHECK-NEXT: "range": { +// CHECK-NEXT: "end": { +// CHECK-NEXT: "character": 9, +// CHECK-NEXT: "line": 0 +// CHECK-NEXT: }, +// CHECK-NEXT: "start": { +// CHECK-NEXT: "character": 6, +// CHECK-NEXT: "line": 0 +// CHECK-NEXT: } +// CHECK-NEXT: }, +// CHECK-NEXT: "uri": "{{.*}}/foo.td" +// CHECK-NEXT: } +// ----- +{"jsonrpc":"2.0","id":2,"method":"textDocument/definition","params":{ + "textDocument":{"uri":"test:///foo.td"}, + "position":{"line":4,"character":9} +}} +// CHECK: "id": 2 +// CHECK-NEXT: "jsonrpc": "2.0", +// CHECK-NEXT: "result": [ +// CHECK-NEXT: { +// CHECK-NEXT: "range": { +// CHECK-NEXT: "end": { +// CHECK-NEXT: "character": 12, +// CHECK-NEXT: "line": 4 +// CHECK-NEXT: }, +// CHECK-NEXT: "start": { +// CHECK-NEXT: "character": 6, +// CHECK-NEXT: "line": 4 +// CHECK-NEXT: } +// CHECK-NEXT: }, +// CHECK-NEXT: "uri": "{{.*}}/foo.td" +// CHECK-NEXT: } +// ----- +{"jsonrpc":"2.0","id":3,"method":"shutdown"} +// ----- +{"jsonrpc":"2.0","method":"exit"} diff --git 
a/mlir/test/mlir-tblgen-lsp-server/initialize-params.test b/mlir/test/mlir-tblgen-lsp-server/initialize-params.test --- a/mlir/test/mlir-tblgen-lsp-server/initialize-params.test +++ b/mlir/test/mlir-tblgen-lsp-server/initialize-params.test @@ -5,11 +5,13 @@ // CHECK-NEXT: "jsonrpc": "2.0", // CHECK-NEXT: "result": { // CHECK-NEXT: "capabilities": { +// CHECK-NEXT: "definitionProvider": true, // CHECK-NEXT: "documentLinkProvider": { // CHECK-NEXT: "resolveProvider": false // CHECK-NEXT: }, // CHECK-NEXT: "hoverProvider": true, -// CHECK-NEXT: "textDocumentSync": { +// CHECK-NEXT: "referencesProvider": true, +// CHECK-NEXT: "textDocumentSync": { // CHECK-NEXT: "change": 1, // CHECK-NEXT: "openClose": true, // CHECK-NEXT: "save": true diff --git a/mlir/test/mlir-tblgen-lsp-server/references.test b/mlir/test/mlir-tblgen-lsp-server/references.test new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-tblgen-lsp-server/references.test @@ -0,0 +1,49 @@ +// RUN: mlir-tblgen-lsp-server -lit-test < %s | FileCheck -strict-whitespace %s +{"jsonrpc":"2.0","id":0,"method":"initialize","params":{"processId":123,"rootPath":"tblgen","capabilities":{},"trace":"off"}} +// ----- +{"jsonrpc":"2.0","method":"textDocument/didOpen","params":{"textDocument":{ + "uri":"test:///foo.td", + "languageId":"tblgen", + "version":1, + "text":"class Foo;\ndef FooDerived : Foo;\n" +}}} +// ----- +{"jsonrpc":"2.0","id":1,"method":"textDocument/references","params":{ + "textDocument":{"uri":"test:///foo.td"}, + "position":{"line":0,"character":7}, + "context":{"includeDeclaration": false} +}} +// CHECK: "id": 1 +// CHECK-NEXT: "jsonrpc": "2.0", +// CHECK-NEXT: "result": [ +// CHECK-NEXT: { +// CHECK-NEXT: "range": { +// CHECK-NEXT: "end": { +// CHECK-NEXT: "character": 9, +// CHECK-NEXT: "line": 0 +// CHECK-NEXT: }, +// CHECK-NEXT: "start": { +// CHECK-NEXT: "character": 6, +// CHECK-NEXT: "line": 0 +// CHECK-NEXT: } +// CHECK-NEXT: }, +// CHECK-NEXT: "uri": "{{.*}}/foo.td" +// CHECK-NEXT: }, +// 
CHECK-NEXT: { +// CHECK-NEXT: "range": { +// CHECK-NEXT: "end": { +// CHECK-NEXT: "character": 20, +// CHECK-NEXT: "line": 1 +// CHECK-NEXT: }, +// CHECK-NEXT: "start": { +// CHECK-NEXT: "character": 17, +// CHECK-NEXT: "line": 1 +// CHECK-NEXT: } +// CHECK-NEXT: }, +// CHECK-NEXT: "uri": "{{.*}}/foo.td" +// CHECK-NEXT: } +// CHECK-NEXT: ] +// ----- +{"jsonrpc":"2.0","id":3,"method":"shutdown"} +// ----- +{"jsonrpc":"2.0","method":"exit"}