diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -1282,6 +1282,58 @@
   PyOperationRef operation;
 };
 
+/// A mapping view over an operation's attributes. Can be indexed by name,
+/// producing attributes, or by index, producing named attributes.
+class PyOpAttributeMap {
+public:
+  explicit PyOpAttributeMap(PyOperationRef operation) : operation(operation) {}
+
+  /// Returns the attribute named `name`. Raises KeyError when the operation
+  /// has no such attribute instead of handing a null attribute to Python.
+  PyAttribute dunderGetItemNamed(const std::string &name) {
+    MlirAttribute attr =
+        mlirOperationGetAttributeByName(operation->get(), name.c_str());
+    if (mlirAttributeIsNull(attr)) {
+      throw SetPyError(PyExc_KeyError,
+                       "attempt to access a non-existent attribute");
+    }
+    return PyAttribute(operation->getContext(), attr);
+  }
+
+  /// Returns the attribute at position `index` as a named attribute. Raises
+  /// IndexError when out of range, which also ends Python's fallback
+  /// sequence iteration (`for attr in op.attributes`) since no `__iter__` is
+  /// bound.
+  PyNamedAttribute dunderGetItemIndexed(intptr_t index) {
+    if (index < 0 || index >= dunderLen()) {
+      throw SetPyError(PyExc_IndexError,
+                       "attempt to access out of bounds attribute");
+    }
+    MlirNamedAttribute namedAttr =
+        mlirOperationGetAttribute(operation->get(), index);
+    return PyNamedAttribute(namedAttr.attribute, std::string(namedAttr.name));
+  }
+
+  /// Returns the number of attributes attached to the operation.
+  intptr_t dunderLen() {
+    return mlirOperationGetNumAttributes(operation->get());
+  }
+
+  /// Registers the Python `OpAttributeMap` class. pybind11 tries the
+  /// `__getitem__` overloads in registration order: string keys resolve by
+  /// name, integer keys fall through to the by-index overload.
+  static void bind(py::module &m) {
+    py::class_<PyOpAttributeMap>(m, "OpAttributeMap")
+        .def("__len__", &PyOpAttributeMap::dunderLen)
+        .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed)
+        .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed);
+  }
+
+private:
+  /// The operation whose attributes this map exposes.
+  PyOperationRef operation;
+};
+
 } // end namespace
 
 //------------------------------------------------------------------------------
@@ -2436,6 +2488,11 @@
            })
       .def("__eq__",
            [](PyOperationBase &self, py::object other) { return false; })
+      .def_property_readonly("attributes",
+                             [](PyOperationBase &self) {
+                               return PyOpAttributeMap(
+                                   self.getOperation().getRef());
+                             })
       .def_property_readonly("operands",
                              [](PyOperationBase &self) {
                                return PyOpOperandList(
@@ -2810,6 +2867,7 @@
   PyBlockList::bind(m);
   PyOperationIterator::bind(m);
   PyOperationList::bind(m);
+  PyOpAttributeMap::bind(m);
   PyOpOperandList::bind(m);
   PyOpResultList::bind(m);
   PyRegionIterator::bind(m);
diff --git a/mlir/test/Bindings/Python/ir_operation.py b/mlir/test/Bindings/Python/ir_operation.py
--- a/mlir/test/Bindings/Python/ir_operation.py
+++ b/mlir/test/Bindings/Python/ir_operation.py
@@ -277,6 +277,38 @@
 run(testOperationResultList)
 
 
+# CHECK-LABEL: TEST: testOperationAttributes
+def testOperationAttributes():
+  ctx = Context()
+  ctx.allow_unregistered_dialects = True
+  module = Module.parse(r"""
+    "some.op"() { some.attribute = 1 : i8,
+                  other.attribute = 3.0,
+                  dependent = "text" } : () -> ()
+  """, ctx)
+  op = module.body.operations[0]
+  assert len(op.attributes) == 3  # exercises OpAttributeMap.__len__
+  iattr = IntegerAttr(op.attributes["some.attribute"])  # by-name lookup
+  fattr = FloatAttr(op.attributes["other.attribute"])
+  sattr = StringAttr(op.attributes["dependent"])
+  # CHECK: Attribute type i8, value 1
+  print(f"Attribute type {iattr.type}, value {iattr.value}")
+  # CHECK: Attribute type f64, value 3.0
+  print(f"Attribute type {fattr.type}, value {fattr.value}")
+  # CHECK: Attribute value text
+  print(f"Attribute value {sattr.value}")
+
+  # Attribute storage order is unspecified, so these lines match unordered.
+  # CHECK-DAG: NamedAttribute(dependent="text")
+  # CHECK-DAG: NamedAttribute(other.attribute=3.000000e+00 : f64)
+  # CHECK-DAG: NamedAttribute(some.attribute=1 : i8)
+  for attr in op.attributes:  # calls __getitem__(0, 1, ...) until IndexError
+    print(str(attr))
+
+
+run(testOperationAttributes)
+
+
 # CHECK-LABEL: TEST: testOperationPrint
 def testOperationPrint():
   ctx = Context()