diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -2118,8 +2118,12 @@ // Check that the size of the hex data correpsonds to the size of the type, or // a splat of the type. + // TODO: bf16 is currently stored as a double, this should be removed when + // APFloat properly supports it. + int64_t elementWidth = + elementType.isBF16() ? 64 : elementType.getIntOrFloatBitWidth(); if (static_cast(data.size() * CHAR_BIT) != - (type.getNumElements() * elementType.getIntOrFloatBitWidth())) { + (type.getNumElements() * elementWidth)) { p.emitError(loc) << "elements hex data size is invalid for provided type: " << type; return nullptr; diff --git a/mlir/test/IR/dense-elements-hex.mlir b/mlir/test/IR/dense-elements-hex.mlir --- a/mlir/test/IR/dense-elements-hex.mlir +++ b/mlir/test/IR/dense-elements-hex.mlir @@ -7,6 +7,9 @@ // CHECK: dense<[1.000000e+01, 5.000000e+00]> : tensor<2xf64> "foo.op"() {dense.attr = dense<"0x00000000000024400000000000001440"> : tensor<2xf64>} : () -> () +// CHECK: dense<[1.000000e+01, 5.000000e+00]> : tensor<2xbf16> +"foo.op"() {dense.attr = dense<"0x00000000000024400000000000001440"> : tensor<2xbf16>} : () -> () + // ----- // expected-error@+1 {{elements hex string should start with '0x'}}