diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -1885,6 +1885,67 @@
   }
 };
 
+/// Python binding for the builtin dictionary attribute, exposed as `DictAttr`.
+/// Supports `len()` and lookup of entries by name with `[]`.
+class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary;
+  static constexpr const char *pyClassName = "DictAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  /// Returns the number of entries in the dictionary.
+  intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); }
+
+  /// Returns the entry called `name`; raises KeyError if there is none.
+  /// The result is bound to the context owning this dictionary, not to the
+  /// currently active one, so lookups also work outside `with Context():`.
+  PyAttribute dunderGetItemNamed(const std::string &name) {
+    MlirAttribute attr =
+        mlirDictionaryAttrGetElementByName(*this, toMlirStringRef(name));
+    if (mlirAttributeIsNull(attr)) {
+      throw SetPyError(PyExc_KeyError,
+                       "attempt to access a non-existent attribute");
+    }
+    return PyAttribute(getContext(), attr);
+  }
+
+  static void bindDerived(ClassTy &c) {
+    c.def("__len__", &PyDictAttribute::dunderLen);
+    c.def_static(
+        "get",
+        [](py::dict attributes, DefaultingPyMlirContext context) {
+          // Named attributes refer to their names through MlirStringRef, so
+          // the key strings are owned here until the dictionary is built.
+          std::vector<std::pair<std::string, MlirAttribute>> entries;
+          entries.reserve(attributes.size());
+          for (const auto &it : attributes) {
+            auto &mlirAttr = it.second.cast<PyAttribute &>();
+            entries.emplace_back(it.first.cast<std::string>(), mlirAttr.get());
+          }
+
+          // Take the string refs only after `entries` is complete so that no
+          // reallocation can move the strings out from under them.
+          std::vector<MlirNamedAttribute> mlirNamedAttributes;
+          mlirNamedAttributes.reserve(entries.size());
+          for (const auto &entry : entries) {
+            mlirNamedAttributes.push_back(mlirNamedAttributeGet(
+                toMlirStringRef(entry.first), entry.second));
+          }
+
+          // Always build the attribute: an empty Python dict yields the empty
+          // dictionary attribute instead of an uninitialized handle.
+          MlirAttribute attr =
+              mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(),
+                                    mlirNamedAttributes.data());
+          return PyDictAttribute(context->getRef(), attr);
+        },
+        py::arg("value") = py::dict(), py::arg("context") = py::none(),
+        "Gets an uniqued dict attribute");
+    c.def("__getitem__", &PyDictAttribute::dunderGetItemNamed,
+          py::arg("name"));
+  }
+};
+
 /// Refinement of PyDenseElementsAttribute for attributes containing
 /// floating-point values. Supports element access.
 class PyDenseFPElementsAttribute
@@ -3095,6 +3156,7 @@
   PyDenseElementsAttribute::bind(m);
   PyDenseIntElementsAttribute::bind(m);
   PyDenseFPElementsAttribute::bind(m);
+  PyDictAttribute::bind(m);
   PyTypeAttribute::bind(m);
   PyUnitAttribute::bind(m);
 
diff --git a/mlir/test/Bindings/Python/ir_attributes.py b/mlir/test/Bindings/Python/ir_attributes.py
--- a/mlir/test/Bindings/Python/ir_attributes.py
+++ b/mlir/test/Bindings/Python/ir_attributes.py
@@ -257,6 +257,43 @@
 
 run(testDenseFPAttr)
 
+# CHECK-LABEL: TEST: testDictAttr
+def testDictAttr():
+  with Context():
+    dict_attr = {
+      'stringattr': StringAttr.get('string'),
+      'integerattr': IntegerAttr.get(
+        IntegerType.get_signless(32), 42)
+    }
+
+    a = DictAttr.get(dict_attr)
+
+    # CHECK: attr: {integerattr = 42 : i32, stringattr = "string"}
+    print("attr:", a)
+
+    assert len(a) == 2
+
+    # CHECK: 42 : i32
+    print(a['integerattr'])
+
+    # CHECK: "string"
+    print(a['stringattr'])
+
+    # CHECK: missing key raises KeyError
+    try:
+      a['missing']
+    except KeyError:
+      print("missing key raises KeyError")
+
+    # CHECK: empty: {}
+    print("empty:", DictAttr.get({}))
+
+  # Lookups use the dictionary's own context, so none has to be active here.
+  # CHECK: outside: 42 : i32
+  print("outside:", a['integerattr'])
+
+run(testDictAttr)
+
 # CHECK-LABEL: TEST: testTypeAttr
 def testTypeAttr():
   with Context():