Index: mlir/lib/Bindings/Python/IRAttributes.cpp
===================================================================
--- mlir/lib/Bindings/Python/IRAttributes.cpp
+++ mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -7,12 +7,15 @@
 //===----------------------------------------------------------------------===//
 
 #include <optional>
+#include <string_view>
 #include <utility>
 
 #include "IRModule.h"
 
 #include "PybindUtils.h"
 
+#include "llvm/ADT/ScopeExit.h"
+
 #include "mlir-c/BuiltinAttributes.h"
 #include "mlir-c/BuiltinTypes.h"
 #include "mlir/Bindings/Python/PybindAdaptors.h"
@@ -612,19 +615,25 @@
                 std::optional<std::vector<int64_t>> explicitShape,
                 DefaultingPyMlirContext contextWrapper) {
-    // Request a contiguous view. In exotic cases, this will cause a copy.
-    int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT;
-    Py_buffer *view = new Py_buffer();
-    if (PyObject_GetBuffer(array.ptr(), view, flags) != 0) {
-      delete view;
+    // Request a C-contiguous view with shape information (PyBUF_ND). Per the
+    // buffer protocol, an exporter that cannot provide one fails the request;
+    // no copy is made here. The format is only requested when the element
+    // type must be inferred from it: some dtypes (e.g. NumPy datetime64 or
+    // extension types such as bfloat16) have no buffer-protocol format.
+    int flags = PyBUF_ND;
+    if (!explicitType) {
+      flags |= PyBUF_FORMAT;
+    }
+    Py_buffer view;
+    if (PyObject_GetBuffer(array.ptr(), &view, flags) != 0) {
       throw py::error_already_set();
     }
-    py::buffer_info arrayInfo(view);
+    // Release the view on every exit path, including the exceptions below.
+    auto freeBuffer = llvm::make_scope_exit([&]() { PyBuffer_Release(&view); });
     SmallVector<int64_t> shape;
     if (explicitShape) {
       shape.append(explicitShape->begin(), explicitShape->end());
     } else {
-      shape.append(arrayInfo.shape.begin(),
-                   arrayInfo.shape.begin() + arrayInfo.ndim);
+      shape.append(view.shape, view.shape + view.ndim);
     }
 
     MlirAttribute encodingAttr = mlirAttributeGetNull();
@@ -638,85 +647,94 @@
     std::optional<MlirType> bulkLoadElementType;
     if (explicitType) {
       bulkLoadElementType = *explicitType;
-    } else if (arrayInfo.format == "f") {
-      // f32
-      assert(arrayInfo.itemsize == 4 && "mismatched array itemsize");
-      bulkLoadElementType = mlirF32TypeGet(context);
-    } else if (arrayInfo.format == "d") {
-      // f64
-      assert(arrayInfo.itemsize == 8 && "mismatched array itemsize");
-      bulkLoadElementType = mlirF64TypeGet(context);
-    } else if (arrayInfo.format == "e") {
-      // f16
-      assert(arrayInfo.itemsize == 2 && "mismatched array itemsize");
-      bulkLoadElementType = mlirF16TypeGet(context);
-    } else if (isSignedIntegerFormat(arrayInfo.format)) {
-      if (arrayInfo.itemsize == 4) {
-        // i32
-        bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 32)
-                                       : mlirIntegerTypeSignedGet(context, 32);
-      } else if (arrayInfo.itemsize == 8) {
-        // i64
-        bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 64)
-                                       : mlirIntegerTypeSignedGet(context, 64);
-      } else if (arrayInfo.itemsize == 1) {
-        // i8
-        bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
-                                       : mlirIntegerTypeSignedGet(context, 8);
-      } else if (arrayInfo.itemsize == 2) {
-        // i16
-        bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 16)
-                                       : mlirIntegerTypeSignedGet(context, 16);
-      }
-    } else if (isUnsignedIntegerFormat(arrayInfo.format)) {
-      if (arrayInfo.itemsize == 4) {
-        // unsigned i32
-        bulkLoadElementType = signless
-                                  ? mlirIntegerTypeGet(context, 32)
-                                  : mlirIntegerTypeUnsignedGet(context, 32);
-      } else if (arrayInfo.itemsize == 8) {
-        // unsigned i64
-        bulkLoadElementType = signless
-                                  ? mlirIntegerTypeGet(context, 64)
-                                  : mlirIntegerTypeUnsignedGet(context, 64);
-      } else if (arrayInfo.itemsize == 1) {
-        // i8
-        bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
-                                       : mlirIntegerTypeUnsignedGet(context, 8);
-      } else if (arrayInfo.itemsize == 2) {
-        // i16
-        bulkLoadElementType = signless
-                                  ? mlirIntegerTypeGet(context, 16)
-                                  : mlirIntegerTypeUnsignedGet(context, 16);
-      }
-    }
-    if (bulkLoadElementType) {
-      MlirType shapedType;
-      if (mlirTypeIsAShaped(*bulkLoadElementType)) {
-        if (explicitShape) {
-          throw std::invalid_argument("Shape can only be specified explicitly "
-                                      "when the type is not a shaped type.");
+    } else {
+      // The buffer protocol documents a NULL format as "B" (unsigned bytes);
+      // map it explicitly: constructing std::string_view from nullptr is UB.
+      std::string_view format(view.format ? view.format : "B");
+      if (format == "f") {
+        // f32
+        assert(view.itemsize == 4 && "mismatched array itemsize");
+        bulkLoadElementType = mlirF32TypeGet(context);
+      } else if (format == "d") {
+        // f64
+        assert(view.itemsize == 8 && "mismatched array itemsize");
+        bulkLoadElementType = mlirF64TypeGet(context);
+      } else if (format == "e") {
+        // f16
+        assert(view.itemsize == 2 && "mismatched array itemsize");
+        bulkLoadElementType = mlirF16TypeGet(context);
+      } else if (isSignedIntegerFormat(format)) {
+        if (view.itemsize == 4) {
+          // i32
+          bulkLoadElementType = signless
+                                    ? mlirIntegerTypeGet(context, 32)
+                                    : mlirIntegerTypeSignedGet(context, 32);
+        } else if (view.itemsize == 8) {
+          // i64
+          bulkLoadElementType = signless
+                                    ? mlirIntegerTypeGet(context, 64)
+                                    : mlirIntegerTypeSignedGet(context, 64);
+        } else if (view.itemsize == 1) {
+          // i8
+          bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
+                                         : mlirIntegerTypeSignedGet(context, 8);
+        } else if (view.itemsize == 2) {
+          // i16
+          bulkLoadElementType = signless
+                                    ? mlirIntegerTypeGet(context, 16)
+                                    : mlirIntegerTypeSignedGet(context, 16);
+        }
+      } else if (isUnsignedIntegerFormat(format)) {
+        if (view.itemsize == 4) {
+          // unsigned i32
+          bulkLoadElementType = signless
+                                    ? mlirIntegerTypeGet(context, 32)
+                                    : mlirIntegerTypeUnsignedGet(context, 32);
+        } else if (view.itemsize == 8) {
+          // unsigned i64
+          bulkLoadElementType = signless
+                                    ? mlirIntegerTypeGet(context, 64)
+                                    : mlirIntegerTypeUnsignedGet(context, 64);
+        } else if (view.itemsize == 1) {
+          // i8
+          bulkLoadElementType = signless
+                                    ? mlirIntegerTypeGet(context, 8)
+                                    : mlirIntegerTypeUnsignedGet(context, 8);
+        } else if (view.itemsize == 2) {
+          // i16
+          bulkLoadElementType = signless
+                                    ? mlirIntegerTypeGet(context, 16)
+                                    : mlirIntegerTypeUnsignedGet(context, 16);
         }
-        shapedType = *bulkLoadElementType;
-      } else {
-        shapedType = mlirRankedTensorTypeGet(
-            shape.size(), shape.data(), *bulkLoadElementType, encodingAttr);
       }
-      size_t rawBufferSize = arrayInfo.size * arrayInfo.itemsize;
-      MlirAttribute attr = mlirDenseElementsAttrRawBufferGet(
-          shapedType, rawBufferSize, arrayInfo.ptr);
-      if (mlirAttributeIsNull(attr)) {
+      if (!bulkLoadElementType) {
         throw std::invalid_argument(
-            "DenseElementsAttr could not be constructed from the given buffer. "
-            "This may mean that the Python buffer layout does not match that "
-            "MLIR expected layout and is a bug.");
+            std::string("unimplemented array format conversion from format: ") +
+            std::string(format));
       }
-      return PyDenseElementsAttribute(contextWrapper->getRef(), attr);
     }
 
-    throw std::invalid_argument(
-        std::string("unimplemented array format conversion from format: ") +
-        arrayInfo.format);
+    MlirType shapedType;
+    if (mlirTypeIsAShaped(*bulkLoadElementType)) {
+      if (explicitShape) {
+        throw std::invalid_argument("Shape can only be specified explicitly "
+                                    "when the type is not a shaped type.");
+      }
+      shapedType = *bulkLoadElementType;
+    } else {
+      shapedType = mlirRankedTensorTypeGet(shape.size(), shape.data(),
+                                           *bulkLoadElementType, encodingAttr);
+    }
+    size_t rawBufferSize = view.len;
+    MlirAttribute attr =
+        mlirDenseElementsAttrRawBufferGet(shapedType, rawBufferSize, view.buf);
+    if (mlirAttributeIsNull(attr)) {
+      throw std::invalid_argument(
+          "DenseElementsAttr could not be constructed from the given buffer. "
+          "This may mean that the Python buffer layout does not match that "
+          "MLIR expected layout and is a bug.");
+    }
+    return PyDenseElementsAttribute(contextWrapper->getRef(), attr);
   }
 
   static PyDenseElementsAttribute getSplat(const PyType &shapedType,
@@ -852,7 +870,7 @@
   }
 
 private:
-  static bool isUnsignedIntegerFormat(const std::string &format) {
+  static bool isUnsignedIntegerFormat(std::string_view format) {
     if (format.empty())
       return false;
     char code = format[0];
@@ -860,7 +878,7 @@
            code == 'Q';
   }
 
-  static bool isSignedIntegerFormat(const std::string &format) {
+  static bool isSignedIntegerFormat(std::string_view format) {
     if (format.empty())
       return false;
     char code = format[0];
Index: mlir/test/python/ir/array_attributes.py
===================================================================
--- mlir/test/python/ir/array_attributes.py
+++ mlir/test/python/ir/array_attributes.py
@@ -30,6 +30,24 @@
             # CHECK: unimplemented array format conversion from format:
             print(e)
 
+# CHECK-LABEL: TEST: testGetDenseElementsUnSupportedTypeOkIfExplicitTypeProvided
+@run
+def testGetDenseElementsUnSupportedTypeOkIfExplicitTypeProvided():
+    with Context():
+        array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64)
+        # datetime64 specifically isn't important: it's just a 64-bit type that
+        # doesn't have a format under the Python buffer protocol. A more
+        # realistic example would be a NumPy extension type like the bfloat16
+        # type from the ml_dtypes package, which isn't a dependency of this
+        # test.
+        attr = DenseElementsAttr.get(array.view(np.datetime64),
+                                     type=IntegerType.get_signless(64))
+        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi64>
+        print(attr)
+        # CHECK: {{\[}}[1 2 3]
+        # CHECK: {{\[}}4 5 6]]
+        print(np.array(attr))
+
 
 ################################################################################
 # Splats.