[mlir][python] Add Python conversion methods to attribute bindings

Summary of the changes in mlir/lib/Bindings/Python/IRAttributes.cpp:

  * FloatAttr: the `value` property binds mlirFloatAttrGetValueDouble
    directly, and the same function is also exposed as `__float__`.
  * IntegerAttr: the signedness-aware conversion (index/signless types ->
    mlirIntegerAttrGetValueInt, signed -> mlirIntegerAttrGetValueSInt,
    otherwise mlirIntegerAttrGetValueUInt) moves out of the `value` lambda
    into the private static helper `toPyInt`, which now backs both `value`
    and the new `__int__`.
  * BoolAttr: mlirBoolAttrGetValue backs both `value` and the new `__bool__`.
  * StringAttr: the private static helper `toPyStr` backs both `value` and
    the new `__str__`, so `str(attr)` yields the bare string contents. Test
    expectations on `str(...)` change from `"hello"` to `hello`, and checks
    that relied on the previous printed form switch to `repr(...)`.

mlir/test/python/ir/attributes.py gains checks for float()/int()/bool()/str()
on these attributes and a new testDenseArrayAttrConstruction test.

diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -389,12 +389,10 @@
         },
         py::arg("value"), py::arg("context") = py::none(),
         "Gets an uniqued float point attribute associated to a f64 type");
-    c.def_property_readonly(
-        "value",
-        [](PyFloatAttribute &self) {
-          return mlirFloatAttrGetValueDouble(self);
-        },
-        "Returns the value of the float point attribute");
+    c.def_property_readonly("value", mlirFloatAttrGetValueDouble,
+                            "Returns the value of the float attribute");
+    c.def("__float__", mlirFloatAttrGetValueDouble,
+          "Converts the value of the float attribute to a Python float");
   }
 };
 
@@ -414,22 +412,25 @@
         },
         py::arg("type"), py::arg("value"),
         "Gets an uniqued integer attribute associated to a type");
-    c.def_property_readonly(
-        "value",
-        [](PyIntegerAttribute &self) -> py::int_ {
-          MlirType type = mlirAttributeGetType(self);
-          if (mlirTypeIsAIndex(type) || mlirIntegerTypeIsSignless(type))
-            return mlirIntegerAttrGetValueInt(self);
-          if (mlirIntegerTypeIsSigned(type))
-            return mlirIntegerAttrGetValueSInt(self);
-          return mlirIntegerAttrGetValueUInt(self);
-        },
-        "Returns the value of the integer attribute");
+    c.def_property_readonly("value", toPyInt,
+                            "Returns the value of the integer attribute");
+    c.def("__int__", toPyInt,
+          "Converts the value of the integer attribute to a Python int");
     c.def_property_readonly_static("static_typeid",
                                    [](py::object & /*class*/) -> MlirTypeID {
                                      return mlirIntegerAttrGetTypeID();
                                    });
   }
+
+private:
+  static py::int_ toPyInt(PyIntegerAttribute &self) {
+    MlirType type = mlirAttributeGetType(self);
+    if (mlirTypeIsAIndex(type) || mlirIntegerTypeIsSignless(type))
+      return mlirIntegerAttrGetValueInt(self);
+    if (mlirIntegerTypeIsSigned(type))
+      return mlirIntegerAttrGetValueSInt(self);
+    return mlirIntegerAttrGetValueUInt(self);
+  }
 };
 
 /// Bool Attribute subclass - BoolAttr.
@@ -448,10 +449,10 @@
         },
         py::arg("value"), py::arg("context") = py::none(),
         "Gets an uniqued bool attribute");
-    c.def_property_readonly(
-        "value",
-        [](PyBoolAttribute &self) { return mlirBoolAttrGetValue(self); },
-        "Returns the value of the bool attribute");
+    c.def_property_readonly("value", mlirBoolAttrGetValue,
+                            "Returns the value of the bool attribute");
+    c.def("__bool__", mlirBoolAttrGetValue,
+          "Converts the value of the bool attribute to a Python bool");
   }
 };
 
@@ -595,13 +596,8 @@
         },
         py::arg("type"), py::arg("value"),
         "Gets a uniqued string attribute associated to a type");
-    c.def_property_readonly(
-        "value",
-        [](PyStringAttribute &self) {
-          MlirStringRef stringRef = mlirStringAttrGetValue(self);
-          return py::str(stringRef.data, stringRef.length);
-        },
-        "Returns the value of the string attribute");
+    c.def_property_readonly("value", toPyStr,
+                            "Returns the value of the string attribute");
     c.def_property_readonly(
         "value_bytes",
         [](PyStringAttribute &self) {
@@ -609,6 +605,14 @@
           return py::bytes(stringRef.data, stringRef.length);
         },
         "Returns the value of the string attribute as `bytes`");
+    c.def("__str__", toPyStr,
+          "Converts the value of the string attribute to a Python str");
+  }
+
+private:
+  static py::str toPyStr(PyStringAttribute &self) {
+    MlirStringRef stringRef = mlirStringAttrGetValue(self);
+    return py::str(stringRef.data, stringRef.length);
   }
 };
 
diff --git a/mlir/test/python/ir/attributes.py b/mlir/test/python/ir/attributes.py
--- a/mlir/test/python/ir/attributes.py
+++ b/mlir/test/python/ir/attributes.py
@@ -21,7 +21,7 @@
     assert t.context is ctx
     ctx = None
     gc.collect()
-    # CHECK: "hello"
+    # CHECK: hello
     print(str(t))
     # CHECK: StringAttr("hello")
     print(repr(t))
@@ -169,6 +169,8 @@
         fattr = FloatAttr(Attribute.parse("42.0 : f32"))
         # CHECK: fattr value: 42.0
         print("fattr value:", fattr.value)
+        # CHECK: fattr float: 42.0
+        print("fattr float:", float(fattr), type(float(fattr)))
 
         # Test factory methods.
         # CHECK: default_get: 4.200000e+01 : f32
@@ -196,15 +198,23 @@
         print("i_attr value:", i_attr.value)
         # CHECK: i_attr type: i64
         print("i_attr type:", i_attr.type)
+        # CHECK: i_attr int: 42
+        print("i_attr int:", int(i_attr), type(int(i_attr)))
         si_attr = IntegerAttr(Attribute.parse("-1 : si8"))
         # CHECK: si_attr value: -1
         print("si_attr value:", si_attr.value)
+        # CHECK: si_attr int: -1
+        print("si_attr int:", int(si_attr), type(int(si_attr)))
         ui_attr = IntegerAttr(Attribute.parse("255 : ui8"))
         # CHECK: ui_attr value: 255
         print("ui_attr value:", ui_attr.value)
+        # CHECK: ui_attr int: 255
+        print("ui_attr int:", int(ui_attr), type(int(ui_attr)))
         idx_attr = IntegerAttr(Attribute.parse("-1 : index"))
         # CHECK: idx_attr value: -1
         print("idx_attr value:", idx_attr.value)
+        # CHECK: idx_attr int: -1
+        print("idx_attr int:", int(idx_attr), type(int(idx_attr)))
 
         # Test factory methods.
         # CHECK: default_get: 42 : i32
@@ -218,6 +228,8 @@
         battr = BoolAttr(Attribute.parse("true"))
         # CHECK: iattr value: True
         print("iattr value:", battr.value)
+        # CHECK: iattr bool: True
+        print("iattr bool:", bool(battr), type(bool(battr)))
 
         # Test factory methods.
         # CHECK: default_get: true
@@ -278,14 +290,25 @@
         sattr = StringAttr(Attribute.parse('"stringattr"'))
         # CHECK: sattr value: stringattr
         print("sattr value:", sattr.value)
-        # CHECK: sattr value: b'stringattr'
-        print("sattr value:", sattr.value_bytes)
+        # CHECK: sattr value_bytes: b'stringattr'
+        print("sattr value_bytes:", sattr.value_bytes)
+        # CHECK: sattr str: stringattr
+        print("sattr str:", str(sattr))
+
+        typed_sattr = StringAttr(Attribute.parse('"stringattr" : i32'))
+        # CHECK: typed_sattr value: stringattr
+        print("typed_sattr value:", typed_sattr.value)
+        # CHECK: typed_sattr str: stringattr
+        print("typed_sattr str:", str(typed_sattr))
 
         # Test factory methods.
-        # CHECK: default_get: "foobar"
-        print("default_get:", StringAttr.get("foobar"))
-        # CHECK: typed_get: "12345" : i32
-        print("typed_get:", StringAttr.get_typed(IntegerType.get_signless(32), "12345"))
+        # CHECK: default_get: StringAttr("foobar")
+        print("default_get:", repr(StringAttr.get("foobar")))
+        # CHECK: typed_get: StringAttr("12345" : i32)
+        print(
+            "typed_get:",
+            repr(StringAttr.get_typed(IntegerType.get_signless(32), "12345")),
+        )
 
 
 # CHECK-LABEL: TEST: testNamedAttr
@@ -294,8 +317,8 @@
     with Context():
         a = Attribute.parse('"stringattr"')
         named = a.get_named("foobar")  # Note: under the small object threshold
-        # CHECK: attr: "stringattr"
-        print("attr:", named.attr)
+        # CHECK: attr: StringAttr("stringattr")
+        print("attr:", repr(named.attr))
         # CHECK: name: foobar
         print("name:", named.name)
         # CHECK: named: NamedAttribute(foobar="stringattr")
@@ -367,6 +390,65 @@
         print("myboolarray:", DenseBoolArrayAttr.get([MyBool()]))
 
 
+# CHECK-LABEL: TEST: testDenseArrayAttrConstruction
+@run
+def testDenseArrayAttrConstruction():
+    with Context(), Location.unknown():
+
+        def create_and_print(cls, x):
+            try:
+                darr = cls.get(x)
+                print(f"input: {x} ({type(x)}), result: {darr}")
+            except Exception as ex:
+                print(f"input: {x} ({type(x)}), error: {ex}")
+
+        # CHECK: input: [4, 2] (<class 'list'>),
+        # CHECK-SAME: result: array
+        create_and_print(DenseI8ArrayAttr, [4, 2])
+
+        # CHECK: input: [4, 2.0] (<class 'list'>),
+        # CHECK-SAME: error: get(): incompatible function arguments
+        create_and_print(DenseI8ArrayAttr, [4, 2.0])
+
+        # CHECK: input: [40000, 2] (<class 'list'>),
+        # CHECK-SAME: error: get(): incompatible function arguments
+        create_and_print(DenseI8ArrayAttr, [40000, 2])
+
+        # CHECK: input: range(0, 4) (<class 'range'>),
+        # CHECK-SAME: result: array
+        create_and_print(DenseI8ArrayAttr, range(4))
+
+        # CHECK: input: [IntegerAttr(4 : i64), IntegerAttr(2 : i64)] (<class 'list'>),
+        # CHECK-SAME: result: array
+        create_and_print(DenseI8ArrayAttr, [Attribute.parse(f"{x}") for x in [4, 2]])
+
+        # CHECK: input: [IntegerAttr(4000 : i64), IntegerAttr(2 : i64)] (<class 'list'>),
+        # CHECK-SAME: error: get(): incompatible function arguments
+        create_and_print(DenseI8ArrayAttr, [Attribute.parse(f"{x}") for x in [4000, 2]])
+
+        # CHECK: input: [IntegerAttr(4 : i64), FloatAttr(2.000000e+00 : f64)] (<class 'list'>),
+        # CHECK-SAME: error: get(): incompatible function arguments
+        create_and_print(DenseI8ArrayAttr, [Attribute.parse(f"{x}") for x in [4, 2.0]])
+
+        # CHECK: input: [IntegerAttr(4 : i8), IntegerAttr(2 : ui16)] (<class 'list'>),
+        # CHECK-SAME: result: array
+        create_and_print(
+            DenseI8ArrayAttr, [Attribute.parse(s) for s in ["4 : i8", "2 : ui16"]]
+        )
+
+        # CHECK: input: [FloatAttr(4.000000e+00 : f64), FloatAttr(2.000000e+00 : f64)] (<class 'list'>)
+        # CHECK-SAME: result: array
+        create_and_print(
+            DenseF32ArrayAttr, [Attribute.parse(f"{x}") for x in [4.0, 2.0]]
+        )
+
+        # CHECK: [BoolAttr(true), BoolAttr(false)] (<class 'list'>),
+        # CHECK-SAME: result: array
+        create_and_print(
+            DenseBoolArrayAttr, [Attribute.parse(f"{x}") for x in ["true", "false"]]
+        )
+
+
 # CHECK-LABEL: TEST: testDenseIntAttrGetItem
 @run
 def testDenseIntAttrGetItem():
@@ -620,7 +702,6 @@
 @run
 def testConcreteAttributesRoundTrip():
     with Context(), Location.unknown():
-
         # CHECK: FloatAttr(4.200000e+01 : f32)
         print(repr(Attribute.parse("42.0 : f32")))
 