diff --git a/mlir/include/mlir/Tools/PDLL/Parser/CodeComplete.h b/mlir/include/mlir/Tools/PDLL/Parser/CodeComplete.h --- a/mlir/include/mlir/Tools/PDLL/Parser/CodeComplete.h +++ b/mlir/include/mlir/Tools/PDLL/Parser/CodeComplete.h @@ -66,6 +66,24 @@ /// Signal code completion for Pattern metadata. virtual void codeCompletePatternMetadata() {} + //===--------------------------------------------------------------------===// + // Signature Hooks + //===--------------------------------------------------------------------===// + + /// Signal code completion for the signature of a callable. + virtual void codeCompleteCallSignature(const ast::CallableDecl *callable, + unsigned currentNumArgs) {} + + /// Signal code completion for the signature of an operation's operands. + virtual void + codeCompleteOperationOperandsSignature(Optional<StringRef> opName, + unsigned currentNumOperands) {} + + /// Signal code completion for the signature of an operation's results. + virtual void + codeCompleteOperationResultsSignature(Optional<StringRef> opName, + unsigned currentNumResults) {} + protected: /// Create a new code completion context with the given code complete /// location.
diff --git a/mlir/lib/Tools/PDLL/Parser/Parser.cpp b/mlir/lib/Tools/PDLL/Parser/Parser.cpp --- a/mlir/lib/Tools/PDLL/Parser/Parser.cpp +++ b/mlir/lib/Tools/PDLL/Parser/Parser.cpp @@ -425,6 +425,12 @@ LogicalResult codeCompleteOperationName(StringRef dialectName); LogicalResult codeCompletePatternMetadata(); + void codeCompleteCallSignature(ast::Node *parent, unsigned currentNumArgs); + void codeCompleteOperationOperandsSignature(Optional<StringRef> opName, + unsigned currentNumOperands); + void codeCompleteOperationResultsSignature(Optional<StringRef> opName, + unsigned currentNumResults); + //===--------------------------------------------------------------------===// // Lexer Utilities //===--------------------------------------------------------------------===// @@ -1762,6 +1768,12 @@ SmallVector<ast::Expr *> arguments; if (curToken.isNot(Token::r_paren)) { do { + // Handle code completion for the call arguments. + if (curToken.is(Token::code_complete)) { + codeCompleteCallSignature(parentExpr, arguments.size()); + return failure(); + } + FailureOr<ast::Expr *> argument = parseExpr(); if (failed(argument)) return failure(); @@ -1933,6 +1945,12 @@ ast::ValueRangeConstraintDecl::create(ctx, loc), valueRangeTy)); } } else if (!consumeIf(Token::r_paren)) { // If the operand list was specified and non-empty, parse the operands. do { + // Check for operand signature code completion. + if (curToken.is(Token::code_complete)) { + codeCompleteOperationOperandsSignature(opName, operands.size()); + return failure(); + } + FailureOr<ast::Expr *> operand = parseExpr(); @@ -1972,6 +1990,12 @@ // Handle the case of an empty result list. if (!consumeIf(Token::r_paren)) { do { + // Check for result signature code completion.
+ if (curToken.is(Token::code_complete)) { + codeCompleteOperationResultsSignature(opName, resultTypes.size()); + return failure(); + } + FailureOr<ast::Expr *> resultTypeExpr = parseExpr(); if (failed(resultTypeExpr)) return failure(); @@ -2899,6 +2923,27 @@ return failure(); } +void Parser::codeCompleteCallSignature(ast::Node *parent, + unsigned currentNumArgs) { + ast::CallableDecl *callableDecl = tryExtractCallableDecl(parent); + if (!callableDecl) + return; + + codeCompleteContext->codeCompleteCallSignature(callableDecl, currentNumArgs); +} + +void Parser::codeCompleteOperationOperandsSignature( + Optional<StringRef> opName, unsigned currentNumOperands) { + codeCompleteContext->codeCompleteOperationOperandsSignature( + opName, currentNumOperands); +} + +void Parser::codeCompleteOperationResultsSignature(Optional<StringRef> opName, + unsigned currentNumResults) { + codeCompleteContext->codeCompleteOperationResultsSignature(opName, + currentNumResults); +} + //===----------------------------------------------------------------------===// // Parser //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Tools/lsp-server-support/Protocol.h b/mlir/lib/Tools/lsp-server-support/Protocol.h --- a/mlir/lib/Tools/lsp-server-support/Protocol.h +++ b/mlir/lib/Tools/lsp-server-support/Protocol.h @@ -871,6 +871,65 @@ bool fromJSON(const llvm::json::Value &value, CompletionParams &result, llvm::json::Path path); +//===----------------------------------------------------------------------===// +// ParameterInformation +//===----------------------------------------------------------------------===// + +/// A single parameter of a particular signature. +struct ParameterInformation { + /// The label of this parameter. Ignored when labelOffsets is set. + std::string labelString; + + /// Inclusive start and exclusive end offsets within the containing signature + /// label. + Optional<std::pair<unsigned, unsigned>> labelOffsets; + + /// The documentation of this parameter. Optional.
+ std::string documentation; +}; + +/// Add support for JSON serialization. +llvm::json::Value toJSON(const ParameterInformation &value); + +//===----------------------------------------------------------------------===// +// SignatureInformation +//===----------------------------------------------------------------------===// + +/// Represents the signature of something callable. +struct SignatureInformation { + /// The label of this signature. Mandatory. + std::string label; + + /// The documentation of this signature. Optional. + std::string documentation; + + /// The parameters of this signature. + std::vector<ParameterInformation> parameters; +}; + +/// Add support for JSON serialization. +llvm::json::Value toJSON(const SignatureInformation &value); +raw_ostream &operator<<(raw_ostream &os, const SignatureInformation &value); + +//===----------------------------------------------------------------------===// +// SignatureHelp +//===----------------------------------------------------------------------===// + +/// Represents the signature of a callable. +struct SignatureHelp { + /// The resulting signatures. + std::vector<SignatureInformation> signatures; + + /// The active signature. + int activeSignature = 0; + + /// The active parameter of the active signature. + int activeParameter = 0; +}; + +/// Add support for JSON serialization.
+llvm::json::Value toJSON(const SignatureHelp &value); + } // namespace lsp } // namespace mlir diff --git a/mlir/lib/Tools/lsp-server-support/Protocol.cpp b/mlir/lib/Tools/lsp-server-support/Protocol.cpp --- a/mlir/lib/Tools/lsp-server-support/Protocol.cpp +++ b/mlir/lib/Tools/lsp-server-support/Protocol.cpp @@ -742,3 +742,57 @@ return fromJSON(*context, result.context, path.field("context")); return true; } + +//===----------------------------------------------------------------------===// +// ParameterInformation +//===----------------------------------------------------------------------===// + +llvm::json::Value mlir::lsp::toJSON(const ParameterInformation &value) { + assert((value.labelOffsets.hasValue() || !value.labelString.empty()) && + "parameter information label is required"); + llvm::json::Object result; + if (value.labelOffsets) + result["label"] = llvm::json::Array( + {value.labelOffsets->first, value.labelOffsets->second}); + else + result["label"] = value.labelString; + if (!value.documentation.empty()) + result["documentation"] = value.documentation; + return std::move(result); +} + +//===----------------------------------------------------------------------===// +// SignatureInformation +//===----------------------------------------------------------------------===// + +llvm::json::Value mlir::lsp::toJSON(const SignatureInformation &value) { + assert(!value.label.empty() && "signature information label is required"); + llvm::json::Object result{ + {"label", value.label}, + {"parameters", llvm::json::Array(value.parameters)}, + }; + if (!value.documentation.empty()) + result["documentation"] = value.documentation; + return std::move(result); +} + +raw_ostream &mlir::lsp::operator<<(raw_ostream &os, + const SignatureInformation &value) { + return os << value.label << " - " << toJSON(value); +} + +//===----------------------------------------------------------------------===// +// SignatureHelp
+//===----------------------------------------------------------------------===// + +llvm::json::Value mlir::lsp::toJSON(const SignatureHelp &value) { + assert(value.activeSignature >= 0 && + "Unexpected negative value for number of active signatures."); + assert(value.activeParameter >= 0 && + "Unexpected negative value for active parameter index"); + return llvm::json::Object{ + {"activeSignature", value.activeSignature}, + {"activeParameter", value.activeParameter}, + {"signatures", llvm::json::Array(value.signatures)}, + }; +} diff --git a/mlir/lib/Tools/mlir-pdll-lsp-server/LSPServer.cpp b/mlir/lib/Tools/mlir-pdll-lsp-server/LSPServer.cpp --- a/mlir/lib/Tools/mlir-pdll-lsp-server/LSPServer.cpp +++ b/mlir/lib/Tools/mlir-pdll-lsp-server/LSPServer.cpp @@ -70,6 +70,12 @@ void onCompletion(const CompletionParams &params, Callback<CompletionList> reply); + //===--------------------------------------------------------------------===// + // Signature Help + + void onSignatureHelp(const TextDocumentPositionParams &params, + Callback<SignatureHelp> reply); + //===--------------------------------------------------------------------===// // Fields //===--------------------------------------------------------------------===// @@ -109,6 +115,10 @@ {"resolveProvider", false}, {"triggerCharacters", {".", ">", "(", "{", ",", "<", ":", "[", " "}}, }}, + {"signatureHelpProvider", + llvm::json::Object{ + {"triggerCharacters", {"(", ","}}, + }}, {"definitionProvider", true}, {"referencesProvider", true}, {"hoverProvider", true}, @@ -209,6 +219,14 @@ reply(server.getCodeCompletion(params.textDocument.uri, params.position)); } +//===----------------------------------------------------------------------===// +// Signature Help + +void LSPServer::onSignatureHelp(const TextDocumentPositionParams &params, + Callback<SignatureHelp> reply) { + reply(server.getSignatureHelp(params.textDocument.uri, params.position)); +} + //===----------------------------------------------------------------------===// // Entry Point
//===----------------------------------------------------------------------===// @@ -249,6 +267,10 @@ messageHandler.method("textDocument/completion", &lspServer, &LSPServer::onCompletion); + // Signature Help + messageHandler.method("textDocument/signatureHelp", &lspServer, + &LSPServer::onSignatureHelp); + // Diagnostics lspServer.publishDiagnostics = messageHandler.outgoingNotification<PublishDiagnosticsParams>( diff --git a/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.h b/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.h --- a/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.h +++ b/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.h @@ -20,6 +20,7 @@ struct Hover; struct Location; struct Position; +struct SignatureHelp; class URIForFile; /// This class implements all of the PDLL related functionality necessary for a @@ -62,6 +63,10 @@ CompletionList getCodeCompletion(const URIForFile &uri, const Position &completePos); + /// Get the signature help for the position within the given file. + SignatureHelp getSignatureHelp(const URIForFile &uri, + const Position &helpPos); + private: struct Impl; diff --git a/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp b/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp --- a/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp +++ b/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp @@ -285,6 +285,13 @@ lsp::CompletionList getCodeCompletion(const lsp::URIForFile &uri, const lsp::Position &completePos); + //===--------------------------------------------------------------------===// + // Signature Help + //===--------------------------------------------------------------------===// + + lsp::SignatureHelp getSignatureHelp(const lsp::URIForFile &uri, + const lsp::Position &helpPos); + //===--------------------------------------------------------------------===// // Fields //===--------------------------------------------------------------------===// @@ -827,6 +834,154 @@ return completionList; }
+//===----------------------------------------------------------------------===// +// PDLDocument: Signature Help +//===----------------------------------------------------------------------===// + +namespace { +class LSPSignatureHelpContext : public CodeCompleteContext { +public: + LSPSignatureHelpContext(SMLoc completeLoc, lsp::SignatureHelp &signatureHelp, + ods::Context &odsContext) + : CodeCompleteContext(completeLoc), signatureHelp(signatureHelp), + odsContext(odsContext) {} + + void codeCompleteCallSignature(const ast::CallableDecl *callable, + unsigned currentNumArgs) final { + signatureHelp.activeParameter = currentNumArgs; + + lsp::SignatureInformation signatureInfo; + { + llvm::raw_string_ostream strOS(signatureInfo.label); + strOS << callable->getName()->getName() << "("; + auto formatParamFn = [&](const ast::VariableDecl *var) { + unsigned paramStart = strOS.str().size(); + strOS << var->getName().getName() << ": " << var->getType(); + unsigned paramEnd = strOS.str().size(); + signatureInfo.parameters.emplace_back(lsp::ParameterInformation{ + StringRef(strOS.str()).slice(paramStart, paramEnd).str(), + std::make_pair(paramStart, paramEnd), /*paramDoc*/ std::string()}); + }; + llvm::interleaveComma(callable->getInputs(), strOS, formatParamFn); + strOS << ") -> " << callable->getResultType(); + } + signatureHelp.signatures.emplace_back(std::move(signatureInfo)); + } + + void + codeCompleteOperationOperandsSignature(Optional<StringRef> opName, + unsigned currentNumOperands) final { + const ods::Operation *odsOp = + opName ? odsContext.lookupOperation(*opName) : nullptr; + codeCompleteOperationOperandOrResultSignature( + opName, odsOp, odsOp ? odsOp->getOperands() : llvm::None, + currentNumOperands, "operand", "Value"); + } + + void codeCompleteOperationResultsSignature(Optional<StringRef> opName, + unsigned currentNumResults) final { + const ods::Operation *odsOp = + opName ?
odsContext.lookupOperation(*opName) : nullptr; + codeCompleteOperationOperandOrResultSignature( + opName, odsOp, odsOp ? odsOp->getResults() : llvm::None, + currentNumResults, "result", "Type"); + } + + void codeCompleteOperationOperandOrResultSignature( + Optional<StringRef> opName, const ods::Operation *odsOp, + ArrayRef<ods::OperandOrResult> values, unsigned currentValue, + StringRef label, StringRef dataType) { + signatureHelp.activeParameter = currentValue; + + // If we have ODS information for the operation, add in the ODS signature + // for the operation. We also verify that the current number of values is + // not more than what is defined in ODS, as this will result in an error + // anyways. + if (odsOp && currentValue < values.size()) { + lsp::SignatureInformation signatureInfo; + + // Build the signature label. + { + llvm::raw_string_ostream strOS(signatureInfo.label); + strOS << "("; + auto formatFn = [&](const ods::OperandOrResult &value) { + unsigned paramStart = strOS.str().size(); + + strOS << value.getName() << ": "; + + StringRef constraintDoc = value.getConstraint().getSummary(); + std::string paramDoc; + switch (value.getVariableLengthKind()) { + case ods::VariableLengthKind::Single: + strOS << dataType; + paramDoc = constraintDoc.str(); + break; + case ods::VariableLengthKind::Optional: + strOS << dataType << "?"; + paramDoc = ("optional: " + constraintDoc).str(); + break; + case ods::VariableLengthKind::Variadic: + strOS << dataType << "Range"; + paramDoc = ("variadic: " + constraintDoc).str(); + break; + } + + unsigned paramEnd = strOS.str().size(); + signatureInfo.parameters.emplace_back(lsp::ParameterInformation{ + StringRef(strOS.str()).slice(paramStart, paramEnd).str(), + std::make_pair(paramStart, paramEnd), paramDoc}); + }; + llvm::interleaveComma(values, strOS, formatFn); + strOS << ")"; + } + signatureInfo.documentation = + llvm::formatv("`op<{0}>` ODS {1} specification", *opName, label) + .str(); + signatureHelp.signatures.emplace_back(std::move(signatureInfo)); + }
+ + // If there aren't any arguments yet, we also add the generic signature. + if (currentValue == 0 && (!odsOp || !values.empty())) { + lsp::SignatureInformation signatureInfo; + signatureInfo.label = + llvm::formatv("(<{0}s>: {1}Range)", label, dataType).str(); + signatureInfo.documentation = + ("Generic operation " + label + " specification").str(); + signatureInfo.parameters.emplace_back(lsp::ParameterInformation{ + StringRef(signatureInfo.label).drop_front().drop_back().str(), + std::pair<unsigned, unsigned>(1, signatureInfo.label.size() - 1), + ("All of the " + label + "s of the operation.").str()}); + signatureHelp.signatures.emplace_back(std::move(signatureInfo)); + } + } + +private: + lsp::SignatureHelp &signatureHelp; + ods::Context &odsContext; +}; +} // namespace + +lsp::SignatureHelp PDLDocument::getSignatureHelp(const lsp::URIForFile &uri, + const lsp::Position &helpPos) { + SMLoc posLoc = helpPos.getAsSMLoc(sourceMgr); + if (!posLoc.isValid()) + return lsp::SignatureHelp(); + + // Adjust the position one further to after the completion trigger token. + posLoc = SMLoc::getFromPointer(posLoc.getPointer() + 1); + + // To perform code completion, we run another parse of the module with the + // code completion context provided.
+ ods::Context tmpODSContext; + lsp::SignatureHelp signatureHelp; + LSPSignatureHelpContext completeContext(posLoc, signatureHelp, tmpODSContext); + + ast::Context tmpContext(tmpODSContext); + (void)parsePDLAST(tmpContext, sourceMgr, &completeContext); + + return signatureHelp; +} + //===----------------------------------------------------------------------===// // PDLTextFileChunk //===----------------------------------------------------------------------===// @@ -883,6 +1038,8 @@ void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols); lsp::CompletionList getCodeCompletion(const lsp::URIForFile &uri, lsp::Position completePos); + lsp::SignatureHelp getSignatureHelp(const lsp::URIForFile &uri, + lsp::Position helpPos); private: /// Find the PDL document that contains the given position, and update the @@ -1036,6 +1193,11 @@ return completionList; } +lsp::SignatureHelp PDLTextFile::getSignatureHelp(const lsp::URIForFile &uri, + lsp::Position helpPos) { + return getChunkFor(helpPos).document.getSignatureHelp(uri, helpPos); +} + PDLTextFileChunk &PDLTextFile::getChunkFor(lsp::Position &pos) { if (chunks.size() == 1) return *chunks.front(); @@ -1123,3 +1285,11 @@ return fileIt->second->getCodeCompletion(uri, completePos); return CompletionList(); } + +lsp::SignatureHelp lsp::PDLLServer::getSignatureHelp(const URIForFile &uri, + const Position &helpPos) { + auto fileIt = impl->files.find(uri.file()); + if (fileIt != impl->files.end()) + return fileIt->second->getSignatureHelp(uri, helpPos); + return SignatureHelp(); +} diff --git a/mlir/test/mlir-pdll-lsp-server/initialize-params.test b/mlir/test/mlir-pdll-lsp-server/initialize-params.test --- a/mlir/test/mlir-pdll-lsp-server/initialize-params.test +++ b/mlir/test/mlir-pdll-lsp-server/initialize-params.test @@ -16,6 +16,12 @@ // CHECK-NEXT: "documentSymbolProvider": true, // CHECK-NEXT: "hoverProvider": true, // CHECK-NEXT: "referencesProvider": true, +// CHECK-NEXT: "signatureHelpProvider": { +// CHECK-NEXT:
"triggerCharacters": [ +// CHECK-NEXT: "(", +// CHECK-NEXT: "," +// CHECK-NEXT: ] +// CHECK-NEXT: }, // CHECK-NEXT: "textDocumentSync": { // CHECK-NEXT: "change": 1, // CHECK-NEXT: "openClose": true, diff --git a/mlir/test/mlir-pdll-lsp-server/signature-help.test b/mlir/test/mlir-pdll-lsp-server/signature-help.test new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-pdll-lsp-server/signature-help.test @@ -0,0 +1,89 @@ +// RUN: mlir-pdll-lsp-server -lit-test < %s | FileCheck -strict-whitespace %s +{"jsonrpc":"2.0","id":0,"method":"initialize","params":{"processId":123,"rootPath":"pdll","capabilities":{},"trace":"off"}} +// ----- +{"jsonrpc":"2.0","method":"textDocument/didOpen","params":{"textDocument":{ + "uri":"test:///foo.pdll", + "languageId":"pdll", + "version":1, + "text":"Constraint ValueCst(value: Value);\nPattern {\nlet root = op<test.op>() -> ();\nValueCst(root);\nerase root;\n}" +}}} +// ----- +{"jsonrpc":"2.0","id":1,"method":"textDocument/signatureHelp","params":{ + "textDocument":{"uri":"test:///foo.pdll"}, + "position":{"line":2,"character":23} +}} +// CHECK: "id": 1 +// CHECK-NEXT: "jsonrpc": "2.0", +// CHECK-NEXT: "result": { +// CHECK-NEXT: "activeParameter": 0, +// CHECK-NEXT: "activeSignature": 0, +// CHECK-NEXT: "signatures": [ +// CHECK-NEXT: { +// CHECK-NEXT: "documentation": "Generic operation operand specification", +// CHECK-NEXT: "label": "(<operands>: ValueRange)", +// CHECK-NEXT: "parameters": [ +// CHECK-NEXT: { +// CHECK-NEXT: "documentation": "All of the operands of the operation.", +// CHECK-NEXT: "label": [ +// CHECK-NEXT: 1, +// CHECK-NEXT: 23 +// CHECK-NEXT: ] +// CHECK-NEXT: } +// CHECK-NEXT: ] +// CHECK-NEXT: } +// CHECK-NEXT: ] +// CHECK-NEXT: } +// ----- +{"jsonrpc":"2.0","id":1,"method":"textDocument/signatureHelp","params":{ + "textDocument":{"uri":"test:///foo.pdll"}, + "position":{"line":2,"character":29} +}} +// CHECK: "id": 1 +// CHECK-NEXT: "jsonrpc": "2.0", +// CHECK-NEXT: "result": { +// CHECK-NEXT: "activeParameter": 0, +//
CHECK-NEXT: "activeSignature": 0, +// CHECK-NEXT: "signatures": [ +// CHECK-NEXT: { +// CHECK-NEXT: "documentation": "Generic operation result specification", +// CHECK-NEXT: "label": "(<results>: TypeRange)", +// CHECK-NEXT: "parameters": [ +// CHECK-NEXT: { +// CHECK-NEXT: "documentation": "All of the results of the operation.", +// CHECK-NEXT: "label": [ +// CHECK-NEXT: 1, +// CHECK-NEXT: 21 +// CHECK-NEXT: ] +// CHECK-NEXT: } +// CHECK-NEXT: ] +// CHECK-NEXT: } +// CHECK-NEXT: ] +// CHECK-NEXT: } +// ----- +{"jsonrpc":"2.0","id":1,"method":"textDocument/signatureHelp","params":{ + "textDocument":{"uri":"test:///foo.pdll"}, + "position":{"line":3,"character":9} +}} +// CHECK: "id": 1 +// CHECK-NEXT: "jsonrpc": "2.0", +// CHECK-NEXT: "result": { +// CHECK-NEXT: "activeParameter": 0, +// CHECK-NEXT: "activeSignature": 0, +// CHECK-NEXT: "signatures": [ +// CHECK-NEXT: { +// CHECK-NEXT: "label": "ValueCst(value: Value) -> Tuple<>", +// CHECK-NEXT: "parameters": [ +// CHECK-NEXT: { +// CHECK-NEXT: "label": [ +// CHECK-NEXT: 9, +// CHECK-NEXT: 21 +// CHECK-NEXT: ] +// CHECK-NEXT: } +// CHECK-NEXT: ] +// CHECK-NEXT: } +// CHECK-NEXT: ] +// CHECK-NEXT: } +// ----- +{"jsonrpc":"2.0","id":3,"method":"shutdown"} +// ----- +{"jsonrpc":"2.0","method":"exit"}