diff --git a/mlir/include/mlir/Parser/AsmParserState.h b/mlir/include/mlir/Parser/AsmParserState.h --- a/mlir/include/mlir/Parser/AsmParserState.h +++ b/mlir/include/mlir/Parser/AsmParserState.h @@ -53,7 +53,8 @@ SMDefinition definition; }; - OperationDefinition(Operation *op, llvm::SMRange loc) : op(op), loc(loc) {} + OperationDefinition(Operation *op, llvm::SMRange loc, llvm::SMLoc endLoc) + : op(op), loc(loc), scopeLoc(loc.Start, endLoc) {} /// The operation representing this definition. Operation *op; @@ -61,6 +62,10 @@ /// The source location for the operation, i.e. the location of its name. llvm::SMRange loc; + /// The full source range of the operation definition, i.e. a range + /// encompassing the start and end of the full operation definition. + llvm::SMRange scopeLoc; + /// Source definitions for any result groups of this operation. SmallVector> resultGroups; @@ -110,6 +115,10 @@ /// state. iterator_range getOpDefs() const; + /// Return the definition for the given operation, or nullptr if the given + /// operation does not have a definition. + const OperationDefinition *getOpDef(Operation *op) const; + /// Returns (heuristically) the range of an identifier given a SMLoc /// corresponding to the start of an identifier location. static llvm::SMRange convertIdLocToRange(llvm::SMLoc loc); @@ -130,7 +139,7 @@ /// Finalize the most recently started operation definition. void finalizeOperationDefinition( - Operation *op, llvm::SMRange nameLoc, + Operation *op, llvm::SMRange nameLoc, llvm::SMLoc endLoc, ArrayRef> resultGroups = llvm::None); /// Start a definition for a region nested under the current operation. diff --git a/mlir/lib/Parser/AsmParserState.cpp b/mlir/lib/Parser/AsmParserState.cpp --- a/mlir/lib/Parser/AsmParserState.cpp +++ b/mlir/lib/Parser/AsmParserState.cpp @@ -36,8 +36,8 @@ std::unique_ptr symbolTable; }; - /// Resolve any symbol table uses under the given partial operation. 
- void resolveSymbolUses(Operation *op, PartialOpDef &opDef); + /// Resolve any symbol table uses in the IR. + void resolveSymbolUses(); /// A mapping from operations in the input source file to their parser state. SmallVector> operations; @@ -51,6 +51,10 @@ /// This map should be empty if the parser finishes successfully. DenseMap> placeholderValueUses; + /// The symbol table operations within the IR. + SmallVector>> + symbolTableOperations; + /// A stack of partial operation definitions that have been started but not /// yet finalized. SmallVector partialOperations; @@ -63,22 +67,21 @@ SymbolTableCollection symbolTable; }; -void AsmParserState::Impl::resolveSymbolUses(Operation *op, - PartialOpDef &opDef) { - assert(opDef.isSymbolTable() && "expected op to be a symbol table"); - +void AsmParserState::Impl::resolveSymbolUses() { SmallVector symbolOps; - for (auto &it : *opDef.symbolTable) { - symbolOps.clear(); - if (failed(symbolTable.lookupSymbolIn(op, it.first.cast(), - symbolOps))) - continue; - - for (ArrayRef useRange : it.second) { - for (const auto &symIt : llvm::zip(symbolOps, useRange)) { - auto opIt = operationToIdx.find(std::get<0>(symIt)); - if (opIt != operationToIdx.end()) - operations[opIt->second]->symbolUses.push_back(std::get<1>(symIt)); + for (auto &opAndUseMapIt : symbolTableOperations) { + for (auto &it : *opAndUseMapIt.second) { + symbolOps.clear(); + if (failed(symbolTable.lookupSymbolIn( + opAndUseMapIt.first, it.first.cast(), symbolOps))) + continue; + + for (ArrayRef useRange : it.second) { + for (const auto &symIt : llvm::zip(symbolOps, useRange)) { + auto opIt = operationToIdx.find(std::get<0>(symIt)); + if (opIt != operationToIdx.end()) + operations[opIt->second]->symbolUses.push_back(std::get<1>(symIt)); + } } } } @@ -112,8 +115,13 @@ return llvm::make_pointee_range(llvm::makeArrayRef(impl->operations)); } -/// Returns (heuristically) the range of an identifier given a SMLoc -/// corresponding to the start of an identifier location. 
+auto AsmParserState::getOpDef(Operation *op) const + -> const OperationDefinition * { + auto it = impl->operationToIdx.find(op); + return it == impl->operationToIdx.end() ? nullptr + : &*impl->operations[it->second]; +} + llvm::SMRange AsmParserState::convertIdLocToRange(llvm::SMLoc loc) { if (!loc.isValid()) return llvm::SMRange(); @@ -124,7 +132,7 @@ }; const char *curPtr = loc.getPointer(); - while (isIdentifierChar(*(++curPtr))) + while (*curPtr && isIdentifierChar(*(++curPtr))) continue; return llvm::SMRange(loc, llvm::SMLoc::getFromPointer(curPtr)); } @@ -147,8 +155,11 @@ Impl::PartialOpDef partialOpDef = impl->partialOperations.pop_back_val(); // If this operation is a symbol table, resolve any symbol uses. - if (partialOpDef.isSymbolTable()) - impl->resolveSymbolUses(topLevelOp, partialOpDef); + if (partialOpDef.isSymbolTable()) { + impl->symbolTableOperations.emplace_back( + topLevelOp, std::move(partialOpDef.symbolTable)); + } + impl->resolveSymbolUses(); } void AsmParserState::startOperationDefinition(const OperationName &opName) { @@ -156,7 +167,7 @@ } void AsmParserState::finalizeOperationDefinition( - Operation *op, llvm::SMRange nameLoc, + Operation *op, llvm::SMRange nameLoc, llvm::SMLoc endLoc, ArrayRef> resultGroups) { assert(!impl->partialOperations.empty() && "expected valid partial operation definition"); @@ -164,7 +175,7 @@ // Build the full operation definition. std::unique_ptr def = - std::make_unique(op, nameLoc); + std::make_unique(op, nameLoc, endLoc); for (auto &resultGroup : resultGroups) def->resultGroups.emplace_back(resultGroup.first, convertIdLocToRange(resultGroup.second)); @@ -172,8 +183,10 @@ impl->operations.emplace_back(std::move(def)); // If this operation is a symbol table, resolve any symbol uses. 
- if (partialOpDef.isSymbolTable()) - impl->resolveSymbolUses(op, partialOpDef); + if (partialOpDef.isSymbolTable()) { + impl->symbolTableOperations.emplace_back( + op, std::move(partialOpDef.symbolTable)); + } } void AsmParserState::startRegionDefinition() { diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -166,12 +166,12 @@ /// operations. class OperationParser : public Parser { public: - OperationParser(ParserState &state, Operation *topLevelOp); + OperationParser(ParserState &state, ModuleOp topLevelOp); ~OperationParser(); /// After parsing is finished, this function must be called to see if there /// are any remaining issues. - ParseResult finalize(Operation *topLevelOp); + ParseResult finalize(); //===--------------------------------------------------------------------===// // SSA Value Handling @@ -399,9 +399,8 @@ }; } // end anonymous namespace -OperationParser::OperationParser(ParserState &state, Operation *topLevelOp) - : Parser(state), opBuilder(topLevelOp->getRegion(0)), - topLevelOp(topLevelOp) { +OperationParser::OperationParser(ParserState &state, ModuleOp topLevelOp) + : Parser(state), opBuilder(topLevelOp.getRegion()), topLevelOp(topLevelOp) { // The top level operation starts a new name scope. pushSSANameScope(/*isIsolated=*/true); @@ -429,7 +428,7 @@ /// After parsing is finished, this function must be called to see if there are /// any remaining issues. -ParseResult OperationParser::finalize(Operation *topLevelOp) { +ParseResult OperationParser::finalize() { // Check for any forward references that are left. If we find any, error // out. if (!forwardRefPlaceholders.empty()) { @@ -466,12 +465,18 @@ opOrArgument.get().setLoc(locAttr); } + // Pop the top level name scope. + if (failed(popSSANameScope())) + return failure(); + + // Verify that the parsed operations are valid. 
+ if (failed(verify(topLevelOp))) + return failure(); + // If we are populating the parser state, finalize the top-level operation. if (state.asmState) state.asmState->finalize(topLevelOp); - - // Pop the top level name scope. - return popSSANameScope(); + return success(); } //===----------------------------------------------------------------------===// @@ -821,8 +826,9 @@ asmResultGroups.emplace_back(resultIt, std::get<2>(record)); resultIt += std::get<1>(record); } - state.asmState->finalizeOperationDefinition(op, nameTok.getLocRange(), - asmResultGroups); + state.asmState->finalizeOperationDefinition( + op, nameTok.getLocRange(), /*endLoc=*/getToken().getLoc(), + asmResultGroups); } // Add definitions for each of the result groups. @@ -837,7 +843,8 @@ // Add this operation to the assembly state if it was provided to populate. } else if (state.asmState) { - state.asmState->finalizeOperationDefinition(op, nameTok.getLocRange()); + state.asmState->finalizeOperationDefinition(op, nameTok.getLocRange(), + /*endLoc=*/getToken().getLoc()); } return success(); @@ -1009,7 +1016,8 @@ // If we are populating the parser asm state, finalize this operation // definition. if (state.asmState) - state.asmState->finalizeOperationDefinition(op, nameToken.getLocRange()); + state.asmState->finalizeOperationDefinition(op, nameToken.getLocRange(), + /*endLoc=*/getToken().getLoc()); return op; } @@ -2019,6 +2027,10 @@ // Add arguments to the entry block. if (!entryArguments.empty()) { + // If we had named arguments, then don't allow a block name. + if (getToken().is(Token::caret_identifier)) + return emitError("invalid block name in region with named arguments"); + for (auto &placeholderArgPair : entryArguments) { auto &argInfo = placeholderArgPair.first; @@ -2040,10 +2052,6 @@ if (addDefinition(argInfo, arg)) return failure(); } - - // If we had named arguments, then don't allow a block name. 
- if (getToken().is(Token::caret_identifier)) - return emitError("invalid block name in region with named arguments"); } if (parseBlock(block)) @@ -2310,7 +2318,7 @@ ParseResult TopLevelOperationParser::parse(Block *topLevelBlock, Location parserLoc) { // Create a top-level operation to contain the parsed state. - OwningOpRef topLevelOp(ModuleOp::create(parserLoc)); + OwningOpRef topLevelOp(ModuleOp::create(parserLoc)); OperationParser opParser(state, topLevelOp.get()); while (true) { switch (getToken().getKind()) { @@ -2322,16 +2330,12 @@ // If we got to the end of the file, then we're done. case Token::eof: { - if (opParser.finalize(topLevelOp.get())) - return failure(); - - // Verify that the parsed operations are valid. - if (failed(verify(topLevelOp.get()))) + if (opParser.finalize()) return failure(); // Splice the blocks of the parsed operation over to the provided // top-level block. - auto &parsedOps = (*topLevelOp)->getRegion(0).front().getOperations(); + auto &parsedOps = topLevelOp->getBody()->getOperations(); auto &destOps = topLevelBlock->getOperations(); destOps.splice(destOps.empty() ? 
destOps.end() : std::prev(destOps.end()), parsedOps, parsedOps.begin(), parsedOps.end()); diff --git a/mlir/lib/Tools/mlir-lsp-server/LSPServer.cpp b/mlir/lib/Tools/mlir-lsp-server/LSPServer.cpp --- a/mlir/lib/Tools/mlir-lsp-server/LSPServer.cpp +++ b/mlir/lib/Tools/mlir-lsp-server/LSPServer.cpp @@ -56,6 +56,16 @@ void onHover(const TextDocumentPositionParams ¶ms, Callback> reply); + //===--------------------------------------------------------------------===// + // Document Symbols + + void onDocumentSymbol(const DocumentSymbolParams ¶ms, + Callback> reply); + + //===--------------------------------------------------------------------===// + // Fields + //===--------------------------------------------------------------------===// + MLIRServer &server; JSONTransport &transport; @@ -73,6 +83,7 @@ void LSPServer::Impl::onInitialize(const InitializeParams ¶ms, Callback reply) { + // Send a response with the capabilities of this server. llvm::json::Object serverCaps{ {"textDocumentSync", llvm::json::Object{ @@ -83,6 +94,11 @@ {"definitionProvider", true}, {"referencesProvider", true}, {"hoverProvider", true}, + + // For now we only support documenting symbols when the client supports + // hierarchical symbols. 
+ {"documentSymbolProvider", + params.capabilities.hierarchicalDocumentSymbol}, }; llvm::json::Object result{ @@ -165,6 +181,17 @@ reply(server.findHover(params.textDocument.uri, params.position)); } +//===----------------------------------------------------------------------===// +// Document Symbols + +void LSPServer::Impl::onDocumentSymbol( + const DocumentSymbolParams ¶ms, + Callback> reply) { + std::vector symbols; + server.findDocumentSymbols(params.textDocument.uri, symbols); + reply(std::move(symbols)); +} + //===----------------------------------------------------------------------===// // LSPServer //===----------------------------------------------------------------------===// @@ -198,6 +225,10 @@ // Hover messageHandler.method("textDocument/hover", impl.get(), &Impl::onHover); + // Document Symbols + messageHandler.method("textDocument/documentSymbol", impl.get(), + &Impl::onDocumentSymbol); + // Diagnostics impl->publishDiagnostics = messageHandler.outgoingNotification( diff --git a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.h b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.h --- a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.h +++ b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.h @@ -17,6 +17,7 @@ namespace lsp { struct Diagnostic; +struct DocumentSymbol; struct Hover; struct Location; struct Position; @@ -55,6 +56,10 @@ /// couldn't be found. Optional findHover(const URIForFile &uri, const Position &hoverPos); + /// Find all of the document symbols within the given file. 
+ void findDocumentSymbols(const URIForFile &uri, + std::vector &symbols); + private: struct Impl; diff --git a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp --- a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp +++ b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp @@ -298,6 +298,18 @@ buildHoverForBlockArgument(llvm::SMRange hoverRange, BlockArgument arg, const AsmParserState::BlockDefinition &block); + //===--------------------------------------------------------------------===// + // Document Symbols + //===--------------------------------------------------------------------===// + + void findDocumentSymbols(std::vector &symbols); + void findDocumentSymbols(Operation *op, + std::vector &symbols); + + //===--------------------------------------------------------------------===// + // Fields + //===--------------------------------------------------------------------===// + /// The context used to hold the state contained by the parsed document. MLIRContext context; @@ -594,6 +606,50 @@ return hover; } +//===----------------------------------------------------------------------===// +// MLIRDocument: Document Symbols +//===----------------------------------------------------------------------===// + +void MLIRDocument::findDocumentSymbols( + std::vector &symbols) { + for (Operation &op : parsedIR) + findDocumentSymbols(&op, symbols); +} + +void MLIRDocument::findDocumentSymbols( + Operation *op, std::vector &symbols) { + std::vector *childSymbols = &symbols; + + // Check for the source information of this operation. + if (const AsmParserState::OperationDefinition *def = asmState.getOpDef(op)) { + // If this operation defines a symbol, record it. + if (SymbolOpInterface symbol = dyn_cast(op)) { + symbols.emplace_back(symbol.getName(), + op->hasTrait() + ? 
lsp::SymbolKind::Function + : lsp::SymbolKind::Class, + getRangeFromLoc(sourceMgr, def->scopeLoc), + getRangeFromLoc(sourceMgr, def->loc)); + childSymbols = &symbols.back().children; + + } else if (op->hasTrait()) { + // Otherwise, if this is a symbol table push an anonymous document symbol. + symbols.emplace_back("<" + op->getName().getStringRef() + ">", + lsp::SymbolKind::Namespace, + getRangeFromLoc(sourceMgr, def->scopeLoc), + getRangeFromLoc(sourceMgr, def->loc)); + childSymbols = &symbols.back().children; + } + } + + // Recurse into the regions of this operation. + if (!op->getNumRegions()) + return; + for (Region ®ion : op->getRegions()) + for (Operation &childOp : region.getOps()) + findDocumentSymbols(&childOp, *childSymbols); +} + //===----------------------------------------------------------------------===// // MLIRTextFileChunk //===----------------------------------------------------------------------===// @@ -649,6 +705,7 @@ std::vector &references); Optional findHover(const lsp::URIForFile &uri, lsp::Position hoverPos); + void findDocumentSymbols(std::vector &symbols); private: /// Find the MLIR document that contains the given position, and update the @@ -662,6 +719,9 @@ /// The version of this file. int64_t version; + /// The number of lines in the file. + int64_t totalNumLines; + /// The chunks of this file. The order of these chunks is the order in which /// they appear in the text file. std::vector> chunks; @@ -671,7 +731,7 @@ MLIRTextFile::MLIRTextFile(const lsp::URIForFile &uri, StringRef fileContents, int64_t version, DialectRegistry ®istry, std::vector &diagnostics) - : contents(fileContents.str()), version(version) { + : contents(fileContents.str()), version(version), totalNumLines(0) { // Split the file into separate MLIR documents. // TODO: Find a way to share the split file marker with other tools. 
We don't // want to use `splitAndProcessBuffer` here, but we do want to make sure this @@ -702,6 +762,7 @@ } chunks.emplace_back(std::move(chunk)); } + totalNumLines = lineOffset; } void MLIRTextFile::getLocationsOf(const lsp::URIForFile &uri, @@ -743,6 +804,45 @@ return hoverInfo; } +void MLIRTextFile::findDocumentSymbols( + std::vector &symbols) { + if (chunks.size() == 1) + return chunks.front()->document.findDocumentSymbols(symbols); + + // If there are multiple chunks in this file, we create top-level symbols for + // each chunk. + for (unsigned i = 0, e = chunks.size(); i < e; ++i) { + MLIRTextFileChunk &chunk = *chunks[i]; + lsp::Position startPos(chunk.lineOffset); + lsp::Position endPos((i == e - 1) ? totalNumLines - 1 + : chunks[i + 1]->lineOffset); + lsp::DocumentSymbol symbol("", + lsp::SymbolKind::Namespace, + /*range=*/lsp::Range(startPos, endPos), + /*selectionRange=*/lsp::Range(startPos)); + chunk.document.findDocumentSymbols(symbol.children); + + // Fixup the locations of document symbols within this chunk. + if (i != 0) { + SmallVector symbolsToFix; + for (lsp::DocumentSymbol &childSymbol : symbol.children) + symbolsToFix.push_back(&childSymbol); + + while (!symbolsToFix.empty()) { + lsp::DocumentSymbol *symbol = symbolsToFix.pop_back_val(); + chunk.adjustLocForChunkOffset(symbol->range); + chunk.adjustLocForChunkOffset(symbol->selectionRange); + + for (lsp::DocumentSymbol &childSymbol : symbol->children) + symbolsToFix.push_back(&childSymbol); + } + } + + // Push the symbol for this chunk. 
+ symbols.emplace_back(std::move(symbol)); + } +} + MLIRTextFileChunk &MLIRTextFile::getChunkFor(lsp::Position &pos) { if (chunks.size() == 1) return *chunks.front(); @@ -821,3 +921,10 @@ return fileIt->second->findHover(uri, hoverPos); return llvm::None; } + +void lsp::MLIRServer::findDocumentSymbols( + const URIForFile &uri, std::vector &symbols) { + auto fileIt = impl->files.find(uri.file()); + if (fileIt != impl->files.end()) + fileIt->second->findDocumentSymbols(symbols); +} diff --git a/mlir/lib/Tools/mlir-lsp-server/lsp/Protocol.h b/mlir/lib/Tools/mlir-lsp-server/lsp/Protocol.h --- a/mlir/lib/Tools/mlir-lsp-server/lsp/Protocol.h +++ b/mlir/lib/Tools/mlir-lsp-server/lsp/Protocol.h @@ -134,6 +134,20 @@ llvm::json::Path path); raw_ostream &operator<<(raw_ostream &os, const URIForFile &value); +//===----------------------------------------------------------------------===// +// ClientCapabilities +//===----------------------------------------------------------------------===// + +struct ClientCapabilities { + /// Client supports hierarchical document symbols. + /// textDocument.documentSymbol.hierarchicalDocumentSymbolSupport + bool hierarchicalDocumentSymbol = false; +}; + +/// Add support for JSON serialization. +bool fromJSON(const llvm::json::Value &value, ClientCapabilities &result, + llvm::json::Path path); + //===----------------------------------------------------------------------===// // InitializeParams //===----------------------------------------------------------------------===// @@ -149,6 +163,9 @@ llvm::json::Path path); struct InitializeParams { + /// The capabilities provided by the client (editor or tool). + ClientCapabilities capabilities; + /// The initial trace setting. If omitted trace is disabled ('off'). 
Optional trace; }; @@ -224,6 +241,9 @@ //===----------------------------------------------------------------------===// struct Position { + Position(int line = 0, int character = 0) + : line(line), character(character) {} + /// Line position in a document (zero-based). int line = 0; @@ -449,6 +469,94 @@ /// Add support for JSON serialization. llvm::json::Value toJSON(const Hover &hover); +//===----------------------------------------------------------------------===// +// SymbolKind +//===----------------------------------------------------------------------===// + +enum class SymbolKind { + File = 1, + Module = 2, + Namespace = 3, + Package = 4, + Class = 5, + Method = 6, + Property = 7, + Field = 8, + Constructor = 9, + Enum = 10, + Interface = 11, + Function = 12, + Variable = 13, + Constant = 14, + String = 15, + Number = 16, + Boolean = 17, + Array = 18, + Object = 19, + Key = 20, + Null = 21, + EnumMember = 22, + Struct = 23, + Event = 24, + Operator = 25, + TypeParameter = 26 +}; + +//===----------------------------------------------------------------------===// +// DocumentSymbol +//===----------------------------------------------------------------------===// + +/// Represents programming constructs like variables, classes, interfaces etc. +/// that appear in a document. Document symbols can be hierarchical and they +/// have two ranges: one that encloses its definition and one that points to its +/// most interesting range, e.g. the range of an identifier. +struct DocumentSymbol { + DocumentSymbol() = default; + DocumentSymbol(DocumentSymbol &&) = default; + DocumentSymbol(const Twine &name, SymbolKind kind, Range range, + Range selectionRange) + : name(name.str()), kind(kind), range(range), + selectionRange(selectionRange) {} + + /// The name of this symbol. + std::string name; + + /// More detail for this symbol, e.g the signature of a function. + std::string detail; + + /// The kind of this symbol. 
+  SymbolKind kind; + +  /// The range enclosing this symbol not including leading/trailing whitespace +  /// but everything else like comments. This information is typically used to +  /// determine if the client's cursor is inside the symbol to reveal the +  /// symbol in the UI. +  Range range; + +  /// The range that should be selected and revealed when this symbol is being +  /// picked, e.g. the name of a function. Must be contained by the `range`. +  Range selectionRange; + +  /// Children of this symbol, e.g. properties of a class. +  std::vector children; +}; + +/// Add support for JSON serialization. +llvm::json::Value toJSON(const DocumentSymbol &symbol); + +//===----------------------------------------------------------------------===// +// DocumentSymbolParams +//===----------------------------------------------------------------------===// + +struct DocumentSymbolParams { +  /// The text document to find symbols in. +  TextDocumentIdentifier textDocument; +}; + +/// Add support for JSON serialization.
+bool fromJSON(const llvm::json::Value &value, DocumentSymbolParams &result, + llvm::json::Path path); + //===----------------------------------------------------------------------===// // DiagnosticRelatedInformation //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Tools/mlir-lsp-server/lsp/Protocol.cpp b/mlir/lib/Tools/mlir-lsp-server/lsp/Protocol.cpp --- a/mlir/lib/Tools/mlir-lsp-server/lsp/Protocol.cpp +++ b/mlir/lib/Tools/mlir-lsp-server/lsp/Protocol.cpp @@ -246,6 +246,28 @@ return os << value.uri(); } +//===----------------------------------------------------------------------===// +// ClientCapabilities +//===----------------------------------------------------------------------===// + +bool mlir::lsp::fromJSON(const llvm::json::Value &value, + ClientCapabilities &result, llvm::json::Path path) { + const llvm::json::Object *o = value.getAsObject(); + if (!o) { + path.report("expected object"); + return false; + } + if (const llvm::json::Object *textDocument = o->getObject("textDocument")) { + if (const llvm::json::Object *documentSymbol = + textDocument->getObject("documentSymbol")) { + if (Optional hierarchicalSupport = + documentSymbol->getBoolean("hierarchicalDocumentSymbolSupport")) + result.hierarchicalDocumentSymbol = *hierarchicalSupport; + } + } + return true; +} + //===----------------------------------------------------------------------===// // InitializeParams //===----------------------------------------------------------------------===// @@ -275,6 +297,7 @@ if (!o) return false; // We deliberately don't fail if we can't parse individual fields. 
+ o.map("capabilities", result.capabilities); o.map("trace", result.trace); return true; } @@ -494,6 +517,33 @@ return std::move(result); } +//===----------------------------------------------------------------------===// +// DocumentSymbol +//===----------------------------------------------------------------------===// + +llvm::json::Value mlir::lsp::toJSON(const DocumentSymbol &symbol) { + llvm::json::Object result{{"name", symbol.name}, + {"kind", static_cast(symbol.kind)}, + {"range", symbol.range}, + {"selectionRange", symbol.selectionRange}}; + + if (!symbol.detail.empty()) + result["detail"] = symbol.detail; + if (!symbol.children.empty()) + result["children"] = symbol.children; + return std::move(result); +} + +//===----------------------------------------------------------------------===// +// DocumentSymbolParams +//===----------------------------------------------------------------------===// + +bool mlir::lsp::fromJSON(const llvm::json::Value &value, + DocumentSymbolParams &result, llvm::json::Path path) { + llvm::json::ObjectMapper o(value, path); + return o && o.map("textDocument", result.textDocument); +} + //===----------------------------------------------------------------------===// // DiagnosticRelatedInformation //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Tools/mlir-lsp-server/lsp/Transport.cpp b/mlir/lib/Tools/mlir-lsp-server/lsp/Transport.cpp --- a/mlir/lib/Tools/mlir-lsp-server/lsp/Transport.cpp +++ b/mlir/lib/Tools/mlir-lsp-server/lsp/Transport.cpp @@ -204,6 +204,8 @@ if (llvm::Expected doc = llvm::json::parse(json)) { if (!handleMessage(std::move(*doc), handler)) return llvm::Error::success(); + } else { + Logger::error("JSON parse error: {0}", llvm::toString(doc.takeError())); } } } diff --git a/mlir/test/mlir-lsp-server/document-symbols.test b/mlir/test/mlir-lsp-server/document-symbols.test new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-lsp-server/document-symbols.test 
@@ -0,0 +1,71 @@ +// RUN: mlir-lsp-server -lit-test < %s | FileCheck -strict-whitespace %s +{"jsonrpc":"2.0","id":0,"method":"initialize","params":{"processId":123,"rootUri":"test:///workspace","capabilities":{"textDocument":{"documentSymbol":{"hierarchicalDocumentSymbolSupport":true}}},"trace":"off"}} +// ----- +{"jsonrpc":"2.0","method":"textDocument/didOpen","params":{"textDocument":{ + "uri":"test:///foo.mlir", + "languageId":"mlir", + "version":1, + "text":"module {\nfunc private @foo()\n}" +}}} +// ----- +{"jsonrpc":"2.0","id":1,"method":"textDocument/documentSymbol","params":{ + "textDocument":{"uri":"test:///foo.mlir"} +}} +// CHECK: "id": 1 +// CHECK-NEXT: "jsonrpc": "2.0", +// CHECK-NEXT: "result": [ +// CHECK-NEXT: { +// CHECK-NEXT: "children": [ +// CHECK-NEXT: { +// CHECK-NEXT: "kind": 12, +// CHECK-NEXT: "name": "foo", +// CHECK-NEXT: "range": { +// CHECK-NEXT: "end": { +// CHECK-NEXT: "character": {{.*}}, +// CHECK-NEXT: "line": {{.*}} +// CHECK-NEXT: }, +// CHECK-NEXT: "start": { +// CHECK-NEXT: "character": {{.*}}, +// CHECK-NEXT: "line": {{.*}} +// CHECK-NEXT: } +// CHECK-NEXT: }, +// CHECK-NEXT: "selectionRange": { +// CHECK-NEXT: "end": { +// CHECK-NEXT: "character": 4, +// CHECK-NEXT: "line": {{.*}} +// CHECK-NEXT: }, +// CHECK-NEXT: "start": { +// CHECK-NEXT: "character": {{.*}}, +// CHECK-NEXT: "line": {{.*}} +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: ], +// CHECK-NEXT: "kind": 3, +// CHECK-NEXT: "name": "", +// CHECK-NEXT: "range": { +// CHECK-NEXT: "end": { +// CHECK-NEXT: "character": {{.*}}, +// CHECK-NEXT: "line": {{.*}} +// CHECK-NEXT: }, +// CHECK-NEXT: "start": { +// CHECK-NEXT: "character": {{.*}}, +// CHECK-NEXT: "line": {{.*}} +// CHECK-NEXT: } +// CHECK-NEXT: }, +// CHECK-NEXT: "selectionRange": { +// CHECK-NEXT: "end": { +// CHECK-NEXT: "character": {{.*}}, +// CHECK-NEXT: "line": {{.*}} +// CHECK-NEXT: }, +// CHECK-NEXT: "start": { +// CHECK-NEXT: "character": {{.*}}, +// CHECK-NEXT: "line": {{.*}} +// 
CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: ] +// ----- +{"jsonrpc":"2.0","id":3,"method":"shutdown"} +// ----- +{"jsonrpc":"2.0","method":"exit"} diff --git a/mlir/test/mlir-lsp-server/initialize-params.test b/mlir/test/mlir-lsp-server/initialize-params.test --- a/mlir/test/mlir-lsp-server/initialize-params.test +++ b/mlir/test/mlir-lsp-server/initialize-params.test @@ -6,6 +6,7 @@ // CHECK-NEXT: "result": { // CHECK-NEXT: "capabilities": { // CHECK-NEXT: "definitionProvider": true, +// CHECK-NEXT: "documentSymbolProvider": false, // CHECK-NEXT: "hoverProvider": true, // CHECK-NEXT: "referencesProvider": true, // CHECK-NEXT: "textDocumentSync": {