diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp --- a/mlir/lib/Bindings/Python/IRModules.cpp +++ b/mlir/lib/Bindings/Python/IRModules.cpp @@ -221,34 +221,37 @@ /// CRTP base classes for Python attributes that subclass Attribute and should /// be castable from it (i.e. via something like StringAttr(attr)). -template -class PyConcreteAttribute : public PyAttribute { +/// By default, attribute class hierarchies are one level deep (i.e. a +/// concrete attribute class extends PyAttribute); however, intermediate +/// python-visible base classes can be modeled by specifying a BaseTy. +template +class PyConcreteAttribute : public BaseTy { public: // Derived classes must define statics for: // IsAFunctionTy isaFunction // const char *pyClassName - using ClassTy = py::class_; + using ClassTy = py::class_; using IsAFunctionTy = int (*)(MlirAttribute); PyConcreteAttribute() = default; - PyConcreteAttribute(MlirAttribute attr) : PyAttribute(attr) {} + PyConcreteAttribute(MlirAttribute attr) : BaseTy(attr) {} PyConcreteAttribute(PyAttribute &orig) : PyConcreteAttribute(castFrom(orig)) {} static MlirAttribute castFrom(PyAttribute &orig) { - if (!T::isaFunction(orig.attr)) { + if (!DerivedTy::isaFunction(orig.attr)) { auto origRepr = py::repr(py::cast(orig)).cast(); throw SetPyError(PyExc_ValueError, llvm::Twine("Cannot cast attribute to ") + - T::pyClassName + " (from " + origRepr + ")"); + DerivedTy::pyClassName + " (from " + origRepr + ")"); } return orig.attr; } static void bind(py::module &m) { - auto cls = ClassTy(m, T::pyClassName); + auto cls = ClassTy(m, DerivedTy::pyClassName); cls.def(py::init(), py::keep_alive<0, 1>()); - T::bindDerived(cls); + DerivedTy::bindDerived(cls); } /// Implemented by derived classes to add methods to the Python subclass. @@ -301,33 +304,36 @@ /// CRTP base classes for Python types that subclass Type and should be /// castable from it (i.e. via something like IntegerType(t)). 
-template -class PyConcreteType : public PyType { +/// By default, type class hierarchies are one level deep (i.e. a +/// concrete type class extends PyType); however, intermediate python-visible +/// base classes can be modeled by specifying a BaseTy. +template +class PyConcreteType : public BaseTy { public: // Derived classes must define statics for: // IsAFunctionTy isaFunction // const char *pyClassName - using ClassTy = py::class_; + using ClassTy = py::class_; using IsAFunctionTy = int (*)(MlirType); PyConcreteType() = default; - PyConcreteType(MlirType t) : PyType(t) {} - PyConcreteType(PyType &orig) : PyType(castFrom(orig)) {} + PyConcreteType(MlirType t) : BaseTy(t) {} + PyConcreteType(PyType &orig) : PyConcreteType(castFrom(orig)) {} static MlirType castFrom(PyType &orig) { - if (!T::isaFunction(orig.type)) { + if (!DerivedTy::isaFunction(orig.type)) { auto origRepr = py::repr(py::cast(orig)).cast(); throw SetPyError(PyExc_ValueError, llvm::Twine("Cannot cast type to ") + - T::pyClassName + " (from " + - origRepr + ")"); + DerivedTy::pyClassName + + " (from " + origRepr + ")"); } return orig.type; } static void bind(py::module &m) { - auto cls = ClassTy(m, T::pyClassName); + auto cls = ClassTy(m, DerivedTy::pyClassName); cls.def(py::init(), py::keep_alive<0, 1>()); - T::bindDerived(cls); + DerivedTy::bindDerived(cls); } /// Implemented by derived classes to add methods to the Python subclass. @@ -590,142 +596,130 @@ }; /// Vector Type subclass - VectorType. -class PyVectorType : public PyShapedType { +class PyVectorType : public PyConcreteType { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAVector; static constexpr const char *pyClassName = "VectorType"; - using PyShapedType::PyShapedType; - // TODO: Switch back to bindDerived by making the ClassTy modifiable by - // subclasses, exposing the ShapedType hierarchy. 
- static void bind(py::module &m) { - py::class_(m, pyClassName) - .def(py::init(), py::keep_alive<0, 1>()) - .def_static( - "get_vector", - // TODO: Make the location optional and create a default location. - [](std::vector shape, PyType &elementType, - PyLocation &loc) { - MlirType t = mlirVectorTypeGetChecked(shape.size(), shape.data(), - elementType.type, loc.loc); - // TODO: Rework error reporting once diagnostic engine is exposed - // in C API. - if (mlirTypeIsNull(t)) { - throw SetPyError( - PyExc_ValueError, - llvm::Twine("invalid '") + - py::repr(py::cast(elementType)).cast() + - "' and expected floating point or integer type."); - } - return PyVectorType(t); - }, - py::keep_alive<0, 2>(), "Create a vector type"); + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get_vector", + // TODO: Make the location optional and create a default location. + [](std::vector shape, PyType &elementType, PyLocation &loc) { + MlirType t = mlirVectorTypeGetChecked(shape.size(), shape.data(), + elementType.type, loc.loc); + // TODO: Rework error reporting once diagnostic engine is exposed + // in C API. + if (mlirTypeIsNull(t)) { + throw SetPyError( + PyExc_ValueError, + llvm::Twine("invalid '") + + py::repr(py::cast(elementType)).cast() + + "' and expected floating point or integer type."); + } + return PyVectorType(t); + }, + py::keep_alive<0, 2>(), "Create a vector type"); } }; /// Ranked Tensor Type subclass - RankedTensorType. -class PyRankedTensorType : public PyShapedType { +class PyRankedTensorType + : public PyConcreteType { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor; static constexpr const char *pyClassName = "RankedTensorType"; - using PyShapedType::PyShapedType; - // TODO: Switch back to bindDerived by making the ClassTy modifiable by - // subclasses, exposing the ShapedType hierarchy. 
- static void bind(py::module &m) { - py::class_(m, pyClassName) - .def(py::init(), py::keep_alive<0, 1>()) - .def_static( - "get_ranked_tensor", - // TODO: Make the location optional and create a default location. - [](std::vector shape, PyType &elementType, - PyLocation &loc) { - MlirType t = mlirRankedTensorTypeGetChecked( - shape.size(), shape.data(), elementType.type, loc.loc); - // TODO: Rework error reporting once diagnostic engine is exposed - // in C API. - if (mlirTypeIsNull(t)) { - throw SetPyError( - PyExc_ValueError, - llvm::Twine("invalid '") + - py::repr(py::cast(elementType)).cast() + - "' and expected floating point, integer, vector or " - "complex " - "type."); - } - return PyRankedTensorType(t); - }, - py::keep_alive<0, 2>(), "Create a ranked tensor type"); + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get_ranked_tensor", + // TODO: Make the location optional and create a default location. + [](std::vector shape, PyType &elementType, PyLocation &loc) { + MlirType t = mlirRankedTensorTypeGetChecked( + shape.size(), shape.data(), elementType.type, loc.loc); + // TODO: Rework error reporting once diagnostic engine is exposed + // in C API. + if (mlirTypeIsNull(t)) { + throw SetPyError( + PyExc_ValueError, + llvm::Twine("invalid '") + + py::repr(py::cast(elementType)).cast() + + "' and expected floating point, integer, vector or " + "complex " + "type."); + } + return PyRankedTensorType(t); + }, + py::keep_alive<0, 2>(), "Create a ranked tensor type"); } }; /// Unranked Tensor Type subclass - UnrankedTensorType. 
-class PyUnrankedTensorType : public PyShapedType { +class PyUnrankedTensorType + : public PyConcreteType { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedTensor; static constexpr const char *pyClassName = "UnrankedTensorType"; - using PyShapedType::PyShapedType; - // TODO: Switch back to bindDerived by making the ClassTy modifiable by - // subclasses, exposing the ShapedType hierarchy. - static void bind(py::module &m) { - py::class_(m, pyClassName) - .def(py::init(), py::keep_alive<0, 1>()) - .def_static( - "get_unranked_tensor", - // TODO: Make the location optional and create a default location. - [](PyType &elementType, PyLocation &loc) { - MlirType t = - mlirUnrankedTensorTypeGetChecked(elementType.type, loc.loc); - // TODO: Rework error reporting once diagnostic engine is exposed - // in C API. - if (mlirTypeIsNull(t)) { - throw SetPyError( - PyExc_ValueError, - llvm::Twine("invalid '") + - py::repr(py::cast(elementType)).cast() + - "' and expected floating point, integer, vector or " - "complex " - "type."); - } - return PyUnrankedTensorType(t); - }, - py::keep_alive<0, 1>(), "Create a unranked tensor type"); + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get_unranked_tensor", + // TODO: Make the location optional and create a default location. + [](PyType &elementType, PyLocation &loc) { + MlirType t = + mlirUnrankedTensorTypeGetChecked(elementType.type, loc.loc); + // TODO: Rework error reporting once diagnostic engine is exposed + // in C API. + if (mlirTypeIsNull(t)) { + throw SetPyError( + PyExc_ValueError, + llvm::Twine("invalid '") + + py::repr(py::cast(elementType)).cast() + + "' and expected floating point, integer, vector or " + "complex " + "type."); + } + return PyUnrankedTensorType(t); + }, + py::keep_alive<0, 1>(), "Create a unranked tensor type"); } }; /// Ranked MemRef Type subclass - MemRefType. 
-class PyMemRefType : public PyShapedType { +class PyMemRefType : public PyConcreteType { public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor; + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAMemRef; static constexpr const char *pyClassName = "MemRefType"; - using PyShapedType::PyShapedType; - // TODO: Switch back to bindDerived by making the ClassTy modifiable by - // subclasses, exposing the ShapedType hierarchy. - static void bind(py::module &m) { - py::class_(m, pyClassName) - .def(py::init(), py::keep_alive<0, 1>()) - // TODO: Add mlirMemRefTypeGet and mlirMemRefTypeGetAffineMap binding - // once the affine map binding is completed. - .def_static( - "get_contiguous_memref", - // TODO: Make the location optional and create a default location. - [](PyType &elementType, std::vector shape, - unsigned memorySpace, PyLocation &loc) { - MlirType t = mlirMemRefTypeContiguousGetChecked( - elementType.type, shape.size(), shape.data(), memorySpace, - loc.loc); - // TODO: Rework error reporting once diagnostic engine is exposed - // in C API. - if (mlirTypeIsNull(t)) { - throw SetPyError( - PyExc_ValueError, - llvm::Twine("invalid '") + - py::repr(py::cast(elementType)).cast() + - "' and expected floating point, integer, vector or " - "complex " - "type."); - } - return PyMemRefType(t); - }, - py::keep_alive<0, 1>(), "Create a memref type") + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + // TODO: Add mlirMemRefTypeGet and mlirMemRefTypeGetAffineMap binding + // once the affine map binding is completed. + c.def_static( + "get_contiguous_memref", + // TODO: Make the location optional and create a default location. + [](PyType &elementType, std::vector shape, + unsigned memorySpace, PyLocation &loc) { + MlirType t = mlirMemRefTypeContiguousGetChecked( + elementType.type, shape.size(), shape.data(), memorySpace, + loc.loc); + // TODO: Rework error reporting once diagnostic engine is exposed + // in C API. 
+ if (mlirTypeIsNull(t)) { + throw SetPyError( + PyExc_ValueError, + llvm::Twine("invalid '") + + py::repr(py::cast(elementType)).cast() + + "' and expected floating point, integer, vector or " + "complex " + "type."); + } + return PyMemRefType(t); + }, + py::keep_alive<0, 1>(), "Create a memref type") .def_property_readonly( "num_affine_maps", [](PyMemRefType &self) -> intptr_t { @@ -743,36 +737,34 @@ }; /// Unranked MemRef Type subclass - UnrankedMemRefType. -class PyUnrankedMemRefType : public PyShapedType { +class PyUnrankedMemRefType + : public PyConcreteType { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedMemRef; static constexpr const char *pyClassName = "UnrankedMemRefType"; - using PyShapedType::PyShapedType; - // TODO: Switch back to bindDerived by making the ClassTy modifiable by - // subclasses, exposing the ShapedType hierarchy. - static void bind(py::module &m) { - py::class_(m, pyClassName) - .def(py::init(), py::keep_alive<0, 1>()) - .def_static( - "get_unranked_memref", - // TODO: Make the location optional and create a default location. - [](PyType &elementType, unsigned memorySpace, PyLocation &loc) { - MlirType t = mlirUnrankedMemRefTypeGetChecked( - elementType.type, memorySpace, loc.loc); - // TODO: Rework error reporting once diagnostic engine is exposed - // in C API. - if (mlirTypeIsNull(t)) { - throw SetPyError( - PyExc_ValueError, - llvm::Twine("invalid '") + - py::repr(py::cast(elementType)).cast() + - "' and expected floating point, integer, vector or " - "complex " - "type."); - } - return PyUnrankedMemRefType(t); - }, - py::keep_alive<0, 1>(), "Create a unranked memref type") + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get_unranked_memref", + // TODO: Make the location optional and create a default location. 
+ [](PyType &elementType, unsigned memorySpace, PyLocation &loc) { + MlirType t = mlirUnrankedMemRefTypeGetChecked(elementType.type, + memorySpace, loc.loc); + // TODO: Rework error reporting once diagnostic engine is exposed + // in C API. + if (mlirTypeIsNull(t)) { + throw SetPyError( + PyExc_ValueError, + llvm::Twine("invalid '") + + py::repr(py::cast(elementType)).cast() + + "' and expected floating point, integer, vector or " + "complex " + "type."); + } + return PyUnrankedMemRefType(t); + }, + py::keep_alive<0, 1>(), "Create a unranked memref type") .def_property_readonly( "memory_space", [](PyUnrankedMemRefType &self) -> unsigned { diff --git a/mlir/test/Bindings/Python/ir_types.py b/mlir/test/Bindings/Python/ir_types.py --- a/mlir/test/Bindings/Python/ir_types.py +++ b/mlir/test/Bindings/Python/ir_types.py @@ -177,11 +177,11 @@ run(testComplexType) -# CHECK-LABEL: TEST: testShapedType +# CHECK-LABEL: TEST: testConcreteShapedType # Shaped type is not a kind of standard types, it is the base class for # vectors, memrefs and tensors, so this test case uses an instance of vector -# to test the shaped type. -def testShapedType(): +# to test the shaped type. The class hierarchy is preserved on the python side. 
+def testConcreteShapedType(): ctx = mlir.ir.Context() vector = mlir.ir.VectorType(ctx.parse_type("vector<2x3xf32>")) # CHECK: element type: f32 @@ -196,12 +196,25 @@ print("whether the dim-th dimension is dynamic:", vector.is_dynamic_dim(0)) # CHECK: dim size: 3 print("dim size:", vector.get_dim_size(1)) - # CHECK: False - print(vector.is_dynamic_size(3)) - # CHECK: False - print(vector.is_dynamic_stride_or_offset(1)) + # CHECK: is_dynamic_size: False + print("is_dynamic_size:", vector.is_dynamic_size(3)) + # CHECK: is_dynamic_stride_or_offset: False + print("is_dynamic_stride_or_offset:", vector.is_dynamic_stride_or_offset(1)) + # CHECK: isinstance(ShapedType): True + print("isinstance(ShapedType):", isinstance(vector, mlir.ir.ShapedType)) + +run(testConcreteShapedType) + +# CHECK-LABEL: TEST: testAbstractShapedType +# Tests that ShapedType operates as an abstract base class of a concrete +# shaped type (using vector as an example). +def testAbstractShapedType(): + ctx = mlir.ir.Context() + vector = mlir.ir.ShapedType(ctx.parse_type("vector<2x3xf32>")) + # CHECK: element type: f32 + print("element type:", vector.element_type) -run(testShapedType) +run(testAbstractShapedType) # CHECK-LABEL: TEST: testVectorType def testVectorType():