Compare the raw data element size against getDenseElementBitwidth() of the
element type instead of type.getElementTypeBitWidth(), and add a unit test
that builds a BF16 splat from a double value.

diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp
--- a/mlir/lib/IR/Attributes.cpp
+++ b/mlir/lib/IR/Attributes.cpp
@@ -676,7 +676,8 @@
 static bool isValidIntOrFloat(ShapedType type, int64_t dataEltSize,
                               bool isInt) {
   // Make sure that the data element size is the same as the type element width.
-  if ((dataEltSize * CHAR_BIT) != type.getElementTypeBitWidth())
+  if (getDenseElementBitwidth(type.getElementType()) !=
+      static_cast<size_t>(dataEltSize * CHAR_BIT))
     return false;
 
   // Check that the element type is valid.
diff --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp
--- a/mlir/unittests/IR/AttributeTest.cpp
+++ b/mlir/unittests/IR/AttributeTest.cpp
@@ -135,4 +135,14 @@
   testSplat(floatTy, value);
 }
+
+TEST(DenseSplatTest, BF16Splat) {
+  MLIRContext context;
+  FloatType floatTy = FloatType::getBF16(&context);
+  // Note: We currently use double to represent bfloat16.
+  double value = 10.0;
+
+  testSplat(floatTy, value);
+}
+
 
 } // end namespace