diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h --- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h +++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h @@ -314,31 +314,34 @@ /// as the mlir.ir class (otherwise, it will trigger a recursive /// initialization). mlir_attribute_subclass(py::handle scope, const char *typeClassName, - IsAFunctionTy isaFunction, - const py::object &superClass) - : pure_subclass(scope, typeClassName, superClass) { - // Casting constructor. Note that defining an __init__ method is special - // and not yet generalized on pure_subclass (it requires a somewhat - // different cpp_function and other requirements on chaining to super - // __init__ make it more awkward to do generally). + IsAFunctionTy isaFunction, const py::object &superCls) + : pure_subclass(scope, typeClassName, superCls) { + // Casting constructor. Note that it is hard, if not impossible, to properly + // call chain to parent `__init__` in pybind11 due to its special handling + // for init functions that don't have a fully constructed self-reference, + // which makes it impossible to forward it to `__init__` of a superclass. + // Instead, provide a custom `__new__` and call that of a superclass, which + // eventually calls `__init__` of the superclass. Since attribute subclasses + // have no additional members, we can just return the instance thus created + // without amending it. std::string captureTypeName( typeClassName); // As string in case if typeClassName is not static.
- py::cpp_function initCf( - [superClass, isaFunction, captureTypeName](py::object self, - py::object otherType) { - MlirAttribute rawAttribute = py::cast(otherType); + py::cpp_function newCf( + [superCls, isaFunction, captureTypeName](py::object cls, + py::object otherAttribute) { + MlirAttribute rawAttribute = py::cast(otherAttribute); if (!isaFunction(rawAttribute)) { - auto origRepr = py::repr(otherType).cast(); + auto origRepr = py::repr(otherAttribute).cast(); throw std::invalid_argument( (llvm::Twine("Cannot cast attribute to ") + captureTypeName + " (from " + origRepr + ")") .str()); } - superClass.attr("__init__")(self, otherType); + py::object self = superCls.attr("__new__")(cls, otherAttribute); + return self; }, - py::arg("cast_from_type"), py::is_method(py::none()), - "Casts the passed type to this specific sub-type."); - thisClass.attr("__init__") = initCf; + py::name("__new__"), py::arg("cls"), py::arg("cast_from_attr")); + thisClass.attr("__new__") = newCf; // 'isinstance' method. def_staticmethod( @@ -366,17 +369,21 @@ /// as the mlir.ir class (otherwise, it will trigger a recursive /// initialization). mlir_type_subclass(py::handle scope, const char *typeClassName, - IsAFunctionTy isaFunction, const py::object &superClass) - : pure_subclass(scope, typeClassName, superClass) { - // Casting constructor. Note that defining an __init__ method is special - // and not yet generalized on pure_subclass (it requires a somewhat - // different cpp_function and other requirements on chaining to super - // __init__ make it more awkward to do generally). + IsAFunctionTy isaFunction, const py::object &superCls) + : pure_subclass(scope, typeClassName, superCls) { + // Casting constructor. Note that it is hard, if not impossible, to properly + // call chain to parent `__init__` in pybind11 due to its special handling + // for init functions that don't have a fully constructed self-reference, + // which makes it impossible to forward it to `__init__` of a superclass.
+ // Instead, provide a custom `__new__` and call that of a superclass, which + // eventually calls `__init__` of the superclass. Since type subclasses + // have no additional members, we can just return the instance thus created + // without amending it. std::string captureTypeName( typeClassName); // As string in case if typeClassName is not static. - py::cpp_function initCf( - [superClass, isaFunction, captureTypeName](py::object self, - py::object otherType) { + py::cpp_function newCf( + [superCls, isaFunction, captureTypeName](py::object cls, + py::object otherType) { MlirType rawType = py::cast(otherType); if (!isaFunction(rawType)) { auto origRepr = py::repr(otherType).cast(); @@ -385,11 +392,11 @@ origRepr + ")") .str()); } - superClass.attr("__init__")(self, otherType); + py::object self = superCls.attr("__new__")(cls, otherType); + return self; }, - py::arg("cast_from_type"), py::is_method(py::none()), - "Casts the passed type to this specific sub-type."); - thisClass.attr("__init__") = initCf; + py::name("__new__"), py::arg("cls"), py::arg("cast_from_type")); + thisClass.attr("__new__") = newCf; // 'isinstance' method. def_staticmethod(