diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp --- a/mlir/lib/Bindings/Python/IRModules.cpp +++ b/mlir/lib/Bindings/Python/IRModules.cpp @@ -1621,11 +1621,14 @@ return PyDenseElementsAttribute(contextWrapper->getRef(), elements); } + intptr_t dunderLen() { return mlirElementsAttrGetNumElements(attr); } + static void bindDerived(ClassTy &c) { - c.def_static("get", PyDenseElementsAttribute::getFromBuffer, - py::arg("array"), py::arg("signless") = true, - py::arg("context") = py::none(), - "Gets from a buffer or ndarray") + c.def("__len__", &PyDenseElementsAttribute::dunderLen) + .def_static("get", PyDenseElementsAttribute::getFromBuffer, + py::arg("array"), py::arg("signless") = true, + py::arg("context") = py::none(), + "Gets from a buffer or ndarray") .def_static("get_splat", PyDenseElementsAttribute::getSplat, py::arg("shaped_type"), py::arg("element_attr"), "Gets a DenseElementsAttr where all values are the same") @@ -1651,6 +1654,101 @@ } }; +/// Refinement of the PyDenseElementsAttribute for attributes containing integer +/// (and boolean) values. Element access supports 1-, 32- and 64-bit widths. +class PyDenseIntElementsAttribute + : public PyConcreteAttribute { +public: + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements; + static constexpr const char *pyClassName = "DenseIntElementsAttr"; + using PyConcreteAttribute::PyConcreteAttribute; + + /// Returns the element at the given linear position. Raises IndexError if the + /// index is out of range and TypeError for unsupported integer widths. + py::int_ dunderGetItem(intptr_t pos) { + if (pos < 0 || pos >= dunderLen()) { + throw SetPyError(PyExc_IndexError, + "attempt to access out of bounds element"); + } + + MlirType type = mlirAttributeGetType(attr); + type = mlirShapedTypeGetElementType(type); + assert(mlirTypeIsAInteger(type) && + "expected integer element type in dense int elements attribute"); + // Dispatch element extraction to an appropriate C function based on the + // elemental type of the attribute. 
py::int_ is implicitly constructible + // from any C++ integral type and handles bitwidth correctly. + // TODO: consider caching the type properties in the constructor to avoid + // querying them on each element access. + unsigned width = mlirIntegerTypeGetWidth(type); + bool isUnsigned = mlirIntegerTypeIsUnsigned(type); + if (isUnsigned) { + if (width == 1) { + return mlirDenseElementsAttrGetBoolValue(attr, pos); + } + if (width == 32) { + return mlirDenseElementsAttrGetUInt32Value(attr, pos); + } + if (width == 64) { + return mlirDenseElementsAttrGetUInt64Value(attr, pos); + } + } else { + if (width == 1) { + return mlirDenseElementsAttrGetBoolValue(attr, pos); + } + if (width == 32) { + return mlirDenseElementsAttrGetInt32Value(attr, pos); + } + if (width == 64) { + return mlirDenseElementsAttrGetInt64Value(attr, pos); + } + } + throw SetPyError(PyExc_TypeError, "Unsupported integer type"); + } + + static void bindDerived(ClassTy &c) { + c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem); + } +}; + +/// Refinement of PyDenseElementsAttribute for attributes containing +/// floating-point values. Element access supports f32 and f64 only. +class PyDenseFPElementsAttribute + : public PyConcreteAttribute { +public: + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements; + static constexpr const char *pyClassName = "DenseFPElementsAttr"; + using PyConcreteAttribute::PyConcreteAttribute; + + py::float_ dunderGetItem(intptr_t pos) { + if (pos < 0 || pos >= dunderLen()) { + throw SetPyError(PyExc_IndexError, + "attempt to access out of bounds element"); + } + + MlirType type = mlirAttributeGetType(attr); + type = mlirShapedTypeGetElementType(type); + // Dispatch element extraction to an appropriate C function based on the + // elemental type of the attribute. py::float_ is implicitly constructible + // from float and double. + // TODO: consider caching the type properties in the constructor to avoid + // querying them on each element access. 
+ if (mlirTypeIsAF32(type)) { + return mlirDenseElementsAttrGetFloatValue(attr, pos); + } + if (mlirTypeIsAF64(type)) { + return mlirDenseElementsAttrGetDoubleValue(attr, pos); + } + throw SetPyError(PyExc_TypeError, "Unsupported floating-point type"); + } + + static void bindDerived(ClassTy &c) { + c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem); + } +}; + } // namespace //------------------------------------------------------------------------------ @@ -2754,6 +2852,8 @@ PyBoolAttribute::bind(m); PyStringAttribute::bind(m); PyDenseElementsAttribute::bind(m); + PyDenseIntElementsAttribute::bind(m); + PyDenseFPElementsAttribute::bind(m); //---------------------------------------------------------------------------- // Mapping of PyType. diff --git a/mlir/test/Bindings/Python/ir_attributes.py b/mlir/test/Bindings/Python/ir_attributes.py --- a/mlir/test/Bindings/Python/ir_attributes.py +++ b/mlir/test/Bindings/Python/ir_attributes.py @@ -181,3 +181,63 @@ print("named:", named) run(testNamedAttr) + + +# CHECK-LABEL: TEST: testDenseIntAttr +def testDenseIntAttr(): + with Context(): + raw = Attribute.parse("dense<[[0,1,2],[3,4,5]]> : vector<2x3xi32>") + # CHECK: attr: dense<[{{\[}}0, 1, 2], [3, 4, 5]]> + print("attr:", raw) + + a = DenseIntElementsAttr(raw) + assert len(a) == 6 + + # CHECK: 0 1 2 3 4 5 + for value in a: + print(value, end=" ") + print() + + # CHECK: i32 + print(ShapedType(a.type).element_type) + + raw = Attribute.parse("dense<[true,false,true,false]> : vector<4xi1>") + # CHECK: attr: dense<[true, false, true, false]> + print("attr:", raw) + + a = DenseIntElementsAttr(raw) + assert len(a) == 4 + + # CHECK: 1 0 1 0 + for value in a: + print(value, end=" ") + print() + + # CHECK: i1 + print(ShapedType(a.type).element_type) + + +run(testDenseIntAttr) + + +# CHECK-LABEL: TEST: testDenseFPAttr +def testDenseFPAttr(): + with Context(): + raw = Attribute.parse("dense<[0.0, 1.0, 2.0, 3.0]> : vector<4xf32>") + # CHECK: attr: dense<[0.000000e+00, 
1.000000e+00, 2.000000e+00, 3.000000e+00]> + + print("attr:", raw) + + a = DenseFPElementsAttr(raw) + assert len(a) == 4 + + # CHECK: 0.0 1.0 2.0 3.0 + for value in a: + print(value, end=" ") + print() + + # CHECK: f32 + print(ShapedType(a.type).element_type) + + +run(testDenseFPAttr)