diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -1968,6 +1968,67 @@
   }
 };
 
+/// Python binding for dictionary attributes, exposed as `DictAttr`.
+/// Supports construction from a Python dict mapping names to attributes,
+/// `len()`, and element access either by name (raises KeyError when the
+/// name is absent) or by position (raises IndexError when out of range).
+class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary;
+  static constexpr const char *pyClassName = "DictAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  /// Number of named attributes stored in the dictionary.
+  intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); }
+
+  static void bindDerived(ClassTy &c) {
+    c.def("__len__", &PyDictAttribute::dunderLen);
+    c.def_static(
+        "get",
+        [](py::dict attributes, DefaultingPyMlirContext context) {
+          SmallVector<MlirNamedAttribute, 8> mlirNamedAttributes;
+          mlirNamedAttributes.reserve(attributes.size());
+          for (auto &it : attributes) {
+            // Each value must be an Attribute and each key convertible to a
+            // str; pybind11's cast throws otherwise.
+            auto &mlir_attr = it.second.cast<PyAttribute &>();
+            auto name = it.first.cast<std::string>();
+            mlirNamedAttributes.push_back(mlirNamedAttributeGet(
+                mlirIdentifierGet(mlirAttributeGetContext(mlir_attr),
+                                  toMlirStringRef(name)),
+                mlir_attr));
+          }
+          MlirAttribute attr =
+              mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(),
+                                    mlirNamedAttributes.data());
+          return PyDictAttribute(context->getRef(), attr);
+        },
+        py::arg("value"), py::arg("context") = py::none(),
+        "Gets a uniqued dict attribute");
+    c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) {
+      MlirAttribute attr =
+          mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name));
+      if (mlirAttributeIsNull(attr)) {
+        throw SetPyError(PyExc_KeyError,
+                         "attempt to access a non-existent attribute");
+      }
+      return PyAttribute(self.getContext(), attr);
+    });
+    c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) {
+      if (index < 0 || index >= self.dunderLen()) {
+        throw SetPyError(PyExc_IndexError,
+                         "attempt to access out of bounds attribute");
+      }
+      MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index);
+      // MlirStringRef is not guaranteed to be null-terminated, so copy the
+      // name using its explicit length rather than relying on a terminator.
+      MlirStringRef name = mlirIdentifierStr(namedAttr.name);
+      return PyNamedAttribute(namedAttr.attribute,
+                              std::string(name.data, name.length));
+    });
+  }
+};
+
 /// Refinement of PyDenseElementsAttribute for attributes containing
 /// floating-point values. Supports element access.
 class PyDenseFPElementsAttribute
@@ -3181,6 +3242,7 @@
   PyDenseElementsAttribute::bind(m);
   PyDenseIntElementsAttribute::bind(m);
   PyDenseFPElementsAttribute::bind(m);
+  PyDictAttribute::bind(m);
   PyTypeAttribute::bind(m);
   PyUnitAttribute::bind(m);
 
diff --git a/mlir/test/Bindings/Python/ir_attributes.py b/mlir/test/Bindings/Python/ir_attributes.py
--- a/mlir/test/Bindings/Python/ir_attributes.py
+++ b/mlir/test/Bindings/Python/ir_attributes.py
@@ -257,6 +257,49 @@
 
 run(testDenseFPAttr)
 
+# CHECK-LABEL: TEST: testDictAttr
+def testDictAttr():
+  with Context():
+    dict_attr = {
+      'stringattr': StringAttr.get('string'),
+      'integerattr' : IntegerAttr.get(
+        IntegerType.get_signless(32), 42)
+    }
+
+    a = DictAttr.get(dict_attr)
+
+    # CHECK: attr: {integerattr = 42 : i32, stringattr = "string"}
+    print("attr:", a)
+
+    assert len(a) == 2
+
+    # CHECK: 42 : i32
+    print(a['integerattr'])
+
+    # CHECK: "string"
+    print(a['stringattr'])
+
+    # Positional access yields NamedAttributes, ordered by name.
+    assert a[0].name == 'integerattr'
+    assert a[1].name == 'stringattr'
+
+    # Check that exceptions are raised as expected.
+    try:
+      _ = a['does_not_exist']
+    except KeyError:
+      pass
+    else:
+      assert False, "Exception not produced"
+
+    try:
+      _ = a[42]
+    except IndexError:
+      pass
+    else:
+      assert False, "expected IndexError on accessing an out-of-bounds attribute"
+
+run(testDictAttr)
+
 # CHECK-LABEL: TEST: testTypeAttr
 def testTypeAttr():
   with Context():