diff --git a/clang/lib/AST/MicrosoftMangle.cpp b/clang/lib/AST/MicrosoftMangle.cpp --- a/clang/lib/AST/MicrosoftMangle.cpp +++ b/clang/lib/AST/MicrosoftMangle.cpp @@ -840,6 +840,8 @@ case APFloat::S_PPCDoubleDouble: Out << 'Z'; break; case APFloat::S_Float8E5M2: case APFloat::S_Float8E4M3FN: + case APFloat::S_Float8E5M2FZN: + case APFloat::S_Float8E4M3FZN: llvm_unreachable("Tried to mangle unexpected APFloat semantics"); } diff --git a/llvm/include/llvm/ADT/APFloat.h b/llvm/include/llvm/ADT/APFloat.h --- a/llvm/include/llvm/ADT/APFloat.h +++ b/llvm/include/llvm/ADT/APFloat.h @@ -163,6 +163,18 @@ // Unlike IEEE-754 types, there are no infinity values, and NaN is // represented with the exponent and mantissa bits set to all 1s. S_Float8E4M3FN, + // 8-bit floating point number mostly following IEEE-754 conventions with + // bit layout S1E5M2 as described in https://arxiv.org/abs/2206.02915. + // Unlike IEEE-754 types, there are no infinity values, there is no + // negative zero, and NaN is represented with the exponent and mantissa + // bits set to all 0s and the sign bit set to 1. + S_Float8E5M2FZN, + // 8-bit floating point number mostly following IEEE-754 conventions with + // bit layout S1E4M3 as described in https://arxiv.org/abs/2206.02915. + // Unlike IEEE-754 types, there are no infinity values, there is no + // negative zero, and NaN is represented with the exponent and mantissa + // bits set to all 0s and the sign bit set to 1. 
+ S_Float8E4M3FZN, S_x87DoubleExtended, S_MaxSemantics = S_x87DoubleExtended, }; @@ -178,6 +190,8 @@ static const fltSemantics &PPCDoubleDouble() LLVM_READNONE; static const fltSemantics &Float8E5M2() LLVM_READNONE; static const fltSemantics &Float8E4M3FN() LLVM_READNONE; + static const fltSemantics &Float8E5M2FZN() LLVM_READNONE; + static const fltSemantics &Float8E4M3FZN() LLVM_READNONE; static const fltSemantics &x87DoubleExtended() LLVM_READNONE; /// A Pseudo fltsemantic used to construct APFloats that cannot conflict with @@ -570,6 +584,8 @@ APInt convertPPCDoubleDoubleAPFloatToAPInt() const; APInt convertFloat8E5M2APFloatToAPInt() const; APInt convertFloat8E4M3FNAPFloatToAPInt() const; + APInt convertFloat8E5M2FZNAPFloatToAPInt() const; + APInt convertFloat8E4M3FZNAPFloatToAPInt() const; void initFromAPInt(const fltSemantics *Sem, const APInt &api); void initFromHalfAPInt(const APInt &api); void initFromBFloatAPInt(const APInt &api); @@ -580,6 +596,8 @@ void initFromPPCDoubleDoubleAPInt(const APInt &api); void initFromFloat8E5M2APInt(const APInt &api); void initFromFloat8E4M3FNAPInt(const APInt &api); + void initFromFloat8E5M2FZNAPInt(const APInt &api); + void initFromFloat8E4M3FZNAPInt(const APInt &api); void assign(const IEEEFloat &); void copySignificand(const IEEEFloat &); diff --git a/llvm/lib/Support/APFloat.cpp b/llvm/lib/Support/APFloat.cpp --- a/llvm/lib/Support/APFloat.cpp +++ b/llvm/lib/Support/APFloat.cpp @@ -65,6 +65,25 @@ // as non-signalling, although the paper does not state whether the NaN // values are signalling or not. NanOnly, + + // Float8E5M2FZN and Float8E4M3FZN have this behavior. There is no Inf + // representation. A value is NaN if the exponent field and the mantissa + // field are all 0s, and sign bit is 1. This behavior matches the FP8 types + // described in https://arxiv.org/abs/2206.02915. + NanOnlyS1E0M0, + }; + + // Whether both positive and negative zero or only positive zero are + // represented. 
+ enum class fltSignedZeroSupport { + // Represents standard IEEE 754 behavior. Zero can be both positive and + // negative. + IEEE754, + + // Float8E5M2FZN and Float8E4M3FZN have this behavior. They only support + // positive zero. This behavior matches the FP8 types described in + // https://arxiv.org/abs/2206.02915. + PositiveOnly, }; /* Represents floating point arithmetic semantics. */ @@ -86,6 +105,9 @@ fltNonfiniteBehavior nonFiniteBehavior = fltNonfiniteBehavior::IEEE754; + /* The supported zero encodings. */ + fltSignedZeroSupport zeroSupport = fltSignedZeroSupport::IEEE754; + // Returns true if any number described by this semantics can be precisely // represented by the specified semantics. Does not take into account // the value of fltNonfiniteBehavior. @@ -103,6 +125,14 @@ static const fltSemantics semFloat8E5M2 = {15, -14, 3, 8}; static const fltSemantics semFloat8E4M3FN = {8, -6, 4, 8, fltNonfiniteBehavior::NanOnly}; + static const fltSemantics semFloat8E5M2FZN = { + 15, -15, 3, 8, + fltNonfiniteBehavior::NanOnlyS1E0M0, + fltSignedZeroSupport::PositiveOnly}; + static const fltSemantics semFloat8E4M3FZN = { + 7, -7, 4, 8, + fltNonfiniteBehavior::NanOnlyS1E0M0, + fltSignedZeroSupport::PositiveOnly}; static const fltSemantics semX87DoubleExtended = {16383, -16382, 64, 80}; static const fltSemantics semBogus = {0, 0, 0, 0}; @@ -162,6 +192,10 @@ return Float8E5M2(); case S_Float8E4M3FN: return Float8E4M3FN(); + case S_Float8E5M2FZN: + return Float8E5M2FZN(); + case S_Float8E4M3FZN: + return Float8E4M3FZN(); case S_x87DoubleExtended: return x87DoubleExtended(); } @@ -186,6 +220,10 @@ return S_Float8E5M2; else if (&Sem == &llvm::APFloat::Float8E4M3FN()) return S_Float8E4M3FN; + else if (&Sem == &llvm::APFloat::Float8E5M2FZN()) + return S_Float8E5M2FZN; + else if (&Sem == &llvm::APFloat::Float8E4M3FZN()) + return S_Float8E4M3FZN; else if (&Sem == &llvm::APFloat::x87DoubleExtended()) return S_x87DoubleExtended; else @@ -210,6 +248,8 @@ } const fltSemantics 
&APFloatBase::Float8E5M2() { return semFloat8E5M2; } const fltSemantics &APFloatBase::Float8E4M3FN() { return semFloat8E4M3FN; } + const fltSemantics &APFloatBase::Float8E5M2FZN() { return semFloat8E5M2FZN; } + const fltSemantics &APFloatBase::Float8E4M3FZN() { return semFloat8E4M3FZN; } const fltSemantics &APFloatBase::x87DoubleExtended() { return semX87DoubleExtended; } @@ -805,6 +845,13 @@ fill = &fill_storage; } + if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnlyS1E0M0) { + // There is only one valid NaN encoding. + sign = 1; + *significand = 0; + return; + } + // Set the significand bits to the fill. if (!fill || fill->getNumWords() < numParts) APInt::tcSet(significand, 0, numParts); @@ -1406,7 +1453,8 @@ rounding_mode == rmNearestTiesToAway || (rounding_mode == rmTowardPositive && !sign) || (rounding_mode == rmTowardNegative && sign)) { - if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly) + if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly || + semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnlyS1E0M0) makeNaN(false, sign); else category = fcInfinity; @@ -1546,6 +1594,12 @@ /* Did the significand increment overflow? */ if (omsb == (unsigned) semantics->precision + 1) { + // NanOnlyS1E0M0 types can't overflow to infinity. + if (exponent == semantics->maxExponent && + semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnlyS1E0M0) { + return handleOverflow(rounding_mode); + } + /* Renormalize by incrementing the exponent and shifting our significand right one. However if we already have the maximum exponent we overflow to infinity. 
*/ @@ -1783,7 +1837,8 @@ return opOK; case PackCategoriesIntoKey(fcNormal, fcZero): - if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly) + if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly || + semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnlyS1E0M0) makeNaN(false, sign); else category = fcInfinity; @@ -2347,8 +2402,8 @@ // If this is a truncation, perform the shift before we narrow the storage. if (shift < 0 && (isFiniteNonZero() || - (category == fcNaN && semantics->nonFiniteBehavior != - fltNonfiniteBehavior::NanOnly))) + (category == fcNaN && semantics->nonFiniteBehavior == + fltNonfiniteBehavior::IEEE754))) lostFraction = shiftRight(significandParts(), oldPartCount, -shift); // Fix the storage so it can hold to new value. @@ -2370,6 +2425,14 @@ significand.part = newPart; } + // A bit needs to be set in the significand when converting from a + // NanOnlyS1E0M0 encoding, otherwise this will become a -Inf for types that + // follow the IEEE-754 convention. + if (category == fcNaN && + (fromSemantics.nonFiniteBehavior == fltNonfiniteBehavior::NanOnlyS1E0M0 && + toSemantics.nonFiniteBehavior != fltNonfiniteBehavior::NanOnlyS1E0M0)) + APInt::tcSetBit(significandParts(), 1); + // Now that we have the right storage, switch the semantics. semantics = &toSemantics; @@ -2382,9 +2445,10 @@ fs = normalize(rounding_mode, lostFraction); *losesInfo = (fs != opOK); } else if (category == fcNaN) { - if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly) { + if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly || + semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnlyS1E0M0) { *losesInfo = - fromSemantics.nonFiniteBehavior != fltNonfiniteBehavior::NanOnly; + fromSemantics.nonFiniteBehavior != semantics->nonFiniteBehavior; makeNaN(false, sign); return is_signaling ? 
opInvalidOp : opOK; } @@ -2406,7 +2470,9 @@ fs = opOK; } } else if (category == fcInfinity && - semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly) { + (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly || + semantics->nonFiniteBehavior == + fltNonfiniteBehavior::NanOnlyS1E0M0)) { makeNaN(false, sign); *losesInfo = true; fs = opInexact; @@ -3534,6 +3600,60 @@ (mysignificand & 0x7))); } +APInt IEEEFloat::convertFloat8E5M2FZNAPFloatToAPInt() const { + assert(semantics == (const llvm::fltSemantics *)&semFloat8E5M2FZN); + assert(partCount() == 1); + + uint32_t mysign, myexponent, mysignificand; + + if (isFiniteNonZero()) { + mysign = sign; + myexponent = exponent + 16; // bias + mysignificand = (uint32_t)*significandParts(); + if (myexponent == 1 && !(mysignificand & 0x4)) + myexponent = 0; // denormal + } else if (category == fcZero) { + mysign = 0; + myexponent = 0; + mysignificand = 0; + } else { + assert(category == fcNaN && "Unknown category!"); + mysign = 1; + myexponent = 0; + mysignificand = 0; + } + + return APInt(8, (((mysign & 1) << 7) | ((myexponent & 0x1f) << 2) | + (mysignificand & 0x3))); +} + +APInt IEEEFloat::convertFloat8E4M3FZNAPFloatToAPInt() const { + assert(semantics == (const llvm::fltSemantics *)&semFloat8E4M3FZN); + assert(partCount() == 1); + + uint32_t mysign, myexponent, mysignificand; + + if (isFiniteNonZero()) { + mysign = sign; + myexponent = exponent + 8; // bias + mysignificand = (uint32_t)*significandParts(); + if (myexponent == 1 && !(mysignificand & 0x8)) + myexponent = 0; // denormal + } else if (category == fcZero) { + mysign = 0; + myexponent = 0; + mysignificand = 0; + } else { + assert(category == fcNaN && "Unknown category!"); + mysign = 1; + myexponent = 0; + mysignificand = 0; + } + + return APInt(8, (((mysign & 1) << 7) | ((myexponent & 0xf) << 3) | + (mysignificand & 0x7))); +} + // This function creates an APInt that is just a bit map of the floating // point constant as it would appear in memory. 
It is not a conversion, // and treating the result as a normal integer is unlikely to be useful. @@ -3563,6 +3683,12 @@ if (semantics == (const llvm::fltSemantics *)&semFloat8E4M3FN) return convertFloat8E4M3FNAPFloatToAPInt(); + if (semantics == (const llvm::fltSemantics *)&semFloat8E5M2FZN) + return convertFloat8E5M2FZNAPFloatToAPInt(); + + if (semantics == (const llvm::fltSemantics *)&semFloat8E4M3FZN) + return convertFloat8E4M3FZNAPFloatToAPInt(); + assert(semantics == (const llvm::fltSemantics*)&semX87DoubleExtended && "unknown format!"); return convertF80LongDoubleAPFloatToAPInt(); @@ -3844,6 +3970,66 @@ } } +void IEEEFloat::initFromFloat8E5M2FZNAPInt(const APInt &api) { + uint32_t i = (uint32_t)*api.getRawData(); + uint32_t mysign = (i >> 7) & 1; + uint32_t myexponent = (i >> 2) & 0x1f; + uint32_t mysignificand = i & 0x3; + + initialize(&semFloat8E5M2FZN); + assert(partCount() == 1); + + if (mysign == 0 && myexponent == 0 && mysignificand == 0) { + // All-zero bits encode +0; there is no -0 (sign bit alone means NaN). + makeZero(false); + } else if (mysign == 1 && myexponent == 0 && mysignificand == 0) { + category = fcNaN; // 0x80 is the only NaN encoding + sign = 1; + exponent = exponentNaN(); + *significandParts() = mysignificand; + } else { + category = fcNormal; + sign = mysign; + exponent = myexponent - 16; // remove the exponent bias (16) + *significandParts() = mysignificand; + + if (myexponent == 0) // denormal: pinned to the minimum exponent + exponent = -15; + else + *significandParts() |= 0x4; // restore the implicit integer bit + } +} + +void IEEEFloat::initFromFloat8E4M3FZNAPInt(const APInt &api) { + uint32_t i = (uint32_t)*api.getRawData(); + uint32_t mysign = (i >> 7) & 1; + uint32_t myexponent = (i >> 3) & 0xf; + uint32_t mysignificand = i & 0x7; + + initialize(&semFloat8E4M3FZN); + assert(partCount() == 1); + + if (mysign == 0 && myexponent == 0 && mysignificand == 0) { + // All-zero bits encode +0; there is no -0 (sign bit alone means NaN).
+ makeZero(false); + } else if (mysign == 1 && myexponent == 0 && mysignificand == 0) { + category = fcNaN; + sign = 1; + exponent = exponentNaN(); + *significandParts() = mysignificand; + } else { + category = fcNormal; + sign = mysign; + exponent = myexponent - 8; // bias + *significandParts() = mysignificand; + + if (myexponent == 0) // denormal + exponent = -7; + else + *significandParts() |= 0x8; // integer bit + } +} + /// Treat api as containing the bits of a floating point number. void IEEEFloat::initFromAPInt(const fltSemantics *Sem, const APInt &api) { assert(api.getBitWidth() == Sem->sizeInBits); @@ -3865,6 +4051,10 @@ return initFromFloat8E5M2APInt(api); if (Sem == &semFloat8E4M3FN) return initFromFloat8E4M3FNAPInt(api); + if (Sem == &semFloat8E5M2FZN) + return initFromFloat8E5M2FZNAPInt(api); + if (Sem == &semFloat8E4M3FZN) + return initFromFloat8E4M3FZNAPInt(api); llvm_unreachable(nullptr); } @@ -4274,6 +4464,8 @@ return false; if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly) return false; + if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnlyS1E0M0) + return false; // IEEE-754R 2008 6.2.1: A signaling NaN bit string should be encoded with the // first bit of the trailing significand being 0. 
@@ -4325,7 +4517,8 @@ } if (isLargest() && !isNegative()) { - if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly) { + if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly || + semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnlyS1E0M0) { // nextUp(getLargest()) == NAN makeNaN(); break; @@ -4409,6 +4602,8 @@ APFloatBase::ExponentType IEEEFloat::exponentNaN() const { if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly) return semantics->maxExponent; + if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnlyS1E0M0) + return 0; return semantics->maxExponent + 1; } @@ -4421,7 +4616,8 @@ } void IEEEFloat::makeInf(bool Negative) { - if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly) { + if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly || + semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnlyS1E0M0) { // There is no Inf, so make NaN instead. makeNaN(false, Negative); return; @@ -4433,6 +4629,10 @@ } void IEEEFloat::makeZero(bool Negative) { + if (semantics->zeroSupport == fltSignedZeroSupport::PositiveOnly) { + Negative = false; + } + category = fcZero; sign = Negative; exponent = exponentZero(); @@ -4441,7 +4641,8 @@ void IEEEFloat::makeQuiet() { assert(isNaN()); - if (semantics->nonFiniteBehavior != fltNonfiniteBehavior::NanOnly) + if (semantics->nonFiniteBehavior != fltNonfiniteBehavior::NanOnly && + semantics->nonFiniteBehavior != fltNonfiniteBehavior::NanOnlyS1E0M0) APInt::tcSetBit(significandParts(), semantics->precision - 2); } diff --git a/llvm/unittests/ADT/APFloatTest.cpp b/llvm/unittests/ADT/APFloatTest.cpp --- a/llvm/unittests/ADT/APFloatTest.cpp +++ b/llvm/unittests/ADT/APFloatTest.cpp @@ -1735,6 +1735,10 @@ EXPECT_EQ(3.402823466e+38f, APFloat::getLargest(APFloat::IEEEsingle()).convertToFloat()); EXPECT_EQ(1.7976931348623158e+308, APFloat::getLargest(APFloat::IEEEdouble()).convertToDouble()); EXPECT_EQ(448, 
APFloat::getLargest(APFloat::Float8E4M3FN()).convertToDouble()); + EXPECT_EQ(57344, + APFloat::getLargest(APFloat::Float8E5M2FZN()).convertToDouble()); + EXPECT_EQ(240, + APFloat::getLargest(APFloat::Float8E4M3FZN()).convertToDouble()); } TEST(APFloatTest, getSmallest) { @@ -1840,9 +1844,12 @@ {&APFloat::Float8E5M2(), true, {0x80ULL, 0}, 1}, {&APFloat::Float8E4M3FN(), false, {0, 0}, 1}, {&APFloat::Float8E4M3FN(), true, {0x80ULL, 0}, 1}, + // Not testing sign = true cases because negative zero isn't supported. + {&APFloat::Float8E5M2FZN(), false, {0, 0}, 1}, + {&APFloat::Float8E4M3FZN(), false, {0, 0}, 1}, }; - const unsigned NumGetZeroTests = 12; - for (unsigned i = 0; i < NumGetZeroTests; ++i) { + + for (unsigned i = 0; i < std::size(GetZeroTest); ++i) { APFloat test = APFloat::getZero(*GetZeroTest[i].semantics, GetZeroTest[i].sign); const char *pattern = GetZeroTest[i].sign? "-0x0p+0" : "0x0p+0"; @@ -5233,20 +5240,501 @@ } } +TEST(APFloatTest, Float8E4M3FZNFromString) { + // Exactly representable + EXPECT_EQ(240, APFloat(APFloat::Float8E4M3FZN(), "240").convertToDouble()); + // Round down to maximum value + EXPECT_EQ(240, APFloat(APFloat::Float8E4M3FZN(), "244").convertToDouble()); + // Round up past the maximum value (240), causing overflow to NaN + EXPECT_TRUE(APFloat(APFloat::Float8E4M3FZN(), "249").isNaN()); + // Overflow without rounding + EXPECT_TRUE(APFloat(APFloat::Float8E4M3FZN(), "480").isNaN()); + // Inf converted to NaN + EXPECT_TRUE(APFloat(APFloat::Float8E4M3FZN(), "inf").isNaN()); + // NaN converted to NaN + EXPECT_TRUE(APFloat(APFloat::Float8E4M3FZN(), "nan").isNaN()); +} + +TEST(APFloatTest, Float8E4M3FZNAdd) { + APFloat QNaN = APFloat::getNaN(APFloat::Float8E4M3FZN(), false); + + auto FromStr = [](StringRef S) { + return APFloat(APFloat::Float8E4M3FZN(), S); + }; + + struct { + APFloat x; + APFloat y; + const char *result; + int status; + int category; + APFloat::roundingMode roundingMode = APFloat::rmNearestTiesToEven; + } AdditionTests[] = { + // Test addition operations
involving NaN, overflow, and the max E4M3FZN + // value (240) because E4M3FZN differs from IEEE-754 types in these + // regards + {FromStr("240"), FromStr("4"), "240", APFloat::opInexact, + APFloat::fcNormal}, + {FromStr("240"), FromStr("8"), "NaN", + APFloat::opOverflow | APFloat::opInexact, APFloat::fcNaN}, + {FromStr("240"), FromStr("16"), "NaN", + APFloat::opOverflow | APFloat::opInexact, APFloat::fcNaN}, + {FromStr("-240"), FromStr("-16"), "NaN", + APFloat::opOverflow | APFloat::opInexact, APFloat::fcNaN}, + {QNaN, FromStr("-240"), "NaN", APFloat::opOK, APFloat::fcNaN}, + {FromStr("240"), FromStr("-16"), "224", APFloat::opOK, APFloat::fcNormal}, + {FromStr("240"), FromStr("0"), "240", APFloat::opOK, APFloat::fcNormal}, + {FromStr("240"), FromStr("32"), "240", APFloat::opInexact, + APFloat::fcNormal, APFloat::rmTowardZero}, + {FromStr("240"), FromStr("240"), "240", APFloat::opInexact, + APFloat::fcNormal, APFloat::rmTowardZero}, + }; + + for (size_t i = 0; i < std::size(AdditionTests); ++i) { + APFloat x(AdditionTests[i].x); + APFloat y(AdditionTests[i].y); + APFloat::opStatus status = x.add(y, AdditionTests[i].roundingMode); + + APFloat result(APFloat::Float8E4M3FZN(), AdditionTests[i].result); + + EXPECT_TRUE(result.bitwiseIsEqual(x)); + EXPECT_EQ(AdditionTests[i].status, (int)status); + EXPECT_EQ(AdditionTests[i].category, (int)x.getCategory()); + } +} + +TEST(APFloatTest, Float8E4M3FZNDivideByZero) { + APFloat x(APFloat::Float8E4M3FZN(), "1"); + APFloat zero(APFloat::Float8E4M3FZN(), "0"); + EXPECT_EQ(x.divide(zero, APFloat::rmNearestTiesToEven), APFloat::opDivByZero); + EXPECT_TRUE(x.isNaN()); +} + +TEST(APFloatTest, Float8E4M3FZNNext) { + APFloat test(APFloat::Float8E4M3FZN(), APFloat::uninitialized); + APFloat expected(APFloat::Float8E4M3FZN(), APFloat::uninitialized); + + // nextUp on positive numbers + for (int i = 0; i < 127; i++) { + test = APFloat(APFloat::Float8E4M3FZN(), APInt(8, i)); + expected = APFloat(APFloat::Float8E4M3FZN(), APInt(8, i + 
1)); + EXPECT_EQ(test.next(false), APFloat::opOK); + EXPECT_TRUE(test.bitwiseIsEqual(expected)); + } + + // nextUp on negative nonzero numbers + for (int i = 130; i < 255; i++) { + test = APFloat(APFloat::Float8E4M3FZN(), APInt(8, i)); + expected = APFloat(APFloat::Float8E4M3FZN(), APInt(8, i - 1)); + EXPECT_EQ(test.next(false), APFloat::opOK); + EXPECT_TRUE(test.bitwiseIsEqual(expected)) << i; + } + + // nextUp on NaN leaves the NaN unchanged and reports opOK + test = APFloat::getQNaN(APFloat::Float8E4M3FZN(), false); + expected = APFloat::getQNaN(APFloat::Float8E4M3FZN(), false); + EXPECT_EQ(test.next(false), APFloat::opOK); + EXPECT_TRUE(test.bitwiseIsEqual(expected)); + + // nextDown on positive nonzero finite numbers + for (int i = 1; i < 127; i++) { + test = APFloat(APFloat::Float8E4M3FZN(), APInt(8, i)); + expected = APFloat(APFloat::Float8E4M3FZN(), APInt(8, i - 1)); + EXPECT_EQ(test.next(true), APFloat::opOK); + EXPECT_TRUE(test.bitwiseIsEqual(expected)); + } + + // nextDown on positive zero (getZero's sign=true is ignored; there is no -0) + test = APFloat::getZero(APFloat::Float8E4M3FZN(), true); + expected = APFloat::getSmallest(APFloat::Float8E4M3FZN(), true); + EXPECT_EQ(test.next(true), APFloat::opOK); + EXPECT_TRUE(test.bitwiseIsEqual(expected)); + + // nextDown on negative finite numbers + for (int i = 129; i < 255; i++) { + test = APFloat(APFloat::Float8E4M3FZN(), APInt(8, i)); + expected = APFloat(APFloat::Float8E4M3FZN(), APInt(8, i + 1)); + EXPECT_EQ(test.next(true), APFloat::opOK); + EXPECT_TRUE(test.bitwiseIsEqual(expected)) << i; + } + + // nextDown on NaN leaves the NaN unchanged and reports opOK + test = APFloat::getQNaN(APFloat::Float8E4M3FZN(), false); + expected = APFloat::getQNaN(APFloat::Float8E4M3FZN(), false); + EXPECT_EQ(test.next(true), APFloat::opOK); + EXPECT_TRUE(test.bitwiseIsEqual(expected)); +} + +TEST(APFloatTest, Float8E4M3FZNExhaustive) { + // Test each of the 256 Float8E4M3FZN values.
+ for (int i = 0; i < 256; i++) { + APFloat test(APFloat::Float8E4M3FZN(), APInt(8, i)); + SCOPED_TRACE("i=" + std::to_string(i)); + + // isLargest: 0x7f and 0xff are +/- the maximum finite value (240) + if (i == 127 || i == 255) { + EXPECT_TRUE(test.isLargest()); + EXPECT_EQ(abs(test).convertToDouble(), 240.); + } else { + EXPECT_FALSE(test.isLargest()); + } + + // isSmallest: 0x01 and 0x81 are +/- the smallest denormal (2^-10) + if (i == 1 || i == 129) { + EXPECT_TRUE(test.isSmallest()); + EXPECT_EQ(abs(test).convertToDouble(), 0x1p-10); + } else { + EXPECT_FALSE(test.isSmallest()); + } + + // convert to BFloat: must be exact; 0x80 (i == 128) is the only NaN + APFloat test2 = test; + bool loses_info; + APFloat::opStatus status = test2.convert( + APFloat::BFloat(), APFloat::rmNearestTiesToEven, &loses_info); + EXPECT_EQ(status, APFloat::opOK); + EXPECT_FALSE(loses_info); + if (i == 128) + EXPECT_TRUE(test2.isNaN()); + else + EXPECT_EQ(test.convertToFloat(), test2.convertToFloat()); + + // bitcastToAPInt must reproduce the bit pattern the value was built from + EXPECT_EQ(i, test.bitcastToAPInt()); + } +} + +TEST(APFloatTest, Float8E4M3FZNExhaustivePair) { + // Test each pair of Float8E4M3FZN values.
+ for (int i = 0; i < 256; i++) { + for (int j = 0; j < 256; j++) { + SCOPED_TRACE("i=" + std::to_string(i) + ",j=" + std::to_string(j)); + APFloat x(APFloat::Float8E4M3FZN(), APInt(8, i)); + APFloat y(APFloat::Float8E4M3FZN(), APInt(8, j)); + + bool losesInfo; + APFloat x16 = x; + x16.convert(APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven, + &losesInfo); + EXPECT_FALSE(losesInfo); + APFloat y16 = y; + y16.convert(APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven, + &losesInfo); + EXPECT_FALSE(losesInfo); + + // Add + APFloat z = x; + z.add(y, APFloat::rmNearestTiesToEven); + APFloat z16 = x16; + z16.add(y16, APFloat::rmNearestTiesToEven); + z16.convert(APFloat::Float8E4M3FZN(), APFloat::rmNearestTiesToEven, + &losesInfo); + EXPECT_TRUE(z.bitwiseIsEqual(z16)); + + // Subtract + z = x; + z.subtract(y, APFloat::rmNearestTiesToEven); + z16 = x16; + z16.subtract(y16, APFloat::rmNearestTiesToEven); + z16.convert(APFloat::Float8E4M3FZN(), APFloat::rmNearestTiesToEven, + &losesInfo); + EXPECT_TRUE(z.bitwiseIsEqual(z16)); + + // Multiply + z = x; + z.multiply(y, APFloat::rmNearestTiesToEven); + z16 = x16; + z16.multiply(y16, APFloat::rmNearestTiesToEven); + z16.convert(APFloat::Float8E4M3FZN(), APFloat::rmNearestTiesToEven, + &losesInfo); + EXPECT_TRUE(z.bitwiseIsEqual(z16)) << "i=" << i << ", j=" << j; + + // Divide + z = x; + z.divide(y, APFloat::rmNearestTiesToEven); + z16 = x16; + z16.divide(y16, APFloat::rmNearestTiesToEven); + z16.convert(APFloat::Float8E4M3FZN(), APFloat::rmNearestTiesToEven, + &losesInfo); + EXPECT_TRUE(z.bitwiseIsEqual(z16)) << "i=" << i << ", j=" << j; + + // Mod + z = x; + z.mod(y); + z16 = x16; + z16.mod(y16); + z16.convert(APFloat::Float8E4M3FZN(), APFloat::rmNearestTiesToEven, + &losesInfo); + EXPECT_TRUE(z.bitwiseIsEqual(z16)) << "i=" << i << ", j=" << j; + + // Remainder + z = x; + z.remainder(y); + z16 = x16; + z16.remainder(y16); + z16.convert(APFloat::Float8E4M3FZN(), APFloat::rmNearestTiesToEven, + &losesInfo); + 
EXPECT_TRUE(z.bitwiseIsEqual(z16)) << "i=" << i << ", j=" << j; + } + } +} + +TEST(APFloatTest, Float8E5M2FZNFromString) { + // Exactly representable + EXPECT_EQ(57344, + APFloat(APFloat::Float8E5M2FZN(), "57344").convertToDouble()); + // Round down to maximum value + EXPECT_EQ(57344, + APFloat(APFloat::Float8E5M2FZN(), "59392").convertToDouble()); + // Round up, causing overflow to NaN + EXPECT_TRUE(APFloat(APFloat::Float8E5M2FZN(), "61440").isNaN()); + // Overflow without rounding + EXPECT_TRUE(APFloat(APFloat::Float8E5M2FZN(), "131072").isNaN()); + // Inf converted to NaN + EXPECT_TRUE(APFloat(APFloat::Float8E5M2FZN(), "inf").isNaN()); + // NaN converted to NaN + EXPECT_TRUE(APFloat(APFloat::Float8E5M2FZN(), "nan").isNaN()); +} + +TEST(APFloatTest, Float8E5M2FZNAdd) { + APFloat QNaN = APFloat::getNaN(APFloat::Float8E5M2FZN(), false); + + auto FromStr = [](StringRef S) { + return APFloat(APFloat::Float8E5M2FZN(), S); + }; + + struct { + APFloat x; + APFloat y; + const char *result; + int status; + int category; + APFloat::roundingMode roundingMode = APFloat::rmNearestTiesToEven; + } AdditionTests[] = { + // Test addition operations involving NaN, overflow, and the max E5M2FZN + // value (57344) because E5M2FZN differs from IEEE-754 types in these + // regards + {FromStr("57344"), FromStr("2048"), "57344", APFloat::opInexact, + APFloat::fcNormal}, + {FromStr("57344"), FromStr("4096"), "NaN", + APFloat::opOverflow | APFloat::opInexact, APFloat::fcNaN}, + {FromStr("-57344"), FromStr("-4096"), "NaN", + APFloat::opOverflow | APFloat::opInexact, APFloat::fcNaN}, + {QNaN, FromStr("-57344"), "NaN", APFloat::opOK, APFloat::fcNaN}, + {FromStr("57344"), FromStr("-8192"), "49152", APFloat::opOK, + APFloat::fcNormal}, + {FromStr("57344"), FromStr("0"), "57344", APFloat::opOK, + APFloat::fcNormal}, + {FromStr("57344"), FromStr("4096"), "57344", APFloat::opInexact, + APFloat::fcNormal, APFloat::rmTowardZero}, + {FromStr("57344"), FromStr("57344"), "57344", APFloat::opInexact, + 
APFloat::fcNormal, APFloat::rmTowardZero}, + }; + + for (size_t i = 0; i < std::size(AdditionTests); ++i) { + APFloat x(AdditionTests[i].x); + APFloat y(AdditionTests[i].y); + APFloat::opStatus status = x.add(y, AdditionTests[i].roundingMode); + + APFloat result(APFloat::Float8E5M2FZN(), AdditionTests[i].result); + + EXPECT_TRUE(result.bitwiseIsEqual(x)); + EXPECT_EQ(AdditionTests[i].status, (int)status); + EXPECT_EQ(AdditionTests[i].category, (int)x.getCategory()); + } +} + +TEST(APFloatTest, Float8E5M2FZNDivideByZero) { + APFloat x(APFloat::Float8E5M2FZN(), "1"); + APFloat zero(APFloat::Float8E5M2FZN(), "0"); + EXPECT_EQ(x.divide(zero, APFloat::rmNearestTiesToEven), APFloat::opDivByZero); + EXPECT_TRUE(x.isNaN()); +} + +TEST(APFloatTest, Float8E5M2FZNNext) { + APFloat test(APFloat::Float8E5M2FZN(), APFloat::uninitialized); + APFloat expected(APFloat::Float8E5M2FZN(), APFloat::uninitialized); + + // nextUp on positive numbers + for (int i = 0; i < 127; i++) { + test = APFloat(APFloat::Float8E5M2FZN(), APInt(8, i)); + expected = APFloat(APFloat::Float8E5M2FZN(), APInt(8, i + 1)); + EXPECT_EQ(test.next(false), APFloat::opOK); + EXPECT_TRUE(test.bitwiseIsEqual(expected)); + } + + // nextUp on negative nonzero numbers + for (int i = 130; i < 255; i++) { + test = APFloat(APFloat::Float8E5M2FZN(), APInt(8, i)); + expected = APFloat(APFloat::Float8E5M2FZN(), APInt(8, i - 1)); + EXPECT_EQ(test.next(false), APFloat::opOK); + EXPECT_TRUE(test.bitwiseIsEqual(expected)) << i; + } + + // nextUp on NaN + test = APFloat::getQNaN(APFloat::Float8E5M2FZN(), false); + expected = APFloat::getQNaN(APFloat::Float8E5M2FZN(), false); + EXPECT_EQ(test.next(false), APFloat::opOK); + EXPECT_TRUE(test.bitwiseIsEqual(expected)); + + // nextDown on positive nonzero finite numbers + for (int i = 1; i < 127; i++) { + test = APFloat(APFloat::Float8E5M2FZN(), APInt(8, i)); + expected = APFloat(APFloat::Float8E5M2FZN(), APInt(8, i - 1)); + EXPECT_EQ(test.next(true), APFloat::opOK); + 
EXPECT_TRUE(test.bitwiseIsEqual(expected)); + } + + // nextDown on positive zero (getZero's sign=true is ignored; there is no -0) + test = APFloat::getZero(APFloat::Float8E5M2FZN(), true); + expected = APFloat::getSmallest(APFloat::Float8E5M2FZN(), true); + EXPECT_EQ(test.next(true), APFloat::opOK); + EXPECT_TRUE(test.bitwiseIsEqual(expected)); + + // nextDown on negative finite numbers + for (int i = 129; i < 255; i++) { + test = APFloat(APFloat::Float8E5M2FZN(), APInt(8, i)); + expected = APFloat(APFloat::Float8E5M2FZN(), APInt(8, i + 1)); + EXPECT_EQ(test.next(true), APFloat::opOK); + EXPECT_TRUE(test.bitwiseIsEqual(expected)) << i; + } + + // nextDown on NaN leaves the NaN unchanged and reports opOK + test = APFloat::getQNaN(APFloat::Float8E5M2FZN(), false); + expected = APFloat::getQNaN(APFloat::Float8E5M2FZN(), false); + EXPECT_EQ(test.next(true), APFloat::opOK); + EXPECT_TRUE(test.bitwiseIsEqual(expected)); +} + +TEST(APFloatTest, Float8E5M2FZNExhaustive) { + // Test each of the 256 Float8E5M2FZN values. + for (int i = 0; i < 256; i++) { + APFloat test(APFloat::Float8E5M2FZN(), APInt(8, i)); + SCOPED_TRACE("i=" + std::to_string(i)); + + // isLargest: 0x7f and 0xff are +/- the maximum finite value (57344) + if (i == 127 || i == 255) { + EXPECT_TRUE(test.isLargest()); + EXPECT_EQ(abs(test).convertToDouble(), 57344); + } else { + EXPECT_FALSE(test.isLargest()); + } + + // isSmallest: 0x01 and 0x81 are +/- the smallest denormal (2^-17) + if (i == 1 || i == 129) { + EXPECT_TRUE(test.isSmallest()); + EXPECT_EQ(abs(test).convertToDouble(), 0x1p-17); + } else { + EXPECT_FALSE(test.isSmallest()); + } + + // convert to BFloat: must be exact; 0x80 (i == 128) is the only NaN + APFloat test2 = test; + bool loses_info; + APFloat::opStatus status = test2.convert( + APFloat::BFloat(), APFloat::rmNearestTiesToEven, &loses_info); + EXPECT_EQ(status, APFloat::opOK); + EXPECT_FALSE(loses_info); + if (i == 128) + EXPECT_TRUE(test2.isNaN()); + else + EXPECT_EQ(test.convertToFloat(), test2.convertToFloat()); + + // bitcastToAPInt must reproduce the bit pattern the value was built from + EXPECT_EQ(i, test.bitcastToAPInt()); + } +} + +TEST(APFloatTest, Float8E5M2FZNExhaustivePair) { + // Test each pair of Float8E5M2FZN values.
+ for (int i = 0; i < 256; i++) { + for (int j = 0; j < 256; j++) { + SCOPED_TRACE("i=" + std::to_string(i) + ",j=" + std::to_string(j)); + APFloat x(APFloat::Float8E5M2FZN(), APInt(8, i)); + APFloat y(APFloat::Float8E5M2FZN(), APInt(8, j)); + + bool losesInfo; + APFloat x16 = x; + x16.convert(APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven, + &losesInfo); + EXPECT_FALSE(losesInfo); + APFloat y16 = y; + y16.convert(APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven, + &losesInfo); + EXPECT_FALSE(losesInfo); + + // Add + APFloat z = x; + z.add(y, APFloat::rmNearestTiesToEven); + APFloat z16 = x16; + z16.add(y16, APFloat::rmNearestTiesToEven); + z16.convert(APFloat::Float8E5M2FZN(), APFloat::rmNearestTiesToEven, + &losesInfo); + EXPECT_TRUE(z.bitwiseIsEqual(z16)); + + // Subtract + z = x; + z.subtract(y, APFloat::rmNearestTiesToEven); + z16 = x16; + z16.subtract(y16, APFloat::rmNearestTiesToEven); + z16.convert(APFloat::Float8E5M2FZN(), APFloat::rmNearestTiesToEven, + &losesInfo); + EXPECT_TRUE(z.bitwiseIsEqual(z16)); + + // Multiply + z = x; + z.multiply(y, APFloat::rmNearestTiesToEven); + z16 = x16; + z16.multiply(y16, APFloat::rmNearestTiesToEven); + z16.convert(APFloat::Float8E5M2FZN(), APFloat::rmNearestTiesToEven, + &losesInfo); + EXPECT_TRUE(z.bitwiseIsEqual(z16)) << "i=" << i << ", j=" << j; + + // Divide + z = x; + z.divide(y, APFloat::rmNearestTiesToEven); + z16 = x16; + z16.divide(y16, APFloat::rmNearestTiesToEven); + z16.convert(APFloat::Float8E5M2FZN(), APFloat::rmNearestTiesToEven, + &losesInfo); + EXPECT_TRUE(z.bitwiseIsEqual(z16)) << "i=" << i << ", j=" << j; + + // Mod + z = x; + z.mod(y); + z16 = x16; + z16.mod(y16); + z16.convert(APFloat::Float8E5M2FZN(), APFloat::rmNearestTiesToEven, + &losesInfo); + EXPECT_TRUE(z.bitwiseIsEqual(z16)) << "i=" << i << ", j=" << j; + + // Remainder + z = x; + z.remainder(y); + z16 = x16; + z16.remainder(y16); + z16.convert(APFloat::Float8E5M2FZN(), APFloat::rmNearestTiesToEven, + &losesInfo); + 
EXPECT_TRUE(z.bitwiseIsEqual(z16)) << "i=" << i << ", j=" << j; + } + } +} + TEST(APFloatTest, F8ToString) { for (APFloat::Semantics S : - {APFloat::S_Float8E5M2, APFloat::S_Float8E4M3FN}) { + {APFloat::S_Float8E5M2, APFloat::S_Float8E4M3FN, + APFloat::S_Float8E5M2FZN, APFloat::S_Float8E4M3FZN}) { SCOPED_TRACE("Semantics=" + std::to_string(S)); for (int i = 0; i < 256; i++) { SCOPED_TRACE("i=" + std::to_string(i)); - APFloat test(APFloat::Float8E5M2(), APInt(8, i)); + APFloat test(APFloat::EnumToSemantics(S), APInt(8, i)); llvm::SmallString<128> str; test.toString(str); if (test.isNaN()) { EXPECT_EQ(str, "NaN"); } else { - APFloat test2(APFloat::Float8E5M2(), str); + APFloat test2(APFloat::EnumToSemantics(S), str); EXPECT_TRUE(test.bitwiseIsEqual(test2)); } } @@ -5458,6 +5946,56 @@ EXPECT_TRUE(std::isnan(QNaN.convertToDouble())); } +TEST(APFloatTest, Float8E5M2FZNToDouble) { + APFloat One(APFloat::Float8E5M2FZN(), "1.0"); + EXPECT_EQ(1.0, One.convertToDouble()); + APFloat Two(APFloat::Float8E5M2FZN(), "2.0"); + EXPECT_EQ(2.0, Two.convertToDouble()); + APFloat PosLargest = APFloat::getLargest(APFloat::Float8E5M2FZN(), false); + EXPECT_EQ(57344., PosLargest.convertToDouble()); + APFloat NegLargest = APFloat::getLargest(APFloat::Float8E5M2FZN(), true); + EXPECT_EQ(-57344., NegLargest.convertToDouble()); + APFloat PosSmallest = + APFloat::getSmallestNormalized(APFloat::Float8E5M2FZN(), false); + EXPECT_EQ(0x1.p-15, PosSmallest.convertToDouble()); + APFloat NegSmallest = + APFloat::getSmallestNormalized(APFloat::Float8E5M2FZN(), true); + EXPECT_EQ(-0x1.p-15, NegSmallest.convertToDouble()); + + APFloat SmallestDenorm = + APFloat::getSmallest(APFloat::Float8E5M2FZN(), false); + EXPECT_TRUE(SmallestDenorm.isDenormal()); + EXPECT_EQ(0x1p-17, SmallestDenorm.convertToDouble()); + + APFloat QNaN = APFloat::getQNaN(APFloat::Float8E5M2FZN()); + EXPECT_TRUE(std::isnan(QNaN.convertToDouble())); +} + +TEST(APFloatTest, Float8E4M3FZNToDouble) { + APFloat 
One(APFloat::Float8E4M3FZN(), "1.0"); + EXPECT_EQ(1.0, One.convertToDouble()); + APFloat Two(APFloat::Float8E4M3FZN(), "2.0"); + EXPECT_EQ(2.0, Two.convertToDouble()); + APFloat PosLargest = APFloat::getLargest(APFloat::Float8E4M3FZN(), false); + EXPECT_EQ(240., PosLargest.convertToDouble()); + APFloat NegLargest = APFloat::getLargest(APFloat::Float8E4M3FZN(), true); + EXPECT_EQ(-240., NegLargest.convertToDouble()); + APFloat PosSmallest = + APFloat::getSmallestNormalized(APFloat::Float8E4M3FZN(), false); + EXPECT_EQ(0x1.p-7, PosSmallest.convertToDouble()); + APFloat NegSmallest = + APFloat::getSmallestNormalized(APFloat::Float8E4M3FZN(), true); + EXPECT_EQ(-0x1.p-7, NegSmallest.convertToDouble()); + + APFloat SmallestDenorm = + APFloat::getSmallest(APFloat::Float8E4M3FZN(), false); + EXPECT_TRUE(SmallestDenorm.isDenormal()); + EXPECT_EQ(0x1p-10, SmallestDenorm.convertToDouble()); + + APFloat QNaN = APFloat::getQNaN(APFloat::Float8E4M3FZN()); + EXPECT_TRUE(std::isnan(QNaN.convertToDouble())); +} + TEST(APFloatTest, IEEEsingleToFloat) { APFloat FPosZero(0.0F); APFloat FPosZeroToFloat(FPosZero.convertToFloat()); @@ -5638,4 +6176,74 @@ EXPECT_TRUE(std::isnan(QNaN.convertToFloat())); } +TEST(APFloatTest, Float8E5M2FZNToFloat) { + APFloat PosZero = APFloat::getZero(APFloat::Float8E5M2FZN()); + APFloat PosZeroToFloat(PosZero.convertToFloat()); + EXPECT_TRUE(PosZeroToFloat.isPosZero()); + + // Negative zero is not supported + APFloat NegZero = APFloat::getZero(APFloat::Float8E5M2FZN(), true); + APFloat NegZeroToFloat(NegZero.convertToFloat()); + EXPECT_TRUE(NegZeroToFloat.isPosZero()); + + APFloat One(APFloat::Float8E5M2FZN(), "1.0"); + EXPECT_EQ(1.0F, One.convertToFloat()); + APFloat Two(APFloat::Float8E5M2FZN(), "2.0"); + EXPECT_EQ(2.0F, Two.convertToFloat()); + + APFloat PosLargest = APFloat::getLargest(APFloat::Float8E5M2FZN(), false); + EXPECT_EQ(57344., PosLargest.convertToFloat()); + APFloat NegLargest = APFloat::getLargest(APFloat::Float8E5M2FZN(), true); + 
EXPECT_EQ(-57344., NegLargest.convertToFloat()); + APFloat PosSmallest = + APFloat::getSmallestNormalized(APFloat::Float8E5M2FZN(), false); + EXPECT_EQ(0x1.p-15, PosSmallest.convertToFloat()); + APFloat NegSmallest = + APFloat::getSmallestNormalized(APFloat::Float8E5M2FZN(), true); + EXPECT_EQ(-0x1.p-15, NegSmallest.convertToFloat()); + + APFloat SmallestDenorm = + APFloat::getSmallest(APFloat::Float8E5M2FZN(), false); + EXPECT_TRUE(SmallestDenorm.isDenormal()); + EXPECT_EQ(0x1.p-17, SmallestDenorm.convertToFloat()); + + APFloat QNaN = APFloat::getQNaN(APFloat::Float8E5M2FZN()); + EXPECT_TRUE(std::isnan(QNaN.convertToFloat())); +} + +TEST(APFloatTest, Float8E4M3FZNToFloat) { + APFloat PosZero = APFloat::getZero(APFloat::Float8E4M3FZN()); + APFloat PosZeroToFloat(PosZero.convertToFloat()); + EXPECT_TRUE(PosZeroToFloat.isPosZero()); + + // No negative zero + APFloat NegZero = APFloat::getZero(APFloat::Float8E4M3FZN(), true); + APFloat NegZeroToFloat(NegZero.convertToFloat()); + EXPECT_TRUE(NegZeroToFloat.isPosZero()); + + APFloat One(APFloat::Float8E4M3FZN(), "1.0"); + EXPECT_EQ(1.0F, One.convertToFloat()); + APFloat Two(APFloat::Float8E4M3FZN(), "2.0"); + EXPECT_EQ(2.0F, Two.convertToFloat()); + + APFloat PosLargest = APFloat::getLargest(APFloat::Float8E4M3FZN(), false); + EXPECT_EQ(240., PosLargest.convertToFloat()); + APFloat NegLargest = APFloat::getLargest(APFloat::Float8E4M3FZN(), true); + EXPECT_EQ(-240, NegLargest.convertToFloat()); + APFloat PosSmallest = + APFloat::getSmallestNormalized(APFloat::Float8E4M3FZN(), false); + EXPECT_EQ(0x1.p-7, PosSmallest.convertToFloat()); + APFloat NegSmallest = + APFloat::getSmallestNormalized(APFloat::Float8E4M3FZN(), true); + EXPECT_EQ(-0x1.p-7, NegSmallest.convertToFloat()); + + APFloat SmallestDenorm = + APFloat::getSmallest(APFloat::Float8E4M3FZN(), false); + EXPECT_TRUE(SmallestDenorm.isDenormal()); + EXPECT_EQ(0x1.p-10, SmallestDenorm.convertToFloat()); + + APFloat QNaN = APFloat::getQNaN(APFloat::Float8E4M3FZN()); + 
EXPECT_TRUE(std::isnan(QNaN.convertToFloat())); +} + } // namespace diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h --- a/mlir/include/mlir-c/BuiltinTypes.h +++ b/mlir/include/mlir-c/BuiltinTypes.h @@ -81,6 +81,20 @@ /// context. MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3FNTypeGet(MlirContext ctx); +/// Checks whether the given type is an f8E4M3FZN type. +MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3FZN(MlirType type); + +/// Creates an f8E4M3FZN type in the given context. The type is owned by the +/// context. +MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3FZNTypeGet(MlirContext ctx); + +/// Checks whether the given type is an f8E5M2FZN type. +MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E5M2FZN(MlirType type); + +/// Creates an f8E5M2FZN type in the given context. The type is owned by the +/// context. +MLIR_CAPI_EXPORTED MlirType mlirFloat8E5M2FZNTypeGet(MlirContext ctx); + /// Checks whether the given type is a bf16 type. MLIR_CAPI_EXPORTED bool mlirTypeIsABF16(MlirType type); diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -61,6 +61,8 @@ // Types. FloatType getFloat8E5M2Type(); FloatType getFloat8E4M3FNType(); + FloatType getFloat8E4M3FZNType(); + FloatType getFloat8E5M2FZNType(); FloatType getBF16Type(); FloatType getF16Type(); FloatType getF32Type(); diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h --- a/mlir/include/mlir/IR/BuiltinTypes.h +++ b/mlir/include/mlir/IR/BuiltinTypes.h @@ -48,6 +48,8 @@ static FloatType getF128(MLIRContext *ctx); static FloatType getFloat8E5M2(MLIRContext *ctx); static FloatType getFloat8E4M3FN(MLIRContext *ctx); + static FloatType getFloat8E4M3FZN(MLIRContext *ctx); + static FloatType getFloat8E5M2FZN(MLIRContext *ctx); /// Methods for support type inquiry through isa, cast, and dyn_cast. 
static bool classof(Type type); @@ -375,8 +377,9 @@ } inline bool FloatType::classof(Type type) { - return type.isa(); + return type.isa(); } inline FloatType FloatType::getFloat8E5M2(MLIRContext *ctx) { @@ -387,6 +390,14 @@ return Float8E4M3FNType::get(ctx); } +inline FloatType FloatType::getFloat8E4M3FZN(MLIRContext *ctx) { + return Float8E4M3FZNType::get(ctx); +} + +inline FloatType FloatType::getFloat8E5M2FZN(MLIRContext *ctx) { + return Float8E5M2FZNType::get(ctx); +} + inline FloatType FloatType::getBF16(MLIRContext *ctx) { return BFloat16Type::get(ctx); } diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td --- a/mlir/include/mlir/IR/BuiltinTypes.td +++ b/mlir/include/mlir/IR/BuiltinTypes.td @@ -119,6 +119,50 @@ }]; } +//===----------------------------------------------------------------------===// +// Float8E4M3FZNType + +def Builtin_Float8E4M3FZN : Builtin_FloatType<"Float8E4M3FZN"> { + let summary = "8-bit floating point with 3 bit mantissa"; + let description = [{ + An 8-bit floating point type with 1 sign bit, 4 bits exponent and 3 bits + mantissa. This is not a standard type as defined by IEEE-754, but it follows + similar conventions, with the exception that there are no infinity values, + no negative zero, and only one NaN representation. This type has the + following characteristics: + + * bit encoding: S1E4M3 + * exponent bias: 8 + * infinities: Not supported + * NaNs: Supported with sign bit set to 1, exponent bits and mantissa bits set to all 0s + * denormals when exponent is 0 + + Described in: https://arxiv.org/abs/2206.02915 + }]; +} + +//===----------------------------------------------------------------------===// +// Float8E5M2FZNType + +def Builtin_Float8E5M2FZN : Builtin_FloatType<"Float8E5M2FZN"> { + let summary = "8-bit floating point with 2 bit mantissa"; + let description = [{ + An 8-bit floating point type with 1 sign bit, 5 bits exponent and 2 bits + mantissa. 
This is not a standard type as defined by IEEE-754, but it follows + similar conventions, with the exception that there are no infinity values, + no negative zero, and only one NaN representation. This type has the + following characteristics: + + * bit encoding: S1E5M2 + * exponent bias: 16 + * infinities: Not supported + * NaNs: Supported with sign bit set to 1, exponent bits and mantissa bits set to all 0s + * denormals when exponent is 0 + + Described in: https://arxiv.org/abs/2206.02915 + }]; +} + //===----------------------------------------------------------------------===// // BFloat16Type diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -494,6 +494,10 @@ BuildableType<"$_builder.getFloat8E4M3FNType()">; def F8E5M2 : Type, "f8E5M2 type">, BuildableType<"$_builder.getFloat8E5M2Type()">; +def F8E4M3FZN : Type, "f8E4M3FZN type">, + BuildableType<"$_builder.getFloat8E4M3FZNType()">; +def F8E5M2FZN : Type, "f8E5M2FZN type">, + BuildableType<"$_builder.getFloat8E5M2FZNType()">; def AnyComplex : Type()">, "complex-type", "::mlir::ComplexType">; diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h --- a/mlir/include/mlir/IR/Types.h +++ b/mlir/include/mlir/IR/Types.h @@ -125,6 +125,8 @@ bool isIndex() const; bool isFloat8E5M2() const; bool isFloat8E4M3FN() const; + bool isFloat8E4M3FZN() const; + bool isFloat8E5M2FZN() const; bool isBF16() const; bool isF16() const; bool isF32() const; diff --git a/mlir/lib/AsmParser/TokenKinds.def b/mlir/lib/AsmParser/TokenKinds.def --- a/mlir/lib/AsmParser/TokenKinds.def +++ b/mlir/lib/AsmParser/TokenKinds.def @@ -95,6 +95,8 @@ TOK_KEYWORD(f80) TOK_KEYWORD(f8E5M2) TOK_KEYWORD(f8E4M3FN) +TOK_KEYWORD(f8E4M3FZN) +TOK_KEYWORD(f8E5M2FZN) TOK_KEYWORD(f128) TOK_KEYWORD(false) TOK_KEYWORD(floordiv) diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp --- 
a/mlir/lib/AsmParser/TypeParser.cpp +++ b/mlir/lib/AsmParser/TypeParser.cpp @@ -32,6 +32,8 @@ case Token::inttype: case Token::kw_f8E5M2: case Token::kw_f8E4M3FN: + case Token::kw_f8E4M3FZN: + case Token::kw_f8E5M2FZN: case Token::kw_bf16: case Token::kw_f16: case Token::kw_f32: @@ -294,6 +296,12 @@ case Token::kw_f8E4M3FN: consumeToken(Token::kw_f8E4M3FN); return builder.getFloat8E4M3FNType(); + case Token::kw_f8E4M3FZN: + consumeToken(Token::kw_f8E4M3FZN); + return builder.getFloat8E4M3FZNType(); + case Token::kw_f8E5M2FZN: + consumeToken(Token::kw_f8E5M2FZN); + return builder.getFloat8E5M2FZNType(); case Token::kw_bf16: consumeToken(Token::kw_bf16); return builder.getBF16Type(); diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp --- a/mlir/lib/Bindings/Python/IRTypes.cpp +++ b/mlir/lib/Bindings/Python/IRTypes.cpp @@ -138,6 +138,42 @@ } }; +/// Floating Point Type subclass - Float8E4M3FZN. +class PyFloat8E4M3FZNType : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FZN; + static constexpr const char *pyClassName = "Float8E4M3FZNType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirFloat8E4M3FZNTypeGet(context->get()); + return PyFloat8E4M3FZNType(context->getRef(), t); + }, + py::arg("context") = py::none(), "Create a f8e4m3fnz type."); + } +}; + +/// Floating Point Type subclass - Float8E5M2FZN. 
+class PyFloat8E5M2FZNType : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2FZN; + static constexpr const char *pyClassName = "Float8E5M2FZNType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirFloat8E5M2FZNTypeGet(context->get()); + return PyFloat8E5M2FZNType(context->getRef(), t); + }, + py::arg("context") = py::none(), "Create a f8e5m2fnz type."); + } +}; + /// Floating Point Type subclass - BF16Type. class PyBF16Type : public PyConcreteType { public: @@ -701,6 +737,8 @@ PyIndexType::bind(m); PyFloat8E4M3FNType::bind(m); PyFloat8E5M2Type::bind(m); + PyFloat8E4M3FZNType::bind(m); + PyFloat8E5M2FZNType::bind(m); PyBF16Type::bind(m); PyF16Type::bind(m); PyF32Type::bind(m); diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp --- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp @@ -84,6 +84,22 @@ return wrap(FloatType::getFloat8E4M3FN(unwrap(ctx))); } +bool mlirTypeIsAFloat8E4M3FZN(MlirType type) { + return unwrap(type).isFloat8E4M3FZN(); +} + +MlirType mlirFloat8E4M3FZNTypeGet(MlirContext ctx) { + return wrap(FloatType::getFloat8E4M3FZN(unwrap(ctx))); +} + +bool mlirTypeIsAFloat8E5M2FZN(MlirType type) { + return unwrap(type).isFloat8E5M2FZN(); +} + +MlirType mlirFloat8E5M2FZNTypeGet(MlirContext ctx) { + return wrap(FloatType::getFloat8E5M2FZN(unwrap(ctx))); +} + bool mlirTypeIsABF16(MlirType type) { return unwrap(type).isBF16(); } MlirType mlirBF16TypeGet(MlirContext ctx) { diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -2412,6 +2412,8 @@ .Case([&](Type) { os << "index"; }) .Case([&](Type) { os << "f8E5M2"; }) .Case([&](Type) { os << "f8E4M3FN"; }) + .Case([&](Type) { os << "f8E4M3FZN"; }) + .Case([&](Type) { os << "f8E5M2FZN"; }) .Case([&](Type) { 
os << "bf16"; }) .Case([&](Type) { os << "f16"; }) .Case([&](Type) { os << "f32"; }) diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -41,6 +41,14 @@ return FloatType::getFloat8E4M3FN(context); } +FloatType Builder::getFloat8E4M3FZNType() { + return FloatType::getFloat8E4M3FZN(context); +} + +FloatType Builder::getFloat8E5M2FZNType() { + return FloatType::getFloat8E5M2FZN(context); +} + FloatType Builder::getBF16Type() { return FloatType::getBF16(context); } FloatType Builder::getF16Type() { return FloatType::getF16(context); } diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp --- a/mlir/lib/IR/BuiltinTypes.cpp +++ b/mlir/lib/IR/BuiltinTypes.cpp @@ -88,7 +88,8 @@ //===----------------------------------------------------------------------===// unsigned FloatType::getWidth() { - if (isa()) + if (isa()) return 8; if (isa()) return 16; @@ -109,6 +110,10 @@ return APFloat::Float8E5M2(); if (isa()) return APFloat::Float8E4M3FN(); + if (isa()) + return APFloat::Float8E4M3FZN(); + if (isa()) + return APFloat::Float8E5M2FZN(); if (isa()) return APFloat::BFloat(); if (isa()) diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -208,6 +208,8 @@ /// Cached Type Instances. Float8E5M2Type f8E5M2Ty; Float8E4M3FNType f8E4M3FNTy; + Float8E4M3FZNType f8E4M3FZNTy; + Float8E5M2FZNType f8E5M2FZNTy; BFloat16Type bf16Ty; Float16Type f16Ty; Float32Type f32Ty; @@ -280,6 +282,8 @@ /// Floating-point Types. 
impl->f8E5M2Ty = TypeUniquer::get(this); impl->f8E4M3FNTy = TypeUniquer::get(this); + impl->f8E4M3FZNTy = TypeUniquer::get(this); + impl->f8E5M2FZNTy = TypeUniquer::get(this); impl->bf16Ty = TypeUniquer::get(this); impl->f16Ty = TypeUniquer::get(this); impl->f32Ty = TypeUniquer::get(this); @@ -866,6 +870,12 @@ Float8E4M3FNType Float8E4M3FNType::get(MLIRContext *context) { return context->getImpl().f8E4M3FNTy; } +Float8E4M3FZNType Float8E4M3FZNType::get(MLIRContext *context) { + return context->getImpl().f8E4M3FZNTy; +} +Float8E5M2FZNType Float8E5M2FZNType::get(MLIRContext *context) { + return context->getImpl().f8E5M2FZNTy; +} BFloat16Type BFloat16Type::get(MLIRContext *context) { return context->getImpl().bf16Ty; } diff --git a/mlir/lib/IR/Types.cpp b/mlir/lib/IR/Types.cpp --- a/mlir/lib/IR/Types.cpp +++ b/mlir/lib/IR/Types.cpp @@ -20,6 +20,8 @@ bool Type::isFloat8E5M2() const { return isa(); } bool Type::isFloat8E4M3FN() const { return isa(); } +bool Type::isFloat8E4M3FZN() const { return isa(); } +bool Type::isFloat8E5M2FZN() const { return isa(); } bool Type::isBF16() const { return isa(); } bool Type::isF16() const { return isa(); } bool Type::isF32() const { return isa(); } diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi --- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi @@ -52,6 +52,8 @@ "DictAttr", "Float8E4M3FNType", "Float8E5M2Type", + "Float8E4M3FZNType", + "Float8E5M2FZNType", "F16Type", "F32Type", "F64Type", @@ -586,6 +588,20 @@ @staticmethod def isinstance(arg: Any) -> bool: ... +class Float8E5M2FZNType(Type): + def __init__(self, cast_from_type: Type) -> None: ... + @staticmethod + def get(*args, **kwargs) -> Float8E5M2FZNType: ... + @staticmethod + def isinstance(arg: Any) -> bool: ... + +class Float8E4M3FZNType(Type): + def __init__(self, cast_from_type: Type) -> None: ... + @staticmethod + def get(*args, **kwargs) -> Float8E4M3FZNType: ... 
+ @staticmethod + def isinstance(arg: Any) -> bool: ... + class Float8E5M2Type(Type): def __init__(self, cast_from_type: Type) -> None: ... @staticmethod diff --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir --- a/mlir/test/IR/attribute.mlir +++ b/mlir/test/IR/attribute.mlir @@ -44,6 +44,14 @@ // CHECK: float_attr = 2.000000e+00 : f8E4M3FN float_attr = 2. : f8E4M3FN } : () -> () + "test.float_attrs"() { + // CHECK: float_attr = 2.000000e+00 : f8E4M3FZN + float_attr = 2. : f8E4M3FZN + } : () -> () + "test.float_attrs"() { + // CHECK: float_attr = 2.000000e+00 : f8E5M2FZN + float_attr = 2. : f8E5M2FZN + } : () -> () "test.float_attrs"() { // CHECK: float_attr = 2.000000e+00 : f16 float_attr = 2. : f16 diff --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py --- a/mlir/test/python/ir/builtin_types.py +++ b/mlir/test/python/ir/builtin_types.py @@ -197,6 +197,10 @@ print("float:", Float8E4M3FNType.get()) # CHECK: float: f8E5M2 print("float:", Float8E5M2Type.get()) + # CHECK: float: f8E4M3FZN + print("float:", Float8E4M3FZNType.get()) + # CHECK: float: f8E5M2FZN + print("float:", Float8E5M2FZNType.get()) # CHECK: float: bf16 print("float:", BF16Type.get()) # CHECK: float: f16 diff --git a/mlir/utils/lldb-scripts/mlirDataFormatters.py b/mlir/utils/lldb-scripts/mlirDataFormatters.py --- a/mlir/utils/lldb-scripts/mlirDataFormatters.py +++ b/mlir/utils/lldb-scripts/mlirDataFormatters.py @@ -52,6 +52,8 @@ "mlir::UnknownLoc": '"loc(unknown)"', "mlir::Float8E5M2Type": '"f8E5M2"', "mlir::Float8E4M3FNType": '"f8E4M3FN"', + "mlir::Float8E4M3FZNType": '"f8E4M3FZN"', + "mlir::Float8E5M2FZNType": '"f8E5M2FZN"', "mlir::BFloat16Type": '"bf16"', "mlir::Float16Type": '"f16"', "mlir::Float32Type": '"f32"',