Patch summary (free-form text ahead of the "diff --git" header):

  * buildHexadecimalFloatLiteral() now returns Optional<APFloat> instead of a
    FloatAttr. On an out-of-range literal it still emits the same diagnostic
    and returns llvm::None (previously a null FloatAttr).
  * The scalar attribute parser wraps the returned APFloat with
    FloatAttr::get(), or yields a null Attribute() when no value came back.
  * TensorLiteralParser::getFloatAttr() accumulates APFloat values rather than
    attributes. Decimal literals are converted to the element type's float
    semantics (round-to-nearest-even) unless the element type is BF16 or F64,
    which are kept as doubles.

NOTE(review): this patch text was reconstructed from a copy that had lost its
line breaks and its template arguments. Optional<APFloat> and
std::vector<APFloat> follow from the `return APFloat(...)`, `*apVal` and
`push_back(apVal)` uses below. The element type on the removed line
(std::vector<Attribute>) and the exact leading whitespace of every line are
assumptions -- TODO confirm against the target tree before applying.

diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -1734,22 +1734,19 @@
 }
 
 /// Construct a float attribute bitwise equivalent to the integer literal.
-static FloatAttr buildHexadecimalFloatLiteral(Parser *p, FloatType type,
-                                              uint64_t value) {
+static Optional<APFloat> buildHexadecimalFloatLiteral(Parser *p, FloatType type,
+                                                      uint64_t value) {
   // FIXME: bfloat is currently stored as a double internally because it doesn't
   // have valid APFloat semantics.
-  if (type.isF64() || type.isBF16()) {
-    APFloat apFloat(type.getFloatSemantics(), APInt(/*numBits=*/64, value));
-    return p->builder.getFloatAttr(type, apFloat);
-  }
+  if (type.isF64() || type.isBF16())
+    return APFloat(type.getFloatSemantics(), APInt(/*numBits=*/64, value));
 
   APInt apInt(type.getWidth(), value);
   if (apInt != value) {
     p->emitError("hexadecimal float constant out of range for type");
-    return nullptr;
+    return llvm::None;
   }
-  APFloat apFloat(type.getFloatSemantics(), apInt);
-  return p->builder.getFloatAttr(type, apFloat);
+  return APFloat(type.getFloatSemantics(), apInt);
 }
 
 /// Parse a decimal or a hexadecimal literal, which can be either an integer
@@ -1787,7 +1784,9 @@
     }
 
     // Construct a float attribute bitwise equivalent to the integer literal.
-    return buildHexadecimalFloatLiteral(this, floatType, *val);
+    Optional<APFloat> apVal =
+        buildHexadecimalFloatLiteral(this, floatType, *val);
+    return apVal ? FloatAttr::get(floatType, *apVal) : Attribute();
   }
 
   if (!type.isIntOrIndex())
@@ -1996,7 +1995,7 @@
 DenseElementsAttr TensorLiteralParser::getFloatAttr(llvm::SMLoc loc,
                                                     ShapedType type,
                                                     FloatType eltTy) {
-  std::vector<Attribute> floatValues;
+  std::vector<APFloat> floatValues;
   floatValues.reserve(storage.size());
   for (const auto &signAndToken : storage) {
     bool isNegative = signAndToken.first;
@@ -2014,10 +2013,10 @@
         p.emitError("hexadecimal float constant out of range for attribute");
         return nullptr;
       }
-      FloatAttr attr = buildHexadecimalFloatLiteral(&p, eltTy, *val);
-      if (!attr)
+      Optional<APFloat> apVal = buildHexadecimalFloatLiteral(&p, eltTy, *val);
+      if (!apVal)
         return nullptr;
-      floatValues.push_back(attr);
+      floatValues.push_back(*apVal);
       continue;
     }
 
@@ -2033,7 +2032,14 @@
       p.emitError("floating point value too large for attribute");
       return nullptr;
     }
-    floatValues.push_back(FloatAttr::get(eltTy, isNegative ? -*val : *val));
+    // Treat BF16 as double because it is not supported in LLVM's APFloat.
+    APFloat apVal(isNegative ? -*val : *val);
+    if (!eltTy.isBF16() && !eltTy.isF64()) {
+      bool unused;
+      apVal.convert(eltTy.getFloatSemantics(), APFloat::rmNearestTiesToEven,
+                    &unused);
+    }
+    floatValues.push_back(apVal);
   }
 
   return DenseElementsAttr::get(type, floatValues);