diff --git a/mlir/include/mlir-c/BuiltinAttributes.h b/mlir/include/mlir-c/BuiltinAttributes.h
--- a/mlir/include/mlir-c/BuiltinAttributes.h
+++ b/mlir/include/mlir-c/BuiltinAttributes.h
@@ -296,6 +296,67 @@
 /// shaped type and use its sizes to build a multi-dimensional index.
 MLIR_CAPI_EXPORTED int64_t mlirElementsAttrGetNumElements(MlirAttribute attr);
 
+//===----------------------------------------------------------------------===//
+// Dense array attribute.
+//===----------------------------------------------------------------------===//
+
+/// Checks whether the given attribute is a dense array attribute whose element
+/// type is the one named by the function (bool, i8, i16, i32, i64, f32, f64).
+MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseBoolArray(MlirAttribute attr);
+MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseI8Array(MlirAttribute attr);
+MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseI16Array(MlirAttribute attr);
+MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseI32Array(MlirAttribute attr);
+MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseI64Array(MlirAttribute attr);
+MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseF32Array(MlirAttribute attr);
+MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseF64Array(MlirAttribute attr);
+
+/// Create a dense array attribute with the given elements. `values` must point
+/// to at least `size` elements. The bool variant takes `int` values and
+/// converts each one to bool, so any nonzero value becomes `true`.
+MLIR_CAPI_EXPORTED MlirAttribute mlirDenseBoolArrayGet(MlirContext ctx,
+                                                       intptr_t size,
+                                                       int const *values);
+MLIR_CAPI_EXPORTED MlirAttribute mlirDenseI8ArrayGet(MlirContext ctx,
+                                                     intptr_t size,
+                                                     int8_t const *values);
+MLIR_CAPI_EXPORTED MlirAttribute mlirDenseI16ArrayGet(MlirContext ctx,
+                                                      intptr_t size,
+                                                      int16_t const *values);
+MLIR_CAPI_EXPORTED MlirAttribute mlirDenseI32ArrayGet(MlirContext ctx,
+                                                      intptr_t size,
+                                                      int32_t const *values);
+MLIR_CAPI_EXPORTED MlirAttribute mlirDenseI64ArrayGet(MlirContext ctx,
+                                                      intptr_t size,
+                                                      int64_t const *values);
+MLIR_CAPI_EXPORTED MlirAttribute mlirDenseF32ArrayGet(MlirContext ctx,
+                                                      intptr_t size,
+                                                      float const *values);
+MLIR_CAPI_EXPORTED MlirAttribute mlirDenseF64ArrayGet(MlirContext ctx,
+                                                      intptr_t size,
+                                                      double const *values);
+
+/// Get the size of a dense array. `attr` must be a dense array attribute; any
+/// element type is accepted.
+MLIR_CAPI_EXPORTED intptr_t mlirDenseArrayGetNumElements(MlirAttribute attr);
+
+/// Get an element of a dense array. `attr` must be a dense array whose element
+/// type matches the function name, and `pos` must be in [0, size): the
+/// implementation does not validate it.
+MLIR_CAPI_EXPORTED bool mlirDenseBoolArrayGetElement(MlirAttribute attr,
+                                                     intptr_t pos);
+MLIR_CAPI_EXPORTED int8_t mlirDenseI8ArrayGetElement(MlirAttribute attr,
+                                                     intptr_t pos);
+MLIR_CAPI_EXPORTED int16_t mlirDenseI16ArrayGetElement(MlirAttribute attr,
+                                                       intptr_t pos);
+MLIR_CAPI_EXPORTED int32_t mlirDenseI32ArrayGetElement(MlirAttribute attr,
+                                                       intptr_t pos);
+MLIR_CAPI_EXPORTED int64_t mlirDenseI64ArrayGetElement(MlirAttribute attr,
+                                                       intptr_t pos);
+MLIR_CAPI_EXPORTED float mlirDenseF32ArrayGetElement(MlirAttribute attr,
+                                                     intptr_t pos);
+MLIR_CAPI_EXPORTED double mlirDenseF64ArrayGetElement(MlirAttribute attr,
+                                                      intptr_t pos);
+
 //===----------------------------------------------------------------------===//
 // Dense elements attribute.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -110,6 +110,166 @@
   }
 }
 
+/// A python-wrapped dense array attribute with an element type and a derived
+/// implementation class.
+template <typename EltTy, typename DerivedT>
+class PyDenseArrayAttribute
+    : public PyConcreteAttribute<PyDenseArrayAttribute<EltTy, DerivedT>> {
+public:
+  static constexpr typename PyConcreteAttribute<
+      PyDenseArrayAttribute<EltTy, DerivedT>>::IsAFunctionTy isaFunction =
+      DerivedT::isaFunction;
+  static constexpr const char *pyClassName = DerivedT::pyClassName;
+  using PyConcreteAttribute<
+      PyDenseArrayAttribute<EltTy, DerivedT>>::PyConcreteAttribute;
+
+  /// Iterator over the elements of a dense array, of any element type.
+  class PyDenseArrayIterator {
+  public:
+    PyDenseArrayIterator(PyAttribute attr) : attr(attr) {}
+
+    /// Return a copy of the iterator.
+    PyDenseArrayIterator dunderIter() { return *this; }
+
+    /// Return the next element.
+    EltTy dunderNext() {
+      // Throw if the index has reached the end.
+      if (nextIndex >= mlirDenseArrayGetNumElements(attr.get()))
+        throw py::stop_iteration();
+      return DerivedT::getElement(attr.get(), nextIndex++);
+    }
+
+    /// Bind the iterator class.
+    static void bind(py::module &m) {
+      py::class_<PyDenseArrayIterator>(m, DerivedT::pyIteratorName,
+                                       py::module_local())
+          .def("__iter__", &PyDenseArrayIterator::dunderIter)
+          .def("__next__", &PyDenseArrayIterator::dunderNext);
+    }
+
+  private:
+    /// The referenced dense array attribute.
+    PyAttribute attr;
+    /// The next index to read, typed like the C API sizes and positions.
+    intptr_t nextIndex = 0;
+  };
+
+  /// Get the element at index `i`, which the caller must have range-checked.
+  EltTy getItem(intptr_t i) { return DerivedT::getElement(*this, i); }
+
+  /// Bind the attribute class.
+  static void bindDerived(typename PyConcreteAttribute<
+                          PyDenseArrayAttribute<EltTy, DerivedT>>::ClassTy &c) {
+    // Bind the constructor.
+    c.def_static(
+        "get",
+        [](const std::vector<EltTy> &values, DefaultingPyMlirContext ctx) {
+          MlirAttribute attr =
+              DerivedT::getAttribute(ctx->get(), values.size(), values.data());
+          return PyDenseArrayAttribute(ctx->getRef(), attr);
+        },
+        py::arg("values"), py::arg("context") = py::none(),
+        "Gets a uniqued dense array attribute");
+    // Bind the array methods.
+    c.def("__getitem__", [](PyDenseArrayAttribute &arr, intptr_t i) {
+      // Accept Python-style negative indices, then reject anything still out
+      // of range: the C API element accessors do not validate the position.
+      intptr_t numElements = mlirDenseArrayGetNumElements(arr);
+      if (i < 0)
+        i += numElements;
+      if (i < 0 || i >= numElements)
+        throw py::index_error("DenseArray index out of range");
+      return arr.getItem(i);
+    });
+    c.def("__len__", [](const PyDenseArrayAttribute &arr) {
+      return mlirDenseArrayGetNumElements(arr);
+    });
+    c.def("__iter__", [](const PyDenseArrayAttribute &arr) {
+      return PyDenseArrayIterator(arr);
+    });
+    // Bind a concat: `arr + [x, ...]` builds a new, longer dense array.
+    c.def("__add__", [](PyDenseArrayAttribute &arr,
+                        py::list extras) {
+      std::vector<EltTy> values;
+      intptr_t numOldElements = mlirDenseArrayGetNumElements(arr);
+      values.reserve(numOldElements + py::len(extras));
+      for (intptr_t i = 0; i < numOldElements; ++i)
+        values.push_back(arr.getItem(i));
+      for (py::handle item : extras)
+        values.push_back(pyTryCast<EltTy>(item));
+      MlirAttribute attr = DerivedT::getAttribute(arr.getContext()->get(),
+                                                  values.size(), values.data());
+      return PyDenseArrayAttribute(arr.getContext(), attr);
+    });
+  }
+};
+
+/// Instantiate the python dense array classes. The bool variant must use `int`
+/// elements: std::vector<bool> has no data() and the C API takes `int const *`.
+struct PyDenseBoolArrayAttribute // Each struct only supplies C API hooks and Python names.
+    : public PyDenseArrayAttribute {
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseBoolArray;
+  static constexpr auto getAttribute = mlirDenseBoolArrayGet;
+  static constexpr auto getElement = mlirDenseBoolArrayGetElement;
+  static constexpr const char *pyClassName = "DenseBoolArrayAttr";
+  static constexpr const char *pyIteratorName = "DenseBoolArrayIterator";
+  using PyDenseArrayAttribute::PyDenseArrayAttribute;
+};
+struct PyDenseI8ArrayAttribute
+    : public PyDenseArrayAttribute {
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI8Array;
+  static constexpr auto getAttribute = mlirDenseI8ArrayGet;
+  static constexpr auto getElement = mlirDenseI8ArrayGetElement;
+  static constexpr const char *pyClassName = "DenseI8ArrayAttr";
+  static constexpr const char *pyIteratorName = "DenseI8ArrayIterator";
+  using PyDenseArrayAttribute::PyDenseArrayAttribute;
+};
+struct PyDenseI16ArrayAttribute
+    : public PyDenseArrayAttribute {
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI16Array;
+  static constexpr auto getAttribute = mlirDenseI16ArrayGet;
+  static constexpr auto getElement = mlirDenseI16ArrayGetElement;
+  static constexpr const char *pyClassName = "DenseI16ArrayAttr";
+  static constexpr const char *pyIteratorName = "DenseI16ArrayIterator";
+  using PyDenseArrayAttribute::PyDenseArrayAttribute;
+};
+struct PyDenseI32ArrayAttribute
+    : public PyDenseArrayAttribute {
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI32Array;
+  static constexpr auto getAttribute = mlirDenseI32ArrayGet;
+  static constexpr auto getElement = mlirDenseI32ArrayGetElement;
+  static constexpr const char *pyClassName = "DenseI32ArrayAttr";
+  static constexpr const char *pyIteratorName = "DenseI32ArrayIterator";
+  using PyDenseArrayAttribute::PyDenseArrayAttribute;
+};
+struct PyDenseI64ArrayAttribute
+    : public PyDenseArrayAttribute {
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI64Array;
+  static constexpr auto getAttribute = mlirDenseI64ArrayGet;
+  static constexpr auto getElement = mlirDenseI64ArrayGetElement;
+  static constexpr const char *pyClassName = "DenseI64ArrayAttr";
+  static constexpr const char *pyIteratorName = "DenseI64ArrayIterator";
+  using PyDenseArrayAttribute::PyDenseArrayAttribute;
+};
+struct PyDenseF32ArrayAttribute
+    : public PyDenseArrayAttribute {
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF32Array;
+  static constexpr auto getAttribute = mlirDenseF32ArrayGet;
+  static constexpr auto getElement = mlirDenseF32ArrayGetElement;
+  static constexpr const char *pyClassName = "DenseF32ArrayAttr";
+  static constexpr const char *pyIteratorName = "DenseF32ArrayIterator";
+  using PyDenseArrayAttribute::PyDenseArrayAttribute;
+};
+struct PyDenseF64ArrayAttribute
+    : public PyDenseArrayAttribute {
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF64Array;
+  static constexpr auto getAttribute = mlirDenseF64ArrayGet;
+  static constexpr auto getElement = mlirDenseF64ArrayGetElement;
+  static constexpr const char *pyClassName = "DenseF64ArrayAttr";
+  static constexpr const char *pyIteratorName = "DenseF64ArrayIterator";
+  using PyDenseArrayAttribute::PyDenseArrayAttribute;
+};
+
 class PyArrayAttribute : public PyConcreteAttribute {
 public:
   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray;
@@ -891,6 +1046,22 @@
 
 void mlir::python::populateIRAttributes(py::module &m) {
   PyAffineMapAttribute::bind(m);
+
+  PyDenseBoolArrayAttribute::bind(m); // Each class is bound, then its iterator.
+  PyDenseBoolArrayAttribute::PyDenseArrayIterator::bind(m);
+  PyDenseI8ArrayAttribute::bind(m);
+  PyDenseI8ArrayAttribute::PyDenseArrayIterator::bind(m);
+  PyDenseI16ArrayAttribute::bind(m);
+  PyDenseI16ArrayAttribute::PyDenseArrayIterator::bind(m);
+  PyDenseI32ArrayAttribute::bind(m);
+  PyDenseI32ArrayAttribute::PyDenseArrayIterator::bind(m);
+  PyDenseI64ArrayAttribute::bind(m);
+  PyDenseI64ArrayAttribute::PyDenseArrayIterator::bind(m);
+  PyDenseF32ArrayAttribute::bind(m);
+  PyDenseF32ArrayAttribute::PyDenseArrayIterator::bind(m);
+  PyDenseF64ArrayAttribute::bind(m);
+  PyDenseF64ArrayAttribute::PyDenseArrayIterator::bind(m);
+
   PyArrayAttribute::bind(m);
   PyArrayAttribute::PyArrayAttributeIterator::bind(m);
   PyBoolAttribute::bind(m);
diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
--- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
@@ -311,6 +311,106 @@
   return unwrap(attr).cast().getNumElements();
 }
 
+//===----------------------------------------------------------------------===//
+// Dense array attribute.
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// IsA support. Each predicate tests for exactly one element type.
+
+bool mlirAttributeIsADenseBoolArray(MlirAttribute attr) {
+  return unwrap(attr).isa();
+}
+bool mlirAttributeIsADenseI8Array(MlirAttribute attr) {
+  return unwrap(attr).isa();
+}
+bool mlirAttributeIsADenseI16Array(MlirAttribute attr) {
+  return unwrap(attr).isa();
+}
+bool mlirAttributeIsADenseI32Array(MlirAttribute attr) {
+  return unwrap(attr).isa();
+}
+bool mlirAttributeIsADenseI64Array(MlirAttribute attr) {
+  return unwrap(attr).isa();
+}
+bool mlirAttributeIsADenseF32Array(MlirAttribute attr) {
+  return unwrap(attr).isa();
+}
+bool mlirAttributeIsADenseF64Array(MlirAttribute attr) {
+  return unwrap(attr).isa();
+}
+
+//===----------------------------------------------------------------------===//
+// Constructors. Each builds an attribute from the `size` elements at `values`.
+
+MlirAttribute mlirDenseBoolArrayGet(MlirContext ctx, intptr_t size,
+                                    int const *values) {
+  SmallVector elements(values, values + size); // `int` flags; nonzero -> true.
+  return wrap(DenseBoolArrayAttr::get(unwrap(ctx), elements));
+}
+MlirAttribute mlirDenseI8ArrayGet(MlirContext ctx, intptr_t size,
+                                  int8_t const *values) {
+  return wrap(
+      DenseI8ArrayAttr::get(unwrap(ctx), ArrayRef(values, size)));
+}
+MlirAttribute mlirDenseI16ArrayGet(MlirContext ctx, intptr_t size,
+                                   int16_t const *values) {
+  return wrap(
+      DenseI16ArrayAttr::get(unwrap(ctx), ArrayRef(values, size)));
+}
+MlirAttribute mlirDenseI32ArrayGet(MlirContext ctx, intptr_t size,
+                                   int32_t const *values) {
+  return wrap(
+      DenseI32ArrayAttr::get(unwrap(ctx), ArrayRef(values, size)));
+}
+MlirAttribute mlirDenseI64ArrayGet(MlirContext ctx, intptr_t size,
+                                   int64_t const *values) {
+  return wrap(
+      DenseI64ArrayAttr::get(unwrap(ctx), ArrayRef(values, size)));
+}
+MlirAttribute mlirDenseF32ArrayGet(MlirContext ctx, intptr_t size,
+                                   float const *values) {
+  return wrap(
+      DenseF32ArrayAttr::get(unwrap(ctx), ArrayRef(values, size)));
+}
+MlirAttribute mlirDenseF64ArrayGet(MlirContext ctx, intptr_t size,
+                                   double const *values) {
+  return wrap(
+      DenseF64ArrayAttr::get(unwrap(ctx), ArrayRef(values, size)));
+}
+
+//===----------------------------------------------------------------------===//
+// Accessors. `attr` must be a dense array; it is cast without an isa check.
+
+intptr_t mlirDenseArrayGetNumElements(MlirAttribute attr) {
+  return unwrap(attr).cast().size();
+}
+
+//===----------------------------------------------------------------------===//
+// Indexed accessors. These do not range-check `pos`; keep it in [0, size).
+
+bool mlirDenseBoolArrayGetElement(MlirAttribute attr, intptr_t pos) {
+  return unwrap(attr).cast<DenseBoolArrayAttr>()[pos]; // `pos` is unchecked.
+}
+int8_t mlirDenseI8ArrayGetElement(MlirAttribute attr, intptr_t pos) {
+  return unwrap(attr).cast<DenseI8ArrayAttr>()[pos];
+}
+int16_t mlirDenseI16ArrayGetElement(MlirAttribute attr, intptr_t pos) {
+  return unwrap(attr).cast<DenseI16ArrayAttr>()[pos];
+}
+int32_t mlirDenseI32ArrayGetElement(MlirAttribute attr, intptr_t pos) {
+  return unwrap(attr).cast<DenseI32ArrayAttr>()[pos];
+}
+int64_t mlirDenseI64ArrayGetElement(MlirAttribute attr, intptr_t pos) {
+  return unwrap(attr).cast<DenseI64ArrayAttr>()[pos];
+}
+float mlirDenseF32ArrayGetElement(MlirAttribute attr, intptr_t pos) {
+  return unwrap(attr).cast<DenseF32ArrayAttr>()[pos];
+}
+double mlirDenseF64ArrayGetElement(MlirAttribute attr, intptr_t pos) {
+  return unwrap(attr).cast<DenseF64ArrayAttr>()[pos];
+}
+
 //===----------------------------------------------------------------------===//
 // Dense elements attribute.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -1186,6 +1186,41 @@
   mlirAttributeDump(sparseAttr);
   // CHECK: sparse<{{\[}}[0, 1]], 0.000000e+00> : tensor<1x2xf32>
 
+  // Dense array attributes: check IsA, size and element access for each type.
+  MlirAttribute boolArray = mlirDenseBoolArrayGet(ctx, 2, bools);
+  MlirAttribute int8Array = mlirDenseI8ArrayGet(ctx, 2, ints8);
+  MlirAttribute int16Array = mlirDenseI16ArrayGet(ctx, 2, ints16);
+  MlirAttribute int32Array = mlirDenseI32ArrayGet(ctx, 2, ints32);
+  MlirAttribute int64Array = mlirDenseI64ArrayGet(ctx, 2, ints64);
+  MlirAttribute floatArray = mlirDenseF32ArrayGet(ctx, 2, floats);
+  MlirAttribute doubleArray = mlirDenseF64ArrayGet(ctx, 2, doubles);
+  if (!mlirAttributeIsADenseBoolArray(boolArray) ||
+      !mlirAttributeIsADenseI8Array(int8Array) ||
+      !mlirAttributeIsADenseI16Array(int16Array) ||
+      !mlirAttributeIsADenseI32Array(int32Array) ||
+      !mlirAttributeIsADenseI64Array(int64Array) ||
+      !mlirAttributeIsADenseF32Array(floatArray) ||
+      !mlirAttributeIsADenseF64Array(doubleArray))
+    return 19;
+
+  if (mlirDenseArrayGetNumElements(boolArray) != 2 ||
+      mlirDenseArrayGetNumElements(int8Array) != 2 ||
+      mlirDenseArrayGetNumElements(int16Array) != 2 ||
+      mlirDenseArrayGetNumElements(int32Array) != 2 ||
+      mlirDenseArrayGetNumElements(int64Array) != 2 ||
+      mlirDenseArrayGetNumElements(floatArray) != 2 ||
+      mlirDenseArrayGetNumElements(doubleArray) != 2)
+    return 20;
+
+  if (mlirDenseBoolArrayGetElement(boolArray, 1) != 1 ||
+      mlirDenseI8ArrayGetElement(int8Array, 1) != 1 ||
+      mlirDenseI16ArrayGetElement(int16Array, 1) != 1 ||
+      mlirDenseI32ArrayGetElement(int32Array, 1) != 1 ||
+      mlirDenseI64ArrayGetElement(int64Array, 1) != 1 ||
+      fabsf(mlirDenseF32ArrayGetElement(floatArray, 1) - 1.0f) > 1E-6f ||
+      fabs(mlirDenseF64ArrayGetElement(doubleArray, 1) - 1.0) > 1E-6)
+    return 21;
+
   return 0;
 }
 
diff --git a/mlir/test/python/ir/attributes.py b/mlir/test/python/ir/attributes.py
--- a/mlir/test/python/ir/attributes.py
+++ b/mlir/test/python/ir/attributes.py
@@ -1,8 +1,10 @@
 # RUN: %PYTHON %s | FileCheck %s
 
 import gc
+
 from mlir.ir import *
 
+
 def run(f):
   print("\nTEST:", f.__name__)
   f()
@@ -319,6 +321,30 @@
     print(ShapedType(a.type).element_type)
 
 
+# CHECK-LABEL: TEST: testDenseArrayGetItem
+@run
+def testDenseArrayGetItem():
+  def print_item(AttrClass, attr_asm):
+    attr = AttrClass(Attribute.parse(attr_asm))
+    print(f"{len(attr)}: {attr[0]}, {attr[1]}")
+
+  with Context():
+    # CHECK: 2: 0, 1
+    print_item(DenseBoolArrayAttr, "array<i1: false, true>")
+    # CHECK: 2: 2, 3
+    print_item(DenseI8ArrayAttr, "array<i8: 2, 3>")
+    # CHECK: 2: 4, 5
+    print_item(DenseI16ArrayAttr, "array<i16: 4, 5>")
+    # CHECK: 2: 6, 7
+    print_item(DenseI32ArrayAttr, "array<i32: 6, 7>")
+    # CHECK: 2: 8, 9
+    print_item(DenseI64ArrayAttr, "array<i64: 8, 9>")
+    # CHECK: 2: 1.{{0+}}, 2.{{0+}}
+    print_item(DenseF32ArrayAttr, "array<f32: 1.0, 2.0>")
+    # CHECK: 2: 3.{{0+}}, 4.{{0+}}
+    print_item(DenseF64ArrayAttr, "array<f64: 3.0, 4.0>")
+
+
 # CHECK-LABEL: TEST: testDenseIntAttrGetItem
 @run
 def testDenseIntAttrGetItem():