Add mlirDenseElementsAttrFloat16Get to the MLIR C API.

The new entry point mirrors mlirDenseElementsAttrBFloat16Get: callers pass the
f16 values as raw uint16_t bit patterns (2 bytes per element) and the buffer is
forwarded to mlirDenseElementsAttrRawBufferGet. mlir/test/CAPI/ir.c is extended
to build, dump, and read back an f16 dense elements attribute.

diff --git a/mlir/include/mlir-c/BuiltinAttributes.h b/mlir/include/mlir-c/BuiltinAttributes.h
--- a/mlir/include/mlir-c/BuiltinAttributes.h
+++ b/mlir/include/mlir-c/BuiltinAttributes.h
@@ -381,6 +381,8 @@
     MlirType shapedType, intptr_t numElements, const double *elements);
 MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrBFloat16Get(
     MlirType shapedType, intptr_t numElements, const uint16_t *elements);
+MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrFloat16Get(
+    MlirType shapedType, intptr_t numElements, const uint16_t *elements);
 
 /// Creates a dense elements attribute with the given shaped type from string
 /// elements.
diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
--- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
@@ -479,6 +479,13 @@
   const void *buffer = static_cast<const void *>(elements);
   return mlirDenseElementsAttrRawBufferGet(shapedType, bufferSize, buffer);
 }
+MlirAttribute mlirDenseElementsAttrFloat16Get(MlirType shapedType,
+                                              intptr_t numElements,
+                                              const uint16_t *elements) {
+  size_t bufferSize = numElements * 2;
+  const void *buffer = static_cast<const void *>(elements);
+  return mlirDenseElementsAttrRawBufferGet(shapedType, bufferSize, buffer);
+}
 
 MlirAttribute mlirDenseElementsAttrStringGet(MlirType shapedType,
                                              intptr_t numElements,
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -959,6 +959,7 @@
   float floats[] = {0.0f, 1.0f};
   double doubles[] = {0.0, 1.0};
   uint16_t bf16s[] = {0x0, 0x3f80};
+  uint16_t f16s[] = {0x0, 0x3c00};
   MlirAttribute encoding = mlirAttributeGetNull();
   MlirAttribute boolElements = mlirDenseElementsAttrBoolGet(
       mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 1), encoding),
@@ -1000,6 +1001,9 @@
   MlirAttribute bf16Elements = mlirDenseElementsAttrBFloat16Get(
       mlirRankedTensorTypeGet(2, shape, mlirBF16TypeGet(ctx), encoding), 2,
       bf16s);
+  MlirAttribute f16Elements = mlirDenseElementsAttrFloat16Get(
+      mlirRankedTensorTypeGet(2, shape, mlirF16TypeGet(ctx), encoding), 2,
+      f16s);
 
   if (!mlirAttributeIsADenseElements(boolElements) ||
       !mlirAttributeIsADenseElements(uint8Elements) ||
@@ -1010,7 +1014,8 @@
       !mlirAttributeIsADenseElements(int64Elements) ||
       !mlirAttributeIsADenseElements(floatElements) ||
       !mlirAttributeIsADenseElements(doubleElements) ||
-      !mlirAttributeIsADenseElements(bf16Elements))
+      !mlirAttributeIsADenseElements(bf16Elements) ||
+      !mlirAttributeIsADenseElements(f16Elements))
     return 14;
 
   if (mlirDenseElementsAttrGetBoolValue(boolElements, 1) != 1 ||
@@ -1037,6 +1042,7 @@
   mlirAttributeDump(floatElements);
   mlirAttributeDump(doubleElements);
   mlirAttributeDump(bf16Elements);
+  mlirAttributeDump(f16Elements);
   // CHECK: dense<{{\[}}[false, true]]> : tensor<1x2xi1>
   // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xui8>
   // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xi8>
@@ -1047,6 +1053,7 @@
   // CHECK: dense<{{\[}}[0.000000e+00, 1.000000e+00]]> : tensor<1x2xf32>
   // CHECK: dense<{{\[}}[0.000000e+00, 1.000000e+00]]> : tensor<1x2xf64>
   // CHECK: dense<{{\[}}[0.000000e+00, 1.000000e+00]]> : tensor<1x2xbf16>
+  // CHECK: dense<{{\[}}[0.000000e+00, 1.000000e+00]]> : tensor<1x2xf16>
 
   MlirAttribute splatBool = mlirDenseElementsAttrBoolSplatGet(
       mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 1), encoding),
@@ -1125,13 +1132,16 @@
       (double *)mlirDenseElementsAttrGetRawData(doubleElements);
   uint16_t *bf16RawData =
       (uint16_t *)mlirDenseElementsAttrGetRawData(bf16Elements);
+  uint16_t *f16RawData =
+      (uint16_t *)mlirDenseElementsAttrGetRawData(f16Elements);
   if (uint8RawData[0] != 0u || uint8RawData[1] != 1u || int8RawData[0] != 0 ||
       int8RawData[1] != 1 || uint32RawData[0] != 0u || uint32RawData[1] != 1u ||
       int32RawData[0] != 0 || int32RawData[1] != 1 || uint64RawData[0] != 0u ||
       uint64RawData[1] != 1u || int64RawData[0] != 0 || int64RawData[1] != 1 ||
       floatRawData[0] != 0.0f || floatRawData[1] != 1.0f ||
       doubleRawData[0] != 0.0 || doubleRawData[1] != 1.0 ||
-      bf16RawData[0] != 0 || bf16RawData[1] != 0x3f80)
+      bf16RawData[0] != 0 || bf16RawData[1] != 0x3f80 || f16RawData[0] != 0 ||
+      f16RawData[1] != 0x3c00)
     return 18;
 
   mlirAttributeDump(splatBool);
@@ -1156,9 +1166,11 @@
   mlirAttributeDump(mlirElementsAttrGetValue(floatElements, 2, uints64));
   mlirAttributeDump(mlirElementsAttrGetValue(doubleElements, 2, uints64));
   mlirAttributeDump(mlirElementsAttrGetValue(bf16Elements, 2, uints64));
+  mlirAttributeDump(mlirElementsAttrGetValue(f16Elements, 2, uints64));
   // CHECK: 1.000000e+00 : f32
   // CHECK: 1.000000e+00 : f64
   // CHECK: 1.000000e+00 : bf16
+  // CHECK: 1.000000e+00 : f16
 
   int64_t indices[] = {0, 1};
   int64_t one = 1;