diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -162,9 +162,7 @@ c.def_static( "get", [](const std::vector &values, DefaultingPyMlirContext ctx) { - MlirAttribute attr = - DerivedT::getAttribute(ctx->get(), values.size(), values.data()); - return DerivedT(ctx->getRef(), attr); + return getAttribute(values, ctx->getRef()); }, py::arg("values"), py::arg("context") = py::none(), "Gets a uniqued dense array attribute"); @@ -187,16 +185,29 @@ values.push_back(arr.getItem(i)); for (py::handle attr : extras) values.push_back(pyTryCast(attr)); - MlirAttribute attr = DerivedT::getAttribute(arr.getContext()->get(), - values.size(), values.data()); - return DerivedT(arr.getContext(), attr); + return getAttribute(values, arr.getContext()); }); } + +private: + static DerivedT getAttribute(const std::vector &values, + PyMlirContextRef ctx) { + if constexpr (std::is_same_v) { + std::vector intValues(values.begin(), values.end()); + MlirAttribute attr = DerivedT::getAttribute(ctx->get(), intValues.size(), + intValues.data()); + return DerivedT(ctx, attr); + } else { + MlirAttribute attr = + DerivedT::getAttribute(ctx->get(), values.size(), values.data()); + return DerivedT(ctx, attr); + } + } }; /// Instantiate the python dense array classes. struct PyDenseBoolArrayAttribute - : public PyDenseArrayAttribute { + : public PyDenseArrayAttribute { static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseBoolArray; static constexpr auto getAttribute = mlirDenseBoolArrayGet; static constexpr auto getElement = mlirDenseBoolArrayGetElement; diff --git a/mlir/test/python/ir/attributes.py b/mlir/test/python/ir/attributes.py --- a/mlir/test/python/ir/attributes.py +++ b/mlir/test/python/ir/attributes.py @@ -344,7 +344,7 @@ print(f"{len(attr)}: {attr[0]}, {attr[1]}") with Context(): - # CHECK: 2: 0, 1 + # CHECK: 2: False, True print_item("array") # CHECK: 2: 2, 3 print_item("array") @@ -359,6 +359,13 @@ # CHECK: 2: 3.{{0+}}, 4.{{0+}} print_item("array") + class MyBool: + def __bool__(self): + return True + + # CHECK: myboolarray: array + print("myboolarray:", DenseBoolArrayAttr.get([MyBool()])) + # CHECK-LABEL: TEST: testDenseIntAttrGetItem @run