diff --git a/mlir/include/mlir-c/BuiltinAttributes.h b/mlir/include/mlir-c/BuiltinAttributes.h
--- a/mlir/include/mlir-c/BuiltinAttributes.h
+++ b/mlir/include/mlir-c/BuiltinAttributes.h
@@ -313,6 +313,10 @@
 MLIR_CAPI_EXPORTED MlirAttribute
 mlirDenseElementsAttrBoolSplatGet(MlirType shapedType, bool element);
 MLIR_CAPI_EXPORTED MlirAttribute
+mlirDenseElementsAttrUInt8SplatGet(MlirType shapedType, uint8_t element);
+MLIR_CAPI_EXPORTED MlirAttribute
+mlirDenseElementsAttrInt8SplatGet(MlirType shapedType, int8_t element);
+MLIR_CAPI_EXPORTED MlirAttribute
 mlirDenseElementsAttrUInt32SplatGet(MlirType shapedType, uint32_t element);
 MLIR_CAPI_EXPORTED MlirAttribute
 mlirDenseElementsAttrInt32SplatGet(MlirType shapedType, int32_t element);
@@ -330,6 +334,10 @@
 /// data element type.
 MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrBoolGet(
     MlirType shapedType, intptr_t numElements, const int *elements);
+MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrUInt8Get(
+    MlirType shapedType, intptr_t numElements, const uint8_t *elements);
+MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrInt8Get(
+    MlirType shapedType, intptr_t numElements, const int8_t *elements);
 MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrUInt32Get(
     MlirType shapedType, intptr_t numElements, const uint32_t *elements);
 MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrInt32Get(
@@ -364,6 +372,10 @@
 mlirDenseElementsAttrGetSplatValue(MlirAttribute attr);
 MLIR_CAPI_EXPORTED int
 mlirDenseElementsAttrGetBoolSplatValue(MlirAttribute attr);
+MLIR_CAPI_EXPORTED int8_t
+mlirDenseElementsAttrGetInt8SplatValue(MlirAttribute attr);
+MLIR_CAPI_EXPORTED uint8_t
+mlirDenseElementsAttrGetUInt8SplatValue(MlirAttribute attr);
 MLIR_CAPI_EXPORTED int32_t
 mlirDenseElementsAttrGetInt32SplatValue(MlirAttribute attr);
 MLIR_CAPI_EXPORTED uint32_t
@@ -383,6 +395,10 @@
 /// contained by the given dense elements attribute.
 MLIR_CAPI_EXPORTED bool mlirDenseElementsAttrGetBoolValue(MlirAttribute attr,
                                                           intptr_t pos);
+MLIR_CAPI_EXPORTED int8_t mlirDenseElementsAttrGetInt8Value(MlirAttribute attr,
+                                                            intptr_t pos);
+MLIR_CAPI_EXPORTED uint8_t
+mlirDenseElementsAttrGetUInt8Value(MlirAttribute attr, intptr_t pos);
 MLIR_CAPI_EXPORTED int32_t
 mlirDenseElementsAttrGetInt32Value(MlirAttribute attr, intptr_t pos);
 MLIR_CAPI_EXPORTED uint32_t
diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
--- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
@@ -341,6 +341,16 @@
   return wrap(
       DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element));
 }
+MlirAttribute mlirDenseElementsAttrUInt8SplatGet(MlirType shapedType,
+                                                 uint8_t element) {
+  return wrap(
+      DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element));
+}
+MlirAttribute mlirDenseElementsAttrInt8SplatGet(MlirType shapedType,
+                                                int8_t element) {
+  return wrap(
+      DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element));
+}
 MlirAttribute mlirDenseElementsAttrUInt32SplatGet(MlirType shapedType,
                                                   uint32_t element) {
   return wrap(
@@ -390,6 +400,16 @@
                              llvm::makeArrayRef(elements, numElements)));
 }
 
+MlirAttribute mlirDenseElementsAttrUInt8Get(MlirType shapedType,
+                                            intptr_t numElements,
+                                            const uint8_t *elements) {
+  return getDenseAttribute(shapedType, numElements, elements);
+}
+MlirAttribute mlirDenseElementsAttrInt8Get(MlirType shapedType,
+                                           intptr_t numElements,
+                                           const int8_t *elements) {
+  return getDenseAttribute(shapedType, numElements, elements);
+}
 MlirAttribute mlirDenseElementsAttrUInt32Get(MlirType shapedType,
                                              intptr_t numElements,
                                              const uint32_t *elements) {
@@ -452,6 +472,12 @@
 int mlirDenseElementsAttrGetBoolSplatValue(MlirAttribute attr) {
   return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<bool>();
 }
+int8_t mlirDenseElementsAttrGetInt8SplatValue(MlirAttribute attr) {
+  return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<int8_t>();
+}
+uint8_t mlirDenseElementsAttrGetUInt8SplatValue(MlirAttribute attr) {
+  return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<uint8_t>();
+}
 int32_t mlirDenseElementsAttrGetInt32SplatValue(MlirAttribute attr) {
   return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<int32_t>();
 }
@@ -482,6 +508,14 @@
   return *(unwrap(attr).cast<DenseElementsAttr>().getValues<bool>().begin() +
            pos);
 }
+int8_t mlirDenseElementsAttrGetInt8Value(MlirAttribute attr, intptr_t pos) {
+  return *(unwrap(attr).cast<DenseElementsAttr>().getValues<int8_t>().begin() +
+           pos);
+}
+uint8_t mlirDenseElementsAttrGetUInt8Value(MlirAttribute attr, intptr_t pos) {
+  return *(unwrap(attr).cast<DenseElementsAttr>().getValues<uint8_t>().begin() +
+           pos);
+}
 int32_t mlirDenseElementsAttrGetInt32Value(MlirAttribute attr, intptr_t pos) {
   return *(unwrap(attr).cast<DenseElementsAttr>().getValues<int32_t>().begin() +
            pos);
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -886,6 +886,8 @@
   int64_t shape[] = {1, 2};
 
   int bools[] = {0, 1};
+  uint8_t uints8[] = {0u, 1u};
+  int8_t ints8[] = {0, 1};
   uint32_t uints32[] = {0u, 1u};
   int32_t ints32[] = {0, 1};
   uint64_t uints64[] = {0u, 1u};
@@ -896,6 +898,13 @@
   MlirAttribute boolElements = mlirDenseElementsAttrBoolGet(
       mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 1), encoding),
       2, bools);
+  MlirAttribute uint8Elements = mlirDenseElementsAttrUInt8Get(
+      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 8),
+                              encoding),
+      2, uints8);
+  MlirAttribute int8Elements = mlirDenseElementsAttrInt8Get(
+      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 8), encoding),
+      2, ints8);
   MlirAttribute uint32Elements = mlirDenseElementsAttrUInt32Get(
       mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 32),
                               encoding),
@@ -918,6 +927,8 @@
       2, doubles);
 
   if (!mlirAttributeIsADenseElements(boolElements) ||
+      !mlirAttributeIsADenseElements(uint8Elements) ||
+      !mlirAttributeIsADenseElements(int8Elements) ||
       !mlirAttributeIsADenseElements(uint32Elements) ||
       !mlirAttributeIsADenseElements(int32Elements) ||
       !mlirAttributeIsADenseElements(uint64Elements) ||
@@ -927,6 +938,8 @@
     return 14;
 
   if (mlirDenseElementsAttrGetBoolValue(boolElements, 1) != 1 ||
+      mlirDenseElementsAttrGetUInt8Value(uint8Elements, 1) != 1 ||
+      mlirDenseElementsAttrGetInt8Value(int8Elements, 1) != 1 ||
       mlirDenseElementsAttrGetUInt32Value(uint32Elements, 1) != 1 ||
       mlirDenseElementsAttrGetInt32Value(int32Elements, 1) != 1 ||
       mlirDenseElementsAttrGetUInt64Value(uint64Elements, 1) != 1 ||
@@ -937,6 +950,8 @@
     return 15;
 
   mlirAttributeDump(boolElements);
+  mlirAttributeDump(uint8Elements);
+  mlirAttributeDump(int8Elements);
   mlirAttributeDump(uint32Elements);
   mlirAttributeDump(int32Elements);
   mlirAttributeDump(uint64Elements);
@@ -944,6 +959,8 @@
   mlirAttributeDump(floatElements);
   mlirAttributeDump(doubleElements);
   // CHECK: dense<{{\[}}[false, true]]> : tensor<1x2xi1>
+  // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xui8>
+  // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xi8>
   // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xui32>
   // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xi32>
   // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xui64>
@@ -952,20 +969,29 @@
   // CHECK: dense<{{\[}}[0.000000e+00, 1.000000e+00]]> : tensor<1x2xf64>
 
   MlirAttribute splatBool = mlirDenseElementsAttrBoolSplatGet(
-      mlirRankedTensorTypeGet(2, shape,
-                              mlirIntegerTypeGet(ctx, 1), encoding), 1);
+      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 1), encoding),
+      1);
+  MlirAttribute splatUInt8 = mlirDenseElementsAttrUInt8SplatGet(
+      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 8),
+                              encoding),
+      1);
+  MlirAttribute splatInt8 = mlirDenseElementsAttrInt8SplatGet(
+      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 8), encoding),
+      1);
   MlirAttribute splatUInt32 = mlirDenseElementsAttrUInt32SplatGet(
-      mlirRankedTensorTypeGet(2, shape,
-                              mlirIntegerTypeGet(ctx, 32), encoding), 1);
+      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 32),
+                              encoding),
+      1);
   MlirAttribute splatInt32 = mlirDenseElementsAttrInt32SplatGet(
-      mlirRankedTensorTypeGet(2, shape,
-                              mlirIntegerTypeGet(ctx, 32), encoding), 1);
+      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 32), encoding),
+      1);
   MlirAttribute splatUInt64 = mlirDenseElementsAttrUInt64SplatGet(
-      mlirRankedTensorTypeGet(2, shape,
-                              mlirIntegerTypeGet(ctx, 64), encoding), 1);
+      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 64),
+                              encoding),
+      1);
   MlirAttribute splatInt64 = mlirDenseElementsAttrInt64SplatGet(
-      mlirRankedTensorTypeGet(2, shape,
-                              mlirIntegerTypeGet(ctx, 64), encoding), 1);
+      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 64), encoding),
+      1);
   MlirAttribute splatFloat = mlirDenseElementsAttrFloatSplatGet(
       mlirRankedTensorTypeGet(2, shape, mlirF32TypeGet(ctx), encoding), 1.0f);
   MlirAttribute splatDouble = mlirDenseElementsAttrDoubleSplatGet(
@@ -973,6 +999,10 @@
 
   if (!mlirAttributeIsADenseElements(splatBool) ||
       !mlirDenseElementsAttrIsSplat(splatBool) ||
+      !mlirAttributeIsADenseElements(splatUInt8) ||
+      !mlirDenseElementsAttrIsSplat(splatUInt8) ||
+      !mlirAttributeIsADenseElements(splatInt8) ||
+      !mlirDenseElementsAttrIsSplat(splatInt8) ||
       !mlirAttributeIsADenseElements(splatUInt32) ||
       !mlirDenseElementsAttrIsSplat(splatUInt32) ||
       !mlirAttributeIsADenseElements(splatInt32) ||
@@ -988,6 +1018,8 @@
     return 16;
 
   if (mlirDenseElementsAttrGetBoolSplatValue(splatBool) != 1 ||
+      mlirDenseElementsAttrGetUInt8SplatValue(splatUInt8) != 1 ||
+      mlirDenseElementsAttrGetInt8SplatValue(splatInt8) != 1 ||
       mlirDenseElementsAttrGetUInt32SplatValue(splatUInt32) != 1 ||
       mlirDenseElementsAttrGetInt32SplatValue(splatInt32) != 1 ||
       mlirDenseElementsAttrGetUInt64SplatValue(splatUInt64) != 1 ||
@@ -997,6 +1029,9 @@
       fabs(mlirDenseElementsAttrGetDoubleSplatValue(splatDouble) - 1.0) > 1E-6)
     return 17;
 
+  uint8_t *uint8RawData =
+      (uint8_t *)mlirDenseElementsAttrGetRawData(uint8Elements);
+  int8_t *int8RawData = (int8_t *)mlirDenseElementsAttrGetRawData(int8Elements);
   uint32_t *uint32RawData =
       (uint32_t *)mlirDenseElementsAttrGetRawData(uint32Elements);
   int32_t *int32RawData =
@@ -1008,7 +1043,8 @@
   float *floatRawData = (float *)mlirDenseElementsAttrGetRawData(floatElements);
   double *doubleRawData =
       (double *)mlirDenseElementsAttrGetRawData(doubleElements);
-  if (uint32RawData[0] != 0u || uint32RawData[1] != 1u ||
+  if (uint8RawData[0] != 0u || uint8RawData[1] != 1u || int8RawData[0] != 0 ||
+      int8RawData[1] != 1 || uint32RawData[0] != 0u || uint32RawData[1] != 1u ||
       int32RawData[0] != 0 || int32RawData[1] != 1 || uint64RawData[0] != 0u ||
       uint64RawData[1] != 1u || int64RawData[0] != 0 || int64RawData[1] != 1 ||
       floatRawData[0] != 0.0f || floatRawData[1] != 1.0f ||
@@ -1016,6 +1052,8 @@
     return 18;
 
   mlirAttributeDump(splatBool);
+  mlirAttributeDump(splatUInt8);
+  mlirAttributeDump(splatInt8);
   mlirAttributeDump(splatUInt32);
   mlirAttributeDump(splatInt32);
   mlirAttributeDump(splatUInt64);
@@ -1023,9 +1061,11 @@
   mlirAttributeDump(splatFloat);
   mlirAttributeDump(splatDouble);
   // CHECK: dense<true> : tensor<1x2xi1>
+  // CHECK: dense<1> : tensor<1x2xui8>
+  // CHECK: dense<1> : tensor<1x2xi8>
+  // CHECK: dense<1> : tensor<1x2xui32>
   // CHECK: dense<1> : tensor<1x2xi32>
-  // CHECK: dense<1> : tensor<1x2xi32>
-  // CHECK: dense<1> : tensor<1x2xi64>
+  // CHECK: dense<1> : tensor<1x2xui64>
   // CHECK: dense<1> : tensor<1x2xi64>
   // CHECK: dense<1.000000e+00> : tensor<1x2xf32>
   // CHECK: dense<1.000000e+00> : tensor<1x2xf64>