diff --git a/mlir/include/mlir-c/BuiltinAttributes.h b/mlir/include/mlir-c/BuiltinAttributes.h
--- a/mlir/include/mlir-c/BuiltinAttributes.h
+++ b/mlir/include/mlir-c/BuiltinAttributes.h
@@ -379,6 +379,11 @@
     MlirType shapedType, intptr_t numElements, const float *elements);
 MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrDoubleGet(
     MlirType shapedType, intptr_t numElements, const double *elements);
+/// Creates a dense bf16 elements attribute with the given shaped type. Each
+/// element is passed as its raw 16-bit bfloat16 bit pattern (e.g. 0x3f80 is
+/// 1.0); `elements` must point to at least `numElements` values.
+MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrBFloat16Get(
+    MlirType shapedType, intptr_t numElements, const uint16_t *elements);
 
 /// Creates a dense elements attribute with the given shaped type from string
 /// elements.
diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
--- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
@@ -474,6 +474,15 @@
                                              const double *elements) {
   return getDenseAttribute(shapedType, numElements, elements);
 }
+MlirAttribute mlirDenseElementsAttrBFloat16Get(MlirType shapedType,
+                                               intptr_t numElements,
+                                               const uint16_t *elements) {
+  // Elements arrive as raw 16-bit bfloat16 bit patterns (uint16_t) and are
+  // forwarded unchanged to the raw-buffer constructor.
+  size_t bufferSize = numElements * sizeof(uint16_t);
+  const void *buffer = static_cast<const void *>(elements);
+  return mlirDenseElementsAttrRawBufferGet(shapedType, bufferSize, buffer);
+}
 
 MlirAttribute mlirDenseElementsAttrStringGet(MlirType shapedType,
                                              intptr_t numElements,
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -936,6 +936,7 @@
   int64_t ints64[] = {0, 1};
   float floats[] = {0.0f, 1.0f};
   double doubles[] = {0.0, 1.0};
+  uint16_t bf16s[] = {0x0, 0x3f80};
   MlirAttribute encoding = mlirAttributeGetNull();
   MlirAttribute boolElements = mlirDenseElementsAttrBoolGet(
       mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 1), encoding),
@@ -974,6 +975,9 @@
   MlirAttribute doubleElements = mlirDenseElementsAttrDoubleGet(
       mlirRankedTensorTypeGet(2, shape, mlirF64TypeGet(ctx), encoding), 2,
       doubles);
+  MlirAttribute bf16Elements = mlirDenseElementsAttrBFloat16Get(
+      mlirRankedTensorTypeGet(2, shape, mlirBF16TypeGet(ctx), encoding), 2,
+      bf16s);
 
   if (!mlirAttributeIsADenseElements(boolElements) ||
       !mlirAttributeIsADenseElements(uint8Elements) ||
@@ -983,7 +987,8 @@
       !mlirAttributeIsADenseElements(uint64Elements) ||
       !mlirAttributeIsADenseElements(int64Elements) ||
       !mlirAttributeIsADenseElements(floatElements) ||
-      !mlirAttributeIsADenseElements(doubleElements))
+      !mlirAttributeIsADenseElements(doubleElements) ||
+      !mlirAttributeIsADenseElements(bf16Elements))
     return 14;
 
   if (mlirDenseElementsAttrGetBoolValue(boolElements, 1) != 1 ||
@@ -1009,6 +1014,7 @@
   mlirAttributeDump(int64Elements);
   mlirAttributeDump(floatElements);
   mlirAttributeDump(doubleElements);
+  mlirAttributeDump(bf16Elements);
   // CHECK: dense<{{\[}}[false, true]]> : tensor<1x2xi1>
   // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xui8>
   // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xi8>
@@ -1018,6 +1024,7 @@
   // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xi64>
   // CHECK: dense<{{\[}}[0.000000e+00, 1.000000e+00]]> : tensor<1x2xf32>
   // CHECK: dense<{{\[}}[0.000000e+00, 1.000000e+00]]> : tensor<1x2xf64>
+  // CHECK: dense<{{\[}}[0.000000e+00, 1.000000e+00]]> : tensor<1x2xbf16>
 
   MlirAttribute splatBool = mlirDenseElementsAttrBoolSplatGet(
       mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 1), encoding),
@@ -1094,12 +1101,15 @@
   float *floatRawData = (float *)mlirDenseElementsAttrGetRawData(floatElements);
   double *doubleRawData =
       (double *)mlirDenseElementsAttrGetRawData(doubleElements);
+  uint16_t *bf16RawData =
+      (uint16_t *)mlirDenseElementsAttrGetRawData(bf16Elements);
   if (uint8RawData[0] != 0u || uint8RawData[1] != 1u || int8RawData[0] != 0 ||
       int8RawData[1] != 1 || uint32RawData[0] != 0u || uint32RawData[1] != 1u ||
       int32RawData[0] != 0 || int32RawData[1] != 1 || uint64RawData[0] != 0u ||
       uint64RawData[1] != 1u || int64RawData[0] != 0 || int64RawData[1] != 1 ||
       floatRawData[0] != 0.0f || floatRawData[1] != 1.0f ||
-      doubleRawData[0] != 0.0 || doubleRawData[1] != 1.0)
+      doubleRawData[0] != 0.0 || doubleRawData[1] != 1.0 ||
+      bf16RawData[0] != 0 || bf16RawData[1] != 0x3f80)
     return 18;
 
   mlirAttributeDump(splatBool);
@@ -1123,8 +1133,10 @@
 
   mlirAttributeDump(mlirElementsAttrGetValue(floatElements, 2, uints64));
   mlirAttributeDump(mlirElementsAttrGetValue(doubleElements, 2, uints64));
+  mlirAttributeDump(mlirElementsAttrGetValue(bf16Elements, 2, uints64));
   // CHECK: 1.000000e+00 : f32
   // CHECK: 1.000000e+00 : f64
+  // CHECK: 1.000000e+00 : bf16
 
   int64_t indices[] = {0, 1};
   int64_t one = 1;