diff --git a/clang/lib/AST/MicrosoftMangle.cpp b/clang/lib/AST/MicrosoftMangle.cpp --- a/clang/lib/AST/MicrosoftMangle.cpp +++ b/clang/lib/AST/MicrosoftMangle.cpp @@ -838,6 +838,8 @@ case APFloat::S_x87DoubleExtended: Out << 'X'; break; case APFloat::S_IEEEquad: Out << 'Y'; break; case APFloat::S_PPCDoubleDouble: Out << 'Z'; break; + case APFloat::S_Float8E5M2: // No mangling letter is assigned to this format. + llvm_unreachable("Tried to mangle unexpected APFloat semantics"); } mangleBits(Number.bitcastToAPInt()); diff --git a/llvm/include/llvm/ADT/APFloat.h b/llvm/include/llvm/ADT/APFloat.h --- a/llvm/include/llvm/ADT/APFloat.h +++ b/llvm/include/llvm/ADT/APFloat.h @@ -153,10 +153,13 @@ S_BFloat, S_IEEEsingle, S_IEEEdouble, - S_x87DoubleExtended, S_IEEEquad, S_PPCDoubleDouble, - S_MaxSemantics = S_PPCDoubleDouble + // 8-bit floating point number following IEEE-754 conventions with bit + // layout S1E5M2 as described in https://arxiv.org/abs/2209.05433 + S_Float8E5M2, + S_x87DoubleExtended, + S_MaxSemantics = S_x87DoubleExtended, }; static const llvm::fltSemantics &EnumToSemantics(Semantics S); @@ -168,6 +171,7 @@ static const fltSemantics &IEEEdouble() LLVM_READNONE; static const fltSemantics &IEEEquad() LLVM_READNONE; static const fltSemantics &PPCDoubleDouble() LLVM_READNONE; + static const fltSemantics &Float8E5M2() LLVM_READNONE; static const fltSemantics &x87DoubleExtended() LLVM_READNONE; /// A Pseudo fltsemantic used to construct APFloats that cannot conflict with @@ -552,6 +556,7 @@ APInt convertQuadrupleAPFloatToAPInt() const; APInt convertF80LongDoubleAPFloatToAPInt() const; APInt convertPPCDoubleDoubleAPFloatToAPInt() const; + APInt convertFloat8E5M2APFloatToAPInt() const; void initFromAPInt(const fltSemantics *Sem, const APInt &api); void initFromHalfAPInt(const APInt &api); void initFromBFloatAPInt(const APInt &api); @@ -560,6 +565,7 @@ void initFromQuadrupleAPInt(const APInt &api); void initFromF80LongDoubleAPInt(const APInt &api); void initFromPPCDoubleDoubleAPInt(const APInt &api); +
void initFromFloat8E5M2APInt(const APInt &api); void assign(const IEEEFloat &); void copySignificand(const IEEEFloat &); diff --git a/llvm/lib/Support/APFloat.cpp b/llvm/lib/Support/APFloat.cpp --- a/llvm/lib/Support/APFloat.cpp +++ b/llvm/lib/Support/APFloat.cpp @@ -80,6 +80,7 @@ static const fltSemantics semIEEEsingle = {127, -126, 24, 32}; static const fltSemantics semIEEEdouble = {1023, -1022, 53, 64}; static const fltSemantics semIEEEquad = {16383, -16382, 113, 128}; + static const fltSemantics semFloat8E5M2 = {15, -14, 3, 8}; // S1E5M2: exponents -14..15 (bias 15), 2 stored mantissa bits + integer bit, 8 bits wide static const fltSemantics semX87DoubleExtended = {16383, -16382, 64, 80}; static const fltSemantics semBogus = {0, 0, 0, 0}; @@ -131,12 +132,14 @@ return IEEEsingle(); case S_IEEEdouble: return IEEEdouble(); - case S_x87DoubleExtended: - return x87DoubleExtended(); case S_IEEEquad: return IEEEquad(); case S_PPCDoubleDouble: return PPCDoubleDouble(); + case S_Float8E5M2: // Cases follow the declaration order of the Semantics enum. + return Float8E5M2(); + case S_x87DoubleExtended: + return x87DoubleExtended(); } llvm_unreachable("Unrecognised floating semantics"); } @@ -151,12 +154,14 @@ return S_IEEEsingle; else if (&Sem == &llvm::APFloat::IEEEdouble()) return S_IEEEdouble; - else if (&Sem == &llvm::APFloat::x87DoubleExtended()) - return S_x87DoubleExtended; else if (&Sem == &llvm::APFloat::IEEEquad()) return S_IEEEquad; else if (&Sem == &llvm::APFloat::PPCDoubleDouble()) return S_PPCDoubleDouble; + else if (&Sem == &llvm::APFloat::Float8E5M2()) // Semantics are compared by identity. + return S_Float8E5M2; + else if (&Sem == &llvm::APFloat::x87DoubleExtended()) + return S_x87DoubleExtended; else llvm_unreachable("Unknown floating semantics"); } @@ -173,18 +178,15 @@ const fltSemantics &APFloatBase::IEEEdouble() { return semIEEEdouble; } - const fltSemantics &APFloatBase::IEEEquad() { - return semIEEEquad; + const fltSemantics &APFloatBase::IEEEquad() { return semIEEEquad; } + const fltSemantics &APFloatBase::PPCDoubleDouble() { + return semPPCDoubleDouble; } + const fltSemantics &APFloatBase::Float8E5M2() { return semFloat8E5M2; } // Singleton semantics object for the 8-bit E5M2 format.
const fltSemantics &APFloatBase::x87DoubleExtended() { return semX87DoubleExtended; } - const fltSemantics &APFloatBase::Bogus() { - return semBogus; - } - const fltSemantics &APFloatBase::PPCDoubleDouble() { - return semPPCDoubleDouble; - } + const fltSemantics &APFloatBase::Bogus() { return semBogus; } constexpr RoundingMode APFloatBase::rmNearestTiesToEven; constexpr RoundingMode APFloatBase::rmTowardPositive; @@ -3353,6 +3355,33 @@ (mysignificand & 0x3ff))); } +APInt IEEEFloat::convertFloat8E5M2APFloatToAPInt() const { // Packs the value into 8 bits as sign:1 | exponent:5 | mantissa:2. + assert(semantics == (const llvm::fltSemantics *)&semFloat8E5M2); + assert(partCount() == 1); + + uint32_t myexponent, mysignificand; + + if (isFiniteNonZero()) { + myexponent = exponent + 15; // bias + mysignificand = (uint32_t)*significandParts(); + if (myexponent == 1 && !(mysignificand & 0x4)) + myexponent = 0; // denormal: integer bit (0x4) is clear + } else if (category == fcZero) { + myexponent = 0; + mysignificand = 0; + } else if (category == fcInfinity) { + myexponent = 0x1f; // all-ones exponent with zero mantissa encodes infinity + mysignificand = 0; + } else { + assert(category == fcNaN && "Unknown category!"); + myexponent = 0x1f; // all-ones exponent with nonzero mantissa encodes NaN + mysignificand = (uint32_t)*significandParts(); + } + + return APInt(8, (((sign & 1) << 7) | ((myexponent & 0x1f) << 2) | + (mysignificand & 0x3))); +} + // This function creates an APInt that is just a bit map of the floating // point constant as it would appear in memory. It is not a conversion, // and treating the result as a normal integer is unlikely to be useful.
@@ -3376,6 +3405,9 @@ if (semantics == (const llvm::fltSemantics *)&semPPCDoubleDoubleLegacy) return convertPPCDoubleDoubleAPFloatToAPInt(); + if (semantics == (const llvm::fltSemantics *)&semFloat8E5M2) + return convertFloat8E5M2APFloatToAPInt(); + assert(semantics == (const llvm::fltSemantics*)&semX87DoubleExtended && "unknown format!"); return convertF80LongDoubleAPFloatToAPInt(); @@ -3603,6 +3635,34 @@ } } +void IEEEFloat::initFromFloat8E5M2APInt(const APInt &api) { // Unpacks the 8-bit layout sign:1 | exponent:5 | mantissa:2. + uint32_t i = (uint32_t)*api.getRawData(); + uint32_t myexponent = (i >> 2) & 0x1f; // bits 6..2 + uint32_t mysignificand = i & 0x3; // bits 1..0 + + initialize(&semFloat8E5M2); + assert(partCount() == 1); + + sign = i >> 7; + if (myexponent == 0 && mysignificand == 0) { + makeZero(sign); + } else if (myexponent == 0x1f && mysignificand == 0) { + makeInf(sign); + } else if (myexponent == 0x1f && mysignificand != 0) { + category = fcNaN; + exponent = exponentNaN(); + *significandParts() = mysignificand; + } else { + category = fcNormal; + exponent = myexponent - 15; // bias + *significandParts() = mysignificand; + if (myexponent == 0) // denormal: exponent fixed at -14, no integer bit + exponent = -14; + else + *significandParts() |= 0x4; // integer bit + } +} + /// Treat api as containing the bits of a floating point number. Currently /// we infer the floating point type from the size of the APInt.
The /// isIEEE argument distinguishes between PPC128 and IEEE128 (not meaningful @@ -3623,6 +3683,8 @@ return initFromQuadrupleAPInt(api); if (Sem == &semPPCDoubleDoubleLegacy) return initFromPPCDoubleDoubleAPInt(api); + if (Sem == &semFloat8E5M2) // 8-bit semantics are dispatched by identity. + return initFromFloat8E5M2APInt(api); llvm_unreachable(nullptr); } diff --git a/llvm/unittests/ADT/APFloatTest.cpp b/llvm/unittests/ADT/APFloatTest.cpp --- a/llvm/unittests/ADT/APFloatTest.cpp +++ b/llvm/unittests/ADT/APFloatTest.cpp @@ -1752,18 +1752,20 @@ const unsigned long long bitPattern[2]; const unsigned bitPatternLength; } const GetZeroTest[] = { - { &APFloat::IEEEhalf(), false, {0, 0}, 1}, - { &APFloat::IEEEhalf(), true, {0x8000ULL, 0}, 1}, - { &APFloat::IEEEsingle(), false, {0, 0}, 1}, - { &APFloat::IEEEsingle(), true, {0x80000000ULL, 0}, 1}, - { &APFloat::IEEEdouble(), false, {0, 0}, 1}, - { &APFloat::IEEEdouble(), true, {0x8000000000000000ULL, 0}, 1}, - { &APFloat::IEEEquad(), false, {0, 0}, 2}, - { &APFloat::IEEEquad(), true, {0, 0x8000000000000000ULL}, 2}, - { &APFloat::PPCDoubleDouble(), false, {0, 0}, 2}, - { &APFloat::PPCDoubleDouble(), true, {0x8000000000000000ULL, 0}, 2}, - { &APFloat::x87DoubleExtended(), false, {0, 0}, 2}, - { &APFloat::x87DoubleExtended(), true, {0, 0x8000ULL}, 2}, + {&APFloat::IEEEhalf(), false, {0, 0}, 1}, + {&APFloat::IEEEhalf(), true, {0x8000ULL, 0}, 1}, + {&APFloat::IEEEsingle(), false, {0, 0}, 1}, + {&APFloat::IEEEsingle(), true, {0x80000000ULL, 0}, 1}, + {&APFloat::IEEEdouble(), false, {0, 0}, 1}, + {&APFloat::IEEEdouble(), true, {0x8000000000000000ULL, 0}, 1}, + {&APFloat::IEEEquad(), false, {0, 0}, 2}, + {&APFloat::IEEEquad(), true, {0, 0x8000000000000000ULL}, 2}, + {&APFloat::PPCDoubleDouble(), false, {0, 0}, 2}, + {&APFloat::PPCDoubleDouble(), true, {0x8000000000000000ULL, 0}, 2}, + {&APFloat::x87DoubleExtended(), false, {0, 0}, 2}, + {&APFloat::x87DoubleExtended(), true, {0, 0x8000ULL}, 2}, + {&APFloat::Float8E5M2(), false, {0, 0}, 1}, // +0.0: all bits clear, fits in one 64-bit word + {&APFloat::Float8E5M2(), true,
{0x80ULL, 0}, 1}, }; - const unsigned NumGetZeroTests = 12; + const unsigned NumGetZeroTests = std::size(GetZeroTest); // Derived from the table so every entry, including Float8E5M2, is exercised. for (unsigned i = 0; i < NumGetZeroTests; ++i) { @@ -4754,7 +4756,7 @@ EXPECT_TRUE(ilogb(F) == -1); } -TEST(APFloatTest, ToDouble) { +TEST(APFloatTest, IEEEdoubleToDouble) { APFloat DPosZero(0.0); APFloat DPosZeroToDouble(DPosZero.convertToDouble()); EXPECT_TRUE(DPosZeroToDouble.isPosZero()); @@ -4790,7 +4792,9 @@ DNegInf.convertToDouble()); APFloat DQNaN = APFloat::getQNaN(APFloat::IEEEdouble()); EXPECT_TRUE(std::isnan(DQNaN.convertToDouble())); +} +TEST(APFloatTest, IEEEsingleToDouble) { APFloat FPosZero(0.0F); APFloat FPosZeroToDouble(FPosZero.convertToDouble()); EXPECT_TRUE(FPosZeroToDouble.isPosZero()); @@ -4825,7 +4829,9 @@ FNegInf.convertToDouble()); APFloat FQNaN = APFloat::getQNaN(APFloat::IEEEsingle()); EXPECT_TRUE(std::isnan(FQNaN.convertToDouble())); +} +TEST(APFloatTest, IEEEhalfToDouble) { APFloat HPosZero = APFloat::getZero(APFloat::IEEEhalf()); APFloat HPosZeroToDouble(HPosZero.convertToDouble()); EXPECT_TRUE(HPosZeroToDouble.isPosZero()); @@ -4867,7 +4873,9 @@ APFloat BNegZero = APFloat::getZero(APFloat::IEEEhalf(), true); APFloat BNegZeroToDouble(BNegZero.convertToDouble()); EXPECT_TRUE(BNegZeroToDouble.isNegZero()); +} +TEST(APFloatTest, BFloatToDouble) { APFloat BOne(APFloat::BFloat(), "1.0"); EXPECT_EQ(1.0, BOne.convertToDouble()); APFloat BPosLargest = APFloat::getLargest(APFloat::BFloat(), false); @@ -4901,7 +4909,35 @@ EXPECT_TRUE(std::isnan(BQNaN.convertToDouble())); } -TEST(APFloatTest, ToFloat) { +TEST(APFloatTest, Float8E5M2ToDouble) { + APFloat One(APFloat::Float8E5M2(), "1.0"); + EXPECT_EQ(1.0, One.convertToDouble()); + APFloat Two(APFloat::Float8E5M2(), "2.0"); + EXPECT_EQ(2.0, Two.convertToDouble()); + APFloat PosLargest = APFloat::getLargest(APFloat::Float8E5M2(), false); + EXPECT_EQ(5.734400e+04, PosLargest.convertToDouble()); + APFloat NegLargest = APFloat::getLargest(APFloat::Float8E5M2(), true); + EXPECT_EQ(-5.734400e+04, NegLargest.convertToDouble()); + APFloat
PosSmallest = + APFloat::getSmallestNormalized(APFloat::Float8E5M2(), false); + EXPECT_EQ(0x1.p-14, PosSmallest.convertToDouble()); + APFloat NegSmallest = + APFloat::getSmallestNormalized(APFloat::Float8E5M2(), true); + EXPECT_EQ(-0x1.p-14, NegSmallest.convertToDouble()); + + APFloat SmallestDenorm = APFloat::getSmallest(APFloat::Float8E5M2(), false); + EXPECT_TRUE(SmallestDenorm.isDenormal()); + EXPECT_EQ(0x1p-16, SmallestDenorm.convertToDouble()); // 2^-14 scaled by the lowest of 2 mantissa bits + + APFloat PosInf = APFloat::getInf(APFloat::Float8E5M2()); + EXPECT_EQ(std::numeric_limits<double>::infinity(), PosInf.convertToDouble()); + APFloat NegInf = APFloat::getInf(APFloat::Float8E5M2(), true); + EXPECT_EQ(-std::numeric_limits<double>::infinity(), NegInf.convertToDouble()); + APFloat QNaN = APFloat::getQNaN(APFloat::Float8E5M2()); + EXPECT_TRUE(std::isnan(QNaN.convertToDouble())); +} + +TEST(APFloatTest, IEEEsingleToFloat) { APFloat FPosZero(0.0F); APFloat FPosZeroToFloat(FPosZero.convertToFloat()); EXPECT_TRUE(FPosZeroToFloat.isPosZero()); @@ -4935,7 +4971,9 @@ EXPECT_EQ(-std::numeric_limits<float>::infinity(), FNegInf.convertToFloat()); APFloat FQNaN = APFloat::getQNaN(APFloat::IEEEsingle()); EXPECT_TRUE(std::isnan(FQNaN.convertToFloat())); +} +TEST(APFloatTest, IEEEhalfToFloat) { APFloat HPosZero = APFloat::getZero(APFloat::IEEEhalf()); APFloat HPosZeroToFloat(HPosZero.convertToFloat()); EXPECT_TRUE(HPosZeroToFloat.isPosZero()); @@ -4969,7 +5007,9 @@ EXPECT_EQ(-std::numeric_limits<float>::infinity(), HNegInf.convertToFloat()); APFloat HQNaN = APFloat::getQNaN(APFloat::IEEEhalf()); EXPECT_TRUE(std::isnan(HQNaN.convertToFloat())); +} +TEST(APFloatTest, BFloatToFloat) { APFloat BPosZero = APFloat::getZero(APFloat::BFloat()); APFloat BPosZeroToDouble(BPosZero.convertToFloat()); EXPECT_TRUE(BPosZeroToDouble.isPosZero()); @@ -5008,4 +5048,41 @@ APFloat BQNaN = APFloat::getQNaN(APFloat::BFloat()); EXPECT_TRUE(std::isnan(BQNaN.convertToFloat())); } + +TEST(APFloatTest, Float8E5M2ToFloat) { + APFloat PosZero =
APFloat::getZero(APFloat::Float8E5M2()); + APFloat PosZeroToFloat(PosZero.convertToFloat()); + EXPECT_TRUE(PosZeroToFloat.isPosZero()); + APFloat NegZero = APFloat::getZero(APFloat::Float8E5M2(), true); + APFloat NegZeroToFloat(NegZero.convertToFloat()); + EXPECT_TRUE(NegZeroToFloat.isNegZero()); + + APFloat One(APFloat::Float8E5M2(), "1.0"); + EXPECT_EQ(1.0F, One.convertToFloat()); + APFloat Two(APFloat::Float8E5M2(), "2.0"); + EXPECT_EQ(2.0F, Two.convertToFloat()); + + APFloat PosLargest = APFloat::getLargest(APFloat::Float8E5M2(), false); + EXPECT_EQ(5.734400e+04, PosLargest.convertToFloat()); // 1.75 * 2^15 + APFloat NegLargest = APFloat::getLargest(APFloat::Float8E5M2(), true); + EXPECT_EQ(-5.734400e+04, NegLargest.convertToFloat()); + APFloat PosSmallest = + APFloat::getSmallestNormalized(APFloat::Float8E5M2(), false); + EXPECT_EQ(0x1.p-14, PosSmallest.convertToFloat()); + APFloat NegSmallest = + APFloat::getSmallestNormalized(APFloat::Float8E5M2(), true); + EXPECT_EQ(-0x1.p-14, NegSmallest.convertToFloat()); + + APFloat SmallestDenorm = APFloat::getSmallest(APFloat::Float8E5M2(), false); + EXPECT_TRUE(SmallestDenorm.isDenormal()); + EXPECT_EQ(0x1.p-16, SmallestDenorm.convertToFloat()); + + APFloat PosInf = APFloat::getInf(APFloat::Float8E5M2()); + EXPECT_EQ(std::numeric_limits<float>::infinity(), PosInf.convertToFloat()); + APFloat NegInf = APFloat::getInf(APFloat::Float8E5M2(), true); + EXPECT_EQ(-std::numeric_limits<float>::infinity(), NegInf.convertToFloat()); + APFloat QNaN = APFloat::getQNaN(APFloat::Float8E5M2()); + EXPECT_TRUE(std::isnan(QNaN.convertToFloat())); } + +} // namespace diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h --- a/mlir/include/mlir-c/BuiltinTypes.h +++ b/mlir/include/mlir-c/BuiltinTypes.h @@ -67,6 +67,13 @@ // Floating-point types. //===----------------------------------------------------------------------===// +/// Checks whether the given type is an f8E5M2 type.
+MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E5M2(MlirType type); + +/// Creates an f8E5M2 type in the given context. The type is owned by the +/// context. +MLIR_CAPI_EXPORTED MlirType mlirFloat8E5M2TypeGet(MlirContext ctx); + /// Checks whether the given type is a bf16 type. MLIR_CAPI_EXPORTED bool mlirTypeIsABF16(MlirType type); diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -59,6 +59,7 @@ Attribute metadata = Attribute()); // Types. + FloatType getFloat8E5M2Type(); FloatType getBF16Type(); FloatType getF16Type(); FloatType getF32Type(); diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h --- a/mlir/include/mlir/IR/BuiltinTypes.h +++ b/mlir/include/mlir/IR/BuiltinTypes.h @@ -46,6 +46,7 @@ static FloatType getF64(MLIRContext *ctx); static FloatType getF80(MLIRContext *ctx); static FloatType getF128(MLIRContext *ctx); + static FloatType getFloat8E5M2(MLIRContext *ctx); /// Methods for support type inquiry through isa, cast, and dyn_cast. static bool classof(Type type); @@ -373,8 +374,12 @@ } inline bool FloatType::classof(Type type) { - return type.isa<BFloat16Type, Float16Type, Float32Type, Float64Type, - Float80Type, Float128Type>(); + return type.isa<Float8E5M2Type, BFloat16Type, Float16Type, Float32Type, + Float64Type, Float80Type, Float128Type>(); +} + +inline FloatType FloatType::getFloat8E5M2(MLIRContext *ctx) { + return Float8E5M2Type::get(ctx); } inline FloatType FloatType::getBF16(MLIRContext *ctx) { diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td --- a/mlir/include/mlir/IR/BuiltinTypes.td +++ b/mlir/include/mlir/IR/BuiltinTypes.td @@ -76,6 +76,28 @@ }]; } +//===----------------------------------------------------------------------===// +// Float8E5M2Type + +def Builtin_Float8E5M2 : Builtin_FloatType<"Float8E5M2"> { + let summary = "8-bit floating point with 2 bit mantissa"; + let description = [{ + An 8-bit floating point type with 1 sign bit, 5 bits exponent and 2 bits +
This is not a standard type as defined by IEEE-754, but it + follows similar conventions with the following characteristics: + + * bit encoding: S1E5M2 + * exponent bias: 15 + * infinities: supported with exponent set to all 1s and mantissa 0s + * NaNs: supported with exponent bits set to all 1s and mantissa of + (01, 10, or 11) + * denormals when exponent is 0 + + Described in: https://arxiv.org/abs/2209.05433 + }]; +} + + //===----------------------------------------------------------------------===// // BFloat16Type diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h --- a/mlir/include/mlir/IR/Types.h +++ b/mlir/include/mlir/IR/Types.h @@ -123,6 +123,7 @@ // Convenience predicates. This is only for floating point types, // derived types should use isa/dyn_cast. bool isIndex() const; + bool isFloat8E5M2() const; bool isBF16() const; bool isF16() const; bool isF32() const; diff --git a/mlir/lib/AsmParser/TokenKinds.def b/mlir/lib/AsmParser/TokenKinds.def --- a/mlir/lib/AsmParser/TokenKinds.def +++ b/mlir/lib/AsmParser/TokenKinds.def @@ -93,6 +93,7 @@ TOK_KEYWORD(f32) TOK_KEYWORD(f64) TOK_KEYWORD(f80) +TOK_KEYWORD(f8E5M2) TOK_KEYWORD(f128) TOK_KEYWORD(false) TOK_KEYWORD(floordiv) diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp --- a/mlir/lib/AsmParser/TypeParser.cpp +++ b/mlir/lib/AsmParser/TypeParser.cpp @@ -30,6 +30,7 @@ case Token::kw_tuple: case Token::kw_vector: case Token::inttype: + case Token::kw_f8E5M2: case Token::kw_bf16: case Token::kw_f16: case Token::kw_f32: @@ -286,6 +287,9 @@ } // float-type + case Token::kw_f8E5M2: + consumeToken(Token::kw_f8E5M2); + return builder.getFloat8E5M2Type(); case Token::kw_bf16: consumeToken(Token::kw_bf16); return builder.getBF16Type(); diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp --- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp @@ -68,6 +68,14 @@ // Floating-point types. 
//===----------------------------------------------------------------------===// +bool mlirTypeIsAFloat8E5M2(MlirType type) { + return unwrap(type).isFloat8E5M2(); +} + +MlirType mlirFloat8E5M2TypeGet(MlirContext ctx) { + return wrap(FloatType::getFloat8E5M2(unwrap(ctx))); +} + bool mlirTypeIsABF16(MlirType type) { return unwrap(type).isBF16(); } MlirType mlirBF16TypeGet(MlirContext ctx) { diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -2179,6 +2179,7 @@ opaqueTy.getTypeData()); }) .Case<IndexType>([&](Type) { os << "index"; }) + .Case<Float8E5M2Type>([&](Type) { os << "f8E5M2"; }) .Case<BFloat16Type>([&](Type) { os << "bf16"; }) .Case<Float16Type>([&](Type) { os << "f16"; }) .Case<Float32Type>([&](Type) { os << "f32"; }) diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -33,6 +33,10 @@ // Types. //===----------------------------------------------------------------------===// +FloatType Builder::getFloat8E5M2Type() { + return FloatType::getFloat8E5M2(context); +} + FloatType Builder::getBF16Type() { return FloatType::getBF16(context); } FloatType Builder::getF16Type() { return FloatType::getF16(context); } diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp --- a/mlir/lib/IR/BuiltinTypes.cpp +++ b/mlir/lib/IR/BuiltinTypes.cpp @@ -88,6 +88,8 @@ //===----------------------------------------------------------------------===// unsigned FloatType::getWidth() { + if (isa<Float8E5M2Type>()) + return 8; if (isa<Float16Type, BFloat16Type>()) return 16; if (isa<Float32Type>()) @@ -103,6 +105,8 @@ /// Returns the floating semantics for the given type.
const llvm::fltSemantics &FloatType::getFloatSemantics() { + if (isa<Float8E5M2Type>()) + return APFloat::Float8E5M2(); if (isa<BFloat16Type>()) return APFloat::BFloat(); if (isa<Float16Type>()) diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -206,6 +206,7 @@ StorageUniquer typeUniquer; /// Cached Type Instances. + Float8E5M2Type f8E5M2Ty; BFloat16Type bf16Ty; Float16Type f16Ty; Float32Type f32Ty; @@ -276,6 +277,7 @@ //// Types. /// Floating-point Types. + impl->f8E5M2Ty = TypeUniquer::get<Float8E5M2Type>(this); impl->bf16Ty = TypeUniquer::get<BFloat16Type>(this); impl->f16Ty = TypeUniquer::get<Float16Type>(this); impl->f32Ty = TypeUniquer::get<Float32Type>(this); @@ -840,6 +842,9 @@ /// This should not be used directly. StorageUniquer &MLIRContext::getTypeUniquer() { return getImpl().typeUniquer; } +Float8E5M2Type Float8E5M2Type::get(MLIRContext *context) { + return context->getImpl().f8E5M2Ty; +} BFloat16Type BFloat16Type::get(MLIRContext *context) { return context->getImpl().bf16Ty; } diff --git a/mlir/lib/IR/Types.cpp b/mlir/lib/IR/Types.cpp --- a/mlir/lib/IR/Types.cpp +++ b/mlir/lib/IR/Types.cpp @@ -18,6 +18,7 @@ MLIRContext *Type::getContext() const { return getDialect().getContext(); } +bool Type::isFloat8E5M2() const { return isa<Float8E5M2Type>(); } bool Type::isBF16() const { return isa<BFloat16Type>(); } bool Type::isF16() const { return isa<Float16Type>(); } bool Type::isF32() const { return isa<Float32Type>(); } diff --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir --- a/mlir/test/IR/attribute.mlir +++ b/mlir/test/IR/attribute.mlir @@ -31,6 +31,42 @@ // ----- +//===----------------------------------------------------------------------===// +// Test float attributes +//===----------------------------------------------------------------------===// + +func.func @float_attrs_pass() { + "test.float_attrs"() { + // CHECK: float_attr = 2.000000e+00 : f8E5M2 + float_attr = 2. : f8E5M2 + } : () -> () + "test.float_attrs"() { + // CHECK: float_attr = 2.000000e+00 : f16 + float_attr = 2.
: f16 + } : () -> () + "test.float_attrs"() { + // CHECK: float_attr = 2.000000e+00 : bf16 + float_attr = 2. : bf16 + } : () -> () + "test.float_attrs"() { + // CHECK: float_attr = 2.000000e+00 : f32 + float_attr = 2. : f32 + } : () -> () + "test.float_attrs"() { + // CHECK: float_attr = 2.000000e+00 : f64 + float_attr = 2. : f64 + } : () -> () + "test.float_attrs"() { + // CHECK: float_attr = 2.000000e+00 : f80 + float_attr = 2. : f80 + } : () -> () + "test.float_attrs"() { + // CHECK: float_attr = 2.000000e+00 : f128 + float_attr = 2. : f128 + } : () -> () + return +} + //===----------------------------------------------------------------------===// // Test integer attributes //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -193,6 +193,14 @@ let assemblyFormat = "$attr attr-dict"; } +def FloatAttrOp : TEST_Op<"float_attrs"> { // Op whose only argument is one attribute of any kind. + // TODO: Clean up the OpBase float type and attribute selectors so they + // can express all of the types. + let arguments = (ins + AnyAttr:$float_attr + ); +} + def I32Case5: I32EnumAttrCase<"case5", 5>; def I32Case10: I32EnumAttrCase<"case10", 10>;