diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h --- a/mlir/include/mlir-c/BuiltinTypes.h +++ b/mlir/include/mlir-c/BuiltinTypes.h @@ -81,6 +81,20 @@ /// context. MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3FNTypeGet(MlirContext ctx); +/// Checks whether the given type is an f8E5M2FNUZ type. +MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E5M2FNUZ(MlirType type); + +/// Creates an f8E5M2FNUZ type in the given context. The type is owned by the +/// context. +MLIR_CAPI_EXPORTED MlirType mlirFloat8E5M2FNUZTypeGet(MlirContext ctx); + +/// Checks whether the given type is an f8E4M3FNUZ type. +MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3FNUZ(MlirType type); + +/// Creates an f8E4M3FNUZ type in the given context. The type is owned by the +/// context. +MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3FNUZTypeGet(MlirContext ctx); + /// Checks whether the given type is a bf16 type. MLIR_CAPI_EXPORTED bool mlirTypeIsABF16(MlirType type); diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -62,6 +62,8 @@ // Types. FloatType getFloat8E5M2Type(); FloatType getFloat8E4M3FNType(); + FloatType getFloat8E5M2FNUZType(); + FloatType getFloat8E4M3FNUZType(); FloatType getBF16Type(); FloatType getF16Type(); FloatType getF32Type(); diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h --- a/mlir/include/mlir/IR/BuiltinTypes.h +++ b/mlir/include/mlir/IR/BuiltinTypes.h @@ -47,6 +47,8 @@ static FloatType getF128(MLIRContext *ctx); static FloatType getFloat8E5M2(MLIRContext *ctx); static FloatType getFloat8E4M3FN(MLIRContext *ctx); + static FloatType getFloat8E5M2FNUZ(MLIRContext *ctx); + static FloatType getFloat8E4M3FNUZ(MLIRContext *ctx); /// Methods for support type inquiry through isa, cast, and dyn_cast. static bool classof(Type type); @@ -374,8 +376,9 @@ } inline bool FloatType::classof(Type type) { - return type.isa(); + return type.isa(); } inline FloatType FloatType::getFloat8E5M2(MLIRContext *ctx) { @@ -386,6 +389,14 @@ return Float8E4M3FNType::get(ctx); } +inline FloatType FloatType::getFloat8E5M2FNUZ(MLIRContext *ctx) { + return Float8E5M2FNUZType::get(ctx); +} + +inline FloatType FloatType::getFloat8E4M3FNUZ(MLIRContext *ctx) { + return Float8E4M3FNUZType::get(ctx); +} + inline FloatType FloatType::getBF16(MLIRContext *ctx) { return BFloat16Type::get(ctx); } diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td --- a/mlir/include/mlir/IR/BuiltinTypes.td +++ b/mlir/include/mlir/IR/BuiltinTypes.td @@ -118,6 +118,50 @@ }]; } +//===----------------------------------------------------------------------===// +// Float8E5M2FNUZType + +def Builtin_Float8E5M2FNUZ : Builtin_FloatType<"Float8E5M2FNUZ"> { + let summary = "8-bit floating point with 2 bit mantissa"; + let description = [{ + An 8-bit floating point type with 1 sign bit, 5 bits exponent and 2 bits + mantissa. This is not a standard type as defined by IEEE-754, but it follows + similar conventions, with the exception that there are no infinity values, + no negative zero, and only one NaN representation. This type has the + following characteristics: + + * bit encoding: S1E5M2 + * exponent bias: 16 + * infinities: Not supported + * NaNs: Supported with sign bit set to 1, exponent bits and mantissa bits set to all 0s + * denormals when exponent is 0 + + Described in: https://arxiv.org/abs/2206.02915 + }]; +} + +//===----------------------------------------------------------------------===// +// Float8E4M3FNUZType + +def Builtin_Float8E4M3FNUZ : Builtin_FloatType<"Float8E4M3FNUZ"> { + let summary = "8-bit floating point with 3 bit mantissa"; + let description = [{ + An 8-bit floating point type with 1 sign bit, 4 bits exponent and 3 bits + mantissa. This is not a standard type as defined by IEEE-754, but it follows + similar conventions, with the exception that there are no infinity values, + no negative zero, and only one NaN representation. This type has the + following characteristics: + + * bit encoding: S1E4M3 + * exponent bias: 8 + * infinities: Not supported + * NaNs: Supported with sign bit set to 1, exponent bits and mantissa bits set to all 0s + * denormals when exponent is 0 + + Described in: https://arxiv.org/abs/2209.05433 + }]; +} + //===----------------------------------------------------------------------===// // BFloat16Type diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -488,6 +488,10 @@ BuildableType<"$_builder.getFloat8E4M3FNType()">; def F8E5M2 : Type, "f8E5M2 type">, BuildableType<"$_builder.getFloat8E5M2Type()">; +def F8E4M3FNUZ : Type, "f8E4M3FNUZ type">, + BuildableType<"$_builder.getFloat8E4M3FNUZType()">; +def F8E5M2FNUZ : Type, "f8E5M2FNUZ type">, + BuildableType<"$_builder.getFloat8E5M2FNUZType()">; def AnyComplex : Type()">, "complex-type", "::mlir::ComplexType">; diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h --- a/mlir/include/mlir/IR/Types.h +++ b/mlir/include/mlir/IR/Types.h @@ -122,6 +122,8 @@ bool isIndex() const; bool isFloat8E5M2() const; bool isFloat8E4M3FN() const; + bool isFloat8E5M2FNUZ() const; + bool isFloat8E4M3FNUZ() const; bool isBF16() const; bool isF16() const; bool isF32() const; diff --git a/mlir/lib/AsmParser/TokenKinds.def b/mlir/lib/AsmParser/TokenKinds.def --- a/mlir/lib/AsmParser/TokenKinds.def +++ b/mlir/lib/AsmParser/TokenKinds.def @@ -95,6 +95,8 @@ TOK_KEYWORD(f80) TOK_KEYWORD(f8E5M2) TOK_KEYWORD(f8E4M3FN) +TOK_KEYWORD(f8E5M2FNUZ) +TOK_KEYWORD(f8E4M3FNUZ) TOK_KEYWORD(f128) TOK_KEYWORD(false) TOK_KEYWORD(floordiv) diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp --- a/mlir/lib/AsmParser/TypeParser.cpp +++ b/mlir/lib/AsmParser/TypeParser.cpp @@ -33,6 +33,8 @@ case Token::inttype: case Token::kw_f8E5M2: case Token::kw_f8E4M3FN: + case Token::kw_f8E5M2FNUZ: + case Token::kw_f8E4M3FNUZ: case Token::kw_bf16: case Token::kw_f16: case Token::kw_f32: @@ -295,6 +297,12 @@ case Token::kw_f8E4M3FN: consumeToken(Token::kw_f8E4M3FN); return builder.getFloat8E4M3FNType(); + case Token::kw_f8E5M2FNUZ: + consumeToken(Token::kw_f8E5M2FNUZ); + return builder.getFloat8E5M2FNUZType(); + case Token::kw_f8E4M3FNUZ: + consumeToken(Token::kw_f8E4M3FNUZ); + return builder.getFloat8E4M3FNUZType(); case Token::kw_bf16: consumeToken(Token::kw_bf16); return builder.getBF16Type(); diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp --- a/mlir/lib/Bindings/Python/IRTypes.cpp +++ b/mlir/lib/Bindings/Python/IRTypes.cpp @@ -139,6 +139,42 @@ } }; +/// Floating Point Type subclass - Float8E4M3FNUZ. +class PyFloat8E4M3FNUZType : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FNUZ; + static constexpr const char *pyClassName = "Float8E4M3FNUZType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirFloat8E4M3FNUZTypeGet(context->get()); + return PyFloat8E4M3FNUZType(context->getRef(), t); + }, + py::arg("context") = py::none(), "Create a float8_e4m3fnuz type."); + } +}; + +/// Floating Point Type subclass - Float8E5M2FNUZ. +class PyFloat8E5M2FNUZType : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2FNUZ; + static constexpr const char *pyClassName = "Float8E5M2FNUZType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirFloat8E5M2FNUZTypeGet(context->get()); + return PyFloat8E5M2FNUZType(context->getRef(), t); + }, + py::arg("context") = py::none(), "Create a float8_e5m2fnuz type."); + } +}; + /// Floating Point Type subclass - BF16Type. class PyBF16Type : public PyConcreteType { public: @@ -700,6 +736,8 @@ PyIndexType::bind(m); PyFloat8E4M3FNType::bind(m); PyFloat8E5M2Type::bind(m); + PyFloat8E4M3FNUZType::bind(m); + PyFloat8E5M2FNUZType::bind(m); PyBF16Type::bind(m); PyF16Type::bind(m); PyF32Type::bind(m); diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp --- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp @@ -84,6 +84,22 @@ return wrap(FloatType::getFloat8E4M3FN(unwrap(ctx))); } +bool mlirTypeIsAFloat8E5M2FNUZ(MlirType type) { + return unwrap(type).isFloat8E5M2FNUZ(); +} + +MlirType mlirFloat8E5M2FNUZTypeGet(MlirContext ctx) { + return wrap(FloatType::getFloat8E5M2FNUZ(unwrap(ctx))); +} + +bool mlirTypeIsAFloat8E4M3FNUZ(MlirType type) { + return unwrap(type).isFloat8E4M3FNUZ(); +} + +MlirType mlirFloat8E4M3FNUZTypeGet(MlirContext ctx) { + return wrap(FloatType::getFloat8E4M3FNUZ(unwrap(ctx))); +} + bool mlirTypeIsABF16(MlirType type) { return unwrap(type).isBF16(); } MlirType mlirBF16TypeGet(MlirContext ctx) { diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -2410,6 +2410,8 @@ .Case([&](Type) { os << "index"; }) .Case([&](Type) { os << "f8E5M2"; }) .Case([&](Type) { os << "f8E4M3FN"; }) + .Case([&](Type) { os << "f8E5M2FNUZ"; }) + .Case([&](Type) { os << "f8E4M3FNUZ"; }) .Case([&](Type) { os << "bf16"; }) .Case([&](Type) { os << "f16"; }) .Case([&](Type) { os << "f32"; }) diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -41,6 +41,14 @@ return FloatType::getFloat8E4M3FN(context); } +FloatType Builder::getFloat8E5M2FNUZType() { + return FloatType::getFloat8E5M2FNUZ(context); +} + +FloatType Builder::getFloat8E4M3FNUZType() { + return FloatType::getFloat8E4M3FNUZ(context); +} + FloatType Builder::getBF16Type() { return FloatType::getBF16(context); } FloatType Builder::getF16Type() { return FloatType::getF16(context); } diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp --- a/mlir/lib/IR/BuiltinTypes.cpp +++ b/mlir/lib/IR/BuiltinTypes.cpp @@ -88,7 +88,8 @@ //===----------------------------------------------------------------------===// unsigned FloatType::getWidth() { - if (isa()) + if (isa()) return 8; if (isa()) return 16; @@ -109,6 +110,10 @@ return APFloat::Float8E5M2(); if (isa()) return APFloat::Float8E4M3FN(); + if (isa()) + return APFloat::Float8E5M2FNUZ(); + if (isa()) + return APFloat::Float8E4M3FNUZ(); if (isa()) return APFloat::BFloat(); if (isa()) diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -209,6 +209,8 @@ /// Cached Type Instances. Float8E5M2Type f8E5M2Ty; Float8E4M3FNType f8E4M3FNTy; + Float8E5M2FNUZType f8E5M2FNUZTy; + Float8E4M3FNUZType f8E4M3FNUZTy; BFloat16Type bf16Ty; Float16Type f16Ty; Float32Type f32Ty; @@ -281,6 +283,8 @@ /// Floating-point Types. impl->f8E5M2Ty = TypeUniquer::get(this); impl->f8E4M3FNTy = TypeUniquer::get(this); + impl->f8E5M2FNUZTy = TypeUniquer::get(this); + impl->f8E4M3FNUZTy = TypeUniquer::get(this); impl->bf16Ty = TypeUniquer::get(this); impl->f16Ty = TypeUniquer::get(this); impl->f32Ty = TypeUniquer::get(this); @@ -870,6 +874,12 @@ Float8E4M3FNType Float8E4M3FNType::get(MLIRContext *context) { return context->getImpl().f8E4M3FNTy; } +Float8E5M2FNUZType Float8E5M2FNUZType::get(MLIRContext *context) { + return context->getImpl().f8E5M2FNUZTy; +} +Float8E4M3FNUZType Float8E4M3FNUZType::get(MLIRContext *context) { + return context->getImpl().f8E4M3FNUZTy; +} BFloat16Type BFloat16Type::get(MLIRContext *context) { return context->getImpl().bf16Ty; } diff --git a/mlir/lib/IR/Types.cpp b/mlir/lib/IR/Types.cpp --- a/mlir/lib/IR/Types.cpp +++ b/mlir/lib/IR/Types.cpp @@ -36,6 +36,8 @@ bool Type::isFloat8E5M2() const { return isa(); } bool Type::isFloat8E4M3FN() const { return isa(); } +bool Type::isFloat8E5M2FNUZ() const { return isa(); } +bool Type::isFloat8E4M3FNUZ() const { return isa(); } bool Type::isBF16() const { return isa(); } bool Type::isF16() const { return isa(); } bool Type::isF32() const { return isa(); } diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi --- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi @@ -52,6 +52,8 @@ "DictAttr", "Float8E4M3FNType", "Float8E5M2Type", + "Float8E4M3FNUZType", + "Float8E5M2FNUZType", "F16Type", "F32Type", "F64Type", @@ -593,6 +595,20 @@ @staticmethod def isinstance(arg: Any) -> bool: ... +class Float8E4M3FNUZType(Type): + def __init__(self, cast_from_type: Type) -> None: ... + @staticmethod + def get(*args, **kwargs) -> Float8E4M3FNUZType: ... + @staticmethod + def isinstance(arg: Any) -> bool: ... + +class Float8E5M2FNUZType(Type): + def __init__(self, cast_from_type: Type) -> None: ... + @staticmethod + def get(*args, **kwargs) -> Float8E5M2FNUZType: ... + @staticmethod + def isinstance(arg: Any) -> bool: ... + # TODO: Auto-generated. Audit and fix. class F16Type(Type): def __init__(self, cast_from_type: Type) -> None: ... diff --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir --- a/mlir/test/IR/attribute.mlir +++ b/mlir/test/IR/attribute.mlir @@ -44,6 +44,14 @@ // CHECK: float_attr = 2.000000e+00 : f8E4M3FN float_attr = 2. : f8E4M3FN } : () -> () + "test.float_attrs"() { + // CHECK: float_attr = 2.000000e+00 : f8E5M2FNUZ + float_attr = 2. : f8E5M2FNUZ + } : () -> () + "test.float_attrs"() { + // CHECK: float_attr = 2.000000e+00 : f8E4M3FNUZ + float_attr = 2. : f8E4M3FNUZ + } : () -> () "test.float_attrs"() { // CHECK: float_attr = 2.000000e+00 : f16 float_attr = 2. : f16 diff --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py --- a/mlir/test/python/ir/builtin_types.py +++ b/mlir/test/python/ir/builtin_types.py @@ -197,6 +197,10 @@ print("float:", Float8E4M3FNType.get()) # CHECK: float: f8E5M2 print("float:", Float8E5M2Type.get()) + # CHECK: float: f8E5M2FNUZ + print("float:", Float8E5M2FNUZType.get()) + # CHECK: float: f8E4M3FNUZ + print("float:", Float8E4M3FNUZType.get()) # CHECK: float: bf16 print("float:", BF16Type.get()) # CHECK: float: f16 diff --git a/mlir/utils/lldb-scripts/mlirDataFormatters.py b/mlir/utils/lldb-scripts/mlirDataFormatters.py --- a/mlir/utils/lldb-scripts/mlirDataFormatters.py +++ b/mlir/utils/lldb-scripts/mlirDataFormatters.py @@ -52,6 +52,8 @@ "mlir::UnknownLoc": '"loc(unknown)"', "mlir::Float8E5M2Type": '"f8E5M2"', "mlir::Float8E4M3FNType": '"f8E4M3FN"', + "mlir::Float8E5M2FNUZType": '"f8E5M2FNUZ"', + "mlir::Float8E4M3FNUZType": '"f8E4M3FNUZ"', "mlir::BFloat16Type": '"bf16"', "mlir::Float16Type": '"f16"', "mlir::Float32Type": '"f32"',