diff --git a/mlir/include/mlir/Parser/AsmParserState.h b/mlir/include/mlir/Parser/AsmParserState.h
--- a/mlir/include/mlir/Parser/AsmParserState.h
+++ b/mlir/include/mlir/Parser/AsmParserState.h
@@ -53,7 +53,8 @@
       SMDefinition definition;
     };
 
-    OperationDefinition(Operation *op, llvm::SMRange loc) : op(op), loc(loc) {}
+    OperationDefinition(Operation *op, llvm::SMRange loc, llvm::SMLoc endLoc)
+        : op(op), loc(loc), scopeLoc(loc.Start, endLoc) {}
 
     /// The operation representing this definition.
     Operation *op;
@@ -61,6 +62,10 @@
     /// The source location for the operation, i.e. the location of its name.
     llvm::SMRange loc;
 
+    /// The full source range of the operation definition, i.e. a range
+    /// encompassing the start and end of the full operation definition.
+    llvm::SMRange scopeLoc;
+
     /// Source definitions for any result groups of this operation.
     SmallVector<std::pair<unsigned, SMDefinition>> resultGroups;
 
@@ -110,6 +115,10 @@
   /// state.
   iterator_range<OperationDefIterator> getOpDefs() const;
 
+  /// Return the definition for the given operation, or nullptr if the given
+  /// operation does not have a definition.
+  const OperationDefinition *getOpDef(Operation *op) const;
+
   /// Returns (heuristically) the range of an identifier given a SMLoc
   /// corresponding to the start of an identifier location.
   static llvm::SMRange convertIdLocToRange(llvm::SMLoc loc);
@@ -130,7 +139,7 @@
 
   /// Finalize the most recently started operation definition.
   void finalizeOperationDefinition(
-      Operation *op, llvm::SMRange nameLoc,
+      Operation *op, llvm::SMRange nameLoc, llvm::SMLoc endLoc,
       ArrayRef<std::pair<unsigned, llvm::SMLoc>> resultGroups = llvm::None);
 
   /// Start a definition for a region nested under the current operation.
diff --git a/mlir/lib/Parser/AsmParserState.cpp b/mlir/lib/Parser/AsmParserState.cpp
--- a/mlir/lib/Parser/AsmParserState.cpp
+++ b/mlir/lib/Parser/AsmParserState.cpp
@@ -109,6 +109,13 @@
   return llvm::make_pointee_range(llvm::makeArrayRef(impl->operations));
 }
 
+auto AsmParserState::getOpDef(Operation *op) const
+    -> const OperationDefinition * {
+  auto it = impl->operationToIdx.find(op);
+  return it == impl->operationToIdx.end() ? nullptr
+                                          : &*impl->operations[it->second];
+}
+
 /// Returns (heuristically) the range of an identifier given a SMLoc
 /// corresponding to the start of an identifier location.
 llvm::SMRange AsmParserState::convertIdLocToRange(llvm::SMLoc loc) {
@@ -153,7 +160,7 @@
 }
 
 void AsmParserState::finalizeOperationDefinition(
-    Operation *op, llvm::SMRange nameLoc,
+    Operation *op, llvm::SMRange nameLoc, llvm::SMLoc endLoc,
     ArrayRef<std::pair<unsigned, llvm::SMLoc>> resultGroups) {
   assert(!impl->partialOperations.empty() &&
          "expected valid partial operation definition");
@@ -161,7 +168,7 @@
 
   // Build the full operation definition.
   std::unique_ptr<OperationDefinition> def =
-      std::make_unique<OperationDefinition>(op, nameLoc);
+      std::make_unique<OperationDefinition>(op, nameLoc, endLoc);
   for (auto &resultGroup : resultGroups)
     def->resultGroups.emplace_back(resultGroup.first,
                                    convertIdLocToRange(resultGroup.second));
diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -821,8 +821,9 @@
         asmResultGroups.emplace_back(resultIt, std::get<2>(record));
         resultIt += std::get<1>(record);
       }
-      state.asmState->finalizeOperationDefinition(op, nameTok.getLocRange(),
-                                                  asmResultGroups);
+      state.asmState->finalizeOperationDefinition(
+          op, nameTok.getLocRange(), /*endLoc=*/getToken().getLoc(),
+          asmResultGroups);
     }
 
     // Add definitions for each of the result groups.
@@ -837,7 +838,8 @@
 
     // Add this operation to the assembly state if it was provided to populate.
   } else if (state.asmState) {
-    state.asmState->finalizeOperationDefinition(op, nameTok.getLocRange());
+    state.asmState->finalizeOperationDefinition(op, nameTok.getLocRange(),
+                                                /*endLoc=*/getToken().getLoc());
   }
 
   return success();
@@ -1009,7 +1011,8 @@
   // If we are populating the parser asm state, finalize this operation
   // definition.
   if (state.asmState)
-    state.asmState->finalizeOperationDefinition(op, nameToken.getLocRange());
+    state.asmState->finalizeOperationDefinition(op, nameToken.getLocRange(),
+                                                /*endLoc=*/getToken().getLoc());
 
   return op;
 }
diff --git a/mlir/lib/Tools/mlir-lsp-server/LSPServer.cpp b/mlir/lib/Tools/mlir-lsp-server/LSPServer.cpp
--- a/mlir/lib/Tools/mlir-lsp-server/LSPServer.cpp
+++ b/mlir/lib/Tools/mlir-lsp-server/LSPServer.cpp
@@ -56,6 +56,16 @@
   void onHover(const TextDocumentPositionParams &params,
                Callback<Optional<Hover>> reply);
 
+  //===--------------------------------------------------------------------===//
+  // Document Symbols
+
+  void onDocumentSymbol(const DocumentSymbolParams &params,
+                        Callback<std::vector<DocumentSymbol>> reply);
+
+  //===--------------------------------------------------------------------===//
+  // Fields
+  //===--------------------------------------------------------------------===//
+
   MLIRServer &server;
   JSONTransport &transport;
 
@@ -73,6 +83,7 @@
 
 void LSPServer::Impl::onInitialize(const InitializeParams &params,
                                    Callback<llvm::json::Value> reply) {
+  // Send a response with the capabilities of this server.
   llvm::json::Object serverCaps{
       {"textDocumentSync",
        llvm::json::Object{
@@ -83,6 +94,11 @@
       {"definitionProvider", true},
       {"referencesProvider", true},
       {"hoverProvider", true},
+
+      // For now we only advertise document symbol support when the client
+      // supports hierarchical symbols.
+      {"documentSymbolProvider",
+       params.capabilities.hierarchicalDocumentSymbol},
   };
 
   llvm::json::Object result{
@@ -165,6 +181,17 @@
   reply(server.findHover(params.textDocument.uri, params.position));
 }
 
+//===----------------------------------------------------------------------===//
+// Document Symbols
+
+void LSPServer::Impl::onDocumentSymbol(
+    const DocumentSymbolParams &params,
+    Callback<std::vector<DocumentSymbol>> reply) {
+  std::vector<DocumentSymbol> symbols;
+  server.findDocumentSymbols(params.textDocument.uri, symbols);
+  reply(std::move(symbols));
+}
+
 //===----------------------------------------------------------------------===//
 // LSPServer
 //===----------------------------------------------------------------------===//
@@ -198,6 +225,10 @@
   // Hover
   messageHandler.method("textDocument/hover", impl.get(), &Impl::onHover);
 
+  // Document Symbols
+  messageHandler.method("textDocument/documentSymbol", impl.get(),
+                        &Impl::onDocumentSymbol);
+
   // Diagnostics
   impl->publishDiagnostics =
       messageHandler.outgoingNotification<PublishDiagnosticsParams>(
diff --git a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.h b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.h
--- a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.h
+++ b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.h
@@ -17,6 +17,7 @@
 
 namespace lsp {
 struct Diagnostic;
+struct DocumentSymbol;
 struct Hover;
 struct Location;
 struct Position;
@@ -55,6 +56,10 @@
   /// couldn't be found.
   Optional<Hover> findHover(const URIForFile &uri, const Position &hoverPos);
 
+  /// Find all of the document symbols within the given file.
+  void findDocumentSymbols(const URIForFile &uri,
+                           std::vector<DocumentSymbol> &symbols);
+
 private:
   struct Impl;
 
diff --git a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp
--- a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp
+++ b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp
@@ -297,6 +297,18 @@
   buildHoverForBlockArgument(llvm::SMRange hoverRange, BlockArgument arg,
                              const AsmParserState::BlockDefinition &block);
 
+  //===--------------------------------------------------------------------===//
+  // Document Symbols
+  //===--------------------------------------------------------------------===//
+
+  void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols);
+  void findDocumentSymbols(Operation *op,
+                           std::vector<lsp::DocumentSymbol> &symbols);
+
+  //===--------------------------------------------------------------------===//
+  // Fields
+  //===--------------------------------------------------------------------===//
+
   /// The context used to hold the state contained by the parsed document.
   MLIRContext context;
 
@@ -593,6 +605,50 @@
   return hover;
 }
 
+//===----------------------------------------------------------------------===//
+// MLIRDocument: Document Symbols
+//===----------------------------------------------------------------------===//
+
+void MLIRDocument::findDocumentSymbols(
+    std::vector<lsp::DocumentSymbol> &symbols) {
+  for (Operation &op : parsedIR)
+    findDocumentSymbols(&op, symbols);
+}
+
+void MLIRDocument::findDocumentSymbols(
+    Operation *op, std::vector<lsp::DocumentSymbol> &symbols) {
+  std::vector<lsp::DocumentSymbol> *childSymbols = &symbols;
+
+  // Check for the source information of this operation.
+  if (const AsmParserState::OperationDefinition *def = asmState.getOpDef(op)) {
+    // If this operation defines a symbol, record it.
+    if (SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op)) {
+      symbols.emplace_back(symbol.getName(),
+                           op->hasTrait<OpTrait::FunctionLike>()
+                               ? lsp::SymbolKind::Function
+                               : lsp::SymbolKind::Class,
+                           getRangeFromLoc(sourceMgr, def->scopeLoc),
+                           getRangeFromLoc(sourceMgr, def->loc));
+      childSymbols = &symbols.back().children;
+
+      // Otherwise, if this is a symbol table push an anonymous document symbol.
+    } else if (op->hasTrait<OpTrait::SymbolTable>()) {
+      symbols.emplace_back("<" + op->getName().getStringRef() + ">",
+                           lsp::SymbolKind::Namespace,
+                           getRangeFromLoc(sourceMgr, def->scopeLoc),
+                           getRangeFromLoc(sourceMgr, def->loc));
+      childSymbols = &symbols.back().children;
+    }
+  }
+
+  // Recurse into the regions of this operation.
+  if (!op->getNumRegions())
+    return;
+  for (Region &region : op->getRegions())
+    for (Operation &childOp : region.getOps())
+      findDocumentSymbols(&childOp, *childSymbols);
+}
+
 //===----------------------------------------------------------------------===//
 // MLIRTextFileChunk
 //===----------------------------------------------------------------------===//
@@ -648,6 +704,7 @@
                         std::vector<lsp::Location> &references);
   Optional<lsp::Hover> findHover(const lsp::URIForFile &uri,
                                  lsp::Position hoverPos);
+  void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols);
 
 private:
   /// Find the MLIR document that contains the given position, and update the
@@ -661,6 +718,9 @@
   /// The version of this file.
   int64_t version;
 
+  /// The number of lines in the file.
+  int64_t totalNumLines;
+
   /// The chunks of this file. The order of these chunks is the order in which
   /// they appear in the text file.
   std::vector<std::unique_ptr<MLIRTextFileChunk>> chunks;
@@ -670,7 +730,7 @@
 MLIRTextFile::MLIRTextFile(const lsp::URIForFile &uri, StringRef fileContents,
                            int64_t version, DialectRegistry &registry,
                            std::vector<lsp::Diagnostic> &diagnostics)
-    : contents(fileContents.str()), version(version) {
+    : contents(fileContents.str()), version(version), totalNumLines(0) {
   // Split the file into separate MLIR documents.
   // TODO: Find a way to share the split file marker with other tools. We don't
   // want to use `splitAndProcessBuffer` here, but we do want to make sure this
@@ -701,6 +761,7 @@
     }
     chunks.emplace_back(std::move(chunk));
   }
+  totalNumLines = lineOffset;
 }
 
 void MLIRTextFile::getLocationsOf(const lsp::URIForFile &uri,
@@ -742,6 +803,45 @@
   return hoverInfo;
 }
 
+void MLIRTextFile::findDocumentSymbols(
+    std::vector<lsp::DocumentSymbol> &symbols) {
+  if (chunks.size() == 1)
+    return chunks.front()->document.findDocumentSymbols(symbols);
+
+  // If there are multiple chunks in this file, we create top-level symbols for
+  // each chunk.
+  for (unsigned i = 0, e = chunks.size(); i < e; ++i) {
+    MLIRTextFileChunk &chunk = *chunks[i];
+    lsp::Position startPos(chunk.lineOffset);
+    lsp::Position endPos((i == e - 1) ? totalNumLines - 1
+                                      : chunks[i + 1]->lineOffset);
+    lsp::DocumentSymbol symbol("<file-split-" + Twine(i) + ">",
+                               lsp::SymbolKind::Namespace,
+                               /*range=*/lsp::Range(startPos, endPos),
+                               /*selectionRange=*/lsp::Range(startPos));
+    chunk.document.findDocumentSymbols(symbol.children);
+
+    // Fixup the locations of document symbols within this chunk.
+    if (i != 0) {
+      SmallVector<lsp::DocumentSymbol *> symbolsToFix;
+      for (lsp::DocumentSymbol &childSymbol : symbol.children)
+        symbolsToFix.push_back(&childSymbol);
+
+      while (!symbolsToFix.empty()) {
+        lsp::DocumentSymbol *nestedSymbol = symbolsToFix.pop_back_val();
+        chunk.adjustLocForChunkOffset(nestedSymbol->range);
+        chunk.adjustLocForChunkOffset(nestedSymbol->selectionRange);
+
+        for (lsp::DocumentSymbol &childSymbol : nestedSymbol->children)
+          symbolsToFix.push_back(&childSymbol);
+      }
+    }
+
+    // Push the symbol for this chunk.
+    symbols.emplace_back(std::move(symbol));
+  }
+}
+
 MLIRTextFileChunk &MLIRTextFile::getChunkFor(lsp::Position &pos) {
   if (chunks.size() == 1)
     return *chunks.front();
@@ -820,3 +920,10 @@
     return fileIt->second->findHover(uri, hoverPos);
   return llvm::None;
 }
+
+void lsp::MLIRServer::findDocumentSymbols(
+    const URIForFile &uri, std::vector<DocumentSymbol> &symbols) {
+  auto fileIt = impl->files.find(uri.file());
+  if (fileIt != impl->files.end())
+    fileIt->second->findDocumentSymbols(symbols);
+}
diff --git a/mlir/lib/Tools/mlir-lsp-server/lsp/Protocol.h b/mlir/lib/Tools/mlir-lsp-server/lsp/Protocol.h
--- a/mlir/lib/Tools/mlir-lsp-server/lsp/Protocol.h
+++ b/mlir/lib/Tools/mlir-lsp-server/lsp/Protocol.h
@@ -134,6 +134,20 @@
               llvm::json::Path path);
 raw_ostream &operator<<(raw_ostream &os, const URIForFile &value);
 
+//===----------------------------------------------------------------------===//
+// ClientCapabilities
+//===----------------------------------------------------------------------===//
+
+struct ClientCapabilities {
+  /// Client supports hierarchical document symbols.
+  /// textDocument.documentSymbol.hierarchicalDocumentSymbolSupport
+  bool hierarchicalDocumentSymbol = false;
+};
+
+/// Add support for JSON serialization.
+bool fromJSON(const llvm::json::Value &value, ClientCapabilities &result,
+              llvm::json::Path path);
+
 //===----------------------------------------------------------------------===//
 // InitializeParams
 //===----------------------------------------------------------------------===//
@@ -149,6 +163,9 @@
               llvm::json::Path path);
 
 struct InitializeParams {
+  /// The capabilities provided by the client (editor or tool).
+  ClientCapabilities capabilities;
+
   /// The initial trace setting. If omitted trace is disabled ('off').
   Optional<TraceLevel> trace;
 };
@@ -224,6 +241,9 @@
 //===----------------------------------------------------------------------===//
 
 struct Position {
+  Position(int line = 0, int character = 0)
+      : line(line), character(character) {}
+
   /// Line position in a document (zero-based).
   int line = 0;
 
@@ -449,6 +469,94 @@
 /// Add support for JSON serialization.
 llvm::json::Value toJSON(const Hover &hover);
 
+//===----------------------------------------------------------------------===//
+// SymbolKind
+//===----------------------------------------------------------------------===//
+
+enum class SymbolKind {
+  File = 1,
+  Module = 2,
+  Namespace = 3,
+  Package = 4,
+  Class = 5,
+  Method = 6,
+  Property = 7,
+  Field = 8,
+  Constructor = 9,
+  Enum = 10,
+  Interface = 11,
+  Function = 12,
+  Variable = 13,
+  Constant = 14,
+  String = 15,
+  Number = 16,
+  Boolean = 17,
+  Array = 18,
+  Object = 19,
+  Key = 20,
+  Null = 21,
+  EnumMember = 22,
+  Struct = 23,
+  Event = 24,
+  Operator = 25,
+  TypeParameter = 26
+};
+
+//===----------------------------------------------------------------------===//
+// DocumentSymbol
+//===----------------------------------------------------------------------===//
+
+/// Represents programming constructs like variables, classes, interfaces etc.
+/// that appear in a document. Document symbols can be hierarchical and they
+/// have two ranges: one that encloses its definition and one that points to its
+/// most interesting range, e.g. the range of an identifier.
+struct DocumentSymbol {
+  DocumentSymbol() = default;
+  DocumentSymbol(DocumentSymbol &&) = default;
+  DocumentSymbol(const Twine &name, SymbolKind kind, Range range,
+                 Range selectionRange)
+      : name(name.str()), kind(kind), range(range),
+        selectionRange(selectionRange) {}
+
+  /// The name of this symbol.
+  std::string name;
+
+  /// More detail for this symbol, e.g. the signature of a function.
+  std::string detail;
+
+  /// The kind of this symbol.
+  SymbolKind kind;
+
+  /// The range enclosing this symbol not including leading/trailing whitespace
+  /// but everything else like comments. This information is typically used to
+  /// determine if the client's cursor is inside the symbol to reveal the
+  /// symbol in the UI.
+  Range range;
+
+  /// The range that should be selected and revealed when this symbol is being
+  /// picked, e.g. the name of a function. Must be contained by the `range`.
+  Range selectionRange;
+
+  /// Children of this symbol, e.g. properties of a class.
+  std::vector<DocumentSymbol> children;
+};
+
+/// Add support for JSON serialization.
+llvm::json::Value toJSON(const DocumentSymbol &symbol);
+
+//===----------------------------------------------------------------------===//
+// DocumentSymbolParams
+//===----------------------------------------------------------------------===//
+
+struct DocumentSymbolParams {
+  /// The text document to find symbols in.
+  TextDocumentIdentifier textDocument;
+};
+
+/// Add support for JSON serialization.
+bool fromJSON(const llvm::json::Value &value, DocumentSymbolParams &result,
+              llvm::json::Path path);
+
 //===----------------------------------------------------------------------===//
 // DiagnosticRelatedInformation
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Tools/mlir-lsp-server/lsp/Protocol.cpp b/mlir/lib/Tools/mlir-lsp-server/lsp/Protocol.cpp
--- a/mlir/lib/Tools/mlir-lsp-server/lsp/Protocol.cpp
+++ b/mlir/lib/Tools/mlir-lsp-server/lsp/Protocol.cpp
@@ -246,6 +246,28 @@
   return os << value.uri();
 }
 
+//===----------------------------------------------------------------------===//
+// ClientCapabilities
+//===----------------------------------------------------------------------===//
+
+bool mlir::lsp::fromJSON(const llvm::json::Value &value,
+                         ClientCapabilities &result, llvm::json::Path path) {
+  const llvm::json::Object *o = value.getAsObject();
+  if (!o) {
+    path.report("expected object");
+    return false;
+  }
+  if (const llvm::json::Object *textDocument = o->getObject("textDocument")) {
+    if (const llvm::json::Object *documentSymbol =
+            textDocument->getObject("documentSymbol")) {
+      if (Optional<bool> hierarchicalSupport =
+              documentSymbol->getBoolean("hierarchicalDocumentSymbolSupport"))
+        result.hierarchicalDocumentSymbol = *hierarchicalSupport;
+    }
+  }
+  return true;
+}
+
 //===----------------------------------------------------------------------===//
 // InitializeParams
 //===----------------------------------------------------------------------===//
@@ -275,6 +297,7 @@
   if (!o)
     return false;
   // We deliberately don't fail if we can't parse individual fields.
+  o.map("capabilities", result.capabilities);
   o.map("trace", result.trace);
   return true;
 }
@@ -494,6 +517,33 @@
   return std::move(result);
 }
 
+//===----------------------------------------------------------------------===//
+// DocumentSymbol
+//===----------------------------------------------------------------------===//
+
+llvm::json::Value mlir::lsp::toJSON(const DocumentSymbol &symbol) {
+  llvm::json::Object result{{"name", symbol.name},
+                            {"kind", static_cast<int>(symbol.kind)},
+                            {"range", symbol.range},
+                            {"selectionRange", symbol.selectionRange}};
+
+  if (!symbol.detail.empty())
+    result["detail"] = symbol.detail;
+  if (!symbol.children.empty())
+    result["children"] = symbol.children;
+  return std::move(result);
+}
+
+//===----------------------------------------------------------------------===//
+// DocumentSymbolParams
+//===----------------------------------------------------------------------===//
+
+bool mlir::lsp::fromJSON(const llvm::json::Value &value,
+                         DocumentSymbolParams &result, llvm::json::Path path) {
+  llvm::json::ObjectMapper o(value, path);
+  return o && o.map("textDocument", result.textDocument);
+}
+
 //===----------------------------------------------------------------------===//
 // DiagnosticRelatedInformation
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/mlir-lsp-server/document-symbols.test b/mlir/test/mlir-lsp-server/document-symbols.test
new file mode 100644
--- /dev/null
+++ b/mlir/test/mlir-lsp-server/document-symbols.test
@@ -0,0 +1,71 @@
+// RUN: mlir-lsp-server -lit-test < %s | FileCheck -strict-whitespace %s
+{"jsonrpc":"2.0","id":0,"method":"initialize","params":{"processId":123,"rootUri":"test:///workspace","capabilities":{"textDocument":{"documentSymbol":{"hierarchicalDocumentSymbolSupport":true}}},"trace":"off"}}
+// -----
+{"jsonrpc":"2.0","method":"textDocument/didOpen","params":{"textDocument":{
+  "uri":"test:///foo.mlir",
+  "languageId":"mlir",
+  "version":1,
+  "text":"module {\nfunc private @foo()\n}"
+}}}
+// -----
+{"jsonrpc":"2.0","id":1,"method":"textDocument/documentSymbol","params":{
+  "textDocument":{"uri":"test:///foo.mlir"}
+}}
+//      CHECK:  "id": 1
+// CHECK-NEXT:  "jsonrpc": "2.0",
+// CHECK-NEXT:  "result": [
+// CHECK-NEXT:    {
+// CHECK-NEXT:      "children": [
+// CHECK-NEXT:        {
+// CHECK-NEXT:          "kind": 12,
+// CHECK-NEXT:          "name": "foo",
+// CHECK-NEXT:          "range": {
+// CHECK-NEXT:            "end": {
+// CHECK-NEXT:              "character": {{.*}},
+// CHECK-NEXT:              "line": {{.*}}
+// CHECK-NEXT:            },
+// CHECK-NEXT:            "start": {
+// CHECK-NEXT:              "character": {{.*}},
+// CHECK-NEXT:              "line": {{.*}}
+// CHECK-NEXT:            }
+// CHECK-NEXT:          },
+// CHECK-NEXT:          "selectionRange": {
+// CHECK-NEXT:            "end": {
+// CHECK-NEXT:              "character": 4,
+// CHECK-NEXT:              "line": {{.*}}
+// CHECK-NEXT:            },
+// CHECK-NEXT:            "start": {
+// CHECK-NEXT:              "character": {{.*}},
+// CHECK-NEXT:              "line": {{.*}}
+// CHECK-NEXT:            }
+// CHECK-NEXT:          }
+// CHECK-NEXT:        }
+// CHECK-NEXT:      ],
+// CHECK-NEXT:      "kind": 3,
+// CHECK-NEXT:      "name": "<{{.*}}module>",
+// CHECK-NEXT:      "range": {
+// CHECK-NEXT:        "end": {
+// CHECK-NEXT:          "character": {{.*}},
+// CHECK-NEXT:          "line": {{.*}}
+// CHECK-NEXT:        },
+// CHECK-NEXT:        "start": {
+// CHECK-NEXT:          "character": {{.*}},
+// CHECK-NEXT:          "line": {{.*}}
+// CHECK-NEXT:        }
+// CHECK-NEXT:      },
+// CHECK-NEXT:      "selectionRange": {
+// CHECK-NEXT:        "end": {
+// CHECK-NEXT:          "character": {{.*}},
+// CHECK-NEXT:          "line": {{.*}}
+// CHECK-NEXT:        },
+// CHECK-NEXT:        "start": {
+// CHECK-NEXT:          "character": {{.*}},
+// CHECK-NEXT:          "line": {{.*}}
+// CHECK-NEXT:        }
+// CHECK-NEXT:      }
+// CHECK-NEXT:    }
+// CHECK-NEXT:  ]
+// -----
+{"jsonrpc":"2.0","id":3,"method":"shutdown"}
+// -----
+{"jsonrpc":"2.0","method":"exit"}
diff --git a/mlir/test/mlir-lsp-server/initialize-params.test b/mlir/test/mlir-lsp-server/initialize-params.test
--- a/mlir/test/mlir-lsp-server/initialize-params.test
+++ b/mlir/test/mlir-lsp-server/initialize-params.test
@@ -6,6 +6,7 @@
 // CHECK-NEXT:  "result": {
 // CHECK-NEXT:    "capabilities": {
 // CHECK-NEXT:      "definitionProvider": true,
+// CHECK-NEXT:      "documentSymbolProvider": false,
 // CHECK-NEXT:      "hoverProvider": true,
 // CHECK-NEXT:      "referencesProvider": true,
 // CHECK-NEXT:      "textDocumentSync": {