diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp --- a/mlir/lib/Bindings/Python/IRTypes.cpp +++ b/mlir/lib/Bindings/Python/IRTypes.cpp @@ -102,6 +102,42 @@ } }; +/// Floating Point Type subclass - Float8E4M3FNType. +class PyFloat8E4M3FNType : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FN; + static constexpr const char *pyClassName = "Float8E4M3FNType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirFloat8E4M3FNTypeGet(context->get()); + return PyFloat8E4M3FNType(context->getRef(), t); + }, + py::arg("context") = py::none(), "Create a float8_e4m3fn type."); + } +}; + +/// Floating Point Type subclass - Float8E5M2Type. +class PyFloat8E5M2Type : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2; + static constexpr const char *pyClassName = "Float8E5M2Type"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirFloat8E5M2TypeGet(context->get()); + return PyFloat8E5M2Type(context->getRef(), t); + }, + py::arg("context") = py::none(), "Create a float8_e5m2 type."); + } +}; + /// Floating Point Type subclass - BF16Type. 
class PyBF16Type : public PyConcreteType { public: @@ -663,6 +699,8 @@ void mlir::python::populateIRTypes(py::module &m) { PyIntegerType::bind(m); PyIndexType::bind(m); + PyFloat8E4M3FNType::bind(m); + PyFloat8E5M2Type::bind(m); PyBF16Type::bind(m); PyF16Type::bind(m); PyF32Type::bind(m); diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi --- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi @@ -50,6 +50,8 @@ "DiagnosticHandler", "DiagnosticSeverity", "DictAttr", + "Float8E4M3FNType", + "Float8E5M2Type", "F16Type", "F32Type", "F64Type", @@ -577,6 +579,20 @@ @property def type(self) -> Type: ... +class Float8E4M3FNType(Type): + def __init__(self, cast_from_type: Type) -> None: ... + @staticmethod + def get(*args, **kwargs) -> Float8E4M3FNType: ... + @staticmethod + def isinstance(arg: Any) -> bool: ... + +class Float8E5M2Type(Type): + def __init__(self, cast_from_type: Type) -> None: ... + @staticmethod + def get(*args, **kwargs) -> Float8E5M2Type: ... + @staticmethod + def isinstance(arg: Any) -> bool: ... + # TODO: Auto-generated. Audit and fix. class F16Type(Type): def __init__(self, cast_from_type: Type) -> None: ... diff --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py --- a/mlir/test/python/ir/builtin_types.py +++ b/mlir/test/python/ir/builtin_types.py @@ -193,6 +193,10 @@ @run def testFloatType(): with Context(): + # CHECK: float: f8E4M3FN + print("float:", Float8E4M3FNType.get()) + # CHECK: float: f8E5M2 + print("float:", Float8E5M2Type.get()) # CHECK: float: bf16 print("float:", BF16Type.get()) # CHECK: float: f16