diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -1461,6 +1461,98 @@
   static void bindDerived(ClassTy &m) {}
 };
 
+/// Array Attribute subclass - ArrayAttr. Exposes the elements of an MLIR
+/// ArrayAttr to Python as generic Attribute objects through `len()`,
+/// indexing (negative indices count from the end) and iteration.
+class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray;
+  static constexpr const char *pyClassName = "ArrayAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  /// Python iterator over the elements of an ArrayAttr; each `__next__` call
+  /// yields the next element wrapped as a generic PyAttribute.
+  class PyArrayAttributeIterator {
+  public:
+    explicit PyArrayAttributeIterator(PyAttribute attr) : attr(attr) {}
+
+    PyArrayAttributeIterator &dunderIter() { return *this; }
+
+    PyAttribute dunderNext() {
+      if (nextIndex >= mlirArrayAttrGetNumElements(attr.get()))
+        throw py::stop_iteration();
+      return PyAttribute(attr.getContext(),
+                         mlirArrayAttrGetElement(attr.get(), nextIndex++));
+    }
+
+    static void bind(py::module &m) {
+      py::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator")
+          .def("__iter__", &PyArrayAttributeIterator::dunderIter)
+          .def("__next__", &PyArrayAttributeIterator::dunderNext);
+    }
+
+  private:
+    PyAttribute attr;
+    /// Index of the next element to yield (intptr_t, matching the C API).
+    intptr_t nextIndex = 0;
+  };
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](py::list attributes, DefaultingPyMlirContext context) {
+          SmallVector<MlirAttribute, 4> mlirAttributes;
+          mlirAttributes.reserve(py::len(attributes));
+          for (auto attribute : attributes) {
+            try {
+              mlirAttributes.push_back(attribute.cast<PyAttribute &>());
+            } catch (const py::cast_error &err) {
+              std::string msg = std::string("Invalid attribute when attempting "
+                                            "to create an ArrayAttribute (") +
+                                err.what() + ")";
+              throw py::cast_error(msg);
+            } catch (const py::reference_cast_error &err) {
+              // This exception seems thrown when the value is "None".
+              std::string msg =
+                  std::string("Invalid attribute (None?) when attempting to "
+                              "create an ArrayAttribute (") +
+                  err.what() + ")";
+              throw py::cast_error(msg);
+            }
+          }
+          MlirAttribute attr = mlirArrayAttrGet(
+              context->get(), mlirAttributes.size(), mlirAttributes.data());
+          return PyArrayAttribute(context->getRef(), attr);
+        },
+        py::arg("attributes"), py::arg("context") = py::none(),
+        "Gets a uniqued Array attribute");
+    c.def("__getitem__",
+          [](PyArrayAttribute &arr, intptr_t i) {
+            intptr_t numElements = mlirArrayAttrGetNumElements(arr);
+            // Follow Python sequence semantics: negative indices count from
+            // the end. Anything still out of range is rejected here so it
+            // never reaches mlirArrayAttrGetElement.
+            if (i < 0)
+              i += numElements;
+            if (i < 0 || i >= numElements)
+              throw py::index_error("ArrayAttribute index out of range");
+            return PyAttribute(arr.getContext(),
+                               mlirArrayAttrGetElement(arr, i));
+          })
+        .def("__len__",
+             [](const PyArrayAttribute &arr) {
+               return mlirArrayAttrGetNumElements(arr);
+             })
+        .def(
+            "__iter__",
+            [](const PyArrayAttribute &arr) {
+              return PyArrayAttributeIterator(arr);
+            },
+            // Essential: keep the array alive while the iterator exists.
+            py::keep_alive<0, 1>());
+  }
+};
+
 /// Float Point Attribute subclass - FloatAttr.
 class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
 public:
@@ -3089,6 +3181,8 @@
 
   // Builtin attribute bindings.
   PyFloatAttribute::bind(m);
+  PyArrayAttribute::bind(m);
+  PyArrayAttribute::PyArrayAttributeIterator::bind(m);
   PyIntegerAttribute::bind(m);
   PyBoolAttribute::bind(m);
   PyStringAttribute::bind(m);
diff --git a/mlir/test/Bindings/Python/ir_attributes.py b/mlir/test/Bindings/Python/ir_attributes.py
--- a/mlir/test/Bindings/Python/ir_attributes.py
+++ b/mlir/test/Bindings/Python/ir_attributes.py
@@ -269,3 +269,64 @@
 
 
 run(testTypeAttr)
+
+
+# CHECK-LABEL: TEST: testArrayAttr
+def testArrayAttr():
+  with Context():
+    raw = Attribute.parse("[42, true, vector<4xf32>]")
+    # CHECK: attr: [42, true, vector<4xf32>]
+    print("raw attr:", raw)
+    # CHECK: - 42
+    # CHECK: - true
+    # CHECK: - vector<4xf32>
+    for attr in ArrayAttr(raw):
+      print("- ", attr)
+
+  with Context():
+    intAttr = Attribute.parse("42")
+    vecAttr = Attribute.parse("vector<4xf32>")
+    boolAttr = BoolAttr.get(True)
+    raw = ArrayAttr.get([vecAttr, boolAttr, intAttr])
+    # CHECK: attr: [vector<4xf32>, true, 42]
+    print("raw attr:", raw)
+    # CHECK: - vector<4xf32>
+    # CHECK: - true
+    # CHECK: - 42
+    arr = ArrayAttr(raw)
+    for attr in arr:
+      print("- ", attr)
+    # CHECK: attr[0]: vector<4xf32>
+    print("attr[0]:", arr[0])
+    # CHECK: attr[1]: true
+    print("attr[1]:", arr[1])
+    # CHECK: attr[2]: 42
+    print("attr[2]:", arr[2])
+    # CHECK: attr[-1]: 42
+    print("attr[-1]:", arr[-1])
+    try:
+      print("attr[3]:", arr[3])
+    except IndexError as e:
+      # CHECK: Error: ArrayAttribute index out of range
+      print("Error: ", e)
+    try:
+      print("attr[-4]:", arr[-4])
+    except IndexError as e:
+      # CHECK: Error: ArrayAttribute index out of range
+      print("Error: ", e)
+  with Context():
+    try:
+      ArrayAttr.get([None])
+    except RuntimeError as e:
+      # CHECK: Error: Invalid attribute (None?) when attempting to create an ArrayAttribute
+      print("Error: ", e)
+    try:
+      ArrayAttr.get([42])
+    except RuntimeError as e:
+      # pybind11 appends build-dependent detail to this cast error (the text
+      # differs between debug and NDEBUG builds), so check only the prefix.
+      # CHECK: Error: Invalid attribute when attempting to create an ArrayAttribute
+      print("Error: ", e)
+
+
+run(testArrayAttr)