Patch summary (free text before the first "diff --git" line is ignored by
patch tools):

  Adds Python bindings for three MLIR standard types and tests for them.

  * mlirTypeIsAIntegerOrFloat: file-local helper; returns non-zero when the
    type is an integer, bf16, f16, f32 or f64 type.
  * ComplexType.get_complex(element_type) / ComplexType.element_type
  * VectorType.get_vector(shape, element_type)
  * TupleType.get_tuple(context, element_list) / TupleType.get_type(pos) /
    TupleType.num_types

  get_complex and get_vector raise ValueError when the element type fails
  mlirTypeIsAIntegerOrFloat.

  NOTE(review): get_type forwards `pos` to mlirTupleTypeGetType without a
  range check in the binding; behaviour for an out-of-range `pos` depends on
  the C API - confirm before relying on it.

  NOTE(review): this text was recovered from a copy in which line breaks and
  angle-bracketed spans had been stripped. The template arguments, the second
  #include and the "complex<...>" / "tuple<...>" strings below were restored
  by inference and should be checked against the upstream change.

diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -11,11 +11,15 @@
 
 #include "mlir-c/StandardAttributes.h"
 #include "mlir-c/StandardTypes.h"
+#include "llvm/ADT/SmallVector.h"
+#include <pybind11/stl.h>
 
 namespace py = pybind11;
 using namespace mlir;
 using namespace mlir::python;
 
+using llvm::SmallVector;
+
 //------------------------------------------------------------------------------
 // Docstrings (trivial, non-duplicated docstrings are included inline).
 //------------------------------------------------------------------------------
@@ -106,6 +110,20 @@
 
 } // namespace
 
+//------------------------------------------------------------------------------
+// Type-checking utilities.
+//------------------------------------------------------------------------------
+
+namespace {
+
+/// Checks whether the given type is an integer or float type.
+int mlirTypeIsAIntegerOrFloat(MlirType type) {
+  return mlirTypeIsAInteger(type) || mlirTypeIsABF16(type) ||
+         mlirTypeIsAF16(type) || mlirTypeIsAF32(type) || mlirTypeIsAF64(type);
+}
+
+} // namespace
+
 //------------------------------------------------------------------------------
 // PyAttribute.
 //------------------------------------------------------------------------------
@@ -401,6 +419,102 @@
   }
 };
 
+/// Complex Type subclass - ComplexType.
+class PyComplexType : public PyConcreteType<PyComplexType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAComplex;
+  static constexpr const char *pyClassName = "ComplexType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get_complex",
+        [](PyType &elementType) {
+          // The element must be a floating point or integer scalar type.
+          if (mlirTypeIsAIntegerOrFloat(elementType.type)) {
+            MlirType t = mlirComplexTypeGet(elementType.type);
+            return PyComplexType(t);
+          }
+          throw SetPyError(
+              PyExc_ValueError,
+              llvm::Twine("invalid '") +
+                  py::repr(py::cast(elementType)).cast<std::string>() +
+                  "' and expected floating point or integer type.");
+        },
+        py::keep_alive<0, 1>(), "Create a complex type");
+    c.def_property_readonly(
+        "element_type",
+        [](PyComplexType &self) -> PyType {
+          MlirType t = mlirComplexTypeGetElementType(self.type);
+          return PyType(t);
+        },
+        "Returns element type.");
+  }
+};
+
+/// Vector Type subclass - VectorType.
+class PyVectorType : public PyConcreteType<PyVectorType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAVector;
+  static constexpr const char *pyClassName = "VectorType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get_vector",
+        [](std::vector<int64_t> shape, PyType &elementType) {
+          // The element must be a floating point or integer scalar type.
+          if (mlirTypeIsAIntegerOrFloat(elementType.type)) {
+            MlirType t =
+                mlirVectorTypeGet(shape.size(), shape.data(), elementType.type);
+            return PyVectorType(t);
+          }
+          throw SetPyError(
+              PyExc_ValueError,
+              llvm::Twine("invalid '") +
+                  py::repr(py::cast(elementType)).cast<std::string>() +
+                  "' and expected floating point or integer type.");
+        },
+        py::keep_alive<0, 2>(), "Create a vector type");
+  }
+};
+
+/// Tuple Type subclass - TupleType.
+class PyTupleType : public PyConcreteType<PyTupleType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsATuple;
+  static constexpr const char *pyClassName = "TupleType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get_tuple",
+        [](PyMlirContext &context, py::list elementList) {
+          intptr_t num = py::len(elementList);
+          // Mapping py::list to SmallVector.
+          SmallVector<MlirType, 4> elements;
+          for (auto element : elementList)
+            elements.push_back(element.cast<PyType>().type);
+          MlirType t = mlirTupleTypeGet(context.context, num, elements.data());
+          return PyTupleType(t);
+        },
+        py::keep_alive<0, 1>(), "Create a tuple type");
+    c.def(
+        "get_type",
+        [](PyTupleType &self, intptr_t pos) -> PyType {
+          MlirType t = mlirTupleTypeGetType(self.type, pos);
+          return PyType(t);
+        },
+        py::keep_alive<0, 1>(), "Returns the pos-th type in the tuple type.");
+    c.def_property_readonly(
+        "num_types",
+        [](PyTupleType &self) -> intptr_t {
+          return mlirTupleTypeGetNumTypes(self.type);
+        },
+        "Returns the number of types contained in a tuple.");
+  }
+};
+
 } // namespace
 
 //------------------------------------------------------------------------------
@@ -591,4 +705,7 @@
   PyF32Type::bind(m);
   PyF64Type::bind(m);
   PyNoneType::bind(m);
+  PyComplexType::bind(m);
+  PyVectorType::bind(m);
+  PyTupleType::bind(m);
 }
diff --git a/mlir/test/Bindings/Python/ir_types.py b/mlir/test/Bindings/Python/ir_types.py
--- a/mlir/test/Bindings/Python/ir_types.py
+++ b/mlir/test/Bindings/Python/ir_types.py
@@ -154,3 +154,61 @@
   print("none type:", mlir.ir.NoneType(ctx))
 
 run(testNoneType)
+
+# CHECK-LABEL: TEST: testComplexType
+def testComplexType():
+  ctx = mlir.ir.Context()
+  complex_i32 = mlir.ir.ComplexType(ctx.parse_type("complex<i32>"))
+  # CHECK: complex type element: i32
+  print("complex type element:", complex_i32.element_type)
+
+  f32 = mlir.ir.F32Type(ctx)
+  # CHECK: complex type: complex<f32>
+  print("complex type:", mlir.ir.ComplexType.get_complex(f32))
+
+  index = mlir.ir.IndexType(ctx)
+  try:
+    complex_invalid = mlir.ir.ComplexType.get_complex(index)
+  except ValueError as e:
+    # CHECK: invalid 'Type(index)' and expected floating point or integer type.
+    print(e)
+  else:
+    print("Exception not produced")
+
+run(testComplexType)
+
+# CHECK-LABEL: TEST: testVectorType
+def testVectorType():
+  ctx = mlir.ir.Context()
+  f32 = mlir.ir.F32Type(ctx)
+  shape = [2, 3]
+  # CHECK: vector type: vector<2x3xf32>
+  print("vector type:", mlir.ir.VectorType.get_vector(shape, f32))
+
+  index = mlir.ir.IndexType(ctx)
+  try:
+    vector_invalid = mlir.ir.VectorType.get_vector(shape, index)
+  except ValueError as e:
+    # CHECK: invalid 'Type(index)' and expected floating point or integer type.
+    print(e)
+  else:
+    print("Exception not produced")
+
+run(testVectorType)
+
+# CHECK-LABEL: TEST: testTupleType
+def testTupleType():
+  ctx = mlir.ir.Context()
+  i32 = mlir.ir.IntegerType(ctx.parse_type("i32"))
+  f32 = mlir.ir.F32Type(ctx)
+  vector = mlir.ir.VectorType(ctx.parse_type("vector<2x3xf32>"))
+  l = [i32, f32, vector]
+  tuple_type = mlir.ir.TupleType.get_tuple(ctx, l)
+  # CHECK: tuple type: tuple<i32, f32, vector<2x3xf32>>
+  print("tuple type:", tuple_type)
+  # CHECK: number of types: 3
+  print("number of types:", tuple_type.num_types)
+  # CHECK: pos-th type in the tuple type: f32
+  print("pos-th type in the tuple type:", tuple_type.get_type(1))
+
+run(testTupleType)