diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -596,8 +596,13 @@ }, py::arg("type"), py::arg("value"), "Gets a uniqued string attribute associated to a type"); - c.def_property_readonly("value", toPyStr, - "Returns the value of the string attribute"); + c.def_property_readonly( + "value", + [](PyStringAttribute &self) { + MlirStringRef stringRef = mlirStringAttrGetValue(self); + return py::str(stringRef.data, stringRef.length); + }, + "Returns the value of the string attribute"); c.def_property_readonly( "value_bytes", [](PyStringAttribute &self) { @@ -605,14 +610,6 @@ return py::bytes(stringRef.data, stringRef.length); }, "Returns the value of the string attribute as `bytes`"); - c.def("__str__", toPyStr, - "Converts the value of the string attribute to a Python str"); - } - -private: - static py::str toPyStr(PyStringAttribute &self) { - MlirStringRef stringRef = mlirStringAttrGetValue(self); - return py::str(stringRef.data, stringRef.length); } }; diff --git a/mlir/test/python/dialects/builtin.py b/mlir/test/python/dialects/builtin.py --- a/mlir/test/python/dialects/builtin.py +++ b/mlir/test/python/dialects/builtin.py @@ -134,13 +134,13 @@ ), visibility="nested", ) - # CHECK: Name is: some_func + # CHECK: Name is: "some_func" print("Name is: ", f.name) # CHECK: Type is: (tensor<2x3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32> print("Type is: ", f.type) - # CHECK: Visibility is: nested + # CHECK: Visibility is: "nested" print("Visibility is: ", f.visibility) try: diff --git a/mlir/test/python/ir/attributes.py b/mlir/test/python/ir/attributes.py --- a/mlir/test/python/ir/attributes.py +++ b/mlir/test/python/ir/attributes.py @@ -21,7 +21,7 @@ assert t.context is ctx ctx = None gc.collect() - # CHECK: hello + # CHECK: "hello" print(str(t)) # CHECK: StringAttr("hello") print(repr(t)) @@ -290,25 +290,14 @@ sattr = StringAttr(Attribute.parse('"stringattr"')) # CHECK: sattr value: stringattr print("sattr value:", sattr.value) - # CHECK: sattr value_bytes: b'stringattr' - print("sattr value_bytes:", sattr.value_bytes) - # CHECK: sattr str: stringattr - print("sattr str:", str(sattr)) - - typed_sattr = StringAttr(Attribute.parse('"stringattr" : i32')) - # CHECK: typed_sattr value: stringattr - print("typed_sattr value:", typed_sattr.value) - # CHECK: typed_sattr str: stringattr - print("typed_sattr str:", str(typed_sattr)) + # CHECK: sattr value: b'stringattr' + print("sattr value:", sattr.value_bytes) # Test factory methods. - # CHECK: default_get: StringAttr("foobar") - print("default_get:", repr(StringAttr.get("foobar"))) - # CHECK: typed_get: StringAttr("12345" : i32) - print( - "typed_get:", - repr(StringAttr.get_typed(IntegerType.get_signless(32), "12345")), - ) + # CHECK: default_get: "foobar" + print("default_get:", StringAttr.get("foobar")) + # CHECK: typed_get: "12345" : i32 + print("typed_get:", StringAttr.get_typed(IntegerType.get_signless(32), "12345")) # CHECK-LABEL: TEST: testNamedAttr @@ -317,8 +306,8 @@ with Context(): a = Attribute.parse('"stringattr"') named = a.get_named("foobar") # Note: under the small object threshold - # CHECK: attr: StringAttr("stringattr") - print("attr:", repr(named.attr)) + # CHECK: attr: "stringattr" + print("attr:", named.attr) # CHECK: name: foobar print("name:", named.name) # CHECK: named: NamedAttribute(foobar="stringattr")