diff --git a/mlir/include/mlir-c/StandardAttributes.h b/mlir/include/mlir-c/StandardAttributes.h
--- a/mlir/include/mlir-c/StandardAttributes.h
+++ b/mlir/include/mlir-c/StandardAttributes.h
@@ -404,6 +404,13 @@
 MLIR_CAPI_EXPORTED MlirStringRef
 mlirDenseElementsAttrGetStringValue(MlirAttribute attr, intptr_t pos);
 
+/** Returns the raw data of the given dense elements attribute. The returned
+ * pointer refers to the attribute's internal storage; it must not be written
+ * to or freed. A splat attribute stores a single element only, regardless of
+ * the number of elements in its type. */
+MLIR_CAPI_EXPORTED const void *
+mlirDenseElementsAttrGetRawData(MlirAttribute attr);
+
 //===----------------------------------------------------------------------===//
 // Opaque elements attribute.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -1360,7 +1360,7 @@
   }
 
   static void bind(py::module &m) {
-    auto cls = ClassTy(m, DerivedTy::pyClassName);
+    auto cls = ClassTy(m, DerivedTy::pyClassName, py::buffer_protocol());
     cls.def(py::init<PyAttribute &>(), py::keep_alive<0, 1>());
     DerivedTy::bindDerived(cls);
   }
@@ -1621,6 +1621,45 @@
 
   intptr_t dunderLen() { return mlirElementsAttrGetNumElements(attr); }
 
+  /// Exposes the attribute contents through the Python buffer protocol (e.g.
+  /// for `np.array(attr)`). Supports f32, f64 and 32/64-bit integer element
+  /// types; raises ValueError for any other element type.
+  py::buffer_info accessBuffer() {
+    MlirType shapedType = mlirAttributeGetType(this->attr);
+    MlirType elementType = mlirShapedTypeGetElementType(shapedType);
+
+    if (mlirTypeIsAF32(elementType)) {
+      // f32
+      return bufferInfo<float>(shapedType);
+    } else if (mlirTypeIsAF64(elementType)) {
+      // f64
+      return bufferInfo<double>(shapedType);
+    } else if (mlirTypeIsAInteger(elementType) &&
+               mlirIntegerTypeGetWidth(elementType) == 32) {
+      if (mlirIntegerTypeIsSignless(elementType) ||
+          mlirIntegerTypeIsSigned(elementType)) {
+        // i32
+        return bufferInfo<int32_t>(shapedType);
+      } else if (mlirIntegerTypeIsUnsigned(elementType)) {
+        // unsigned i32
+        return bufferInfo<uint32_t>(shapedType);
+      }
+    } else if (mlirTypeIsAInteger(elementType) &&
+               mlirIntegerTypeGetWidth(elementType) == 64) {
+      if (mlirIntegerTypeIsSignless(elementType) ||
+          mlirIntegerTypeIsSigned(elementType)) {
+        // i64
+        return bufferInfo<int64_t>(shapedType);
+      } else if (mlirIntegerTypeIsUnsigned(elementType)) {
+        // unsigned i64
+        return bufferInfo<uint64_t>(shapedType);
+      }
+    }
+
+    std::string message = "unimplemented array format.";
+    throw SetPyError(PyExc_ValueError, message);
+  }
+
   static void bindDerived(ClassTy &c) {
     c.def("__len__", &PyDenseElementsAttribute::dunderLen)
         .def_static("get", PyDenseElementsAttribute::getFromBuffer,
@@ -1633,7 +1672,8 @@
         .def_property_readonly("is_splat",
                                [](PyDenseElementsAttribute &self) -> bool {
                                  return mlirDenseElementsAttrIsSplat(self.attr);
-                               });
+                               })
+        .def_buffer(&PyDenseElementsAttribute::accessBuffer);
   }
 
 private:
@@ -1650,6 +1690,36 @@
     const ElementTy *contents = static_cast<const ElementTy *>(arrayInfo.ptr);
     return ctor(shapedType, numElements, contents);
   }
+
+  /// Builds a read-only buffer description over the attribute's raw storage,
+  /// viewed as densely packed, row-major elements of `Type`.
+  template <typename Type>
+  py::buffer_info bufferInfo(MlirType shapedType) {
+    intptr_t rank = mlirShapedTypeGetRank(shapedType);
+    // Prepare the data for the buffer_info. Constness is cast away only
+    // because py::buffer_info takes a mutable pointer; the buffer itself is
+    // marked read-only below.
+    Type *data = static_cast<Type *>(
+        const_cast<void *>(mlirDenseElementsAttrGetRawData(this->attr)));
+    // Prepare the shape for the buffer_info.
+    std::vector<intptr_t> shape(rank);
+    for (intptr_t i = 0; i < rank; ++i)
+      shape[i] = mlirShapedTypeGetDimSize(shapedType, i);
+    // Prepare the strides for the buffer_info. A splat only stores a single
+    // element, so all of its strides stay zero and every index aliases that
+    // element; reading it with dense strides would run past the storage.
+    std::vector<intptr_t> strides(rank, 0);
+    if (!mlirDenseElementsAttrIsSplat(this->attr)) {
+      intptr_t stride = sizeof(Type);
+      for (intptr_t i = rank - 1; i >= 0; --i) {
+        strides[i] = stride;
+        stride *= shape[i];
+      }
+    }
+    return py::buffer_info(data, sizeof(Type),
+                           py::format_descriptor<Type>::format(), rank, shape,
+                           strides, /*readonly=*/true);
+  }
 };
 
 /// Refinement of the PyDenseElementsAttribute for attributes containing integer
diff --git a/mlir/lib/CAPI/IR/StandardAttributes.cpp b/mlir/lib/CAPI/IR/StandardAttributes.cpp
--- a/mlir/lib/CAPI/IR/StandardAttributes.cpp
+++ b/mlir/lib/CAPI/IR/StandardAttributes.cpp
@@ -516,6 +516,14 @@
         pos));
 }
 
+//===----------------------------------------------------------------------===//
+// Raw data accessors.
+
+const void *mlirDenseElementsAttrGetRawData(MlirAttribute attr) {
+  return static_cast<const void *>(
+      unwrap(attr).cast<DenseElementsAttr>().getRawData().data());
+}
+
 //===----------------------------------------------------------------------===//
 // Opaque elements attribute.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Bindings/Python/ir_array_attributes.py b/mlir/test/Bindings/Python/ir_array_attributes.py
--- a/mlir/test/Bindings/Python/ir_array_attributes.py
+++ b/mlir/test/Bindings/Python/ir_array_attributes.py
@@ -106,6 +106,9 @@
     print(attr)
     # CHECK: is_splat: False
     print("is_splat:", attr.is_splat)
+    # CHECK: {{\[}}[1.1 2.2 3.3]
+    # CHECK: {{\[}}4.4 5.5 6.6]]
+    print(np.array(attr))
 
 run(testGetDenseElementsF32)
 
@@ -117,6 +120,9 @@
     attr = DenseElementsAttr.get(array)
     # CHECK: dense<{{\[}}[1.100000e+00, 2.200000e+00, 3.300000e+00], [4.400000e+00, 5.500000e+00, 6.600000e+00]]> : tensor<2x3xf64>
     print(attr)
+    # CHECK: {{\[}}[1.1 2.2 3.3]
+    # CHECK: {{\[}}4.4 5.5 6.6]]
+    print(np.array(attr))
 
 run(testGetDenseElementsF64)
 
@@ -129,6 +135,9 @@
     attr = DenseElementsAttr.get(array)
     # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>
     print(attr)
+    # CHECK: {{\[}}[1 2 3]
+    # CHECK: {{\[}}4 5 6]]
+    print(np.array(attr))
 
 run(testGetDenseElementsI32Signless)
 
@@ -140,6 +149,9 @@
     attr = DenseElementsAttr.get(array)
     # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>
     print(attr)
+    # CHECK: {{\[}}[1 2 3]
+    # CHECK: {{\[}}4 5 6]]
+    print(np.array(attr))
 
 run(testGetDenseElementsUI32Signless)
 
@@ -150,6 +162,9 @@
     attr = DenseElementsAttr.get(array, signless=False)
     # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xsi32>
     print(attr)
+    # CHECK: {{\[}}[1 2 3]
+    # CHECK: {{\[}}4 5 6]]
+    print(np.array(attr))
 
 run(testGetDenseElementsI32)
 
@@ -161,6 +176,9 @@
     attr = DenseElementsAttr.get(array, signless=False)
     # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xui32>
     print(attr)
+    # CHECK: {{\[}}[1 2 3]
+    # CHECK: {{\[}}4 5 6]]
+    print(np.array(attr))
 
 run(testGetDenseElementsUI32)
 
@@ -173,6 +191,9 @@
     attr = DenseElementsAttr.get(array)
     # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi64>
     print(attr)
+    # CHECK: {{\[}}[1 2 3]
+    # CHECK: {{\[}}4 5 6]]
+    print(np.array(attr))
 
 run(testGetDenseElementsI64Signless)
 
@@ -184,6 +205,9 @@
     attr = DenseElementsAttr.get(array)
     # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi64>
     print(attr)
+    # CHECK: {{\[}}[1 2 3]
+    # CHECK: {{\[}}4 5 6]]
+    print(np.array(attr))
 
 run(testGetDenseElementsUI64Signless)
 
@@ -194,6 +218,9 @@
     attr = DenseElementsAttr.get(array, signless=False)
     # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xsi64>
     print(attr)
+    # CHECK: {{\[}}[1 2 3]
+    # CHECK: {{\[}}4 5 6]]
+    print(np.array(attr))
 
 run(testGetDenseElementsI64)
 
@@ -205,6 +232,9 @@
     attr = DenseElementsAttr.get(array, signless=False)
     # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xui64>
     print(attr)
+    # CHECK: {{\[}}[1 2 3]
+    # CHECK: {{\[}}4 5 6]]
+    print(np.array(attr))
 
 run(testGetDenseElementsUI64)
 