diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -106,6 +106,21 @@
 
 } // namespace
 
+//------------------------------------------------------------------------------
+// Type-checking utilities.
+//------------------------------------------------------------------------------
+
+namespace {
+
+/// Returns non-zero if the given type is an integer type or one of the
+/// BF16/F16/F32/F64 floating point types.
+int mlirTypeIsAIntegerOrFloat(MlirType type) {
+  return mlirTypeIsAInteger(type) || mlirTypeIsABF16(type) ||
+         mlirTypeIsAF16(type) || mlirTypeIsAF32(type) || mlirTypeIsAF64(type);
+}
+
+} // namespace
+
 //------------------------------------------------------------------------------
 // PyAttribute.
 //------------------------------------------------------------------------------
@@ -401,6 +416,114 @@
   }
 };
 
+/// Complex Type subclass - ComplexType.
+class PyComplexType : public PyConcreteType<PyComplexType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAComplex;
+  static constexpr const char *pyClassName = "ComplexType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get_complex",
+        [](PyType &elementType) {
+          // The element must be a floating point or integer scalar type.
+          if (!mlirTypeIsAIntegerOrFloat(elementType.type))
+            throw SetPyError(PyExc_ValueError,
+                             "invalid element type for complex type.");
+          MlirType t = mlirComplexTypeGet(elementType.type);
+          return PyComplexType(t);
+        },
+        py::keep_alive<0, 1>(), "Create a complex type");
+    c.def_property_readonly(
+        "element_type",
+        [](PyComplexType &self) -> PyType {
+          MlirType t = mlirComplexTypeGetElementType(self.type);
+          return PyType(t);
+        },
+        "Returns element type.");
+  }
+};
+
+/// Vector Type subclass - VectorType.
+class PyVectorType : public PyConcreteType<PyVectorType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAVector;
+  static constexpr const char *pyClassName = "VectorType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get_vector",
+        [](intptr_t rank, py::list shape, PyType &elementType) {
+          // The element must be a floating point or integer scalar type.
+          if (!mlirTypeIsAIntegerOrFloat(elementType.type))
+            throw SetPyError(PyExc_ValueError,
+                             "invalid element type for vector type.");
+          // `shape` must list exactly `rank` dimensions; copy them into owned
+          // storage because mlirVectorTypeGet reads `rank` values from it.
+          if (rank <= 0 || static_cast<size_t>(rank) != py::len(shape))
+            throw SetPyError(PyExc_ValueError,
+                             "vector rank must be positive and equal to the "
+                             "length of shape.");
+          std::vector<int64_t> dims;
+          dims.reserve(static_cast<size_t>(rank));
+          for (py::handle dim : shape) {
+            dims.push_back(dim.cast<int64_t>());
+            if (dims.back() <= 0)
+              throw SetPyError(PyExc_ValueError,
+                               "vector dimensions must be positive.");
+          }
+          MlirType t = mlirVectorTypeGet(rank, dims.data(), elementType.type);
+          return PyVectorType(t);
+        },
+        // Pin the element type (argument 3), mirroring get_complex.
+        py::keep_alive<0, 3>(), "Create a vector type");
+  }
+};
+
+/// Tuple Type subclass - TupleType.
+class PyTupleType : public PyConcreteType<PyTupleType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsATuple;
+  static constexpr const char *pyClassName = "TupleType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get_tuple",
+        [](PyMlirContext &context, py::list elementList) {
+          // Collect the element types in a std::vector (not new[]/delete[])
+          // so the storage is released even if a cast below throws.
+          std::vector<MlirType> elements;
+          elements.reserve(py::len(elementList));
+          for (py::handle element : elementList)
+            elements.push_back(element.cast<PyType>().type);
+          intptr_t num = static_cast<intptr_t>(elements.size());
+          MlirType t = mlirTupleTypeGet(context.context, num, elements.data());
+          return PyTupleType(t);
+        },
+        py::keep_alive<0, 1>(), "Create a tuple type");
+    c.def(
+        "get_type",
+        [](PyTupleType &self, intptr_t pos) -> PyType {
+          // Reject out-of-range positions before they reach the C API.
+          if (pos < 0 || pos >= mlirTupleTypeGetNumTypes(self.type))
+            throw SetPyError(PyExc_IndexError,
+                             "tuple type index out of range.");
+          MlirType t = mlirTupleTypeGetType(self.type, pos);
+          return PyType(t);
+        },
+        py::keep_alive<0, 1>(), "Returns the pos-th type in the tuple type.");
+    c.def_property_readonly(
+        "types_num",
+        [](PyTupleType &self) -> intptr_t {
+          return mlirTupleTypeGetNumTypes(self.type);
+        },
+        "Returns the number of types contained in a tuple.");
+  }
+};
+
 } // namespace
 
 //------------------------------------------------------------------------------
@@ -591,4 +714,7 @@
   PyF32Type::bind(m);
   PyF64Type::bind(m);
   PyNoneType::bind(m);
+  PyComplexType::bind(m);
+  PyVectorType::bind(m);
+  PyTupleType::bind(m);
 }
diff --git a/mlir/test/Bindings/Python/ir_types.py b/mlir/test/Bindings/Python/ir_types.py
--- a/mlir/test/Bindings/Python/ir_types.py
+++ b/mlir/test/Bindings/Python/ir_types.py
@@ -154,3 +154,76 @@
   print("none type:", mlir.ir.NoneType(ctx))
 
 run(testNoneType)
+
+# CHECK-LABEL: TEST: testComplexType
+def testComplexType():
+  ctx = mlir.ir.Context()
+  complex_i32 = mlir.ir.ComplexType(ctx.parse_type("complex<i32>"))
+  # CHECK: complex type element: i32
+  print("complex type element:", complex_i32.element_type)
+
+  f32 = mlir.ir.F32Type(ctx)
+  # CHECK: complex type: complex<f32>
+  print("complex type:", mlir.ir.ComplexType.get_complex(f32))
+
+  index = mlir.ir.IndexType(ctx)
+  try:
+    complex_invalid = mlir.ir.ComplexType.get_complex(index)
+  except ValueError as e:
+    # CHECK: invalid element type for complex type.
+    print(e)
+  else:
+    print("Exception not produced")
+
+run(testComplexType)
+
+# CHECK-LABEL: TEST: testVectorType
+def testVectorType():
+  ctx = mlir.ir.Context()
+  f32 = mlir.ir.F32Type(ctx)
+  # CHECK: vector type: vector<3x2xf32>
+  print("vector type:", mlir.ir.VectorType.get_vector(2, [3, 2], f32))
+
+  index = mlir.ir.IndexType(ctx)
+  try:
+    vector_invalid = mlir.ir.VectorType.get_vector(2, [3, 2], index)
+  except ValueError as e:
+    # CHECK: invalid element type for vector type.
+    print(e)
+  else:
+    print("Exception not produced")
+
+  try:
+    vector_invalid = mlir.ir.VectorType.get_vector(3, [3, 2], f32)
+  except ValueError as e:
+    # CHECK: vector rank must be positive and equal to the length of shape.
+    print(e)
+  else:
+    print("Exception not produced")
+
+run(testVectorType)
+
+# CHECK-LABEL: TEST: testTupleType
+def testTupleType():
+  ctx = mlir.ir.Context()
+  i32 = mlir.ir.IntegerType(ctx.parse_type("i32"))
+  f32 = mlir.ir.F32Type(ctx)
+  vector = mlir.ir.VectorType.get_vector(2, [3, 2], f32)
+  l = [i32, f32, vector]
+  tuple_type = mlir.ir.TupleType.get_tuple(ctx, l)
+  # CHECK: tuple type: tuple<i32, f32, vector<3x2xf32>>
+  print("tuple type:", tuple_type)
+  # CHECK: number of types: 3
+  print("number of types:", tuple_type.types_num)
+  # CHECK: pos-th type in the tuple type: f32
+  print("pos-th type in the tuple type:", tuple_type.get_type(1))
+
+  try:
+    tuple_type.get_type(3)
+  except IndexError as e:
+    # CHECK: tuple type index out of range.
+    print(e)
+  else:
+    print("Exception not produced")
+
+run(testTupleType)