diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -705,6 +705,49 @@
   MlirType type;
 };
 
+/// CRTP base classes for Python types that subclass Type and should be
+/// castable from it (i.e. via something like IntegerType(t)).
+/// By default, type class hierarchies are one level deep (i.e. a
+/// concrete type class extends PyType); however, intermediate python-visible
+/// base classes can be modeled by specifying a BaseTy.
+template <typename DerivedTy, typename BaseTy = PyType>
+class PyConcreteType : public BaseTy {
+public:
+  // Derived classes must define statics for:
+  //   IsAFunctionTy isaFunction
+  //   const char *pyClassName
+  using ClassTy = pybind11::class_<DerivedTy, BaseTy>;
+  using IsAFunctionTy = bool (*)(MlirType);
+
+  PyConcreteType() = default;
+  PyConcreteType(PyMlirContextRef contextRef, MlirType t)
+      : BaseTy(std::move(contextRef), t) {}
+  PyConcreteType(PyType &orig)
+      : PyConcreteType(orig.getContext(), castFrom(orig)) {}
+
+  static MlirType castFrom(PyType &orig) {
+    if (!DerivedTy::isaFunction(orig)) {
+      auto origRepr = pybind11::repr(pybind11::cast(orig)).cast<std::string>();
+      throw SetPyError(PyExc_ValueError, llvm::Twine("Cannot cast type to ") +
+                                         DerivedTy::pyClassName +
+                                         " (from " + origRepr + ")");
+    }
+    return orig;
+  }
+
+  static void bind(pybind11::module &m) {
+    auto cls = ClassTy(m, DerivedTy::pyClassName);
+    cls.def(pybind11::init<PyType &>(), pybind11::keep_alive<0, 1>());
+    cls.def_static("isinstance", [](PyType &otherType) -> bool {
+      return DerivedTy::isaFunction(otherType);
+    });
+    DerivedTy::bindDerived(cls);
+  }
+
+  /// Implemented by derived classes to add methods to the Python subclass.
+  static void bindDerived(ClassTy &m) {}
+};
+
 /// Wrapper around the generic MlirValue.
 /// Values are managed completely by the operation that resulted in their
 /// definition. For op result value, this is the operation that defines the
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -28,49 +28,6 @@
          mlirTypeIsAF16(type) || mlirTypeIsAF32(type) || mlirTypeIsAF64(type);
 }
 
-/// CRTP base classes for Python types that subclass Type and should be
-/// castable from it (i.e. via something like IntegerType(t)).
-/// By default, type class hierarchies are one level deep (i.e. a
-/// concrete type class extends PyType); however, intermediate python-visible
-/// base classes can be modeled by specifying a BaseTy.
-template <typename DerivedTy, typename BaseTy = PyType>
-class PyConcreteType : public BaseTy {
-public:
-  // Derived classes must define statics for:
-  //   IsAFunctionTy isaFunction
-  //   const char *pyClassName
-  using ClassTy = py::class_<DerivedTy, BaseTy>;
-  using IsAFunctionTy = bool (*)(MlirType);
-
-  PyConcreteType() = default;
-  PyConcreteType(PyMlirContextRef contextRef, MlirType t)
-      : BaseTy(std::move(contextRef), t) {}
-  PyConcreteType(PyType &orig)
-      : PyConcreteType(orig.getContext(), castFrom(orig)) {}
-
-  static MlirType castFrom(PyType &orig) {
-    if (!DerivedTy::isaFunction(orig)) {
-      auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
-      throw SetPyError(PyExc_ValueError, Twine("Cannot cast type to ") +
-                                         DerivedTy::pyClassName +
-                                         " (from " + origRepr + ")");
-    }
-    return orig;
-  }
-
-  static void bind(py::module &m) {
-    auto cls = ClassTy(m, DerivedTy::pyClassName);
-    cls.def(py::init<PyType &>(), py::keep_alive<0, 1>());
-    cls.def_static("isinstance", [](PyType &otherType) -> bool {
-      return DerivedTy::isaFunction(otherType);
-    });
-    DerivedTy::bindDerived(cls);
-  }
-
-  /// Implemented by derived classes to add methods to the Python subclass.
-  static void bindDerived(ClassTy &m) {}
-};
-
 class PyIntegerType : public PyConcreteType<PyIntegerType> {
 public:
   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAInteger;