diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -27,46 +27,6 @@
   return mlirStringRefCreate(s.data(), s.size());
 }
 
-/// CRTP base classes for Python attributes that subclass Attribute and should
-/// be castable from it (i.e. via something like StringAttr(attr)).
-/// By default, attribute class hierarchies are one level deep (i.e. a
-/// concrete attribute class extends PyAttribute); however, intermediate
-/// python-visible base classes can be modeled by specifying a BaseTy.
-template <typename DerivedTy, typename BaseTy = PyAttribute>
-class PyConcreteAttribute : public BaseTy {
-public:
-  // Derived classes must define statics for:
-  //   IsAFunctionTy isaFunction
-  //   const char *pyClassName
-  using ClassTy = py::class_<DerivedTy, BaseTy>;
-  using IsAFunctionTy = bool (*)(MlirAttribute);
-
-  PyConcreteAttribute() = default;
-  PyConcreteAttribute(PyMlirContextRef contextRef, MlirAttribute attr)
-      : BaseTy(std::move(contextRef), attr) {}
-  PyConcreteAttribute(PyAttribute &orig)
-      : PyConcreteAttribute(orig.getContext(), castFrom(orig)) {}
-
-  static MlirAttribute castFrom(PyAttribute &orig) {
-    if (!DerivedTy::isaFunction(orig)) {
-      auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
-      throw SetPyError(PyExc_ValueError, Twine("Cannot cast attribute to ") +
-                                             DerivedTy::pyClassName +
-                                             " (from " + origRepr + ")");
-    }
-    return orig;
-  }
-
-  static void bind(py::module &m) {
-    auto cls = ClassTy(m, DerivedTy::pyClassName, py::buffer_protocol());
-    cls.def(py::init<PyAttribute &>(), py::keep_alive<0, 1>());
-    DerivedTy::bindDerived(cls);
-  }
-
-  /// Implemented by derived classes to add methods to the Python subclass.
-  static void bindDerived(ClassTy &m) {}
-};
-
 class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> {
 public:
   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap;
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -642,6 +642,52 @@
   std::unique_ptr<std::string> ownedName;
 };
 
+/// CRTP base classes for Python attributes that subclass Attribute and should
+/// be castable from it (i.e. via something like StringAttr(attr)).
+/// By default, attribute class hierarchies are one level deep (i.e. a
+/// concrete attribute class extends PyAttribute); however, intermediate
+/// python-visible base classes can be modeled by specifying a BaseTy.
+template <typename DerivedTy, typename BaseTy = PyAttribute>
+class PyConcreteAttribute : public BaseTy {
+public:
+  // Derived classes must define statics for:
+  //   IsAFunctionTy isaFunction
+  //   const char *pyClassName
+  using ClassTy = pybind11::class_<DerivedTy, BaseTy>;
+  using IsAFunctionTy = bool (*)(MlirAttribute);
+
+  PyConcreteAttribute() = default;
+  PyConcreteAttribute(PyMlirContextRef contextRef, MlirAttribute attr)
+      : BaseTy(std::move(contextRef), attr) {}
+  PyConcreteAttribute(PyAttribute &orig)
+      : PyConcreteAttribute(orig.getContext(), castFrom(orig)) {}
+
+  /// Returns `orig` unchanged if DerivedTy::isaFunction accepts it. Otherwise
+  /// throws the error produced by SetPyError(PyExc_ValueError, ...), whose
+  /// message names DerivedTy::pyClassName and includes the repr of `orig`.
+  static MlirAttribute castFrom(PyAttribute &orig) {
+    if (!DerivedTy::isaFunction(orig)) {
+      auto origRepr = pybind11::repr(pybind11::cast(orig)).cast<std::string>();
+      throw SetPyError(PyExc_ValueError,
+                       llvm::Twine("Cannot cast attribute to ") +
+                           DerivedTy::pyClassName + " (from " + origRepr + ")");
+    }
+    return orig;
+  }
+
+  /// Registers DerivedTy in module `m` as the Python class named
+  /// DerivedTy::pyClassName (buffer protocol enabled), adds an __init__ taking
+  /// an existing PyAttribute, then calls DerivedTy::bindDerived.
+  static void bind(pybind11::module &m) {
+    auto cls = ClassTy(m, DerivedTy::pyClassName, pybind11::buffer_protocol());
+    cls.def(pybind11::init<PyAttribute &>(), pybind11::keep_alive<0, 1>());
+    DerivedTy::bindDerived(cls);
+  }
+
+  /// Implemented by derived classes to add methods to the Python subclass.
+  static void bindDerived(ClassTy &m) {}
+};
+
 /// Wrapper around the generic MlirType.
 /// The lifetime of a type is bound by the PyContext that created it.
 class PyType : public BaseContextObject {