diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -688,13 +688,6 @@ intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); } py::buffer_info accessBuffer() { - if (mlirDenseElementsAttrIsSplat(*this)) { - // TODO: Currently crashes the program. - // Reported as https://github.com/pybind/pybind11/issues/3336 - throw std::invalid_argument( - "unsupported data type for conversion to Python buffer"); - } - MlirType shapedType = mlirAttributeGetType(*this); MlirType elementType = mlirShapedTypeGetElementType(shapedType); std::string format; @@ -821,15 +814,18 @@ shape.push_back(mlirShapedTypeGetDimSize(shapedType, i)); // Prepare the strides for the buffer_info. SmallVector strides; - intptr_t strideFactor = 1; - for (intptr_t i = 1; i < rank; ++i) { - strideFactor = 1; - for (intptr_t j = i; j < rank; ++j) { - strideFactor *= mlirShapedTypeGetDimSize(shapedType, j); + if (mlirDenseElementsAttrIsSplat(*this)) { + // Splats are special, only the single value is stored. + strides.assign(rank, 0); + } else { + for (intptr_t i = 1; i < rank; ++i) { + intptr_t strideFactor = 1; + for (intptr_t j = i; j < rank; ++j) + strideFactor *= mlirShapedTypeGetDimSize(shapedType, j); + strides.push_back(sizeof(Type) * strideFactor); } - strides.push_back(sizeof(Type) * strideFactor); + strides.push_back(sizeof(Type)); } - strides.push_back(sizeof(Type)); std::string format; if (explicitFormat) { format = explicitFormat; diff --git a/mlir/test/python/ir/array_attributes.py b/mlir/test/python/ir/array_attributes.py --- a/mlir/test/python/ir/array_attributes.py +++ b/mlir/test/python/ir/array_attributes.py @@ -100,10 +100,9 @@ print(attr) # CHECK: is_splat: True print("is_splat:", attr.is_splat) - # TODO: Re-enable this once a solution is found to raising an exception - # from buffer protocol. - # Reported as https://github.com/pybind/pybind11/issues/3336 - # print(np.array(attr)) + # CHECK{LITERAL}: [[1. 1. 1.] + # CHECK{LITERAL}: [1. 1. 1.]] + print(np.array(attr)) # CHECK-LABEL: TEST: testNonSplat