diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -2364,6 +2364,7 @@ .def("__eq__", [](PyAttribute &self, PyAttribute &other) { return self == other; }) .def("__eq__", [](PyAttribute &self, py::object &other) { return false; }) + .def("__hash__", [](PyAttribute &self) { return (size_t)self.get().ptr; }) .def( "dump", [](PyAttribute &self) { mlirAttributeDump(self); }, kDumpDocstring) @@ -2457,6 +2458,7 @@ "Context that owns the Type") .def("__eq__", [](PyType &self, PyType &other) { return self == other; }) .def("__eq__", [](PyType &self, py::object &other) { return false; }) + .def("__hash__", [](PyType &self) { return (size_t)self.get().ptr; }) .def( "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring) .def( diff --git a/mlir/test/python/ir/attributes.py b/mlir/test/python/ir/attributes.py --- a/mlir/test/python/ir/attributes.py +++ b/mlir/test/python/ir/attributes.py @@ -57,6 +57,28 @@ print("a1 == None:", a1 == None) +# CHECK-LABEL: TEST: testAttrHash +@run +def testAttrHash(): + with Context(): + a1 = Attribute.parse('"attr1"') + a2 = Attribute.parse('"attr2"') + a3 = Attribute.parse('"attr1"') + # CHECK: hash(a1) == hash(a3): True + print("hash(a1) == hash(a3):", hash(a1) == hash(a3)) + # In general, hashes don't have to be unique. In this case, however, the + # hash is just the underlying pointer so it will be. 
+ # CHECK: hash(a1) == hash(a2): False + print("hash(a1) == hash(a2):", hash(a1) == hash(a2)) + + s = set() + s.add(a1) + s.add(a2) + s.add(a3) + # CHECK: len(s): 2 + print("len(s): ", len(s)) + + # CHECK-LABEL: TEST: testAttrCast @run def testAttrCast(): @@ -382,4 +404,3 @@ except RuntimeError as e: # CHECK: Error: Invalid attribute when attempting to create an ArrayAttribute print("Error: ", e) - diff --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py --- a/mlir/test/python/ir/builtin_types.py +++ b/mlir/test/python/ir/builtin_types.py @@ -57,6 +57,28 @@ print("t1 == None:", t1 == None) +# CHECK-LABEL: TEST: testTypeHash +@run +def testTypeHash(): + ctx = Context() + t1 = Type.parse("i32", ctx) + t2 = Type.parse("f32", ctx) + t3 = Type.parse("i32", ctx) + + # CHECK: hash(t1) == hash(t3): True + print("hash(t1) == hash(t3):", hash(t1) == hash(t3)) + # In general, hashes don't have to be unique. In this case, however, the + # hash is just the underlying pointer so it will be. + # CHECK: hash(t1) == hash(t2): False + print("hash(t1) == hash(t2):", hash(t1) == hash(t2)) + + s = set() + s.add(t1) + s.add(t2) + s.add(t3) + # CHECK: len(s): 2 + print("len(s): ", len(s)) + # CHECK-LABEL: TEST: testTypeCast @run def testTypeCast():