diff --git a/mlir/include/mlir-c/BuiltinAttributes.h b/mlir/include/mlir-c/BuiltinAttributes.h --- a/mlir/include/mlir-c/BuiltinAttributes.h +++ b/mlir/include/mlir-c/BuiltinAttributes.h @@ -125,9 +125,17 @@ int64_t value); /// Returns the value stored in the given integer attribute, assuming the value -/// fits into a 64-bit integer. +/// is of signless or index type and fits into a signed 64-bit integer. MLIR_CAPI_EXPORTED int64_t mlirIntegerAttrGetValueInt(MlirAttribute attr); +/// Returns the value stored in the given integer attribute, assuming the value +/// is of signed type and fits into a signed 64-bit integer. +MLIR_CAPI_EXPORTED int64_t mlirIntegerAttrGetValueSInt(MlirAttribute attr); + +/// Returns the value stored in the given integer attribute, assuming the value +/// is of unsigned type and fits into an unsigned 64-bit integer. +MLIR_CAPI_EXPORTED uint64_t mlirIntegerAttrGetValueUInt(MlirAttribute attr); + //===----------------------------------------------------------------------===// // Bool attribute.
//===----------------------------------------------------------------------===// diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -258,8 +258,13 @@ "Gets an uniqued integer attribute associated to a type"); c.def_property_readonly( "value", - [](PyIntegerAttribute &self) { - return mlirIntegerAttrGetValueInt(self); + [](PyIntegerAttribute &self) -> py::int_ { + MlirType type = mlirAttributeGetType(self); + if (mlirTypeIsAIndex(type) || mlirIntegerTypeIsSignless(type)) + return mlirIntegerAttrGetValueInt(self); + if (mlirIntegerTypeIsSigned(type)) + return mlirIntegerAttrGetValueSInt(self); + return mlirIntegerAttrGetValueUInt(self); }, "Returns the value of the integer attribute"); } diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp --- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp @@ -129,6 +129,14 @@ return unwrap(attr).cast<IntegerAttr>().getInt(); } +int64_t mlirIntegerAttrGetValueSInt(MlirAttribute attr) { + return unwrap(attr).cast<IntegerAttr>().getSInt(); +} + +uint64_t mlirIntegerAttrGetValueUInt(MlirAttribute attr) { + return unwrap(attr).cast<IntegerAttr>().getUInt(); +} + //===----------------------------------------------------------------------===// // Bool attribute.
//===----------------------------------------------------------------------===// diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c --- a/mlir/test/CAPI/ir.c +++ b/mlir/test/CAPI/ir.c @@ -813,11 +813,21 @@ // CHECK: f64 MlirAttribute integer = mlirIntegerAttrGet(mlirIntegerTypeGet(ctx, 32), 42); + MlirAttribute signedInteger = + mlirIntegerAttrGet(mlirIntegerTypeSignedGet(ctx, 8), -1); + MlirAttribute unsignedInteger = + mlirIntegerAttrGet(mlirIntegerTypeUnsignedGet(ctx, 8), 255); if (!mlirAttributeIsAInteger(integer) || - mlirIntegerAttrGetValueInt(integer) != 42) + mlirIntegerAttrGetValueInt(integer) != 42 || + mlirIntegerAttrGetValueSInt(signedInteger) != -1 || + mlirIntegerAttrGetValueUInt(unsignedInteger) != 255) return 2; mlirAttributeDump(integer); + mlirAttributeDump(signedInteger); + mlirAttributeDump(unsignedInteger); // CHECK: 42 : i32 + // CHECK: -1 : si8 + // CHECK: 255 : ui8 MlirAttribute boolean = mlirBoolAttrGet(ctx, 1); if (!mlirAttributeIsABool(boolean) || !mlirBoolAttrGetValue(boolean)) diff --git a/mlir/test/python/ir/attributes.py b/mlir/test/python/ir/attributes.py --- a/mlir/test/python/ir/attributes.py +++ b/mlir/test/python/ir/attributes.py @@ -189,11 +189,24 @@ @run def testIntegerAttr(): with Context() as ctx: - iattr = IntegerAttr(Attribute.parse("42")) - # CHECK: iattr value: 42 - print("iattr value:", iattr.value) - # CHECK: iattr type: i64 - print("iattr type:", iattr.type) + i_attr = IntegerAttr(Attribute.parse("42")) + # CHECK: i_attr value: 42 + print("i_attr value:", i_attr.value) + # CHECK: i_attr type: i64 + print("i_attr type:", i_attr.type) + si_attr = IntegerAttr(Attribute.parse("-1 : si8")) + # CHECK: si_attr value: -1 + print("si_attr value:", si_attr.value) + ui_attr = IntegerAttr(Attribute.parse("255 : ui8")) + # CHECK: ui_attr value: 255 + print("ui_attr value:", ui_attr.value) + # Exceeds INT64_MAX, so this checks the full uint64_t range is preserved. + ui64_attr = IntegerAttr(Attribute.parse("18446744073709551615 : ui64")) + # CHECK: ui64_attr value: 18446744073709551615 + print("ui64_attr value:", ui64_attr.value) + idx_attr = IntegerAttr(Attribute.parse("-1 : index")) + # CHECK: idx_attr value: -1 + print("idx_attr value:", idx_attr.value) # Test
factory methods. # CHECK: default_get: 42 : i32