diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -18,7 +18,6 @@
 using namespace mlir::python;
 
 using llvm::SmallVector;
-using llvm::StringRef;
 using llvm::Twine;
 
 namespace {
@@ -44,6 +43,27 @@
   }
 };
 
+/// Casts `object` to `T`, rethrowing py::cast_error / py::reference_cast_error
+/// as a py::cast_error that reports the failed ArrayAttribute creation.
+template <typename T>
+static T pyTryCast(py::handle object) {
+  try {
+    return object.cast<T>();
+  } catch (py::cast_error &err) {
+    std::string msg =
+        std::string(
+            "Invalid attribute when attempting to create an ArrayAttribute (") +
+        err.what() + ")";
+    throw py::cast_error(msg);
+  } catch (py::reference_cast_error &err) {
+    // This exception seems thrown when the value is "None".
+    std::string msg = std::string("Invalid attribute (None?) when attempting "
+                                  "to create an ArrayAttribute (") +
+                      err.what() + ")";
+    throw py::cast_error(msg);
+  }
+}
+
 class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
 public:
   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray;
@@ -76,6 +96,11 @@
     int nextIndex = 0;
   };
 
+  /// Returns the element at index `i`; performs no bounds checking.
+  PyAttribute getItem(intptr_t i) {
+    return PyAttribute(getContext(), mlirArrayAttrGetElement(*this, i));
+  }
+
   static void bindDerived(ClassTy &c) {
     c.def_static(
         "get",
@@ -83,21 +108,7 @@
           SmallVector<MlirAttribute> mlirAttributes;
           mlirAttributes.reserve(py::len(attributes));
           for (auto attribute : attributes) {
-            try {
-              mlirAttributes.push_back(attribute.cast<PyAttribute>());
-            } catch (py::cast_error &err) {
-              std::string msg = std::string("Invalid attribute when attempting "
-                                            "to create an ArrayAttribute (") +
-                                err.what() + ")";
-              throw py::cast_error(msg);
-            } catch (py::reference_cast_error &err) {
-              // This exception seems thrown when the value is "None".
-              std::string msg =
-                  std::string("Invalid attribute (None?) when attempting to "
-                              "create an ArrayAttribute (") +
-                  err.what() + ")";
-              throw py::cast_error(msg);
-            }
+            mlirAttributes.push_back(pyTryCast<PyAttribute>(attribute));
           }
           MlirAttribute attr = mlirArrayAttrGet(
               context->get(), mlirAttributes.size(), mlirAttributes.data());
@@ -109,8 +120,7 @@
           [](PyArrayAttribute &arr, intptr_t i) {
             if (i >= mlirArrayAttrGetNumElements(arr))
               throw py::index_error("ArrayAttribute index out of range");
-            return PyAttribute(arr.getContext(),
-                               mlirArrayAttrGetElement(arr, i));
+            return arr.getItem(i);
           })
         .def("__len__",
              [](const PyArrayAttribute &arr) {
@@ -119,6 +129,20 @@
         .def("__iter__", [](const PyArrayAttribute &arr) {
           return PyArrayAttributeIterator(arr);
         });
+    // `array + [attr, ...]` builds a new ArrayAttr holding this array's
+    // elements followed by each element of `extras` (cast via pyTryCast).
+    c.def("__add__", [](PyArrayAttribute arr, py::list extras) {
+      std::vector<MlirAttribute> attributes;
+      intptr_t numOldElements = mlirArrayAttrGetNumElements(arr);
+      attributes.reserve(numOldElements + py::len(extras));
+      for (intptr_t i = 0; i < numOldElements; ++i)
+        attributes.push_back(arr.getItem(i));
+      for (py::handle attr : extras)
+        attributes.push_back(pyTryCast<PyAttribute>(attr));
+      MlirAttribute arrayAttr = mlirArrayAttrGet(
+          arr.getContext()->get(), attributes.size(), attributes.data());
+      return PyArrayAttribute(arr.getContext(), arrayAttr);
+    });
   }
 };
 
@@ -602,7 +626,7 @@
               mlirNamedAttributes.data());
           return PyDictAttribute(context->getRef(), attr);
         },
-        py::arg("value"), py::arg("context") = py::none(),
+        py::arg("value") = py::dict(), py::arg("context") = py::none(),
         "Gets an uniqued dict attribute");
     c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) {
       MlirAttribute attr =
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1590,6 +1590,19 @@
   }
 };
 
+/// Returns the list of types of the values held by container.
+template <typename Container>
+static std::vector<PyType> getValueTypes(Container &container,
+                                         PyMlirContextRef &context) {
+  std::vector<PyType> result;
+  result.reserve(container.getNumElements());
+  for (intptr_t i = 0, e = container.getNumElements(); i < e; ++i) {
+    result.push_back(
+        PyType(context, mlirValueGetType(container.getElement(i).get())));
+  }
+  return result;
+}
+
 /// A list of block arguments. Internally, these are stored as consecutive
 /// elements, random access is cheap. The argument list is associated with the
 /// operation that contains the block (detached blocks are not allowed in
@@ -1625,6 +1638,13 @@
     return PyBlockArgumentList(operation, block, startIndex, length, step);
   }
 
+  /// Binds the read-only `types` property (the type of each argument, in order).
+  static void bindDerived(ClassTy &c) {
+    c.def_property_readonly("types", [](PyBlockArgumentList &self) {
+      return getValueTypes(self, self.operation->getContext());
+    });
+  }
+
 private:
   PyOperationRef operation;
   MlirBlock block;
@@ -1712,6 +1732,13 @@
     return PyOpResultList(operation, startIndex, length, step);
   }
 
+  /// Binds the read-only `types` property (the type of each result, in order).
+  static void bindDerived(ClassTy &c) {
+    c.def_property_readonly("types", [](PyOpResultList &self) {
+      return getValueTypes(self, self.operation->getContext());
+    });
+  }
+
 private:
   PyOperationRef operation;
 };
diff --git a/mlir/test/python/ir/attributes.py b/mlir/test/python/ir/attributes.py
--- a/mlir/test/python/ir/attributes.py
+++ b/mlir/test/python/ir/attributes.py
@@ -343,6 +343,9 @@
     else:
       assert False, "expected IndexError on accessing an out-of-bounds attribute"
 
+    # CHECK: empty: {}
+    print("empty: ", DictAttr.get())
+
 
 # CHECK-LABEL: TEST: testTypeAttr
 @run
@@ -404,3 +407,9 @@
     except RuntimeError as e:
       # CHECK: Error: Invalid attribute when attempting to create an ArrayAttribute
       print("Error: ", e)
+
+  with Context():
+    array = ArrayAttr.get([StringAttr.get("a"), StringAttr.get("b")])
+    array = array + [StringAttr.get("c")]
+    # CHECK: concat: ["a", "b", "c"]
+    print("concat: ", array)
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -145,6 +145,12 @@
     print("Length: ",
           len(entry_block.arguments[:2] + entry_block.arguments[1:]))
 
+    # CHECK: Type: i8
+    # CHECK: Type: i16
+    # CHECK: Type: i24
+    for t in entry_block.arguments.types:
+      print("Type: ", t)
+
 
 run(testBlockArgumentList)
 
@@ -380,6 +386,12 @@
   for res in call.results:
     print(f"Result {res.result_number}, type {res.type}")
 
+  # CHECK: Result type i32
+  # CHECK: Result type f64
+  # CHECK: Result type index
+  for t in call.results.types:
+    print(f"Result type {t}")
+
 
 run(testOperationResultList)
 