diff --git a/flang/lib/Optimizer/Dialect/FIRType.cpp b/flang/lib/Optimizer/Dialect/FIRType.cpp --- a/flang/lib/Optimizer/Dialect/FIRType.cpp +++ b/flang/lib/Optimizer/Dialect/FIRType.cpp @@ -116,10 +116,8 @@ mlir::Type fir::parseFirType(FIROpsDialect *dialect, mlir::DialectAsmParser &parser) { mlir::StringRef typeTag; - if (parser.parseKeyword(&typeTag)) - return {}; mlir::Type genType; - auto parseResult = generatedTypeParser(parser, typeTag, genType); + auto parseResult = generatedTypeParser(parser, &typeTag, genType); if (parseResult.hasValue()) return genType; parser.emitError(parser.getNameLoc(), "unknown fir type: ") << typeTag; diff --git a/mlir/docs/AttributesAndTypes.md b/mlir/docs/AttributesAndTypes.md --- a/mlir/docs/AttributesAndTypes.md +++ b/mlir/docs/AttributesAndTypes.md @@ -473,10 +473,10 @@ and have the following function signatures: ```c++ -static ParseResult generatedAttributeParser(DialectAsmParser& parser, StringRef mnemonic, Type attrType, Attribute &result); +static OptionalParseResult generatedAttributeParser(DialectAsmParser& parser, StringRef *mnemonic, Type attrType, Attribute &result); static LogicalResult generatedAttributePrinter(Attribute attr, DialectAsmPrinter& printer); -static ParseResult generatedTypeParser(DialectAsmParser& parser, StringRef mnemonic, Type &result); +static OptionalParseResult generatedTypeParser(DialectAsmParser& parser, StringRef *mnemonic, Type &result); static LogicalResult generatedTypePrinter(Type type, DialectAsmPrinter& printer); ``` diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -571,43 +571,6 @@ /// Parse a quoted string token if present. virtual ParseResult parseOptionalString(std::string *string) = 0; - /// Parse a given keyword.
- ParseResult parseKeyword(StringRef keyword) { - return parseKeyword(keyword, ""); - } - virtual ParseResult parseKeyword(StringRef keyword, const Twine &msg) = 0; - - /// Parse a keyword into 'keyword'. - ParseResult parseKeyword(StringRef *keyword) { - auto loc = getCurrentLocation(); - if (parseOptionalKeyword(keyword)) - return emitError(loc, "expected valid keyword"); - return success(); - } - - /// Parse the given keyword if present. - virtual ParseResult parseOptionalKeyword(StringRef keyword) = 0; - - /// Parse a keyword, if present, into 'keyword'. - virtual ParseResult parseOptionalKeyword(StringRef *keyword) = 0; - - /// Parse a keyword, if present, and if one of the 'allowedValues', - /// into 'keyword' - virtual ParseResult - parseOptionalKeyword(StringRef *keyword, - ArrayRef allowedValues) = 0; - - /// Parse a keyword or a quoted string. - ParseResult parseKeywordOrString(std::string *result) { - if (failed(parseOptionalKeywordOrString(result))) - return emitError(getCurrentLocation()) - << "expected valid keyword or string"; - return success(); - } - - /// Parse an optional keyword or string. - virtual ParseResult parseOptionalKeywordOrString(std::string *result) = 0; - /// Parse a `(` token. virtual ParseResult parseLParen() = 0; @@ -712,6 +675,115 @@ return parseCommaSeparatedList(Delimiter::None, parseElementFn); } + //===--------------------------------------------------------------------===// + // Keyword Parsing + //===--------------------------------------------------------------------===// + + /// This class represents a StringSwitch-like class that is useful for parsing + /// expected keywords. On construction, it invokes `parseKeywordOrCompletion` + /// and processes each of the provided case statements until a match is hit. + /// The provided `ResultT` must be assignable from `failure()`.
+ template + class KeywordSwitch { + public: + KeywordSwitch(AsmParser &parser) + : parser(parser), loc(parser.getCurrentLocation()) { + if (failed(parser.parseKeywordOrCompletion(&keyword))) + result = failure(); + } + + /// Case that uses the provided value when true. + KeywordSwitch &Case(StringLiteral str, ResultT value) { + return Case(str, [&](StringRef, SMLoc) { return std::move(value); }); + } + KeywordSwitch &Default(ResultT value) { + return Default([&](StringRef, SMLoc) { return std::move(value); }); + } + /// Case that invokes the provided functor when true. The parameters passed + /// to the functor are the keyword, and the location of the keyword (in case + /// any errors need to be emitted). + template + std::enable_if_t::value, KeywordSwitch &> + Case(StringLiteral str, FnT &&fn) { + if (result) + return *this; + + // If the word was empty, record this as a completion. + if (keyword.empty()) + parser.codeCompleteExpectedTokens(str); + else if (keyword == str) + result.emplace(std::move(fn(keyword, loc))); + return *this; + } + template + std::enable_if_t::value, KeywordSwitch &> + Default(FnT &&fn) { + if (!result) + result.emplace(fn(keyword, loc)); + return *this; + } + + /// Returns true if this switch has a value yet. + bool hasValue() const { return result.hasValue(); } + + /// Return the result of the switch. + LLVM_NODISCARD operator ResultT() { + if (!result) + return parser.emitError(loc, "unexpected keyword: ") << keyword; + return std::move(*result); + } + + private: + /// The parser used to construct this switch. + AsmParser &parser; + + /// The location of the keyword, used to emit errors as necessary. + SMLoc loc; + + /// The parsed keyword itself. + StringRef keyword; + + /// The result of the switch statement or none if currently unknown. + Optional result; + }; + + /// Parse a given keyword. 
+ ParseResult parseKeyword(StringRef keyword) { + return parseKeyword(keyword, ""); + } + virtual ParseResult parseKeyword(StringRef keyword, const Twine &msg) = 0; + + /// Parse a keyword into 'keyword'. + ParseResult parseKeyword(StringRef *keyword) { + auto loc = getCurrentLocation(); + if (parseOptionalKeyword(keyword)) + return emitError(loc, "expected valid keyword"); + return success(); + } + + /// Parse the given keyword if present. + virtual ParseResult parseOptionalKeyword(StringRef keyword) = 0; + + /// Parse a keyword, if present, into 'keyword'. + virtual ParseResult parseOptionalKeyword(StringRef *keyword) = 0; + + /// Parse a keyword, if present, and if one of the 'allowedValues', + /// into 'keyword' + virtual ParseResult + parseOptionalKeyword(StringRef *keyword, + ArrayRef allowedValues) = 0; + + /// Parse a keyword or a quoted string. + ParseResult parseKeywordOrString(std::string *result) { + if (failed(parseOptionalKeywordOrString(result))) + return emitError(getCurrentLocation()) + << "expected valid keyword or string"; + return success(); + } + + /// Parse an optional keyword or string. + virtual ParseResult parseOptionalKeywordOrString(std::string *result) = 0; + //===--------------------------------------------------------------------===// // Attribute/Type Parsing //===--------------------------------------------------------------------===// @@ -1124,6 +1196,17 @@ virtual FailureOr parseResourceHandle(Dialect *dialect) = 0; + //===--------------------------------------------------------------------===// + // Code Completion + //===--------------------------------------------------------------------===// + + /// Parse a keyword, or an empty string if the current location signals a code + /// completion. + virtual ParseResult parseKeywordOrCompletion(StringRef *keyword) = 0; + + /// Signal the code completion of a set of expected tokens. 
+ virtual void codeCompleteExpectedTokens(ArrayRef tokens) = 0; + private: AsmParser(const AsmParser &) = delete; void operator=(const AsmParser &) = delete; diff --git a/mlir/include/mlir/Parser/CodeComplete.h b/mlir/include/mlir/Parser/CodeComplete.h --- a/mlir/include/mlir/Parser/CodeComplete.h +++ b/mlir/include/mlir/Parser/CodeComplete.h @@ -10,9 +10,13 @@ #define MLIR_PARSER_CODECOMPLETE_H #include "mlir/Support/LLVM.h" +#include "llvm/ADT/StringMap.h" #include "llvm/Support/SourceMgr.h" namespace mlir { +class Attribute; +class Type; + /// This class provides an abstract interface into the parser for hooking in /// code completion events. This class is only really useful for providing /// language tooling for MLIR, general clients should not need to use this @@ -28,8 +32,9 @@ // Completion Hooks //===--------------------------------------------------------------------===// - /// Signal code completion for a dialect name. - virtual void completeDialectName() = 0; + /// Signal code completion for a dialect name, with an optional prefix. + virtual void completeDialectName(StringRef prefix) = 0; + void completeDialectName() { completeDialectName(""); } /// Signal code completion for an operation name within the given dialect. virtual void completeOperationName(StringRef dialectName) = 0; @@ -48,6 +53,16 @@ virtual void completeExpectedTokens(ArrayRef tokens, bool optional) = 0; + /// Signal a completion for an attribute. + virtual void completeAttribute(const llvm::StringMap &aliases) = 0; + virtual void completeDialectAttributeOrAlias( + const llvm::StringMap &aliases) = 0; + + /// Signal a completion for a type. + virtual void completeType(const llvm::StringMap &aliases) = 0; + virtual void + completeDialectTypeOrAlias(const llvm::StringMap &aliases) = 0; + protected: /// Create a new code completion context with the given code complete /// location. 
diff --git a/mlir/lib/Dialect/PDL/IR/PDLTypes.cpp b/mlir/lib/Dialect/PDL/IR/PDLTypes.cpp --- a/mlir/lib/Dialect/PDL/IR/PDLTypes.cpp +++ b/mlir/lib/Dialect/PDL/IR/PDLTypes.cpp @@ -35,11 +35,9 @@ static Type parsePDLType(AsmParser &parser) { StringRef typeTag; - if (parser.parseKeyword(&typeTag)) - return Type(); { Type genType; - auto parseResult = generatedTypeParser(parser, typeTag, genType); + auto parseResult = generatedTypeParser(parser, &typeTag, genType); if (parseResult.hasValue()) return genType; } diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp @@ -577,17 +577,11 @@ // Parse the kind keyword first. StringRef attrKind; - if (parser.parseKeyword(&attrKind)) - return {}; - Attribute attr; OptionalParseResult result = - generatedAttributeParser(parser, attrKind, type, attr); - if (result.hasValue()) { - if (failed(result.getValue())) - return {}; + generatedAttributeParser(parser, &attrKind, type, attr); + if (result.hasValue()) return attr; - } if (attrKind == spirv::TargetEnvAttr::getKindName()) return parseTargetEnvAttr(parser); diff --git a/mlir/lib/Parser/AsmParserImpl.h b/mlir/lib/Parser/AsmParserImpl.h --- a/mlir/lib/Parser/AsmParserImpl.h +++ b/mlir/lib/Parser/AsmParserImpl.h @@ -242,6 +242,56 @@ return success(); } + /// Parse a floating point value from the stream. + ParseResult parseFloat(double &result) override { + bool isNegative = parser.consumeIf(Token::minus); + Token curTok = parser.getToken(); + SMLoc loc = curTok.getLoc(); + + // Check for a floating point value. + if (curTok.is(Token::floatliteral)) { + auto val = curTok.getFloatingPointValue(); + if (!val) + return emitError(loc, "floating point value too large"); + parser.consumeToken(Token::floatliteral); + result = isNegative ? -*val : *val; + return success(); + } + + // Check for a hexadecimal float value. 
+ if (curTok.is(Token::integer)) { + Optional apResult; + if (failed(parser.parseFloatFromIntegerLiteral( + apResult, curTok, isNegative, APFloat::IEEEdouble(), + /*typeSizeInBits=*/64))) + return failure(); + + parser.consumeToken(Token::integer); + result = apResult->convertToDouble(); + return success(); + } + + return emitError(loc, "expected floating point literal"); + } + + /// Parse an optional integer value from the stream. + OptionalParseResult parseOptionalInteger(APInt &result) override { + return parser.parseOptionalInteger(result); + } + + /// Parse a list of comma-separated items with an optional delimiter. If a + /// delimiter is provided, then an empty list is allowed. If not, then at + /// least one element will be parsed. + ParseResult parseCommaSeparatedList(Delimiter delimiter, + function_ref parseElt, + StringRef contextMessage) override { + return parser.parseCommaSeparatedList(delimiter, parseElt, contextMessage); + } + + //===--------------------------------------------------------------------===// + // Keyword Parsing + //===--------------------------------------------------------------------===// + ParseResult parseKeyword(StringRef keyword, const Twine &msg) override { if (parser.getToken().isCodeCompletion()) return parser.codeCompleteExpectedTokens(keyword); @@ -251,6 +301,7 @@ return emitError(loc, "expected '") << keyword << "'" << msg; return success(); } + using AsmParser::parseKeyword; /// Parse the given keyword if present. ParseResult parseOptionalKeyword(StringRef keyword) override { @@ -308,52 +359,6 @@ return parseOptionalString(result); } - /// Parse a floating point value from the stream. - ParseResult parseFloat(double &result) override { - bool isNegative = parser.consumeIf(Token::minus); - Token curTok = parser.getToken(); - SMLoc loc = curTok.getLoc(); - - // Check for a floating point value. 
- if (curTok.is(Token::floatliteral)) { - auto val = curTok.getFloatingPointValue(); - if (!val) - return emitError(loc, "floating point value too large"); - parser.consumeToken(Token::floatliteral); - result = isNegative ? -*val : *val; - return success(); - } - - // Check for a hexadecimal float value. - if (curTok.is(Token::integer)) { - Optional apResult; - if (failed(parser.parseFloatFromIntegerLiteral( - apResult, curTok, isNegative, APFloat::IEEEdouble(), - /*typeSizeInBits=*/64))) - return failure(); - - parser.consumeToken(Token::integer); - result = apResult->convertToDouble(); - return success(); - } - - return emitError(loc, "expected floating point literal"); - } - - /// Parse an optional integer value from the stream. - OptionalParseResult parseOptionalInteger(APInt &result) override { - return parser.parseOptionalInteger(result); - } - - /// Parse a list of comma-separated items with an optional delimiter. If a - /// delimiter is provided, then an empty list is allowed. If not, then at - /// least one element will be parsed. - ParseResult parseCommaSeparatedList(Delimiter delimiter, - function_ref parseElt, - StringRef contextMessage) override { - return parser.parseCommaSeparatedList(delimiter, parseElt, contextMessage); - } - //===--------------------------------------------------------------------===// // Attribute Parsing //===--------------------------------------------------------------------===// @@ -528,6 +533,28 @@ return parser.parseXInDimensionList(); } + //===--------------------------------------------------------------------===// + // Code Completion + //===--------------------------------------------------------------------===// + + /// Parse a keyword, or an empty string if the current location signals a code + /// completion. 
+ ParseResult parseKeywordOrCompletion(StringRef *keyword) override { + Token tok = parser.getToken(); + if (tok.isCodeCompletion() && tok.getSpelling().empty()) { + *keyword = ""; + return success(); + } + return parseKeyword(keyword); + } + + /// Signal the code completion of a set of expected tokens. + void codeCompleteExpectedTokens(ArrayRef tokens) override { + Token tok = parser.getToken(); + if (tok.isCodeCompletion() && tok.getSpelling().empty()) + (void)parser.codeCompleteExpectedTokens(tokens); + } + protected: /// The source location of the dialect symbol. SMLoc nameLoc; diff --git a/mlir/lib/Parser/AttributeParser.cpp b/mlir/lib/Parser/AttributeParser.cpp --- a/mlir/lib/Parser/AttributeParser.cpp +++ b/mlir/lib/Parser/AttributeParser.cpp @@ -213,6 +213,12 @@ consumeToken(Token::kw_unit); return builder.getUnitAttr(); + // Handle completion of an attribute. + case Token::code_complete: + if (getToken().isCodeCompletionFor(Token::hash_identifier)) + return parseExtendedAttr(type); + return codeCompleteAttribute(); + default: // Parse a type attribute. We parse `Optional` here to allow for providing a // better error message. diff --git a/mlir/lib/Parser/DialectSymbolParser.cpp b/mlir/lib/Parser/DialectSymbolParser.cpp --- a/mlir/lib/Parser/DialectSymbolParser.cpp +++ b/mlir/lib/Parser/DialectSymbolParser.cpp @@ -43,9 +43,6 @@ }; } // namespace -/// Parse the body of a dialect symbol, which starts and ends with <>'s, and may -/// be recursive. Return with the 'body' StringRef encompassing the entire -/// body. 
/// /// pretty-dialect-sym-body ::= '<' pretty-dialect-sym-contents+ '>' /// pretty-dialect-sym-contents ::= pretty-dialect-sym-body @@ -54,7 +51,8 @@ /// | '{' pretty-dialect-sym-contents+ '}' /// | '[^[<({>\])}\0]+' /// -ParseResult Parser::parseDialectSymbolBody(StringRef &body) { +ParseResult Parser::parseDialectSymbolBody(StringRef &body, + bool &isCodeCompletion) { // Symbol bodies are a relatively unstructured format that contains a series // of properly nested punctuation, with anything else in the middle. Scan // ahead to find it and consume it if successful, otherwise emit an error. @@ -65,7 +63,16 @@ // go until we find the matching '>' character. assert(*curPtr == '<'); SmallVector nestedPunctuation; + const char *codeCompleteLoc = state.lex.getCodeCompleteLoc(); do { + // Handle code completions, which may appear in the middle of the symbol + // body. + if (curPtr == codeCompleteLoc) { + isCodeCompletion = true; + nestedPunctuation.clear(); + break; + } + char c = *curPtr++; switch (c) { case '\0': @@ -107,9 +114,19 @@ case '"': { // Dispatch to the lexer to lex past strings. resetToken(curPtr - 1); + curPtr = state.curToken.getEndLoc().getPointer(); + + // Handle code completions, which may appear in the middle of the symbol + // body. + if (state.curToken.isCodeCompletion()) { + isCodeCompletion = true; + nestedPunctuation.clear(); + break; + } + + // Otherwise, ensure this token was actually a string. if (state.curToken.isNot(Token::string)) return failure(); - curPtr = state.curToken.getEndLoc().getPointer(); break; } @@ -129,19 +146,24 @@ /// Parse an extended dialect symbol. template -static Symbol parseExtendedSymbol(Parser &p, Token::Kind identifierTok, - SymbolAliasMap &aliases, +static Symbol parseExtendedSymbol(Parser &p, SymbolAliasMap &aliases, CreateFn &&createSymbol) { + Token tok = p.getToken(); + + // Handle code completion of the extended symbol. 
+ StringRef identifier = tok.getSpelling().drop_front(); + if (tok.isCodeCompletion() && identifier.empty()) + return p.codeCompleteDialectSymbol(aliases); + // Parse the dialect namespace. - StringRef identifier = p.getTokenSpelling().drop_front(); SMLoc loc = p.getToken().getLoc(); - p.consumeToken(identifierTok); + p.consumeToken(); // Check to see if this is a pretty name. StringRef dialectName; StringRef symbolData; std::tie(dialectName, symbolData) = identifier.split('.'); - bool isPrettyName = !symbolData.empty(); + bool isPrettyName = !symbolData.empty() || identifier.back() == '.'; // Check to see if the symbol has trailing data, i.e. has an immediately // following '<'. @@ -167,9 +189,17 @@ if (!isPrettyName) { // Point the symbol data to the end of the dialect name to start. symbolData = StringRef(dialectName.end(), 0); - if (p.parseDialectSymbolBody(symbolData)) + + // Parse the body of the symbol. + bool isCodeCompletion = false; + if (p.parseDialectSymbolBody(symbolData, isCodeCompletion)) return nullptr; - symbolData = symbolData.drop_front().drop_back(); + symbolData = symbolData.drop_front(); + + // If the body contained a code completion it won't have the trailing `>` + // token, so don't drop it. + if (!isCodeCompletion) + symbolData = symbolData.drop_back(); } else { loc = SMLoc::getFromPointer(symbolData.data()); @@ -192,7 +222,7 @@ Attribute Parser::parseExtendedAttr(Type type) { MLIRContext *ctx = getContext(); Attribute attr = parseExtendedSymbol( - *this, Token::hash_identifier, state.symbols.attributeAliasDefinitions, + *this, state.symbols.attributeAliasDefinitions, [&](StringRef dialectName, StringRef symbolData, SMLoc loc) -> Attribute { // Parse an optional trailing colon type. 
Type attrType = type; @@ -238,7 +268,7 @@ Type Parser::parseExtendedType() { MLIRContext *ctx = getContext(); return parseExtendedSymbol( - *this, Token::exclamation_identifier, state.symbols.typeAliasDefinitions, + *this, state.symbols.typeAliasDefinitions, [&](StringRef dialectName, StringRef symbolData, SMLoc loc) -> Type { // If we found a registered dialect, then ask it to parse the type. if (auto *dialect = ctx->getOrLoadDialect(dialectName)) { diff --git a/mlir/lib/Parser/Lexer.h b/mlir/lib/Parser/Lexer.h --- a/mlir/lib/Parser/Lexer.h +++ b/mlir/lib/Parser/Lexer.h @@ -40,6 +40,10 @@ /// Returns the start of the buffer. const char *getBufferBegin() { return curBuffer.data(); } + /// Return the code completion location of the lexer, or nullptr if there is + /// none. + const char *getCodeCompleteLoc() const { return codeCompleteLoc; } + private: // Helpers. Token formToken(Token::Kind kind, const char *tokStart) { diff --git a/mlir/lib/Parser/Parser.h b/mlir/lib/Parser/Parser.h --- a/mlir/lib/Parser/Parser.h +++ b/mlir/lib/Parser/Parser.h @@ -57,7 +57,16 @@ return parseCommaSeparatedList(Delimiter::None, parseElementFn); } - ParseResult parseDialectSymbolBody(StringRef &body); + /// Parse the body of a dialect symbol, which starts and ends with <>'s, and + /// may be recursive. Return with the 'body' StringRef encompassing the entire + /// body. `isCodeCompletion` is set to true if the body contained a code + /// completion location, in which case the body is only populated up to the + /// completion. 
+ ParseResult parseDialectSymbolBody(StringRef &body, bool &isCodeCompletion); + ParseResult parseDialectSymbolBody(StringRef &body) { + bool isCodeCompletion = false; + return parseDialectSymbolBody(body, isCodeCompletion); + } // We have two forms of parsing methods - those that return a non-null // pointer on success, and those that return a ParseResult to indicate whether @@ -322,6 +331,12 @@ ParseResult codeCompleteExpectedTokens(ArrayRef tokens); ParseResult codeCompleteOptionalTokens(ArrayRef tokens); + Attribute codeCompleteAttribute(); + Type codeCompleteType(); + Attribute + codeCompleteDialectSymbol(const llvm::StringMap &aliases); + Type codeCompleteDialectSymbol(const llvm::StringMap &aliases); + protected: /// The Parser is subclassed and reinstantiated. Do not add additional /// non-trivial state here, add it to the ParserState class. diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -404,6 +404,26 @@ return failure(); } +Attribute Parser::codeCompleteAttribute() { + state.codeCompleteContext->completeAttribute( + state.symbols.attributeAliasDefinitions); + return {}; +} +Type Parser::codeCompleteType() { + state.codeCompleteContext->completeType(state.symbols.typeAliasDefinitions); + return {}; +} + +Attribute +Parser::codeCompleteDialectSymbol(const llvm::StringMap &aliases) { + state.codeCompleteContext->completeDialectAttributeOrAlias(aliases); + return {}; +} +Type Parser::codeCompleteDialectSymbol(const llvm::StringMap &aliases) { + state.codeCompleteContext->completeDialectTypeOrAlias(aliases); + return {}; +} + //===----------------------------------------------------------------------===// // OperationParser //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Parser/TypeParser.cpp b/mlir/lib/Parser/TypeParser.cpp --- a/mlir/lib/Parser/TypeParser.cpp +++ b/mlir/lib/Parser/TypeParser.cpp @@ -358,6 +358,12 
@@ // extended type case Token::exclamation_identifier: return parseExtendedType(); + + // Handle completion of a dialect type. + case Token::code_complete: + if (getToken().isCodeCompletionFor(Token::exclamation_identifier)) + return parseExtendedType(); + return codeCompleteType(); } } diff --git a/mlir/lib/Tools/lsp-server-support/Protocol.h b/mlir/lib/Tools/lsp-server-support/Protocol.h --- a/mlir/lib/Tools/lsp-server-support/Protocol.h +++ b/mlir/lib/Tools/lsp-server-support/Protocol.h @@ -781,8 +781,9 @@ struct CompletionItem { CompletionItem() = default; - CompletionItem(StringRef label, CompletionItemKind kind) - : label(label.str()), kind(kind), + CompletionItem(const Twine &label, CompletionItemKind kind, + StringRef sortText = "") + : label(label.str()), kind(kind), sortText(sortText.str()), insertTextFormat(InsertTextFormat::PlainText) {} /// The label of this completion item. By default also the text that is diff --git a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp --- a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp +++ b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp @@ -636,15 +636,17 @@ : AsmParserCodeCompleteContext(completeLoc), completionList(completionList), ctx(ctx) {} - /// Signal code completion for a dialect name. - void completeDialectName() final { + /// Signal code completion for a dialect name, with an optional prefix. + void completeDialectName(StringRef prefix) final { for (StringRef dialect : ctx->getAvailableDialects()) { - lsp::CompletionItem item(dialect, lsp::CompletionItemKind::Module); - item.sortText = "2"; + lsp::CompletionItem item(prefix + dialect, + lsp::CompletionItemKind::Module, + /*sortText=*/"3"); item.detail = "dialect"; completionList.items.emplace_back(item); } } + using AsmParserCodeCompleteContext::completeDialectName; /// Signal code completion for an operation name within the given dialect. 
void completeOperationName(StringRef dialectName) final { @@ -658,8 +660,8 @@ lsp::CompletionItem item( op.getStringRef().drop_front(dialectName.size() + 1), - lsp::CompletionItemKind::Field); - item.sortText = "1"; + lsp::CompletionItemKind::Field, + /*sortText=*/"1"); item.detail = "operation"; completionList.items.emplace_back(item); } @@ -693,13 +695,71 @@ /// Signal a completion for the given expected token. void completeExpectedTokens(ArrayRef tokens, bool optional) final { for (StringRef token : tokens) { - lsp::CompletionItem item(token, lsp::CompletionItemKind::Keyword); - item.sortText = "0"; + lsp::CompletionItem item(token, lsp::CompletionItemKind::Keyword, + /*sortText=*/"0"); item.detail = optional ? "optional" : ""; completionList.items.emplace_back(item); } } + /// Signal a completion for an attribute. + void completeAttribute(const llvm::StringMap &aliases) override { + appendSimpleCompletions({"affine_set", "affine_map", "dense", "false", + "loc", "opaque", "sparse", "true", "unit"}, + lsp::CompletionItemKind::Field, + /*sortText=*/"1"); + + completeDialectName("#"); + completeAliases(aliases, "#"); + } + void completeDialectAttributeOrAlias( + const llvm::StringMap &aliases) override { + completeDialectName(); + completeAliases(aliases); + } + + /// Signal a completion for a type. + void completeType(const llvm::StringMap &aliases) override { + appendSimpleCompletions({"memref", "tensor", "complex", "tuple", "vector", + "bf16", "f16", "f32", "f64", "f80", "f128", + "index", "none"}, + lsp::CompletionItemKind::Field, + /*sortText=*/"1"); + lsp::CompletionItem item("i", lsp::CompletionItemKind::Field, + /*sortText=*/"1"); + item.insertText = "i"; + completionList.items.emplace_back(item); + + completeDialectName("!"); + completeAliases(aliases, "!"); + } + void + completeDialectTypeOrAlias(const llvm::StringMap &aliases) override { + completeDialectName(); + completeAliases(aliases); + } + + /// Add completion results for the given set of aliases. 
+ template + void completeAliases(const llvm::StringMap &aliases, + StringRef prefix = "") { + for (const auto &alias : aliases) { + lsp::CompletionItem item(prefix + alias.getKey(), + lsp::CompletionItemKind::Field, + /*sortText=*/"2"); + llvm::raw_string_ostream(item.detail) << "alias: " << alias.getValue(); + completionList.items.emplace_back(item); + } + } + + /// Add a set of simple completions that all have the same kind. + void appendSimpleCompletions(ArrayRef completions, + lsp::CompletionItemKind kind, + StringRef sortText = "") { + for (StringRef completion : completions) + completionList.items.emplace_back(completion, kind, sortText); + } + private: lsp::CompletionList &completionList; MLIRContext *ctx; diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp --- a/mlir/test/lib/Dialect/Test/TestTypes.cpp +++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp @@ -408,12 +408,9 @@ Type TestDialect::parseTestType(AsmParser &parser, SetVector &stack) const { StringRef typeTag; - if (failed(parser.parseKeyword(&typeTag))) - return Type(); - { Type genType; - auto parseResult = generatedTypeParser(parser, typeTag, genType); + auto parseResult = generatedTypeParser(parser, &typeTag, genType); if (parseResult.hasValue()) return genType; } diff --git a/mlir/test/mlir-lsp-server/completion.test b/mlir/test/mlir-lsp-server/completion.test --- a/mlir/test/mlir-lsp-server/completion.test +++ b/mlir/test/mlir-lsp-server/completion.test @@ -5,14 +5,14 @@ "uri":"test:///foo.mlir", "languageId":"mlir", "version":1, - "text":"func.func private @foo(%arg: i32) -> i32 {\n%cast = \"builtin.unrealized_conversion_cast\"() : () -> (i32)\nreturn %" + "text":"#attr = i32\n!alias = i32\nfunc.func private @foo(%arg: i32) -> i32 {\n%cast = \"builtin.unrealized_conversion_cast\"() : () -> (!pdl.value)\nreturn %" }}} // ----- {"jsonrpc":"2.0","id":1,"method":"textDocument/completion","params":{ "textDocument":{"uri":"test:///foo.mlir"}, - 
"position":{"line":0,"character":0} + "position":{"line":2,"character":0} }} -// CHECK: "id": 1 +// CHECK-LABEL: "id": 1 // CHECK-NEXT: "jsonrpc": "2.0", // CHECK-NEXT: "result": { // CHECK-NEXT: "isIncomplete": false, @@ -22,7 +22,7 @@ // CHECK: "insertTextFormat": 1, // CHECK: "kind": 9, // CHECK: "label": "builtin", -// CHECK: "sortText": "2" +// CHECK: "sortText": "3" // CHECK: }, // CHECK: { // CHECK: "detail": "operation", @@ -34,11 +34,11 @@ // CHECK: ] // CHECK-NEXT: } // ----- -{"jsonrpc":"2.0","id":1,"method":"textDocument/completion","params":{ +{"jsonrpc":"2.0","id":2,"method":"textDocument/completion","params":{ "textDocument":{"uri":"test:///foo.mlir"}, - "position":{"line":1,"character":9} + "position":{"line":3,"character":9} }} -// CHECK: "id": 1 +// CHECK-LABEL: "id": 2 // CHECK-NEXT: "jsonrpc": "2.0", // CHECK-NEXT: "result": { // CHECK-NEXT: "isIncomplete": false, @@ -48,17 +48,17 @@ // CHECK: "insertTextFormat": 1, // CHECK: "kind": 9, // CHECK: "label": "builtin", -// CHECK: "sortText": "2" +// CHECK: "sortText": "3" // CHECK: }, // CHECK-NOT: "detail": "operation", // CHECK: ] // CHECK-NEXT: } // ----- -{"jsonrpc":"2.0","id":1,"method":"textDocument/completion","params":{ +{"jsonrpc":"2.0","id":3,"method":"textDocument/completion","params":{ "textDocument":{"uri":"test:///foo.mlir"}, - "position":{"line":1,"character":17} + "position":{"line":3,"character":17} }} -// CHECK: "id": 1 +// CHECK-LABEL: "id": 3 // CHECK-NEXT: "jsonrpc": "2.0", // CHECK-NEXT: "result": { // CHECK-NEXT: "isIncomplete": false, @@ -74,17 +74,17 @@ // CHECK: ] // CHECK-NEXT: } // ----- -{"jsonrpc":"2.0","id":1,"method":"textDocument/completion","params":{ +{"jsonrpc":"2.0","id":4,"method":"textDocument/completion","params":{ "textDocument":{"uri":"test:///foo.mlir"}, - "position":{"line":2,"character":8} + "position":{"line":4,"character":8} }} -// CHECK: "id": 1 +// CHECK-LABEL: "id": 4 // CHECK-NEXT: "jsonrpc": "2.0", // CHECK-NEXT: "result": { // CHECK-NEXT: 
"isIncomplete": false, // CHECK-NEXT: "items": [ // CHECK-NEXT: { -// CHECK-NEXT: "detail": "builtin.unrealized_conversion_cast: i32", +// CHECK-NEXT: "detail": "builtin.unrealized_conversion_cast: !pdl.value", // CHECK-NEXT: "insertText": "cast", // CHECK-NEXT: "insertTextFormat": 1, // CHECK-NEXT: "kind": 6, @@ -100,11 +100,11 @@ // CHECK: ] // CHECK-NEXT: } // ----- -{"jsonrpc":"2.0","id":1,"method":"textDocument/completion","params":{ +{"jsonrpc":"2.0","id":5,"method":"textDocument/completion","params":{ "textDocument":{"uri":"test:///foo.mlir"}, - "position":{"line":0,"character":10} + "position":{"line":2,"character":10} }} -// CHECK: "id": 1 +// CHECK-LABEL: "id": 5 // CHECK-NEXT: "jsonrpc": "2.0", // CHECK-NEXT: "result": { // CHECK-NEXT: "isIncomplete": false, @@ -133,6 +133,134 @@ // CHECK-NEXT: ] // CHECK-NEXT: } // ----- -{"jsonrpc":"2.0","id":3,"method":"shutdown"} +{"jsonrpc":"2.0","id":6,"method":"textDocument/completion","params":{ + "textDocument":{"uri":"test:///foo.mlir"}, + "position":{"line":0,"character":8} +}} +// CHECK-LABEL: "id": 6 +// CHECK-NEXT: "jsonrpc": "2.0", +// CHECK-NEXT: "result": { +// CHECK-NEXT: "isIncomplete": false, +// CHECK-NEXT: "items": [ +// CHECK: { +// CHECK: "insertTextFormat": 1, +// CHECK: "kind": 5, +// CHECK: "label": "false" +// CHECK: }, +// CHECK: { +// CHECK: "insertTextFormat": 1, +// CHECK: "kind": 5, +// CHECK: "label": "loc" +// CHECK: }, +// CHECK: { +// CHECK: "insertTextFormat": 1, +// CHECK: "kind": 5, +// CHECK: "label": "true" +// CHECK: }, +// CHECK: { +// CHECK: "insertTextFormat": 1, +// CHECK: "kind": 5, +// CHECK: "label": "unit" +// CHECK: } +// CHECK: ] +// CHECK: } +// ----- +{"jsonrpc":"2.0","id":7,"method":"textDocument/completion","params":{ + "textDocument":{"uri":"test:///foo.mlir"}, + "position":{"line":3,"character":56} +}} +// CHECK-LABEL: "id": 7 +// CHECK-NEXT: "jsonrpc": "2.0", +// CHECK-NEXT: "result": { +// CHECK-NEXT: "isIncomplete": false, +// CHECK-NEXT: "items": [ +// CHECK: 
{ +// CHECK: "insertTextFormat": 1, +// CHECK: "kind": 5, +// CHECK: "label": "index" +// CHECK: }, +// CHECK: { +// CHECK: "insertTextFormat": 1, +// CHECK: "kind": 5, +// CHECK: "label": "none" +// CHECK: }, +// CHECK: { +// CHECK: "insertText": "i", +// CHECK: "insertTextFormat": 1, +// CHECK: "kind": 5, +// CHECK: "label": "i" +// CHECK: } +// CHECK: ] +// CHECK-NEXT: } +// ----- +{"jsonrpc":"2.0","id":8,"method":"textDocument/completion","params":{ + "textDocument":{"uri":"test:///foo.mlir"}, + "position":{"line":3,"character":57} +}} +// CHECK-LABEL: "id": 8 +// CHECK-NEXT: "jsonrpc": "2.0", +// CHECK-NEXT: "result": { +// CHECK-NEXT: "isIncomplete": false, +// CHECK-NEXT: "items": [ +// CHECK: { +// CHECK: "detail": "dialect", +// CHECK: "insertTextFormat": 1, +// CHECK: "kind": 9, +// CHECK: "label": "builtin", +// CHECK: "sortText": "3" +// CHECK: }, +// CHECK: { +// CHECK: "detail": "alias: i32", +// CHECK: "insertTextFormat": 1, +// CHECK: "kind": 5, +// CHECK: "label": "alias", +// CHECK: "sortText": "2" +// CHECK: } +// CHECK: ] +// CHECK-NEXT: } +// ----- +{"jsonrpc":"2.0","id":9,"method":"textDocument/completion","params":{ + "textDocument":{"uri":"test:///foo.mlir"}, + "position":{"line":3,"character":61} +}} +// CHECK-LABEL: "id": 9 +// CHECK-NEXT: "jsonrpc": "2.0", +// CHECK-NEXT: "result": { +// CHECK-NEXT: "isIncomplete": false, +// CHECK-NEXT: "items": [ +// CHECK-NEXT: { +// CHECK-NEXT: "insertTextFormat": 1, +// CHECK-NEXT: "kind": 14, +// CHECK-NEXT: "label": "attribute", +// CHECK-NEXT: "sortText": "0" +// CHECK-NEXT: }, +// CHECK-NEXT: { +// CHECK-NEXT: "insertTextFormat": 1, +// CHECK-NEXT: "kind": 14, +// CHECK-NEXT: "label": "operation", +// CHECK-NEXT: "sortText": "0" +// CHECK-NEXT: }, +// CHECK-NEXT: { +// CHECK-NEXT: "insertTextFormat": 1, +// CHECK-NEXT: "kind": 14, +// CHECK-NEXT: "label": "range", +// CHECK-NEXT: "sortText": "0" +// CHECK-NEXT: }, +// CHECK-NEXT: { +// CHECK-NEXT: "insertTextFormat": 1, +// CHECK-NEXT: "kind": 
14, +// CHECK-NEXT: "label": "type", +// CHECK-NEXT: "sortText": "0" +// CHECK-NEXT: }, +// CHECK-NEXT: { +// CHECK-NEXT: "insertTextFormat": 1, +// CHECK-NEXT: "kind": 14, +// CHECK-NEXT: "label": "value", +// CHECK-NEXT: "sortText": "0" +// CHECK-NEXT: } +// CHECK-NEXT: ] +// CHECK-NEXT: } +// ----- +{"jsonrpc":"2.0","id":10,"method":"shutdown"} // ----- {"jsonrpc":"2.0","method":"exit"} diff --git a/mlir/test/mlir-tblgen/attrdefs.td b/mlir/test/mlir-tblgen/attrdefs.td --- a/mlir/test/mlir-tblgen/attrdefs.td +++ b/mlir/test/mlir-tblgen/attrdefs.td @@ -21,16 +21,19 @@ // DEF-LABEL: ::mlir::OptionalParseResult generatedAttributeParser( // DEF-SAME: ::mlir::AsmParser &parser, -// DEF-SAME: ::llvm::StringRef mnemonic, ::mlir::Type type, +// DEF-SAME: ::llvm::StringRef *mnemonic, ::mlir::Type type, // DEF-SAME: ::mlir::Attribute &value) { -// DEF: if (mnemonic == ::test::CompoundAAttr::getMnemonic()) { +// DEF: return ::mlir::AsmParser::KeywordSwitch<::mlir::OptionalParseResult>(parser) +// DEF: .Case(::test::CompoundAAttr::getMnemonic() // DEF-NEXT: value = ::test::CompoundAAttr::parse(parser, type); // DEF-NEXT: return ::mlir::success(!!value); -// DEF-NEXT: } -// DEF-NEXT: if (mnemonic == ::test::IndexAttr::getMnemonic()) { +// DEF-NEXT: }) +// DEF-NEXT: .Case(::test::IndexAttr::getMnemonic() // DEF-NEXT: value = ::test::IndexAttr::parse(parser, type); // DEF-NEXT: return ::mlir::success(!!value); -// DEF: return {}; +// DEF: .Default([&](llvm::StringRef keyword, +// DEF-NEXT: *mnemonic = keyword; +// DEF-NEXT: return llvm::None; def Test_Dialect: Dialect { // DECL-NOT: TestDialect diff --git a/mlir/test/mlir-tblgen/default-type-attr-print-parser.td b/mlir/test/mlir-tblgen/default-type-attr-print-parser.td --- a/mlir/test/mlir-tblgen/default-type-attr-print-parser.td +++ b/mlir/test/mlir-tblgen/default-type-attr-print-parser.td @@ -27,11 +27,9 @@ // ATTR: ::mlir::Type type) const { // ATTR: ::llvm::SMLoc typeLoc = parser.getCurrentLocation(); // ATTR: 
::llvm::StringRef attrTag; -// ATTR: if (::mlir::failed(parser.parseKeyword(&attrTag))) -// ATTR: return {}; // ATTR: { // ATTR: ::mlir::Attribute attr; -// ATTR: auto parseResult = generatedAttributeParser(parser, attrTag, type, attr); +// ATTR: auto parseResult = generatedAttributeParser(parser, &attrTag, type, attr); // ATTR: if (parseResult.hasValue()) // ATTR: return attr; // ATTR: } @@ -57,10 +55,8 @@ // TYPE: ::mlir::Type TestDialect::parseType(::mlir::DialectAsmParser &parser) const { // TYPE: ::llvm::SMLoc typeLoc = parser.getCurrentLocation(); // TYPE: ::llvm::StringRef mnemonic; -// TYPE: if (parser.parseKeyword(&mnemonic)) -// TYPE: return ::mlir::Type(); // TYPE: ::mlir::Type genType; -// TYPE: auto parseResult = generatedTypeParser(parser, mnemonic, genType); +// TYPE: auto parseResult = generatedTypeParser(parser, &mnemonic, genType); // TYPE: if (parseResult.hasValue()) // TYPE: return genType; // TYPE: parser.emitError(typeLoc) << "unknown type `" diff --git a/mlir/test/mlir-tblgen/typedefs.td b/mlir/test/mlir-tblgen/typedefs.td --- a/mlir/test/mlir-tblgen/typedefs.td +++ b/mlir/test/mlir-tblgen/typedefs.td @@ -22,16 +22,18 @@ // DEF-LABEL: ::mlir::OptionalParseResult generatedTypeParser( // DEF-SAME: ::mlir::AsmParser &parser, -// DEF-SAME: ::llvm::StringRef mnemonic, +// DEF-SAME: ::llvm::StringRef *mnemonic, // DEF-SAME: ::mlir::Type &value) { -// DEF: if (mnemonic == ::test::CompoundAType::getMnemonic()) { +// DEF: .Case(::test::CompoundAType::getMnemonic() // DEF-NEXT: value = ::test::CompoundAType::parse(parser); // DEF-NEXT: return ::mlir::success(!!value); -// DEF-NEXT: } -// DEF-NEXT: if (mnemonic == ::test::IndexType::getMnemonic()) { +// DEF-NEXT: }) +// DEF-NEXT: .Case(::test::IndexType::getMnemonic() // DEF-NEXT: value = ::test::IndexType::parse(parser); // DEF-NEXT: return ::mlir::success(!!value); -// DEF: return {}; +// DEF: .Default([&](llvm::StringRef keyword, +// DEF-NEXT: *mnemonic = keyword; +// DEF-NEXT: return llvm::None; def 
Test_Dialect: Dialect { // DECL-NOT: TestDialect diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp --- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp +++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp @@ -673,11 +673,9 @@ ::mlir::Type type) const {{ ::llvm::SMLoc typeLoc = parser.getCurrentLocation(); ::llvm::StringRef attrTag; - if (::mlir::failed(parser.parseKeyword(&attrTag))) - return {{}; {{ ::mlir::Attribute attr; - auto parseResult = generatedAttributeParser(parser, attrTag, type, attr); + auto parseResult = generatedAttributeParser(parser, &attrTag, type, attr); if (parseResult.hasValue()) return attr; } @@ -723,10 +721,8 @@ ::mlir::Type {0}::parseType(::mlir::DialectAsmParser &parser) const {{ ::llvm::SMLoc typeLoc = parser.getCurrentLocation(); ::llvm::StringRef mnemonic; - if (parser.parseKeyword(&mnemonic)) - return ::mlir::Type(); ::mlir::Type genType; - auto parseResult = generatedTypeParser(parser, mnemonic, genType); + auto parseResult = generatedTypeParser(parser, &mnemonic, genType); if (parseResult.hasValue()) return genType; {1} @@ -771,7 +767,7 @@ } // Declare the parser. SmallVector params = {{"::mlir::AsmParser &", "parser"}, - {"::llvm::StringRef", "mnemonic"}}; + {"::llvm::StringRef *", "mnemonic"}}; if (isAttrGenerator) params.emplace_back("::mlir::Type", "type"); params.emplace_back(strfmt("::mlir::{0} &", valueType), "value"); @@ -784,14 +780,18 @@ {{strfmt("::mlir::{0}", valueType), "def"}, {"::mlir::AsmPrinter &", "printer"}}); - // The parser dispatch is just a list of if-elses, matching on the mnemonic - // and calling the def's parse function. + // The parser dispatch uses a KeywordSwitch, matching on the mnemonic and + // calling the def's parse function. 
+ parse.body() << " return " + "::mlir::AsmParser::KeywordSwitch<::mlir::" + "OptionalParseResult>(parser)\n"; const char *const getValueForMnemonic = - R"( if (mnemonic == {0}::getMnemonic()) {{ - value = {0}::{1}; - return ::mlir::success(!!value); - } + R"( .Case({0}::getMnemonic(), [&](llvm::StringRef, llvm::SMLoc) {{ + value = {0}::{1}; + return ::mlir::success(!!value); + }) )"; + // The printer dispatch uses llvm::TypeSwitch to find and call the correct // printer. printer.body() << " return ::llvm::TypeSwitch<::mlir::" << valueType @@ -822,7 +822,10 @@ printDef = "\nt.print(printer);"; printer.body() << llvm::formatv(printValue, defClass, printDef); } - parse.body() << " return {};"; + parse.body() << " .Default([&](llvm::StringRef keyword, llvm::SMLoc) {\n" + " *mnemonic = keyword;\n" + " return llvm::None;\n" + " });"; printer.body() << " .Default([](auto) { return ::mlir::failure(); });"; raw_indented_ostream indentedOs(os);