diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h
--- a/mlir/include/mlir-c/BuiltinTypes.h
+++ b/mlir/include/mlir-c/BuiltinTypes.h
@@ -74,6 +74,13 @@
 /// context.
 MLIR_CAPI_EXPORTED MlirType mlirFloat8E5M2TypeGet(MlirContext ctx);
 
+/// Checks whether the given type is an f8E4M3FN type.
+MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3FN(MlirType type);
+
+/// Creates an f8E4M3FN type in the given context. The type is owned by the
+/// context.
+MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3FNTypeGet(MlirContext ctx);
+
 /// Checks whether the given type is a bf16 type.
 MLIR_CAPI_EXPORTED bool mlirTypeIsABF16(MlirType type);
 
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -60,6 +60,7 @@
 
   // Types.
   FloatType getFloat8E5M2Type();
+  FloatType getFloat8E4M3FNType();
   FloatType getBF16Type();
   FloatType getF16Type();
   FloatType getF32Type();
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -47,6 +47,7 @@
   static FloatType getF80(MLIRContext *ctx);
   static FloatType getF128(MLIRContext *ctx);
   static FloatType getFloat8E5M2(MLIRContext *ctx);
+  static FloatType getFloat8E4M3FN(MLIRContext *ctx);
 
   /// Methods for support type inquiry through isa, cast, and dyn_cast.
   static bool classof(Type type);
@@ -374,14 +375,18 @@
 }
 
 inline bool FloatType::classof(Type type) {
-  return type.isa<Float8E5M2Type, BFloat16Type, Float16Type, Float32Type,
-                  Float64Type, Float80Type, Float128Type>();
+  return type.isa<Float8E5M2Type, Float8E4M3FNType, BFloat16Type, Float16Type,
+                  Float32Type, Float64Type, Float80Type, Float128Type>();
 }
 
 inline FloatType FloatType::getFloat8E5M2(MLIRContext *ctx) {
   return Float8E5M2Type::get(ctx);
 }
 
+inline FloatType FloatType::getFloat8E4M3FN(MLIRContext *ctx) {
+  return Float8E4M3FNType::get(ctx);
+}
+
 inline FloatType FloatType::getBF16(MLIRContext *ctx) {
   return BFloat16Type::get(ctx);
 }
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -89,7 +89,7 @@
       * bit encoding: S1E5M2
       * exponent bias: 15
       * infinities: supported with exponent set to all 1s and mantissa 0s
-      * NaNs: supported with exponent bits set to all 1s and mantissa of 
+      * NaNs: supported with exponent bits set to all 1s and mantissa of
         (01, 10, or 11)
       * denormals when exponent is 0
 
@@ -97,6 +97,27 @@
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// Float8E4M3FNType
+
+def Builtin_Float8E4M3FN : Builtin_FloatType<"Float8E4M3FN"> {
+  let summary = "8-bit floating point with 3 bit mantissa";
+  let description = [{
+    An 8-bit floating point type with 1 sign bit, 4 bits exponent and 3 bits
+    mantissa. This is not a standard type as defined by IEEE-754, but it follows
+    similar conventions, with the exception that there are no infinity values
+    and only two NaN representations. This type has the following
+    characteristics:
+
+      * bit encoding: S1E4M3
+      * exponent bias: 7
+      * infinities: Not supported
+      * NaNs: supported with exponent bits and mantissa bits set to all 1s
+      * denormals when exponent is 0
+
+    Described in: https://arxiv.org/abs/2209.05433
+  }];
+}
 //===----------------------------------------------------------------------===//
 // BFloat16Type
 
diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h
--- a/mlir/include/mlir/IR/Types.h
+++ b/mlir/include/mlir/IR/Types.h
@@ -124,6 +124,7 @@
   // derived types should use isa/dyn_cast.
   bool isIndex() const;
   bool isFloat8E5M2() const;
+  bool isFloat8E4M3FN() const;
   bool isBF16() const;
   bool isF16() const;
   bool isF32() const;
diff --git a/mlir/lib/AsmParser/TokenKinds.def b/mlir/lib/AsmParser/TokenKinds.def
--- a/mlir/lib/AsmParser/TokenKinds.def
+++ b/mlir/lib/AsmParser/TokenKinds.def
@@ -94,6 +94,7 @@
 TOK_KEYWORD(f64)
 TOK_KEYWORD(f80)
 TOK_KEYWORD(f8E5M2)
+TOK_KEYWORD(f8E4M3FN)
 TOK_KEYWORD(f128)
 TOK_KEYWORD(false)
 TOK_KEYWORD(floordiv)
diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp
--- a/mlir/lib/AsmParser/TypeParser.cpp
+++ b/mlir/lib/AsmParser/TypeParser.cpp
@@ -31,6 +31,7 @@
   case Token::kw_vector:
   case Token::inttype:
   case Token::kw_f8E5M2:
+  case Token::kw_f8E4M3FN:
   case Token::kw_bf16:
   case Token::kw_f16:
   case Token::kw_f32:
@@ -290,6 +291,9 @@
   case Token::kw_f8E5M2:
     consumeToken(Token::kw_f8E5M2);
     return builder.getFloat8E5M2Type();
+  case Token::kw_f8E4M3FN:
+    consumeToken(Token::kw_f8E4M3FN);
+    return builder.getFloat8E4M3FNType();
   case Token::kw_bf16:
     consumeToken(Token::kw_bf16);
     return builder.getBF16Type();
diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -76,6 +76,14 @@
   return wrap(FloatType::getFloat8E5M2(unwrap(ctx)));
 }
 
+bool mlirTypeIsAFloat8E4M3FN(MlirType type) {
+  return unwrap(type).isFloat8E4M3FN();
+}
+
+MlirType mlirFloat8E4M3FNTypeGet(MlirContext ctx) {
+  return wrap(FloatType::getFloat8E4M3FN(unwrap(ctx)));
+}
+
 bool mlirTypeIsABF16(MlirType type) { return unwrap(type).isBF16(); }
 
 MlirType mlirBF16TypeGet(MlirContext ctx) {
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -2244,6 +2244,7 @@
       })
       .Case<IndexType>([&](Type) { os << "index"; })
       .Case<Float8E5M2Type>([&](Type) { os << "f8E5M2"; })
+      .Case<Float8E4M3FNType>([&](Type) { os << "f8E4M3FN"; })
       .Case<BFloat16Type>([&](Type) { os << "bf16"; })
       .Case<Float16Type>([&](Type) { os << "f16"; })
       .Case<Float32Type>([&](Type) { os << "f32"; })
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -37,6 +37,10 @@
   return FloatType::getFloat8E5M2(context);
 }
 
+FloatType Builder::getFloat8E4M3FNType() {
+  return FloatType::getFloat8E4M3FN(context);
+}
+
 FloatType Builder::getBF16Type() { return FloatType::getBF16(context); }
 
 FloatType Builder::getF16Type() { return FloatType::getF16(context); }
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -88,7 +88,7 @@
 //===----------------------------------------------------------------------===//
 
 unsigned FloatType::getWidth() {
-  if (isa<Float8E5M2Type>())
+  if (isa<Float8E5M2Type, Float8E4M3FNType>())
     return 8;
   if (isa<Float16Type, BFloat16Type>())
     return 16;
@@ -107,6 +107,8 @@
 const llvm::fltSemantics &FloatType::getFloatSemantics() {
   if (isa<Float8E5M2Type>())
     return APFloat::Float8E5M2();
+  if (isa<Float8E4M3FNType>())
+    return APFloat::Float8E4M3FN();
   if (isa<BFloat16Type>())
     return APFloat::BFloat();
   if (isa<Float16Type>())
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -207,6 +207,7 @@
 
   /// Cached Type Instances.
   Float8E5M2Type f8E5M2Ty;
+  Float8E4M3FNType f8E4M3FNTy;
   BFloat16Type bf16Ty;
   Float16Type f16Ty;
   Float32Type f32Ty;
@@ -278,6 +279,7 @@
   //// Types.
   /// Floating-point Types.
   impl->f8E5M2Ty = TypeUniquer::get<Float8E5M2Type>(this);
+  impl->f8E4M3FNTy = TypeUniquer::get<Float8E4M3FNType>(this);
   impl->bf16Ty = TypeUniquer::get<BFloat16Type>(this);
   impl->f16Ty = TypeUniquer::get<Float16Type>(this);
   impl->f32Ty = TypeUniquer::get<Float32Type>(this);
@@ -861,6 +863,9 @@
 Float8E5M2Type Float8E5M2Type::get(MLIRContext *context) {
   return context->getImpl().f8E5M2Ty;
 }
+Float8E4M3FNType Float8E4M3FNType::get(MLIRContext *context) {
+  return context->getImpl().f8E4M3FNTy;
+}
 BFloat16Type BFloat16Type::get(MLIRContext *context) {
   return context->getImpl().bf16Ty;
 }
diff --git a/mlir/lib/IR/Types.cpp b/mlir/lib/IR/Types.cpp
--- a/mlir/lib/IR/Types.cpp
+++ b/mlir/lib/IR/Types.cpp
@@ -19,6 +19,7 @@
 MLIRContext *Type::getContext() const { return getDialect().getContext(); }
 
 bool Type::isFloat8E5M2() const { return isa<Float8E5M2Type>(); }
+bool Type::isFloat8E4M3FN() const { return isa<Float8E4M3FNType>(); }
 bool Type::isBF16() const { return isa<BFloat16Type>(); }
 bool Type::isF16() const { return isa<Float16Type>(); }
 bool Type::isF32() const { return isa<Float32Type>(); }
diff --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir
--- a/mlir/test/IR/attribute.mlir
+++ b/mlir/test/IR/attribute.mlir
@@ -40,6 +40,10 @@
     // CHECK: float_attr = 2.000000e+00 : f8E5M2
     float_attr = 2. : f8E5M2
   } : () -> ()
+  "test.float_attrs"() {
+    // CHECK: float_attr = 2.000000e+00 : f8E4M3FN
+    float_attr = 2. : f8E4M3FN
+  } : () -> ()
   "test.float_attrs"() {
     // CHECK: float_attr = 2.000000e+00 : f16
     float_attr = 2. : f16