diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -305,6 +305,102 @@
   }
 };
 
+/// Index Type subclass - IndexType.
+class PyIndexType : public PyConcreteType<PyIndexType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAIndex;
+  static constexpr const char *pyClassName = "IndexType";
+  using PyConcreteType::PyConcreteType;
+  /// Binds IndexType(context). self (arg 1) keeps the context (arg 2) alive.
+  static void bindDerived(ClassTy &c) {
+    c.def(py::init([](PyMlirContext &context) {
+            MlirType t = mlirIndexTypeGet(context.context);
+            return PyIndexType(t);
+          }),
+          py::keep_alive<1, 2>(), "Create an index type.");
+  }
+};
+
+/// Floating Point Type subclass - BF16Type.
+class PyBF16Type : public PyConcreteType<PyBF16Type> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsABF16;
+  static constexpr const char *pyClassName = "BF16Type";
+  using PyConcreteType::PyConcreteType;
+  /// Binds BF16Type(context). self (arg 1) keeps the context (arg 2) alive.
+  static void bindDerived(ClassTy &c) {
+    c.def(py::init([](PyMlirContext &context) {
+            MlirType t = mlirBF16TypeGet(context.context);
+            return PyBF16Type(t);
+          }),
+          py::keep_alive<1, 2>(), "Create a bf16 type.");
+  }
+};
+
+/// Floating Point Type subclass - F16Type.
+class PyF16Type : public PyConcreteType<PyF16Type> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF16;
+  static constexpr const char *pyClassName = "F16Type";
+  using PyConcreteType::PyConcreteType;
+  /// Binds F16Type(context). self (arg 1) keeps the context (arg 2) alive.
+  static void bindDerived(ClassTy &c) {
+    c.def(py::init([](PyMlirContext &context) {
+            MlirType t = mlirF16TypeGet(context.context);
+            return PyF16Type(t);
+          }),
+          py::keep_alive<1, 2>(), "Create an f16 type.");
+  }
+};
+
+/// Floating Point Type subclass - F32Type.
+class PyF32Type : public PyConcreteType<PyF32Type> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF32;
+  static constexpr const char *pyClassName = "F32Type";
+  using PyConcreteType::PyConcreteType;
+  /// Binds F32Type(context). self (arg 1) keeps the context (arg 2) alive.
+  static void bindDerived(ClassTy &c) {
+    c.def(py::init([](PyMlirContext &context) {
+            MlirType t = mlirF32TypeGet(context.context);
+            return PyF32Type(t);
+          }),
+          py::keep_alive<1, 2>(), "Create an f32 type.");
+  }
+};
+
+/// Floating Point Type subclass - F64Type.
+class PyF64Type : public PyConcreteType<PyF64Type> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF64;
+  static constexpr const char *pyClassName = "F64Type";
+  using PyConcreteType::PyConcreteType;
+  /// Binds F64Type(context). self (arg 1) keeps the context (arg 2) alive.
+  static void bindDerived(ClassTy &c) {
+    c.def(py::init([](PyMlirContext &context) {
+            MlirType t = mlirF64TypeGet(context.context);
+            return PyF64Type(t);
+          }),
+          py::keep_alive<1, 2>(), "Create an f64 type.");
+  }
+};
+
+/// None Type subclass - NoneType.
+class PyNoneType : public PyConcreteType<PyNoneType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsANone;
+  static constexpr const char *pyClassName = "NoneType";
+  using PyConcreteType::PyConcreteType;
+  /// Binds NoneType(context). self (arg 1) keeps the context (arg 2) alive.
+  static void bindDerived(ClassTy &c) {
+    c.def(py::init([](PyMlirContext &context) {
+            MlirType t = mlirNoneTypeGet(context.context);
+            return PyNoneType(t);
+          }),
+          py::keep_alive<1, 2>(), "Create a none type.");
+  }
+};
+
 } // namespace
 
 //------------------------------------------------------------------------------
@@ -489,4 +585,10 @@
 
   // Standard type bindings.
   PyIntegerType::bind(m);
+  PyIndexType::bind(m);
+  PyBF16Type::bind(m);
+  PyF16Type::bind(m);
+  PyF32Type::bind(m);
+  PyF64Type::bind(m);
+  PyNoneType::bind(m);
 }
diff --git a/mlir/test/Bindings/Python/ir_types.py b/mlir/test/Bindings/Python/ir_types.py
--- a/mlir/test/Bindings/Python/ir_types.py
+++ b/mlir/test/Bindings/Python/ir_types.py
@@ -124,3 +124,45 @@
   print("unsigned:", mlir.ir.IntegerType.get_unsigned(ctx, 64))
 
 run(testIntegerType)
+
+# CHECK-LABEL: TEST: testIndexType
+def testIndexType():
+  ctx = mlir.ir.Context()
+  # CHECK: index type: index
+  print("index type:", mlir.ir.IndexType(ctx))
+
+run(testIndexType)
+
+# CHECK-LABEL: TEST: testFloatType
+def testFloatType():
+  ctx = mlir.ir.Context()
+  # CHECK: float: bf16
+  print("float:", mlir.ir.BF16Type(ctx))
+  # CHECK: float: f16
+  print("float:", mlir.ir.F16Type(ctx))
+  # CHECK: float: f32
+  print("float:", mlir.ir.F32Type(ctx))
+  # CHECK: float: f64
+  print("float:", mlir.ir.F64Type(ctx))
+
+run(testFloatType)
+
+# CHECK-LABEL: TEST: testNoneType
+def testNoneType():
+  ctx = mlir.ir.Context()
+  # CHECK: none type: none
+  print("none type:", mlir.ir.NoneType(ctx))
+
+run(testNoneType)
+
+# CHECK-LABEL: TEST: testTypeKeepsContextAlive
+def testTypeKeepsContextAlive():
+  import gc
+  # The Context is only reachable through the type, so the binding must keep
+  # it alive; otherwise printing below would read a destroyed context.
+  f32 = mlir.ir.F32Type(mlir.ir.Context())
+  gc.collect()
+  # CHECK: f32 after gc: f32
+  print("f32 after gc:", f32)
+
+run(testTypeKeepsContextAlive)