diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -1282,6 +1282,55 @@
   PyOperationRef operation;
 };
 
+/// A view over the attributes of an operation. Can be indexed by name,
+/// producing attributes, or by index, producing named attributes.
+class PyOpAttributeMap {
+public:
+  explicit PyOpAttributeMap(PyOperationRef operation) : operation(operation) {}
+
+  /// Returns the attribute named `name`; raises KeyError if the operation
+  /// has no such attribute.
+  PyAttribute dunderGetItemNamed(const std::string &name) {
+    MlirAttribute attr =
+        mlirOperationGetAttributeByName(operation->get(), name.c_str());
+    if (mlirAttributeIsNull(attr)) {
+      throw SetPyError(PyExc_KeyError,
+                       "attempt to access a non-existent attribute");
+    }
+    return PyAttribute(operation->getContext(), attr);
+  }
+
+  /// Returns the named attribute at position `index`; raises IndexError if
+  /// `index` is outside [0, len). Python's fallback iteration protocol relies
+  /// on this IndexError to stop, which is what makes `for a in attrs` work.
+  PyNamedAttribute dunderGetItemIndexed(intptr_t index) {
+    if (index < 0 || index >= dunderLen()) {
+      throw SetPyError(PyExc_IndexError,
+                       "attempt to access out of bounds attribute");
+    }
+    MlirNamedAttribute namedAttr =
+        mlirOperationGetAttribute(operation->get(), index);
+    return PyNamedAttribute(namedAttr.attribute, std::string(namedAttr.name));
+  }
+
+  /// Returns the number of attributes attached to the operation.
+  intptr_t dunderLen() {
+    return mlirOperationGetNumAttributes(operation->get());
+  }
+
+  static void bind(py::module &m) {
+    // pybind11 picks the overload by argument type: `str` keys select the
+    // named lookup, `int` keys the positional one.
+    py::class_<PyOpAttributeMap>(m, "OpAttributeMap")
+        .def("__len__", &PyOpAttributeMap::dunderLen)
+        .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed)
+        .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed);
+  }
+
+private:
+  PyOperationRef operation;
+};
+
 } // end namespace
 
 //------------------------------------------------------------------------------
@@ -2436,6 +2485,11 @@
            })
       .def("__eq__",
            [](PyOperationBase &self, py::object other) { return false; })
+      .def_property_readonly("attributes",
+                             [](PyOperationBase &self) {
+                               return PyOpAttributeMap(
+                                   self.getOperation().getRef());
+                             })
       .def_property_readonly("operands",
                              [](PyOperationBase &self) {
                                return PyOpOperandList(
@@ -2810,6 +2864,7 @@
   PyBlockList::bind(m);
   PyOperationIterator::bind(m);
   PyOperationList::bind(m);
+  PyOpAttributeMap::bind(m);
   PyOpOperandList::bind(m);
   PyOpResultList::bind(m);
   PyRegionIterator::bind(m);
diff --git a/mlir/test/Bindings/Python/ir_operation.py b/mlir/test/Bindings/Python/ir_operation.py
--- a/mlir/test/Bindings/Python/ir_operation.py
+++ b/mlir/test/Bindings/Python/ir_operation.py
@@ -277,6 +277,53 @@
 run(testOperationResultList)
 
 
+# CHECK-LABEL: TEST: testOperationAttributes
+def testOperationAttributes():
+  ctx = Context()
+  ctx.allow_unregistered_dialects = True
+  module = Module.parse(r"""
+    "some.op"() { some.attribute = 1 : i8,
+                  other.attribute = 3.0,
+                  dependent = "text" } : () -> ()
+  """, ctx)
+  op = module.body.operations[0]
+  assert len(op.attributes) == 3
+  iattr = IntegerAttr(op.attributes["some.attribute"])
+  fattr = FloatAttr(op.attributes["other.attribute"])
+  sattr = StringAttr(op.attributes["dependent"])
+  # CHECK: Attribute type i8, value 1
+  print(f"Attribute type {iattr.type}, value {iattr.value}")
+  # CHECK: Attribute type f64, value 3.0
+  print(f"Attribute type {fattr.type}, value {fattr.value}")
+  # CHECK: Attribute value text
+  print(f"Attribute value {sattr.value}")
+
+  # We don't know in which order the attributes are stored.
+  # CHECK-DAG: NamedAttribute(dependent="text")
+  # CHECK-DAG: NamedAttribute(other.attribute=3.000000e+00 : f64)
+  # CHECK-DAG: NamedAttribute(some.attribute=1 : i8)
+  for attr in op.attributes:
+    print(str(attr))
+
+  # Check that exceptions are raised as expected.
+  try:
+    op.attributes["does_not_exist"]
+  except KeyError:
+    pass
+  else:
+    assert False, "expected KeyError on accessing a non-existent attribute"
+
+  try:
+    op.attributes[42]
+  except IndexError:
+    pass
+  else:
+    assert False, "expected IndexError on accessing an out-of-bounds attribute"
+
+
+run(testOperationAttributes)
+
+
 # CHECK-LABEL: TEST: testOperationPrint
 def testOperationPrint():
   ctx = Context()