# Patch summary (preamble text before the first `diff --git`; not part of any hunk).
#
# mlir/lib/Bindings/Python/IRAttributes.cpp:
#   Adds a `get_splat_value` method to the binding chain for PyDenseElementsAttribute,
#   placed between the existing lambda that returns mlirDenseElementsAttrIsSplat(self)
#   and the `.def_buffer(...)` call.
#   - If mlirDenseElementsAttrIsSplat(self) is false, the lambda throws the result of
#     SetPyError(PyExc_ValueError, "get_splat_value called on a non-splat attribute").
#   - Otherwise it returns a PyAttribute constructed from self.getContext() and
#     mlirDenseElementsAttrGetSplatValue(self).
#
# mlir/test/python/ir/array_attributes.py:
#   Adds `assert attr.get_splat_value() == element` in two places: once after the
#   `is_splat` print, and once after the print of the f32 splat attribute.
#   NOTE(review): the enclosing test function names are outside the visible hunks;
#   presumably these are the integer and float splat tests - confirm against the file.
#
# NOTE(review): no test visible in this patch exercises the non-splat ValueError path.
# NOTE(review): the diff below has its newlines collapsed onto one line; it must be
# re-split before it can be applied - confirm against the original patch.
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -777,6 +777,16 @@ [](PyDenseElementsAttribute &self) -> bool { return mlirDenseElementsAttrIsSplat(self); }) + .def("get_splat_value", + [](PyDenseElementsAttribute &self) -> PyAttribute { + if (!mlirDenseElementsAttrIsSplat(self)) { + throw SetPyError( + PyExc_ValueError, + "get_splat_value called on a non-splat attribute"); + } + return PyAttribute(self.getContext(), + mlirDenseElementsAttrGetSplatValue(self)); + }) .def_buffer(&PyDenseElementsAttribute::accessBuffer); } diff --git a/mlir/test/python/ir/array_attributes.py b/mlir/test/python/ir/array_attributes.py --- a/mlir/test/python/ir/array_attributes.py +++ b/mlir/test/python/ir/array_attributes.py @@ -43,6 +43,7 @@ print(attr) # CHECK: is_splat: True print("is_splat:", attr.is_splat) + assert attr.get_splat_value() == element # CHECK-LABEL: TEST: testGetDenseElementsSplatFloat @@ -55,6 +56,7 @@ attr = DenseElementsAttr.get_splat(shaped_type, element) # CHECK: dense<1.200000e+00> : tensor<2x3x4xf32> print(attr) + assert attr.get_splat_value() == element # CHECK-LABEL: TEST: testGetDenseElementsSplatErrors