diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -1687,8 +1687,14 @@
 /// Construct a float attribute bitwise equivalent to the integer literal.
 static FloatAttr buildHexadecimalFloatLiteral(Parser *p, FloatType type,
                                               uint64_t value) {
-  int width = type.getIntOrFloatBitWidth();
-  APInt apInt(width, value);
+  // FIXME: bfloat is currently stored as a double internally because it doesn't
+  // have valid APFloat semantics.
+  if (type.isF64() || type.isBF16()) {
+    APFloat apFloat(type.getFloatSemantics(), APInt(/*numBits=*/64, value));
+    return p->builder.getFloatAttr(type, apFloat);
+  }
+
+  APInt apInt(type.getWidth(), value);
   if (apInt != value) {
     p->emitError("hexadecimal float constant out of range for type");
     return nullptr;
@@ -1719,11 +1725,6 @@
   }
 
   if (auto floatType = type.dyn_cast<FloatType>()) {
-    // TODO(zinenko): Update once hex format for bfloat16 is supported.
-    if (type.isBF16())
-      return emitError(loc,
-                       "hexadecimal float literal not supported for bfloat16"),
-             nullptr;
     if (isNegative)
       return emitError(
           loc,
diff --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir
--- a/mlir/test/IR/invalid.mlir
+++ b/mlir/test/IR/invalid.mlir
@@ -1120,13 +1120,6 @@
 
 // -----
 
-func @hexadecimal_bf16() {
-  // expected-error @+1 {{hexadecimal float literal not supported for bfloat16}}
-  "foo"() {value = 0xffff : bf16} : () -> ()
-}
-
-// -----
-
 func @hexadecimal_float_leading_minus() {
   // expected-error @+1 {{hexadecimal float literal should not have a leading minus}}
   "foo"() {value = -0x7fff : f16} : () -> ()
diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir
--- a/mlir/test/IR/parser.mlir
+++ b/mlir/test/IR/parser.mlir
@@ -1030,6 +1030,32 @@
   return
 }
 
+// FIXME: bfloat16 currently uses f64 as a storage format. This test should be
+// changed when that gets fixed.
+// CHECK-LABEL: @bfloat16_special_values
+func @bfloat16_special_values() {
+  // bfloat16 NaNs (f64 bit patterns): signaling, then canonical quiet.
+  // CHECK: constant 0x7FF0000000000001 : bf16
+  %0 = constant 0x7FF0000000000001 : bf16
+  // CHECK: constant 0x7FF8000000000000 : bf16
+  %1 = constant 0x7FF8000000000000 : bf16
+
+  // bfloat16 signaling NaNs (f64 quiet bit clear), positive and negative.
+  // CHECK: constant 0x7FF0000001000000 : bf16
+  %2 = constant 0x7FF0000001000000 : bf16
+  // CHECK: constant 0xFFF0000001000000 : bf16
+  %3 = constant 0xFFF0000001000000 : bf16
+
+  // bfloat16 positive infinity.
+  // CHECK: constant 0x7FF0000000000000 : bf16
+  %4 = constant 0x7FF0000000000000 : bf16
+  // bfloat16 negative infinity.
+  // CHECK: constant 0xFFF0000000000000 : bf16
+  %5 = constant 0xFFF0000000000000 : bf16
+
+  return
+}
+
 // We want to print floats in exponential notation with 6 significant digits,
 // but it may lead to precision loss when parsing back, in which case we print
 // the decimal form instead.