[mlir][python] Support 8- and 16-bit element access in DenseIntElementsAttr

Add mlirDenseElementsAttrUInt16Get / mlirDenseElementsAttrInt16Get
constructors and mlirDenseElementsAttrGetInt16Value /
mlirDenseElementsAttrGetUInt16Value accessors to the MLIR C API, extend the
width dispatch in the Python DenseIntElementsAttr element access with 8- and
16-bit cases (signed and unsigned), and exercise the new entry points in the
C API test (mlir/test/CAPI/ir.c) and the Python attribute test.

diff --git a/mlir/include/mlir-c/BuiltinAttributes.h b/mlir/include/mlir-c/BuiltinAttributes.h
--- a/mlir/include/mlir-c/BuiltinAttributes.h
+++ b/mlir/include/mlir-c/BuiltinAttributes.h
@@ -355,6 +355,10 @@
     MlirType shapedType, intptr_t numElements, const uint8_t *elements);
 MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrInt8Get(
     MlirType shapedType, intptr_t numElements, const int8_t *elements);
+MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrUInt16Get(
+    MlirType shapedType, intptr_t numElements, const uint16_t *elements);
+MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrInt16Get(
+    MlirType shapedType, intptr_t numElements, const int16_t *elements);
 MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrUInt32Get(
     MlirType shapedType, intptr_t numElements, const uint32_t *elements);
 MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrInt32Get(
@@ -416,6 +420,10 @@
                                                             intptr_t pos);
 MLIR_CAPI_EXPORTED uint8_t
 mlirDenseElementsAttrGetUInt8Value(MlirAttribute attr, intptr_t pos);
+MLIR_CAPI_EXPORTED int16_t
+mlirDenseElementsAttrGetInt16Value(MlirAttribute attr, intptr_t pos);
+MLIR_CAPI_EXPORTED uint16_t
+mlirDenseElementsAttrGetUInt16Value(MlirAttribute attr, intptr_t pos);
 MLIR_CAPI_EXPORTED int32_t
 mlirDenseElementsAttrGetInt32Value(MlirAttribute attr, intptr_t pos);
 MLIR_CAPI_EXPORTED uint32_t
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -673,6 +673,12 @@
       if (width == 1) {
         return mlirDenseElementsAttrGetBoolValue(*this, pos);
       }
+      if (width == 8) {
+        return mlirDenseElementsAttrGetUInt8Value(*this, pos);
+      }
+      if (width == 16) {
+        return mlirDenseElementsAttrGetUInt16Value(*this, pos);
+      }
       if (width == 32) {
         return mlirDenseElementsAttrGetUInt32Value(*this, pos);
       }
@@ -683,6 +689,12 @@
       if (width == 1) {
         return mlirDenseElementsAttrGetBoolValue(*this, pos);
       }
+      if (width == 8) {
+        return mlirDenseElementsAttrGetInt8Value(*this, pos);
+      }
+      if (width == 16) {
+        return mlirDenseElementsAttrGetInt16Value(*this, pos);
+      }
       if (width == 32) {
         return mlirDenseElementsAttrGetInt32Value(*this, pos);
       }
diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
--- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
@@ -426,6 +426,16 @@
                                            const int8_t *elements) {
   return getDenseAttribute(shapedType, numElements, elements);
 }
+MlirAttribute mlirDenseElementsAttrUInt16Get(MlirType shapedType,
+                                             intptr_t numElements,
+                                             const uint16_t *elements) {
+  return getDenseAttribute(shapedType, numElements, elements);
+}
+MlirAttribute mlirDenseElementsAttrInt16Get(MlirType shapedType,
+                                            intptr_t numElements,
+                                            const int16_t *elements) {
+  return getDenseAttribute(shapedType, numElements, elements);
+}
 MlirAttribute mlirDenseElementsAttrUInt32Get(MlirType shapedType,
                                              intptr_t numElements,
                                              const uint32_t *elements) {
@@ -530,6 +540,12 @@
 uint8_t mlirDenseElementsAttrGetUInt8Value(MlirAttribute attr, intptr_t pos) {
   return unwrap(attr).cast<DenseElementsAttr>().getValues<uint8_t>()[pos];
 }
+int16_t mlirDenseElementsAttrGetInt16Value(MlirAttribute attr, intptr_t pos) {
+  return unwrap(attr).cast<DenseElementsAttr>().getValues<int16_t>()[pos];
+}
+uint16_t mlirDenseElementsAttrGetUInt16Value(MlirAttribute attr, intptr_t pos) {
+  return unwrap(attr).cast<DenseElementsAttr>().getValues<uint16_t>()[pos];
+}
 int32_t mlirDenseElementsAttrGetInt32Value(MlirAttribute attr, intptr_t pos) {
   return unwrap(attr).cast<DenseElementsAttr>().getValues<int32_t>()[pos];
 }
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -904,6 +904,8 @@
   int bools[] = {0, 1};
   uint8_t uints8[] = {0u, 1u};
   int8_t ints8[] = {0, 1};
+  uint16_t uints16[] = {0u, 1u};
+  int16_t ints16[] = {0, 1};
   uint32_t uints32[] = {0u, 1u};
   int32_t ints32[] = {0, 1};
   uint64_t uints64[] = {0u, 1u};
@@ -921,6 +923,13 @@
   MlirAttribute int8Elements = mlirDenseElementsAttrInt8Get(
       mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 8), encoding),
       2, ints8);
+  MlirAttribute uint16Elements = mlirDenseElementsAttrUInt16Get(
+      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 16),
+                              encoding),
+      2, uints16);
+  MlirAttribute int16Elements = mlirDenseElementsAttrInt16Get(
+      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 16), encoding),
+      2, ints16);
   MlirAttribute uint32Elements = mlirDenseElementsAttrUInt32Get(
       mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 32),
                               encoding),
@@ -956,6 +965,8 @@
   if (mlirDenseElementsAttrGetBoolValue(boolElements, 1) != 1 ||
       mlirDenseElementsAttrGetUInt8Value(uint8Elements, 1) != 1 ||
       mlirDenseElementsAttrGetInt8Value(int8Elements, 1) != 1 ||
+      mlirDenseElementsAttrGetUInt16Value(uint16Elements, 1) != 1 ||
+      mlirDenseElementsAttrGetInt16Value(int16Elements, 1) != 1 ||
       mlirDenseElementsAttrGetUInt32Value(uint32Elements, 1) != 1 ||
       mlirDenseElementsAttrGetInt32Value(int32Elements, 1) != 1 ||
       mlirDenseElementsAttrGetUInt64Value(uint64Elements, 1) != 1 ||
diff --git a/mlir/test/python/ir/attributes.py b/mlir/test/python/ir/attributes.py
--- a/mlir/test/python/ir/attributes.py
+++ b/mlir/test/python/ir/attributes.py
@@ -292,6 +292,50 @@
     print(ShapedType(a.type).element_type)
 
 
+# CHECK-LABEL: TEST: testDenseIntAttrGetItem
+@run
+def testDenseIntAttrGetItem():
+  def print_item(attr_asm):
+    attr = DenseIntElementsAttr(Attribute.parse(attr_asm))
+    dtype = ShapedType(attr.type).element_type
+    try:
+      item = attr[0]
+      print(f"{dtype}:", item)
+    except TypeError as e:
+      print(f"{dtype}:", e)
+
+  with Context():
+    # CHECK: i1: 1
+    print_item("dense<true> : tensor<i1>")
+    # CHECK: i8: 123
+    print_item("dense<123> : tensor<i8>")
+    # CHECK: i16: 123
+    print_item("dense<123> : tensor<i16>")
+    # CHECK: i32: 123
+    print_item("dense<123> : tensor<i32>")
+    # CHECK: i64: 123
+    print_item("dense<123> : tensor<i64>")
+    # CHECK: ui8: 123
+    print_item("dense<123> : tensor<ui8>")
+    # CHECK: ui16: 123
+    print_item("dense<123> : tensor<ui16>")
+    # CHECK: ui32: 123
+    print_item("dense<123> : tensor<ui32>")
+    # CHECK: ui64: 123
+    print_item("dense<123> : tensor<ui64>")
+    # CHECK: si8: -123
+    print_item("dense<-123> : tensor<si8>")
+    # CHECK: si16: -123
+    print_item("dense<-123> : tensor<si16>")
+    # CHECK: si32: -123
+    print_item("dense<-123> : tensor<si32>")
+    # CHECK: si64: -123
+    print_item("dense<-123> : tensor<si64>")
+
+    # CHECK: i7: Unsupported integer type
+    print_item("dense<123> : tensor<i7>")
+
+
 # CHECK-LABEL: TEST: testDenseFPAttr
 @run
 def testDenseFPAttr():