diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -413,6 +413,39 @@ /// Parse a `...` token if present; virtual ParseResult parseOptionalEllipsis() = 0; + /// Parse an integer value from the stream. Emits an error and returns + /// failure if no integer is present or if it does not fit in `IntT`. + template <typename IntT> ParseResult parseInteger(IntT &result) { + auto loc = getCurrentLocation(); + OptionalParseResult parseResult = parseOptionalInteger(result); + if (!parseResult.hasValue()) + return emitError(loc, "expected integer value"); + return *parseResult; + } + + /// Parse an optional integer value from the stream. + virtual OptionalParseResult parseOptionalInteger(uint64_t &result) = 0; + + /// Parse an optional integer value from the stream and convert it to + /// `IntT`. Returns an empty result if no integer is present; emits an + /// error if the parsed value does not fit in `IntT`. + template <typename IntT> + OptionalParseResult parseOptionalInteger(IntT &result) { + auto loc = getCurrentLocation(); + + // Parse the unsigned variant. + uint64_t uintResult; + OptionalParseResult parseResult = parseOptionalInteger(uintResult); + if (!parseResult.hasValue() || failed(*parseResult)) + return parseResult; + + // Try to convert to the provided integer type. + result = IntT(uintResult); + if (uint64_t(result) != uintResult) + return emitError(loc, "integer value too large"); + return success(); + } + //===--------------------------------------------------------------------===// // Attribute Parsing //===--------------------------------------------------------------------===// diff --git a/mlir/lib/Parser/DialectSymbolParser.cpp b/mlir/lib/Parser/DialectSymbolParser.cpp --- a/mlir/lib/Parser/DialectSymbolParser.cpp +++ b/mlir/lib/Parser/DialectSymbolParser.cpp @@ -82,20 +82,7 @@ /// Parse an optional integer value from the stream.
OptionalParseResult parseOptionalInteger(uint64_t &result) override { - Token curToken = parser.getToken(); - if (curToken.isNot(Token::integer, Token::minus)) - return llvm::None; - - bool negative = parser.consumeIf(Token::minus); - Token curTok = parser.getToken(); - if (parser.parseToken(Token::integer, "expected integer value")) - return failure(); - - auto val = curTok.getUInt64IntegerValue(); - if (!val) - return emitError(curTok.getLoc(), "integer value too large"); - result = negative ? -*val : *val; - return success(); + return parser.parseOptionalInteger(result); } //===--------------------------------------------------------------------===//
diff --git a/mlir/lib/Parser/Parser.h b/mlir/lib/Parser/Parser.h --- a/mlir/lib/Parser/Parser.h +++ b/mlir/lib/Parser/Parser.h @@ -127,6 +127,9 @@ /// output a diagnostic and return failure. ParseResult parseToken(Token::Kind expectedToken, const Twine &message); + /// Parse an optional integer value from the stream. + OptionalParseResult parseOptionalInteger(uint64_t &result); + //===--------------------------------------------------------------------===// // Type Parsing //===--------------------------------------------------------------------===//
diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -94,6 +94,28 @@ return emitError(message); } +/// Parse an optional integer value from the stream. Returns llvm::None if the +/// current token is neither an integer nor `-`. Negative values are stored in +/// `result` in two's complement form. +OptionalParseResult Parser::parseOptionalInteger(uint64_t &result) { + Token curToken = getToken(); + if (curToken.isNot(Token::integer, Token::minus)) + return llvm::None; + + bool negative = consumeIf(Token::minus); + Token curTok = getToken(); + if (parseToken(Token::integer, "expected integer value")) + return failure(); + + auto val = curTok.getUInt64IntegerValue(); + // Reject values that overflow 64 bits. A negated magnitude above 2^63 would + // otherwise wrap silently to an unrelated (possibly positive) value. + if (!val || (negative && *val > (uint64_t(1) << 63))) + return emitError(curTok.getLoc(), "integer value too large"); + result = negative ? -*val : *val; + return success(); +} + //===----------------------------------------------------------------------===// // OperationParser //===----------------------------------------------------------------------===// @@ -1109,6 +1131,11 @@ return success(parser.consumeIf(Token::star)); } + /// Parse an optional integer value from the stream. + OptionalParseResult parseOptionalInteger(uint64_t &result) override { + return parser.parseOptionalInteger(result); + } + //===--------------------------------------------------------------------===// // Attribute Parsing //===--------------------------------------------------------------------===//
diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir --- a/mlir/test/IR/parser.mlir +++ b/mlir/test/IR/parser.mlir @@ -1174,10 +1174,17 @@ // CHECK-LABEL: func private @escaped_string_char(i1 {foo.value = "\0A"}) func private @escaped_string_char(i1 {foo.value = "\n"}) -// CHECK-LABEL: func @wrapped_keyword_test -func @wrapped_keyword_test() { - // CHECK: test.wrapped_keyword foo.keyword - test.wrapped_keyword foo.keyword +// CHECK-LABEL: func @parse_integer_literal_test +func @parse_integer_literal_test() { + // CHECK: test.parse_integer_literal : 5 + test.parse_integer_literal : 5 + return +} + +// CHECK-LABEL: func @parse_wrapped_keyword_test +func @parse_wrapped_keyword_test() { + // CHECK: test.parse_wrapped_keyword foo.keyword + test.parse_wrapped_keyword foo.keyword return }
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -564,8 +564,28 @@ // Test parser.
//===----------------------------------------------------------------------===// -static ParseResult parseWrappedKeywordOp(OpAsmParser &parser, - OperationState &result) { +static ParseResult parseParseIntegerLiteralOp(OpAsmParser &parser, + OperationState &result) { + if (parser.parseOptionalColon()) + return success(); + unsigned numResults; + if (parser.parseInteger(numResults)) + return failure(); + + IndexType type = parser.getBuilder().getIndexType(); + for (unsigned i = 0; i < numResults; ++i) + result.addTypes(type); + return success(); +} + +static void print(OpAsmPrinter &p, ParseIntegerLiteralOp op) { + p << ParseIntegerLiteralOp::getOperationName(); + if (unsigned numResults = op->getNumResults()) + p << " : " << numResults; +} + +static ParseResult parseParseWrappedKeywordOp(OpAsmParser &parser, + OperationState &result) { StringRef keyword; if (parser.parseKeyword(&keyword)) return failure(); @@ -573,8 +593,8 @@ return success(); } -static void print(OpAsmPrinter &p, WrappedKeywordOp op) { - p << WrappedKeywordOp::getOperationName() << " " << op.keyword(); +static void print(OpAsmPrinter &p, ParseWrappedKeywordOp op) { + p << ParseWrappedKeywordOp::getOperationName() << " " << op.keyword(); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -1293,7 +1293,13 @@ // Test parser.
//===----------------------------------------------------------------------===// -def WrappedKeywordOp : TEST_Op<"wrapped_keyword"> { +def ParseIntegerLiteralOp : TEST_Op<"parse_integer_literal"> { + let results = (outs Variadic<Index>:$results); + let parser = [{ return ::parse$cppClass(parser, result); }]; + let printer = [{ return ::print(p, *this); }]; +} + +def ParseWrappedKeywordOp : TEST_Op<"parse_wrapped_keyword"> { let arguments = (ins StrAttr:$keyword); let parser = [{ return ::parse$cppClass(parser, result); }]; let printer = [{ return ::print(p, *this); }];