diff --git a/mlir/lib/Parser/AttributeParser.cpp b/mlir/lib/Parser/AttributeParser.cpp --- a/mlir/lib/Parser/AttributeParser.cpp +++ b/mlir/lib/Parser/AttributeParser.cpp @@ -307,20 +307,6 @@ return FloatAttr::get(type, isNegative ? -val.getValue() : val.getValue()); } -/// Construct a float attribute bitwise equivalent to the integer literal. -static Optional<APFloat> buildHexadecimalFloatLiteral(Parser *p, FloatType type, - uint64_t value) { - if (type.isF64()) - return APFloat(type.getFloatSemantics(), APInt(/*numBits=*/64, value)); - - APInt apInt(type.getWidth(), value); - if (apInt != value) { - p->emitError("hexadecimal float constant out of range for type"); - return llvm::None; - } - return APFloat(type.getFloatSemantics(), apInt); -} - /// Construct an APint from a parsed value, a known attribute type and /// sign. static Optional<APInt> buildAttributeAPInt(Type type, bool isNegative, @@ -369,10 +355,9 @@ /// Parse a decimal or a hexadecimal literal, which can be either an integer /// or a float attribute. Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) { - // Remember if the literal is hexadecimal.
- StringRef spelling = getToken().getSpelling(); - auto loc = state.curToken.getLoc(); - bool isHex = spelling.size() > 1 && spelling[1] == 'x'; + Token tok = getToken(); + StringRef spelling = tok.getSpelling(); + llvm::SMLoc loc = tok.getLoc(); consumeToken(Token::integer); if (!type) { @@ -384,26 +369,12 @@ } if (auto floatType = type.dyn_cast<FloatType>()) { - if (isNegative) - return emitError( - loc, - "hexadecimal float literal should not have a leading minus"), - nullptr; - if (!isHex) { - emitError(loc, "unexpected decimal integer literal for a float attribute") - .attachNote() - << "add a trailing dot to make the literal a float"; - return nullptr; - } - - auto val = Token::getUInt64IntegerValue(spelling); - if (!val.hasValue()) - return emitError("integer constant out of range for attribute"), nullptr; - - // Construct a float attribute bitwise equivalent to the integer literal. - Optional<APFloat> apVal = - buildHexadecimalFloatLiteral(this, floatType, *val); - return apVal ? FloatAttr::get(floatType, *apVal) : Attribute(); + Optional<APFloat> result; + if (failed(parseFloatFromIntegerLiteral(result, tok, isNegative, + floatType.getFloatSemantics(), + floatType.getWidth()))) + return Attribute(); + return FloatAttr::get(floatType, *result); } if (!type.isa()) @@ -638,19 +609,13 @@ // Handle hexadecimal float literals.
if (token.is(Token::integer) && token.getSpelling().startswith("0x")) { - if (isNegative) { - return p.emitError(token.getLoc()) - << "hexadecimal float literal should not have a leading minus"; - } - auto val = token.getUInt64IntegerValue(); - if (!val.hasValue()) { - return p.emitError( - "hexadecimal float constant out of range for attribute"); - } - Optional<APFloat> apVal = buildHexadecimalFloatLiteral(&p, eltTy, *val); - if (!apVal) + Optional<APFloat> result; + if (failed(p.parseFloatFromIntegerLiteral(result, token, isNegative, + eltTy.getFloatSemantics(), + eltTy.getWidth()))) return failure(); - floatValues.push_back(*apVal); + + floatValues.push_back(*result); continue; } diff --git a/mlir/lib/Parser/DialectSymbolParser.cpp b/mlir/lib/Parser/DialectSymbolParser.cpp --- a/mlir/lib/Parser/DialectSymbolParser.cpp +++ b/mlir/lib/Parser/DialectSymbolParser.cpp @@ -63,21 +63,34 @@ /// Parse a floating point value from the stream. ParseResult parseFloat(double &result) override { - bool negative = parser.consumeIf(Token::minus); + bool isNegative = parser.consumeIf(Token::minus); Token curTok = parser.getToken(); + llvm::SMLoc loc = curTok.getLoc(); // Check for a floating point value. if (curTok.is(Token::floatliteral)) { auto val = curTok.getFloatingPointValue(); if (!val.hasValue()) - return emitError(curTok.getLoc(), "floating point value too large"); + return emitError(loc, "floating point value too large"); parser.consumeToken(Token::floatliteral); - result = negative ? -*val : *val; + result = isNegative ? -*val : *val; return success(); } - // TODO: support hex floating point values. - return emitError(getCurrentLocation(), "expected floating point literal"); + // Check for a hexadecimal float value.
+ if (curTok.is(Token::integer)) { + Optional<APFloat> apResult; + if (failed(parser.parseFloatFromIntegerLiteral( + apResult, curTok, isNegative, APFloat::IEEEdouble(), + /*typeSizeInBits=*/64))) + return failure(); + + parser.consumeToken(Token::integer); + result = apResult->convertToDouble(); + return success(); + } + + return emitError(loc, "expected floating point literal"); } /// Parse an optional integer value from the stream. diff --git a/mlir/lib/Parser/Parser.h b/mlir/lib/Parser/Parser.h --- a/mlir/lib/Parser/Parser.h +++ b/mlir/lib/Parser/Parser.h @@ -130,6 +130,12 @@ /// Parse an optional integer value from the stream. OptionalParseResult parseOptionalInteger(uint64_t &result); + /// Parse a floating point value from an integer literal token. + ParseResult parseFloatFromIntegerLiteral(Optional<APFloat> &result, + const Token &tok, bool isNegative, + const llvm::fltSemantics &semantics, + size_t typeSizeInBits); + //===--------------------------------------------------------------------===// // Type Parsing //===--------------------------------------------------------------------===// diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -112,6 +112,41 @@ return success(); } +/// Parse a floating point value from an integer literal token. The token must +/// be a hexadecimal literal whose bits are reinterpreted as a float with the +/// given semantics; `typeSizeInBits` should match the width of `semantics`. +/// Emits an error and returns failure for decimal or negated literals, and for +/// values that do not fit in `typeSizeInBits` bits.
+ParseResult Parser::parseFloatFromIntegerLiteral( + Optional<APFloat> &result, const Token &tok, bool isNegative, + const llvm::fltSemantics &semantics, size_t typeSizeInBits) { + llvm::SMLoc loc = tok.getLoc(); + StringRef spelling = tok.getSpelling(); + bool isHex = spelling.startswith("0x"); + if (!isHex) { + return emitError(loc, "unexpected decimal integer literal for a " + "floating point value") + .attachNote() + << "add a trailing dot to make the literal a float"; + } + if (isNegative) { + return emitError(loc, "hexadecimal float literal should not have a " + "leading minus"); + } + + Optional<uint64_t> value = tok.getUInt64IntegerValue(); + if (!value.hasValue()) + return emitError(loc, "hexadecimal float constant out of range for type"); + + // Reject literals with more significant bits than the target type holds. + APInt apInt(typeSizeInBits, *value); + if (apInt != *value) + return emitError(loc, "hexadecimal float constant out of range for type"); + result = APFloat(semantics, apInt); + + return success(); +} + //===----------------------------------------------------------------------===// // OperationParser //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Quant/parse-uniform.mlir b/mlir/test/Dialect/Quant/parse-uniform.mlir --- a/mlir/test/Dialect/Quant/parse-uniform.mlir +++ b/mlir/test/Dialect/Quant/parse-uniform.mlir @@ -92,6 +92,15 @@ return %0 : !qalias } +// ----- +// Expressed type: f32 +// CHECK: !quant.uniform +!qalias = type !quant.uniform +func @parse() -> !qalias { + %0 = "foo"() : () -> !qalias + return %0 : !qalias +} + // ----- // Expressed type: f16 // CHECK: !quant.uniform diff --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir --- a/mlir/test/IR/invalid.mlir +++ b/mlir/test/IR/invalid.mlir @@ -1191,7 +1191,7 @@ // ----- func @decimal_float_literal() { - // expected-error @+2 {{unexpected decimal integer
literal for a float attribute}} + // expected-error @+2 {{unexpected decimal integer literal for a floating point value}} // expected-note @+1 {{add a trailing dot to make the literal a float}} "foo"() {value = 42 : f32} : () -> () } @@ -1244,7 +1244,7 @@ // Check that we report an error when a value is too wide to be parsed. func @hexadecimal_float_too_wide_in_tensor() { - // expected-error @+1 {{hexadecimal float constant out of range for attribute}} + // expected-error @+1 {{hexadecimal float constant out of range for type}} "foo"() {bar = dense<0x7FFFFFF0000000000000> : tensor<2xf32>} : () -> () }