diff --git a/mlir/lib/IR/AttributeDetail.h b/mlir/lib/IR/AttributeDetail.h --- a/mlir/lib/IR/AttributeDetail.h +++ b/mlir/lib/IR/AttributeDetail.h @@ -132,9 +132,7 @@ /// Construct a key with a type and double. static KeyTy getKey(Type type, double value) { - // Treat BF16 as double because it is not supported in LLVM's APFloat. - // TODO(b/121118307): add BF16 support to APFloat? - if (type.isBF16() || type.isF64()) + if (type.isF64()) return KeyTy(type, APFloat(value)); // This handles, e.g., F16 because there is no APFloat constructor for it. @@ -355,10 +353,6 @@ // Align the width for complex to 8 to make storage and interpretation easier. if (ComplexType comp = eltType.dyn_cast<ComplexType>()) return llvm::alignTo<8>(getDenseElementBitWidth(comp.getElementType())) * 2; - // FIXME(b/121118307): using 64 bits for BF16 because it is currently stored - // with double semantics. - if (eltType.isBF16()) - return 64; if (eltType.isIndex()) return IndexType::kInternalStorageBitWidth; return eltType.getIntOrFloatBitWidth(); diff --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp --- a/mlir/lib/IR/StandardTypes.cpp +++ b/mlir/lib/IR/StandardTypes.cpp @@ -157,12 +157,7 @@ /// Returns the floating semantics for the given type. const llvm::fltSemantics &FloatType::getFloatSemantics() { if (isBF16()) - // Treat BF16 like a double. This is unfortunate but BF16 fltSemantics is - // not defined in LLVM. - // TODO(jpienaar): add BF16 to LLVM? fltSemantics are internal to APFloat.cc - // else one could add it. - // static const fltSemantics semBF16 = {127, -126, 8, 16}; - return APFloat::IEEEdouble(); + return APFloat::BFloat(); if (isF16()) return APFloat::IEEEhalf(); if (isF32()) diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -1774,9 +1774,7 @@ /// Construct a float attribute bitwise equivalent to the integer literal.
static Optional<APFloat> buildHexadecimalFloatLiteral(Parser *p, FloatType type, uint64_t value) { - // FIXME: bfloat is currently stored as a double internally because it doesn't - // have valid APFloat semantics. - if (type.isF64() || type.isBF16()) + if (type.isF64()) return APFloat(type.getFloatSemantics(), APInt(/*numBits=*/64, value)); APInt apInt(type.getWidth(), value); @@ -2153,9 +2151,8 @@ if (!val.hasValue()) return p.emitError("floating point value too large for attribute"); - // Treat BF16 as double because it is not supported in LLVM's APFloat. APFloat apVal(isNegative ? -*val : *val); - if (!eltTy.isBF16() && !eltTy.isF64()) { + if (!eltTy.isF64()) { bool unused; apVal.convert(eltTy.getFloatSemantics(), APFloat::rmNearestTiesToEven, &unused); diff --git a/mlir/test/IR/dense-elements-hex.mlir b/mlir/test/IR/dense-elements-hex.mlir --- a/mlir/test/IR/dense-elements-hex.mlir +++ b/mlir/test/IR/dense-elements-hex.mlir @@ -11,7 +11,7 @@ "foo.op"() {dense.attr = dense<"0x0000000000002440000000000000144000000000000024400000000000001440"> : tensor<2xcomplex<f64>>} : () -> () // CHECK: dense<[1.000000e+01, 5.000000e+00]> : tensor<2xbf16> -"foo.op"() {dense.attr = dense<"0x00000000000024400000000000001440"> : tensor<2xbf16>} : () -> () +"foo.op"() {dense.attr = dense<"0x2041A040"> : tensor<2xbf16>} : () -> () // ----- diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir --- a/mlir/test/IR/parser.mlir +++ b/mlir/test/IR/parser.mlir @@ -1073,28 +1073,26 @@ return } -// FIXME: bfloat16 currently uses f64 as a storage format. This test should be -// changed when that gets fixed. // CHECK-LABEL: @bfloat16_special_values func @bfloat16_special_values() { // bfloat16 signaling NaNs.
- // CHECK: constant 0x7FF0000000000001 : bf16 - %0 = constant 0x7FF0000000000001 : bf16 - // CHECK: constant 0x7FF8000000000000 : bf16 - %1 = constant 0x7FF8000000000000 : bf16 + // CHECK: constant 0x7F81 : bf16 + %0 = constant 0x7F81 : bf16 + // CHECK: constant 0xFF81 : bf16 + %1 = constant 0xFF81 : bf16 // bfloat16 quiet NaNs. - // CHECK: constant 0x7FF0000001000000 : bf16 - %2 = constant 0x7FF0000001000000 : bf16 - // CHECK: constant 0xFFF0000001000000 : bf16 - %3 = constant 0xFFF0000001000000 : bf16 + // CHECK: constant 0x7FC0 : bf16 + %2 = constant 0x7FC0 : bf16 + // CHECK: constant 0xFFC0 : bf16 + %3 = constant 0xFFC0 : bf16 // bfloat16 positive infinity. - // CHECK: constant 0x7FF0000000000000 : bf16 - %4 = constant 0x7FF0000000000000 : bf16 + // CHECK: constant 0x7F80 : bf16 + %4 = constant 0x7F80 : bf16 // bfloat16 negative infinity. - // CHECK: constant 0xFFF0000000000000 : bf16 - %5 = constant 0xFFF0000000000000 : bf16 + // CHECK: constant 0xFF80 : bf16 + %5 = constant 0xFF80 : bf16 return } @@ -1215,12 +1213,12 @@ %x = test.string_attr_pretty_name // CHECK: %x = test.string_attr_pretty_name // CHECK-NOT: attributes - + // This specifies an explicit name, which should override the result. %YY = test.string_attr_pretty_name attributes { names = ["y"] } // CHECK: %y = test.string_attr_pretty_name // CHECK-NOT: attributes - + // Conflicts with the 'y' name, so need an explicit attribute. 
%0 = "test.string_attr_pretty_name"() { names = ["y"]} : () -> i32 // CHECK: %y_0 = test.string_attr_pretty_name attributes {names = ["y"]} diff --git a/mlir/test/Target/llvmir.mlir b/mlir/test/Target/llvmir.mlir --- a/mlir/test/Target/llvmir.mlir +++ b/mlir/test/Target/llvmir.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s +// RUN: mlir-translate -mlir-to-llvmir -split-input-file %s | FileCheck %s // CHECK: @i32_global = internal global i32 42 llvm.mlir.global internal @i32_global(42: i32) : !llvm.i32 @@ -1214,3 +1214,14 @@ // CHECK-DAG: alignstack=4 // CHECK-DAG: null_pointer_is_valid // CHECK-DAG: "foo"="bar" + +// ----- + +// CHECK-LABEL: @constant_bf16 +llvm.func @constant_bf16() -> !llvm<"bfloat"> { + %0 = llvm.mlir.constant(1.000000e+01 : bf16) : !llvm<"bfloat"> + llvm.return %0 : !llvm<"bfloat"> +} + +// CHECK: ret bfloat 0xR4120 + diff --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp --- a/mlir/unittests/IR/AttributeTest.cpp +++ b/mlir/unittests/IR/AttributeTest.cpp @@ -134,7 +134,7 @@ TEST(DenseSplatTest, FloatAttrSplat) { MLIRContext context; - FloatType floatTy = FloatType::getBF16(&context); + FloatType floatTy = FloatType::getF32(&context); Attribute value = FloatAttr::get(floatTy, 10.0); testSplat(floatTy, value); @@ -143,8 +143,7 @@ TEST(DenseSplatTest, BF16Splat) { MLIRContext context; FloatType floatTy = FloatType::getBF16(&context); - // Note: We currently use double to represent bfloat16. - double value = 10.0; + Attribute value = FloatAttr::get(floatTy, 10.0); testSplat(floatTy, value); }