diff --git a/mlir/include/mlir/IR/DialectImplementation.h b/mlir/include/mlir/IR/DialectImplementation.h --- a/mlir/include/mlir/IR/DialectImplementation.h +++ b/mlir/include/mlir/IR/DialectImplementation.h @@ -139,7 +139,8 @@ virtual ParseResult parseFloat(double &result) = 0; /// Parse an integer value from the stream. - template ParseResult parseInteger(IntT &result) { + template + ParseResult parseInteger(IntT &result) { auto loc = getCurrentLocation(); OptionalParseResult parseResult = parseOptionalInteger(result); if (!parseResult.hasValue()) @@ -148,21 +149,23 @@ } /// Parse an optional integer value from the stream. - virtual OptionalParseResult parseOptionalInteger(uint64_t &result) = 0; + virtual OptionalParseResult parseOptionalInteger(APInt &result) = 0; template OptionalParseResult parseOptionalInteger(IntT &result) { auto loc = getCurrentLocation(); // Parse the unsigned variant. - uint64_t uintResult; + APInt uintResult; OptionalParseResult parseResult = parseOptionalInteger(uintResult); if (!parseResult.hasValue() || failed(*parseResult)) return parseResult; - // Try to convert to the provided integer type. - result = IntT(uintResult); - if (uint64_t(result) != uintResult) + // Try to convert to the provided integer type. sextOrTrunc is correct even + // for unsigned types because parseOptionalInteger ensures the sign bit is + // zero for non-negated integers; the check below sign-extends signed IntT. + result = (IntT)uintResult.sextOrTrunc(sizeof(IntT) * 8).getLimitedValue(); + if (APInt(uintResult.getBitWidth(), result, std::is_signed<IntT>::value) != uintResult) return emitError(loc, "integer value too large"); return success(); } @@ -172,13 +175,14 @@ /// unlike `OpBuilder::getType`, this method does not implicitly insert a /// context parameter. template - T getChecked(llvm::SMLoc loc, ParamsT &&...params) { + T getChecked(llvm::SMLoc loc, ParamsT &&... 
params) { return T::getChecked([&] { return emitError(loc); }, std::forward(params)...); } /// A variant of `getChecked` that uses the result of `getNameLoc` to emit /// errors. - template T getChecked(ParamsT &&...params) { + template + T getChecked(ParamsT &&... params) { return T::getChecked([&] { return emitError(getNameLoc()); }, std::forward(params)...); } @@ -331,7 +335,8 @@ virtual ParseResult parseType(Type &result) = 0; /// Parse a type of a specific kind, e.g. a FunctionType. - template ParseResult parseType(TypeType &result) { + template + ParseResult parseType(TypeType &result) { llvm::SMLoc loc = getCurrentLocation(); // Parse any kind of type. diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -435,21 +435,23 @@ } /// Parse an optional integer value from the stream. - virtual OptionalParseResult parseOptionalInteger(uint64_t &result) = 0; + virtual OptionalParseResult parseOptionalInteger(APInt &result) = 0; template OptionalParseResult parseOptionalInteger(IntT &result) { auto loc = getCurrentLocation(); // Parse the unsigned variant. - uint64_t uintResult; + APInt uintResult; OptionalParseResult parseResult = parseOptionalInteger(uintResult); if (!parseResult.hasValue() || failed(*parseResult)) return parseResult; - // Try to convert to the provided integer type. - result = IntT(uintResult); - if (uint64_t(result) != uintResult) + // Try to convert to the provided integer type. sextOrTrunc is correct even + // for unsigned types because parseOptionalInteger ensures the sign bit is + // zero for non-negated integers; the check below sign-extends signed IntT.
+ result = (IntT)uintResult.sextOrTrunc(sizeof(IntT) * 8).getLimitedValue(); + if (APInt(uintResult.getBitWidth(), result, std::is_signed<IntT>::value) != uintResult) return emitError(loc, "integer value too large"); return success(); } diff --git a/mlir/lib/Parser/DialectSymbolParser.cpp b/mlir/lib/Parser/DialectSymbolParser.cpp --- a/mlir/lib/Parser/DialectSymbolParser.cpp +++ b/mlir/lib/Parser/DialectSymbolParser.cpp @@ -94,7 +94,7 @@ } /// Parse an optional integer value from the stream. - OptionalParseResult parseOptionalInteger(uint64_t &result) override { + OptionalParseResult parseOptionalInteger(APInt &result) override { return parser.parseOptionalInteger(result); } diff --git a/mlir/lib/Parser/Parser.h b/mlir/lib/Parser/Parser.h --- a/mlir/lib/Parser/Parser.h +++ b/mlir/lib/Parser/Parser.h @@ -128,7 +128,7 @@ ParseResult parseToken(Token::Kind expectedToken, const Twine &message); /// Parse an optional integer value from the stream. - OptionalParseResult parseOptionalInteger(uint64_t &result); + OptionalParseResult parseOptionalInteger(APInt &result); /// Parse a floating point value from an integer literal token. ParseResult parseFloatFromIntegerLiteral(Optional &result, diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -96,7 +96,7 @@ } /// Parse an optional integer value from the stream. -OptionalParseResult Parser::parseOptionalInteger(uint64_t &result) { +OptionalParseResult Parser::parseOptionalInteger(APInt &result) { Token curToken = getToken(); if (curToken.isNot(Token::integer, Token::minus)) return llvm::None; @@ -106,10 +106,19 @@ if (parseToken(Token::integer, "expected integer value")) return failure(); - auto val = curTok.getUInt64IntegerValue(); - if (!val) + StringRef spelling = curTok.getSpelling(); + bool isHex = spelling.size() > 1 && spelling[1] == 'x'; + if (spelling.getAsInteger(isHex ?
0 : 10, result)) return emitError(curTok.getLoc(), "integer value too large"); - result = negative ? -*val : *val; + + // Make sure we have a zero at the top so we return the right signedness. + if (result.isNegative()) + result = result.zext(result.getBitWidth() + 1); + + // Process the negative sign if present. + if (negative) + result.negate(); + return success(); } @@ -1217,7 +1226,7 @@ } /// Parse an optional integer value from the stream. - OptionalParseResult parseOptionalInteger(uint64_t &result) override { + OptionalParseResult parseOptionalInteger(APInt &result) override { return parser.parseOptionalInteger(result); }