diff --git a/mlir/include/mlir-c/StandardAttributes.h b/mlir/include/mlir-c/StandardAttributes.h
--- a/mlir/include/mlir-c/StandardAttributes.h
+++ b/mlir/include/mlir-c/StandardAttributes.h
@@ -277,6 +277,9 @@
  * shaped type and use its sizes to build a multi-dimensional index. */
 int64_t mlirElementsAttrGetNumElements(MlirAttribute attr);
 
+/** Returns the type of the given elements attribute. */
+MlirType mlirElementsAttrGetType(MlirAttribute attr);
+
 //===----------------------------------------------------------------------===//
 // Dense elements attribute.
 //===----------------------------------------------------------------------===//
@@ -381,6 +384,11 @@
 MlirStringRef mlirDenseElementsAttrGetStringValue(MlirAttribute attr,
                                                   intptr_t pos);
 
+/** Returns the raw data of the given dense elements attribute. A splat
+ * attribute may store only its single splat value, so check
+ * mlirDenseElementsAttrIsSplat before reading more than one element. */
+const void *mlirDenseElementsAttrGetRawData(MlirAttribute attr);
+
 //===----------------------------------------------------------------------===//
 // Opaque elements attribute.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -1321,7 +1321,7 @@
   }
 
   static void bind(py::module &m) {
-    auto cls = ClassTy(m, DerivedTy::pyClassName);
+    auto cls = ClassTy(m, DerivedTy::pyClassName, py::buffer_protocol());
     cls.def(py::init<PyAttribute &>(), py::keep_alive<0, 1>());
     DerivedTy::bindDerived(cls);
   }
@@ -1580,6 +1580,46 @@
     return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
   }
 
+  /// Exposes the attribute's elements through the Python buffer protocol
+  /// (e.g. `np.array(attr)`). Supports f32, f64 and 32/64-bit integer element
+  /// types; signless integers use the signed format. Raises ValueError for
+  /// other element types and for splat attributes with more than one element.
+  py::buffer_info accessBuffer() {
+    // A splat attribute may store only its single value, so exposing a
+    // multi-element splat with its full shape could read past the storage.
+    if (mlirDenseElementsAttrIsSplat(this->attr) &&
+        mlirElementsAttrGetNumElements(this->attr) > 1) {
+      std::string message = "unsupported buffer access to a splat attribute.";
+      throw SetPyError(PyExc_ValueError, message);
+    }
+
+    MlirType shapedType = mlirElementsAttrGetType(this->attr);
+    MlirType elementType = mlirShapedTypeGetElementType(shapedType);
+
+    if (mlirTypeIsAF32(elementType))
+      return bufferInfo(shapedType, mlirDenseElementsAttrGetFloatValue);
+    if (mlirTypeIsAF64(elementType))
+      return bufferInfo(shapedType, mlirDenseElementsAttrGetDoubleValue);
+    if (mlirTypeIsAInteger(elementType)) {
+      // Signless and signed integers are both exposed with the signed format.
+      bool isUnsigned = mlirIntegerTypeIsUnsigned(elementType);
+      unsigned width = mlirIntegerTypeGetWidth(elementType);
+      if (width == 32) {
+        if (isUnsigned)
+          return bufferInfo(shapedType, mlirDenseElementsAttrGetUInt32Value);
+        return bufferInfo(shapedType, mlirDenseElementsAttrGetInt32Value);
+      }
+      if (width == 64) {
+        if (isUnsigned)
+          return bufferInfo(shapedType, mlirDenseElementsAttrGetUInt64Value);
+        return bufferInfo(shapedType, mlirDenseElementsAttrGetInt64Value);
+      }
+    }
+
+    std::string message = "unimplemented array format.";
+    throw SetPyError(PyExc_ValueError, message);
+  }
+
   static void bindDerived(ClassTy &c) {
     c.def_static("get", PyDenseElementsAttribute::getFromBuffer,
                  py::arg("array"), py::arg("signless") = true,
@@ -1591,7 +1631,8 @@
         .def_property_readonly("is_splat",
                                [](PyDenseElementsAttribute &self) -> bool {
                                  return mlirDenseElementsAttrIsSplat(self.attr);
-                               });
+                               })
+        .def_buffer(&PyDenseElementsAttribute::accessBuffer);
   }
 
 private:
@@ -1608,6 +1649,36 @@
     const ElementTy *contents = static_cast<const ElementTy *>(arrayInfo.ptr);
     return ctor(shapedType, numElements, contents);
   }
+
+  /// Describes this attribute's raw element storage as a row-major
+  /// (C-contiguous) buffer of `Type`; the accessor argument only serves to
+  /// deduce `Type`. The buffer is read-only because MLIR attributes are
+  /// immutable and uniqued; writing through it would corrupt shared storage.
+  template <typename Type>
+  py::buffer_info bufferInfo(MlirType shapedType,
+                             Type (*value)(MlirAttribute, intptr_t)) {
+    intptr_t rank = mlirShapedTypeGetRank(shapedType);
+    // Prepare the data for the buffer_info.
+    Type *data = static_cast<Type *>(
+        const_cast<void *>(mlirDenseElementsAttrGetRawData(this->attr)));
+    // Prepare the shape for the buffer_info.
+    std::vector<intptr_t> shape;
+    shape.reserve(rank);
+    for (intptr_t i = 0; i < rank; ++i)
+      shape.push_back(mlirShapedTypeGetDimSize(shapedType, i));
+    // Row-major byte strides, computed innermost-first: the last dimension is
+    // contiguous and strides[i] = strides[i + 1] * shape[i + 1]. A rank-0
+    // attribute gets no strides, matching its empty shape.
+    std::vector<intptr_t> strides(rank);
+    intptr_t stride = sizeof(Type);
+    for (intptr_t i = rank - 1; i >= 0; --i) {
+      strides[i] = stride;
+      stride *= shape[i];
+    }
+    return py::buffer_info(data, sizeof(Type),
+                           py::format_descriptor<Type>::format(), rank, shape,
+                           strides, /*readonly=*/true);
+  }
 };
 
 } // namespace
diff --git a/mlir/lib/CAPI/IR/StandardAttributes.cpp b/mlir/lib/CAPI/IR/StandardAttributes.cpp
--- a/mlir/lib/CAPI/IR/StandardAttributes.cpp
+++ b/mlir/lib/CAPI/IR/StandardAttributes.cpp
@@ -302,6 +302,10 @@
   return unwrap(attr).cast<ElementsAttr>().getNumElements();
 }
 
+MlirType mlirElementsAttrGetType(MlirAttribute attr) {
+  return wrap(unwrap(attr).cast<ElementsAttr>().getType());
+}
+
 //===----------------------------------------------------------------------===//
 // Dense elements attribute.
 //===----------------------------------------------------------------------===//
@@ -516,6 +520,14 @@
       pos));
 }
 
+//===----------------------------------------------------------------------===//
+// Raw data accessors.
+
+const void *mlirDenseElementsAttrGetRawData(MlirAttribute attr) {
+  return static_cast<const void *>(
+      unwrap(attr).cast<DenseElementsAttr>().getRawData().data());
+}
+
 //===----------------------------------------------------------------------===//
 // Opaque elements attribute.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Bindings/Python/ir_array_attributes.py b/mlir/test/Bindings/Python/ir_array_attributes.py
--- a/mlir/test/Bindings/Python/ir_array_attributes.py
+++ b/mlir/test/Bindings/Python/ir_array_attributes.py
@@ -106,6 +106,9 @@
     print(attr)
     # CHECK: is_splat: False
     print("is_splat:", attr.is_splat)
+    # CHECK: {{\[}}[1.1 2.2 3.3]
+    # CHECK: {{\[}}4.4 5.5 6.6]]
+    print(np.array(attr))  # Materializes the attribute via its buffer protocol.
 
 run(testGetDenseElementsF32)
 
@@ -117,6 +120,9 @@
     attr = DenseElementsAttr.get(array)
     # CHECK: dense<{{\[}}[1.100000e+00, 2.200000e+00, 3.300000e+00], [4.400000e+00, 5.500000e+00, 6.600000e+00]]> : tensor<2x3xf64>
     print(attr)
+    # CHECK: {{\[}}[1.1 2.2 3.3]
+    # CHECK: {{\[}}4.4 5.5 6.6]]
+    print(np.array(attr))
 
 run(testGetDenseElementsF64)
 
@@ -129,6 +135,9 @@
     attr = DenseElementsAttr.get(array)
     # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>
     print(attr)
+    # CHECK: {{\[}}[1 2 3]
+    # CHECK: {{\[}}4 5 6]]
+    print(np.array(attr))
 
 run(testGetDenseElementsI32Signless)
 
@@ -140,6 +149,9 @@
     attr = DenseElementsAttr.get(array)
     # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>
     print(attr)
+    # CHECK: {{\[}}[1 2 3]
+    # CHECK: {{\[}}4 5 6]]
+    print(np.array(attr))
 
 run(testGetDenseElementsUI32Signless)
 
@@ -150,6 +162,9 @@
     attr = DenseElementsAttr.get(array, signless=False)
     # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xsi32>
     print(attr)
+    # CHECK: {{\[}}[1 2 3]
+    # CHECK: {{\[}}4 5 6]]
+    print(np.array(attr))
 
 run(testGetDenseElementsI32)
 
@@ -161,6 +176,9 @@
     attr = DenseElementsAttr.get(array, signless=False)
     # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xui32>
     print(attr)
+    # CHECK: {{\[}}[1 2 3]
+    # CHECK: {{\[}}4 5 6]]
+    print(np.array(attr))
 
 run(testGetDenseElementsUI32)
 
@@ -173,6 +191,9 @@
     attr = DenseElementsAttr.get(array)
     # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi64>
     print(attr)
+    # CHECK: {{\[}}[1 2 3]
+    # CHECK: {{\[}}4 5 6]]
+    print(np.array(attr))
 
 run(testGetDenseElementsI64Signless)
 
@@ -184,6 +205,9 @@
     attr = DenseElementsAttr.get(array)
     # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi64>
     print(attr)
+    # CHECK: {{\[}}[1 2 3]
+    # CHECK: {{\[}}4 5 6]]
+    print(np.array(attr))
 
 run(testGetDenseElementsUI64Signless)
 
@@ -194,6 +218,9 @@
     attr = DenseElementsAttr.get(array, signless=False)
     # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xsi64>
     print(attr)
+    # CHECK: {{\[}}[1 2 3]
+    # CHECK: {{\[}}4 5 6]]
+    print(np.array(attr))
 
 run(testGetDenseElementsI64)
 
@@ -205,6 +232,9 @@
     attr = DenseElementsAttr.get(array, signless=False)
     # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xui64>
     print(attr)
+    # CHECK: {{\[}}[1 2 3]
+    # CHECK: {{\[}}4 5 6]]
+    print(np.array(attr))
 
 run(testGetDenseElementsUI64)
 