diff --git a/clang/include/clang/AST/Stmt.h b/clang/include/clang/AST/Stmt.h --- a/clang/include/clang/AST/Stmt.h +++ b/clang/include/clang/AST/Stmt.h @@ -22,6 +22,7 @@ #include "clang/Basic/LangOptions.h" #include "clang/Basic/SourceLocation.h" #include "clang/Basic/Specifiers.h" +#include "llvm/ADT/APFloat.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/BitmaskEnum.h" #include "llvm/ADT/PointerIntPair.h" @@ -389,7 +390,10 @@ unsigned : NumExprBits; - unsigned Semantics : 3; // Provides semantics for APFloat construction + static_assert( + llvm::APFloat::S_MaxSemantics < 16, + "Too many Semantics enum values to fit in bitfield of size 4"); + unsigned Semantics : 4; // Provides semantics for APFloat construction unsigned IsExact : 1; }; diff --git a/clang/lib/AST/MicrosoftMangle.cpp b/clang/lib/AST/MicrosoftMangle.cpp --- a/clang/lib/AST/MicrosoftMangle.cpp +++ b/clang/lib/AST/MicrosoftMangle.cpp @@ -839,6 +839,7 @@ case APFloat::S_IEEEquad: Out << 'Y'; break; case APFloat::S_PPCDoubleDouble: Out << 'Z'; break; case APFloat::S_Float8E5M2: + case APFloat::S_Float8E4M3FN: llvm_unreachable("Tried to mangle unexpected APFloat semantics"); } diff --git a/llvm/include/llvm/ADT/APFloat.h b/llvm/include/llvm/ADT/APFloat.h --- a/llvm/include/llvm/ADT/APFloat.h +++ b/llvm/include/llvm/ADT/APFloat.h @@ -156,8 +156,13 @@ S_IEEEquad, S_PPCDoubleDouble, // 8-bit floating point number following IEEE-754 conventions with bit - // layout S1E5M2 as described in https://arxiv.org/abs/2209.05433 + // layout S1E5M2 as described in https://arxiv.org/abs/2209.05433. S_Float8E5M2, + // 8-bit floating point number mostly following IEEE-754 conventions with + // bit layout S1E4M3 as described in https://arxiv.org/abs/2209.05433. + // Unlike IEEE-754 types, there are no infinity values, and NaN is + // represented with the exponent and mantissa bits set to all 1s. 
+ S_Float8E4M3FN, S_x87DoubleExtended, S_MaxSemantics = S_x87DoubleExtended, }; @@ -172,6 +177,7 @@ static const fltSemantics &IEEEquad() LLVM_READNONE; static const fltSemantics &PPCDoubleDouble() LLVM_READNONE; static const fltSemantics &Float8E5M2() LLVM_READNONE; + static const fltSemantics &Float8E4M3FN() LLVM_READNONE; static const fltSemantics &x87DoubleExtended() LLVM_READNONE; /// A Pseudo fltsemantic used to construct APFloats that cannot conflict with @@ -508,6 +514,7 @@ void zeroSignificand(); /// Return true if the significand excluding the integral bit is all ones. bool isSignificandAllOnes() const; + bool isSignificandAllOnesExceptLSB() const; /// Return true if the significand excluding the integral bit is all zeros. bool isSignificandAllZeros() const; @@ -557,6 +564,7 @@ APInt convertF80LongDoubleAPFloatToAPInt() const; APInt convertPPCDoubleDoubleAPFloatToAPInt() const; APInt convertFloat8E5M2APFloatToAPInt() const; + APInt convertFloat8E4M3FNAPFloatToAPInt() const; void initFromAPInt(const fltSemantics *Sem, const APInt &api); void initFromHalfAPInt(const APInt &api); void initFromBFloatAPInt(const APInt &api); @@ -566,6 +574,7 @@ void initFromF80LongDoubleAPInt(const APInt &api); void initFromPPCDoubleDoubleAPInt(const APInt &api); void initFromFloat8E5M2APInt(const APInt &api); + void initFromFloat8E4M3FNAPInt(const APInt &api); void assign(const IEEEFloat &); void copySignificand(const IEEEFloat &); diff --git a/llvm/lib/Support/APFloat.cpp b/llvm/lib/Support/APFloat.cpp --- a/llvm/lib/Support/APFloat.cpp +++ b/llvm/lib/Support/APFloat.cpp @@ -50,6 +50,23 @@ static_assert(APFloatBase::integerPartWidth % 4 == 0, "Part width must be divisible by 4!"); namespace llvm { + + // How the nonfinite values Inf and NaN are represented. + enum class fltNonfiniteBehavior { + // Represents standard IEEE 754 behavior. A value is nonfinite if the + // exponent field is all 1s. 
In such cases, a value is Inf if the + // significand bits are all zero, and NaN otherwise. + IEEE754, + + // Only Float8E4M3FN has this behavior. There is no Inf representation. A + // value is NaN if the exponent field and the mantissa field are all 1s. + // This behavior matches the FP8 E4M3 type described in + // https://arxiv.org/abs/2209.05433. We treat NaNs of either sign + // as non-signalling, although the paper does not state whether the NaN + // values are signalling or not. + NanOnly, + }; + /* Represents floating point arithmetic semantics. */ struct fltSemantics { /* The largest E such that 2^E is representable; this matches the @@ -67,8 +84,11 @@ /* Number of bits actually used in the semantics. */ unsigned int sizeInBits; + fltNonfiniteBehavior nonFiniteBehavior = fltNonfiniteBehavior::IEEE754; + // Returns true if any number described by this semantics can be precisely - // represented by the specified semantics. + // represented by the specified semantics. Does not take into account + // the value of fltNonfiniteBehavior. 
bool isRepresentableBy(const fltSemantics &S) const { return maxExponent <= S.maxExponent && minExponent >= S.minExponent && precision <= S.precision; @@ -81,6 +101,8 @@ static const fltSemantics semIEEEdouble = {1023, -1022, 53, 64}; static const fltSemantics semIEEEquad = {16383, -16382, 113, 128}; static const fltSemantics semFloat8E5M2 = {15, -14, 3, 8}; + static const fltSemantics semFloat8E4M3FN = {8, -6, 4, 8, + fltNonfiniteBehavior::NanOnly}; static const fltSemantics semX87DoubleExtended = {16383, -16382, 64, 80}; static const fltSemantics semBogus = {0, 0, 0, 0}; @@ -138,6 +160,8 @@ return PPCDoubleDouble(); case S_Float8E5M2: return Float8E5M2(); + case S_Float8E4M3FN: + return Float8E4M3FN(); case S_x87DoubleExtended: return x87DoubleExtended(); } @@ -160,6 +184,8 @@ return S_PPCDoubleDouble; else if (&Sem == &llvm::APFloat::Float8E5M2()) return S_Float8E5M2; + else if (&Sem == &llvm::APFloat::Float8E4M3FN()) + return S_Float8E4M3FN; else if (&Sem == &llvm::APFloat::x87DoubleExtended()) return S_x87DoubleExtended; else @@ -183,6 +209,7 @@ return semPPCDoubleDouble; } const fltSemantics &APFloatBase::Float8E5M2() { return semFloat8E5M2; } + const fltSemantics &APFloatBase::Float8E4M3FN() { return semFloat8E4M3FN; } const fltSemantics &APFloatBase::x87DoubleExtended() { return semX87DoubleExtended; } @@ -769,6 +796,15 @@ integerPart *significand = significandParts(); unsigned numParts = partCount(); + APInt fill_storage; + if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly) { + // The only NaN representation is where the mantissa is all 1s, which is + // non-signalling. + SNaN = false; + fill_storage = APInt::getAllOnes(semantics->precision - 1); + fill = &fill_storage; + } + // Set the significand bits to the fill. 
if (!fill || fill->getNumWords() < numParts) APInt::tcSet(significand, 0, numParts); @@ -869,6 +905,33 @@ return true; } +bool IEEEFloat::isSignificandAllOnesExceptLSB() const { + // Test if the significand excluding the integral bit is all ones except for + // the least significant bit, which must be zero. + const integerPart *Parts = significandParts(); + + if (Parts[0] & 1) + return false; + + const unsigned PartCount = partCountForBits(semantics->precision); + for (unsigned i = 0; i < PartCount - 1; i++) { + if (~Parts[i] & ~integerPart(!i)) // full-width mask; only part 0 skips bit 0 + return false; + } + + // Set the unused high bits (and the LSB, if it lives here) before comparing. + const unsigned NumHighBits = + PartCount * integerPartWidth - semantics->precision + 1; + assert(NumHighBits <= integerPartWidth && NumHighBits > 0 && + "Can not have more high bits to fill than integerPartWidth"); + const integerPart HighBitFill = ~integerPart(0) + << (integerPartWidth - NumHighBits); + if (~(Parts[PartCount - 1] | HighBitFill | integerPart(PartCount == 1))) + return false; + + return true; +} + bool IEEEFloat::isSignificandAllZeros() const { // Test if the significand excluding the integral bit is all zeros. This // allows us to test for binade boundaries. @@ -893,10 +956,18 @@ } bool IEEEFloat::isLargest() const { - // The largest number by magnitude in our format will be the floating point - // number with maximum exponent and with significand that is all ones. - return isFiniteNonZero() && exponent == semantics->maxExponent - && isSignificandAllOnes(); + if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly) { + // The largest number by magnitude in our format will be the floating point + // number with maximum exponent and with significand that is all ones except + // the LSB. + return isFiniteNonZero() && exponent == semantics->maxExponent && + isSignificandAllOnesExceptLSB(); + } else { + // The largest number by magnitude in our format will be the floating point + // number with maximum exponent and with significand that is all ones. 
+ return isFiniteNonZero() && exponent == semantics->maxExponent && + isSignificandAllOnes(); + } } bool IEEEFloat::isInteger() const { @@ -1315,7 +1386,10 @@ rounding_mode == rmNearestTiesToAway || (rounding_mode == rmTowardPositive && !sign) || (rounding_mode == rmTowardNegative && sign)) { - category = fcInfinity; + if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly) + makeNaN(false, sign); + else + category = fcInfinity; return (opStatus) (opOverflow | opInexact); } @@ -1324,6 +1398,8 @@ exponent = semantics->maxExponent; tcSetLeastSignificantBits(significandParts(), partCount(), semantics->precision); + if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly) + APInt::tcClearBit(significandParts(), 0); return opInexact; } @@ -1423,6 +1499,10 @@ } } + if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly && + exponent == semantics->maxExponent && isSignificandAllOnes()) + return handleOverflow(rounding_mode); + /* Now round the number according to rounding_mode given the lost fraction. 
*/ @@ -1459,6 +1539,10 @@ return opInexact; } + + if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly && + exponent == semantics->maxExponent && isSignificandAllOnes()) + return handleOverflow(rounding_mode); } /* The normal case - we were and are not denormal, and any @@ -1679,7 +1763,10 @@ return opOK; case PackCategoriesIntoKey(fcNormal, fcZero): - category = fcInfinity; + if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly) + makeNaN(false, sign); + else + category = fcInfinity; return opDivByZero; case PackCategoriesIntoKey(fcInfinity, fcInfinity): @@ -1965,9 +2052,12 @@ while (isFiniteNonZero() && rhs.isFiniteNonZero() && compareAbsoluteValue(rhs) != cmpLessThan) { - IEEEFloat V = scalbn(rhs, ilogb(*this) - ilogb(rhs), rmNearestTiesToEven); - if (compareAbsoluteValue(V) == cmpLessThan) - V = scalbn(V, -1, rmNearestTiesToEven); + int Exp = ilogb(*this) - ilogb(rhs); + IEEEFloat V = scalbn(rhs, Exp, rmNearestTiesToEven); + // V can overflow to NaN with fltNonfiniteBehavior::NanOnly, so explicitly + // check for it. + if (V.isNaN() || compareAbsoluteValue(V) == cmpLessThan) + V = scalbn(rhs, Exp - 1, rmNearestTiesToEven); V.sign = sign; fs = subtract(V, rmNearestTiesToEven); @@ -2194,6 +2284,7 @@ opStatus fs; int shift; const fltSemantics &fromSemantics = *semantics; + bool is_signaling = isSignaling(); lostFraction = lfExactlyZero; newPartCount = partCountForBits(toSemantics.precision + 1); @@ -2235,7 +2326,9 @@ } // If this is a truncation, perform the shift before we narrow the storage. - if (shift < 0 && (isFiniteNonZero() || category==fcNaN)) + if (shift < 0 && (isFiniteNonZero() || + (category == fcNaN && semantics->nonFiniteBehavior != + fltNonfiniteBehavior::NanOnly))) lostFraction = shiftRight(significandParts(), oldPartCount, -shift); // Fix the storage so it can hold to new value. 
@@ -2269,6 +2362,13 @@ fs = normalize(rounding_mode, lostFraction); *losesInfo = (fs != opOK); } else if (category == fcNaN) { + if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly) { + *losesInfo = + fromSemantics.nonFiniteBehavior != fltNonfiniteBehavior::NanOnly; + makeNaN(false, sign); + return is_signaling ? opInvalidOp : opOK; + } + *losesInfo = lostFraction != lfExactlyZero || X86SpecialNan; // For x87 extended precision, we want to make a NaN, not a special NaN if @@ -2279,12 +2379,17 @@ // Convert of sNaN creates qNaN and raises an exception (invalid op). // This also guarantees that a sNaN does not become Inf on a truncation // that loses all payload bits. - if (isSignaling()) { + if (is_signaling) { makeQuiet(); fs = opInvalidOp; } else { fs = opOK; } + } else if (category == fcInfinity && + semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly) { + makeNaN(false, sign); + *losesInfo = true; + fs = opInexact; } else { *losesInfo = false; fs = opOK; @@ -3382,6 +3487,33 @@ (mysignificand & 0x3))); } +APInt IEEEFloat::convertFloat8E4M3FNAPFloatToAPInt() const { + assert(semantics == (const llvm::fltSemantics *)&semFloat8E4M3FN); + assert(partCount() == 1); + + uint32_t myexponent, mysignificand; + + if (isFiniteNonZero()) { + myexponent = exponent + 7; // bias + mysignificand = (uint32_t)*significandParts(); + if (myexponent == 1 && !(mysignificand & 0x8)) + myexponent = 0; // denormal + } else if (category == fcZero) { + myexponent = 0; + mysignificand = 0; + } else if (category == fcInfinity) { + myexponent = 0xf; + mysignificand = 0; + } else { + assert(category == fcNaN && "Unknown category!"); + myexponent = 0xf; + mysignificand = (uint32_t)*significandParts(); + } + + return APInt(8, (((sign & 1) << 7) | ((myexponent & 0xf) << 3) | + (mysignificand & 0x7))); +} + // This function creates an APInt that is just a bit map of the floating // point constant as it would appear in memory. 
It is not a conversion, // and treating the result as a normal integer is unlikely to be useful. @@ -3408,6 +3540,9 @@ if (semantics == (const llvm::fltSemantics *)&semFloat8E5M2) return convertFloat8E5M2APFloatToAPInt(); + if (semantics == (const llvm::fltSemantics *)&semFloat8E4M3FN) + return convertFloat8E4M3FNAPFloatToAPInt(); + assert(semantics == (const llvm::fltSemantics*)&semX87DoubleExtended && "unknown format!"); return convertF80LongDoubleAPFloatToAPInt(); @@ -3663,10 +3798,33 @@ } } -/// Treat api as containing the bits of a floating point number. Currently -/// we infer the floating point type from the size of the APInt. The -/// isIEEE argument distinguishes between PPC128 and IEEE128 (not meaningful -/// when the size is anything else). +void IEEEFloat::initFromFloat8E4M3FNAPInt(const APInt &api) { + uint32_t i = (uint32_t)*api.getRawData(); + uint32_t myexponent = (i >> 3) & 0xf; + uint32_t mysignificand = i & 0x7; + + initialize(&semFloat8E4M3FN); + assert(partCount() == 1); + + sign = i >> 7; + if (myexponent == 0 && mysignificand == 0) { + makeZero(sign); + } else if (myexponent == 0xf && mysignificand == 7) { + category = fcNaN; + exponent = exponentNaN(); + *significandParts() = mysignificand; + } else { + category = fcNormal; + exponent = myexponent - 7; // bias + *significandParts() = mysignificand; + if (myexponent == 0) // denormal + exponent = -6; + else + *significandParts() |= 0x8; // integer bit + } +} + +/// Treat api as containing the bits of a floating point number. void IEEEFloat::initFromAPInt(const fltSemantics *Sem, const APInt &api) { assert(api.getBitWidth() == Sem->sizeInBits); if (Sem == &semIEEEhalf) @@ -3685,6 +3843,8 @@ return initFromPPCDoubleDoubleAPInt(api); if (Sem == &semFloat8E5M2) return initFromFloat8E5M2APInt(api); + if (Sem == &semFloat8E4M3FN) + return initFromFloat8E4M3FNAPInt(api); llvm_unreachable(nullptr); } @@ -3712,6 +3872,9 @@ significand[PartCount - 1] = (NumUnusedHighBits < integerPartWidth) ? 
(~integerPart(0) >> NumUnusedHighBits) : 0; + + if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly) + significand[0] &= ~integerPart(1); } /// Make this number the smallest magnitude denormal number in the given @@ -4085,6 +4248,8 @@ bool IEEEFloat::isSignaling() const { if (!isNaN()) return false; + if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly) + return false; // IEEE-754R 2008 6.2.1: A signaling NaN bit string should be encoded with the // first bit of the trailing significand being 0. @@ -4135,12 +4300,18 @@ break; } - // nextUp(getLargest()) == INFINITY if (isLargest() && !isNegative()) { - APInt::tcSet(significandParts(), 0, partCount()); - category = fcInfinity; - exponent = semantics->maxExponent + 1; - break; + if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly) { + // nextUp(getLargest()) == NAN + makeNaN(); + break; + } else { + // nextUp(getLargest()) == INFINITY + APInt::tcSet(significandParts(), 0, partCount()); + category = fcInfinity; + exponent = semantics->maxExponent + 1; + break; + } } // nextUp(normal) == normal + inc. @@ -4212,6 +4383,8 @@ } APFloatBase::ExponentType IEEEFloat::exponentNaN() const { + if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly) + return semantics->maxExponent; return semantics->maxExponent + 1; } @@ -4224,6 +4397,11 @@ } void IEEEFloat::makeInf(bool Negative) { + if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly) { + // There is no Inf, so make NaN instead. 
+ makeNaN(false, Negative); + return; + } category = fcInfinity; sign = Negative; exponent = exponentInf(); @@ -4239,7 +4417,8 @@ void IEEEFloat::makeQuiet() { assert(isNaN()); - APInt::tcSetBit(significandParts(), semantics->precision - 2); + if (semantics->nonFiniteBehavior != fltNonfiniteBehavior::NanOnly) + APInt::tcSetBit(significandParts(), semantics->precision - 2); } int ilogb(const IEEEFloat &Arg) { diff --git a/llvm/unittests/ADT/APFloatTest.cpp b/llvm/unittests/ADT/APFloatTest.cpp --- a/llvm/unittests/ADT/APFloatTest.cpp +++ b/llvm/unittests/ADT/APFloatTest.cpp @@ -1683,6 +1683,7 @@ TEST(APFloatTest, getLargest) { EXPECT_EQ(3.402823466e+38f, APFloat::getLargest(APFloat::IEEEsingle()).convertToFloat()); EXPECT_EQ(1.7976931348623158e+308, APFloat::getLargest(APFloat::IEEEdouble()).convertToDouble()); + EXPECT_EQ(448, APFloat::getLargest(APFloat::Float8E4M3FN()).convertToDouble()); } TEST(APFloatTest, getSmallest) { @@ -1766,6 +1767,8 @@ {&APFloat::x87DoubleExtended(), true, {0, 0x8000ULL}, 2}, {&APFloat::Float8E5M2(), false, {0, 0}, 1}, {&APFloat::Float8E5M2(), true, {0x80ULL, 0}, 1}, + {&APFloat::Float8E4M3FN(), false, {0, 0}, 1}, + {&APFloat::Float8E4M3FN(), true, {0x80ULL, 0}, 1}, }; const unsigned NumGetZeroTests = 12; for (unsigned i = 0; i < NumGetZeroTests; ++i) { @@ -3665,6 +3668,16 @@ EXPECT_EQ(f1.mod(f2), APFloat::opOK); EXPECT_TRUE(f1.bitwiseIsEqual(expected)); } + { + // Test E4M3FN mod where the LHS exponent is maxExponent (8) and the RHS is + // the max value whose exponent is minExponent (-6). This requires special + // logic in the mod implementation to prevent overflow to NaN. 
+ APFloat f1(APFloat::Float8E4M3FN(), "0x1p8"); // 256 + APFloat f2(APFloat::Float8E4M3FN(), "0x1.ep-6"); // 0.029296875 + APFloat expected(APFloat::Float8E4M3FN(), "0x1p-8"); // 0.00390625 + EXPECT_EQ(f1.mod(f2), APFloat::opOK); + EXPECT_TRUE(f1.bitwiseIsEqual(expected)); + } } TEST(APFloatTest, remainder) { @@ -4756,6 +4769,389 @@ EXPECT_TRUE(ilogb(F) == -1); } +TEST(APFloatTest, ConvertE4M3FNToE5M2) { + bool losesInfo; + APFloat test(APFloat::Float8E4M3FN(), "1.0"); + APFloat::opStatus status = test.convert( + APFloat::Float8E5M2(), APFloat::rmNearestTiesToEven, &losesInfo); + EXPECT_EQ(1.0f, test.convertToFloat()); + EXPECT_FALSE(losesInfo); + EXPECT_EQ(status, APFloat::opOK); + + test = APFloat(APFloat::Float8E4M3FN(), "0.0"); + status = test.convert(APFloat::Float8E5M2(), APFloat::rmNearestTiesToEven, + &losesInfo); + EXPECT_EQ(0.0f, test.convertToFloat()); + EXPECT_FALSE(losesInfo); + EXPECT_EQ(status, APFloat::opOK); + + test = APFloat(APFloat::Float8E4M3FN(), "0x1.2p0"); // 1.125 + status = test.convert(APFloat::Float8E5M2(), APFloat::rmNearestTiesToEven, + &losesInfo); + EXPECT_EQ(0x1.0p0 /* 1.0 */, test.convertToFloat()); + EXPECT_TRUE(losesInfo); + EXPECT_EQ(status, APFloat::opInexact); + + test = APFloat(APFloat::Float8E4M3FN(), "0x1.6p0"); // 1.375 + status = test.convert(APFloat::Float8E5M2(), APFloat::rmNearestTiesToEven, + &losesInfo); + EXPECT_EQ(0x1.8p0 /* 1.5 */, test.convertToFloat()); + EXPECT_TRUE(losesInfo); + EXPECT_EQ(status, APFloat::opInexact); + + // Convert E4M3 denormal to E5M2 normal. 
Should not be truncated, despite the + // destination format having one fewer significand bit + test = APFloat(APFloat::Float8E4M3FN(), "0x1.Cp-7"); + status = test.convert(APFloat::Float8E5M2(), APFloat::rmNearestTiesToEven, + &losesInfo); + EXPECT_EQ(0x1.Cp-7, test.convertToFloat()); + EXPECT_FALSE(losesInfo); + EXPECT_EQ(status, APFloat::opOK); + + // Test convert from NaN + test = APFloat(APFloat::Float8E4M3FN(), "nan"); + status = test.convert(APFloat::Float8E5M2(), APFloat::rmNearestTiesToEven, + &losesInfo); + EXPECT_TRUE(std::isnan(test.convertToFloat())); + EXPECT_FALSE(losesInfo); + EXPECT_EQ(status, APFloat::opOK); +} + +TEST(APFloatTest, ConvertE5M2ToE4M3FN) { + bool losesInfo; + APFloat test(APFloat::Float8E5M2(), "1.0"); + APFloat::opStatus status = test.convert( + APFloat::Float8E4M3FN(), APFloat::rmNearestTiesToEven, &losesInfo); + EXPECT_EQ(1.0f, test.convertToFloat()); + EXPECT_FALSE(losesInfo); + EXPECT_EQ(status, APFloat::opOK); + + test = APFloat(APFloat::Float8E5M2(), "0.0"); + status = test.convert(APFloat::Float8E4M3FN(), APFloat::rmNearestTiesToEven, + &losesInfo); + EXPECT_EQ(0.0f, test.convertToFloat()); + EXPECT_FALSE(losesInfo); + EXPECT_EQ(status, APFloat::opOK); + + test = APFloat(APFloat::Float8E5M2(), "0x1.Cp8"); // 448 + status = test.convert(APFloat::Float8E4M3FN(), APFloat::rmNearestTiesToEven, + &losesInfo); + EXPECT_EQ(0x1.Cp8 /* 448 */, test.convertToFloat()); + EXPECT_FALSE(losesInfo); + EXPECT_EQ(status, APFloat::opOK); + + // Test overflow + test = APFloat(APFloat::Float8E5M2(), "0x1.0p9"); // 512 + status = test.convert(APFloat::Float8E4M3FN(), APFloat::rmNearestTiesToEven, + &losesInfo); + EXPECT_TRUE(std::isnan(test.convertToFloat())); + EXPECT_TRUE(losesInfo); + EXPECT_EQ(status, APFloat::opOverflow | APFloat::opInexact); + + // Test underflow + test = APFloat(APFloat::Float8E5M2(), "0x1.0p-10"); + status = test.convert(APFloat::Float8E4M3FN(), APFloat::rmNearestTiesToEven, + &losesInfo); + EXPECT_EQ(0., 
test.convertToFloat()); + EXPECT_TRUE(losesInfo); + EXPECT_EQ(status, APFloat::opUnderflow | APFloat::opInexact); + + // Test rounding up to smallest denormal number + test = APFloat(APFloat::Float8E5M2(), "0x1.8p-10"); + status = test.convert(APFloat::Float8E4M3FN(), APFloat::rmNearestTiesToEven, + &losesInfo); + EXPECT_EQ(0x1.0p-9, test.convertToFloat()); + EXPECT_TRUE(losesInfo); + EXPECT_EQ(status, APFloat::opUnderflow | APFloat::opInexact); + + // Testing inexact rounding to denormal number + test = APFloat(APFloat::Float8E5M2(), "0x1.8p-9"); + status = test.convert(APFloat::Float8E4M3FN(), APFloat::rmNearestTiesToEven, + &losesInfo); + EXPECT_EQ(0x1.0p-8, test.convertToFloat()); + EXPECT_TRUE(losesInfo); + EXPECT_EQ(status, APFloat::opUnderflow | APFloat::opInexact); + + APFloat nan = APFloat(APFloat::Float8E4M3FN(), "nan"); + + // Testing convert from Inf + test = APFloat(APFloat::Float8E5M2(), "inf"); + status = test.convert(APFloat::Float8E4M3FN(), APFloat::rmNearestTiesToEven, + &losesInfo); + EXPECT_TRUE(std::isnan(test.convertToFloat())); + EXPECT_TRUE(losesInfo); + EXPECT_EQ(status, APFloat::opInexact); + EXPECT_TRUE(test.bitwiseIsEqual(nan)); + + // Testing convert from quiet NaN + test = APFloat(APFloat::Float8E5M2(), "nan"); + status = test.convert(APFloat::Float8E4M3FN(), APFloat::rmNearestTiesToEven, + &losesInfo); + EXPECT_TRUE(std::isnan(test.convertToFloat())); + EXPECT_TRUE(losesInfo); + EXPECT_EQ(status, APFloat::opOK); + EXPECT_TRUE(test.bitwiseIsEqual(nan)); + + // Testing convert from signaling NaN + test = APFloat(APFloat::Float8E5M2(), "snan"); + status = test.convert(APFloat::Float8E4M3FN(), APFloat::rmNearestTiesToEven, + &losesInfo); + EXPECT_TRUE(std::isnan(test.convertToFloat())); + EXPECT_TRUE(losesInfo); + EXPECT_EQ(status, APFloat::opInvalidOp); + EXPECT_TRUE(test.bitwiseIsEqual(nan)); +} + +TEST(APFloatTest, Float8E4M3FNGetInf) { + APFloat t = APFloat::getInf(APFloat::Float8E4M3FN()); + EXPECT_TRUE(t.isNaN()); + 
EXPECT_FALSE(t.isInfinity()); +} + +TEST(APFloatTest, Float8E4M3FNFromString) { + // Exactly representable + EXPECT_EQ(448, APFloat(APFloat::Float8E4M3FN(), "448").convertToDouble()); + // Round down to maximum value + EXPECT_EQ(448, APFloat(APFloat::Float8E4M3FN(), "464").convertToDouble()); + // Round up, causing overflow to NaN + EXPECT_TRUE(APFloat(APFloat::Float8E4M3FN(), "465").isNaN()); + // Overflow without rounding + EXPECT_TRUE(APFloat(APFloat::Float8E4M3FN(), "480").isNaN()); + // Inf converted to NaN + EXPECT_TRUE(APFloat(APFloat::Float8E4M3FN(), "inf").isNaN()); + // NaN converted to NaN + EXPECT_TRUE(APFloat(APFloat::Float8E4M3FN(), "nan").isNaN()); +} + +TEST(APFloatTest, Float8E4M3FNAdd) { + APFloat QNaN = APFloat::getNaN(APFloat::Float8E4M3FN(), false); + + auto FromStr = [](StringRef S) { + return APFloat(APFloat::Float8E4M3FN(), S); + }; + + struct { + APFloat x; + APFloat y; + const char *result; + int status; + int category; + APFloat::roundingMode roundingMode = APFloat::rmNearestTiesToEven; + } AdditionTests[] = { + // Test addition operations involving NaN, overflow, and the max E4M3 + // value (448) because E4M3 differs from IEEE-754 types in these regards + {FromStr("448"), FromStr("16"), "448", APFloat::opInexact, + APFloat::fcNormal}, + {FromStr("448"), FromStr("18"), "NaN", + APFloat::opOverflow | APFloat::opInexact, APFloat::fcNaN}, + {FromStr("448"), FromStr("32"), "NaN", + APFloat::opOverflow | APFloat::opInexact, APFloat::fcNaN}, + {FromStr("-448"), FromStr("-32"), "-NaN", + APFloat::opOverflow | APFloat::opInexact, APFloat::fcNaN}, + {QNaN, FromStr("-448"), "NaN", APFloat::opOK, APFloat::fcNaN}, + {FromStr("448"), FromStr("-32"), "416", APFloat::opOK, APFloat::fcNormal}, + {FromStr("448"), FromStr("0"), "448", APFloat::opOK, APFloat::fcNormal}, + {FromStr("448"), FromStr("32"), "448", APFloat::opInexact, + APFloat::fcNormal, APFloat::rmTowardZero}, + {FromStr("448"), FromStr("448"), "448", APFloat::opInexact, + APFloat::fcNormal, 
APFloat::rmTowardZero}, + }; + + for (size_t i = 0; i < std::size(AdditionTests); ++i) { + APFloat x(AdditionTests[i].x); + APFloat y(AdditionTests[i].y); + APFloat::opStatus status = x.add(y, AdditionTests[i].roundingMode); + + APFloat result(APFloat::Float8E4M3FN(), AdditionTests[i].result); + + EXPECT_TRUE(result.bitwiseIsEqual(x)); + EXPECT_EQ(AdditionTests[i].status, (int)status); + EXPECT_EQ(AdditionTests[i].category, (int)x.getCategory()); + } +} + +TEST(APFloatTest, Float8E4M3FNDivideByZero) { + APFloat x(APFloat::Float8E4M3FN(), "1"); + APFloat zero(APFloat::Float8E4M3FN(), "0"); + EXPECT_EQ(x.divide(zero, APFloat::rmNearestTiesToEven), APFloat::opDivByZero); + EXPECT_TRUE(x.isNaN()); +} + +TEST(APFloatTest, Float8E4M3FNNext) { + APFloat test(APFloat::Float8E4M3FN(), APFloat::uninitialized); + APFloat expected(APFloat::Float8E4M3FN(), APFloat::uninitialized); + + // nextUp on positive numbers + for (int i = 0; i < 127; i++) { + test = APFloat(APFloat::Float8E4M3FN(), APInt(8, i)); + expected = APFloat(APFloat::Float8E4M3FN(), APInt(8, i + 1)); + EXPECT_EQ(test.next(false), APFloat::opOK); + EXPECT_TRUE(test.bitwiseIsEqual(expected)); + } + + // nextUp on negative zero + test = APFloat::getZero(APFloat::Float8E4M3FN(), true); + expected = APFloat::getSmallest(APFloat::Float8E4M3FN(), false); + EXPECT_EQ(test.next(false), APFloat::opOK); + EXPECT_TRUE(test.bitwiseIsEqual(expected)); + + // nextUp on negative nonzero numbers + for (int i = 129; i < 255; i++) { + test = APFloat(APFloat::Float8E4M3FN(), APInt(8, i)); + expected = APFloat(APFloat::Float8E4M3FN(), APInt(8, i - 1)); + EXPECT_EQ(test.next(false), APFloat::opOK); + EXPECT_TRUE(test.bitwiseIsEqual(expected)); + } + + // nextUp on NaN + test = APFloat::getQNaN(APFloat::Float8E4M3FN(), false); + expected = APFloat::getQNaN(APFloat::Float8E4M3FN(), false); + EXPECT_EQ(test.next(false), APFloat::opOK); + EXPECT_TRUE(test.bitwiseIsEqual(expected)); + + // nextDown on positive nonzero finite numbers + for 
(int i = 1; i < 127; i++) { + test = APFloat(APFloat::Float8E4M3FN(), APInt(8, i)); + expected = APFloat(APFloat::Float8E4M3FN(), APInt(8, i - 1)); + EXPECT_EQ(test.next(true), APFloat::opOK); + EXPECT_TRUE(test.bitwiseIsEqual(expected)); + } + + // nextDown on positive zero (negative zero is covered by i == 128 below) + test = APFloat::getZero(APFloat::Float8E4M3FN(), false); + expected = APFloat::getSmallest(APFloat::Float8E4M3FN(), true); + EXPECT_EQ(test.next(true), APFloat::opOK); + EXPECT_TRUE(test.bitwiseIsEqual(expected)); + + // nextDown on negative finite numbers + for (int i = 128; i < 255; i++) { + test = APFloat(APFloat::Float8E4M3FN(), APInt(8, i)); + expected = APFloat(APFloat::Float8E4M3FN(), APInt(8, i + 1)); + EXPECT_EQ(test.next(true), APFloat::opOK); + EXPECT_TRUE(test.bitwiseIsEqual(expected)); + } + + // nextDown on NaN + test = APFloat::getQNaN(APFloat::Float8E4M3FN(), false); + expected = APFloat::getQNaN(APFloat::Float8E4M3FN(), false); + EXPECT_EQ(test.next(true), APFloat::opOK); + EXPECT_TRUE(test.bitwiseIsEqual(expected)); +} + +TEST(APFloatTest, Float8E4M3FNExhaustive) { + // Test each of the 256 Float8E4M3FN values. 
+ for (int i = 0; i < 256; i++) { + APFloat test(APFloat::Float8E4M3FN(), APInt(8, i)); + SCOPED_TRACE("i=" + std::to_string(i)); + + // isLargest + if (i == 126 || i == 254) { + EXPECT_TRUE(test.isLargest()); + EXPECT_EQ(abs(test).convertToDouble(), 448.); + } else { + EXPECT_FALSE(test.isLargest()); + } + + // isSmallest + if (i == 1 || i == 129) { + EXPECT_TRUE(test.isSmallest()); + EXPECT_EQ(abs(test).convertToDouble(), 0x1p-9); + } else { + EXPECT_FALSE(test.isSmallest()); + } + + // convert to BFloat + APFloat test2 = test; + bool loses_info; + APFloat::opStatus status = test2.convert( + APFloat::BFloat(), APFloat::rmNearestTiesToEven, &loses_info); + EXPECT_EQ(status, APFloat::opOK); + EXPECT_FALSE(loses_info); + if (i == 127 || i == 255) + EXPECT_TRUE(test2.isNaN()); + else + EXPECT_EQ(test.convertToFloat(), test2.convertToFloat()); + + // bitcastToAPInt + EXPECT_EQ(i, test.bitcastToAPInt()); + } +} + +TEST(APFloatTest, Float8E4M3FNExhaustivePair) { + // Test each pair of Float8E4M3FN values. 
+ for (int i = 0; i < 256; i++) { + for (int j = 0; j < 256; j++) { + SCOPED_TRACE("i=" + std::to_string(i) + ",j=" + std::to_string(j)); + APFloat x(APFloat::Float8E4M3FN(), APInt(8, i)); + APFloat y(APFloat::Float8E4M3FN(), APInt(8, j)); + + bool losesInfo; + APFloat x16 = x; + x16.convert(APFloat::IEEEhalf(), APFloat::rmNearestTiesToEven, + &losesInfo); + EXPECT_FALSE(losesInfo); + APFloat y16 = y; + y16.convert(APFloat::IEEEhalf(), APFloat::rmNearestTiesToEven, + &losesInfo); + EXPECT_FALSE(losesInfo); + + // Add + APFloat z = x; + z.add(y, APFloat::rmNearestTiesToEven); + APFloat z16 = x16; + z16.add(y16, APFloat::rmNearestTiesToEven); + z16.convert(APFloat::Float8E4M3FN(), APFloat::rmNearestTiesToEven, + &losesInfo); + EXPECT_TRUE(z.bitwiseIsEqual(z16)); + + // Subtract + z = x; + z.subtract(y, APFloat::rmNearestTiesToEven); + z16 = x16; + z16.subtract(y16, APFloat::rmNearestTiesToEven); + z16.convert(APFloat::Float8E4M3FN(), APFloat::rmNearestTiesToEven, + &losesInfo); + EXPECT_TRUE(z.bitwiseIsEqual(z16)); + + // Multiply + z = x; + z.multiply(y, APFloat::rmNearestTiesToEven); + z16 = x16; + z16.multiply(y16, APFloat::rmNearestTiesToEven); + z16.convert(APFloat::Float8E4M3FN(), APFloat::rmNearestTiesToEven, + &losesInfo); + EXPECT_TRUE(z.bitwiseIsEqual(z16)) << "i=" << i << ", j=" << j; + + // Divide + z = x; + z.divide(y, APFloat::rmNearestTiesToEven); + z16 = x16; + z16.divide(y16, APFloat::rmNearestTiesToEven); + z16.convert(APFloat::Float8E4M3FN(), APFloat::rmNearestTiesToEven, + &losesInfo); + EXPECT_TRUE(z.bitwiseIsEqual(z16)) << "i=" << i << ", j=" << j; + + // Mod + z = x; + z.mod(y); + z16 = x16; + z16.mod(y16); + z16.convert(APFloat::Float8E4M3FN(), APFloat::rmNearestTiesToEven, + &losesInfo); + EXPECT_TRUE(z.bitwiseIsEqual(z16)) << "i=" << i << ", j=" << j; + + // Remainder + z = x; + z.remainder(y); + z16 = x16; + z16.remainder(y16); + z16.convert(APFloat::Float8E4M3FN(), APFloat::rmNearestTiesToEven, + &losesInfo); + 
EXPECT_TRUE(z.bitwiseIsEqual(z16)) << "i=" << i << ", j=" << j; + } + } +} + TEST(APFloatTest, IEEEdoubleToDouble) { APFloat DPosZero(0.0); APFloat DPosZeroToDouble(DPosZero.convertToDouble()); @@ -4937,6 +5333,30 @@ EXPECT_TRUE(std::isnan(QNaN.convertToDouble())); } +TEST(APFloatTest, Float8E4M3FNToDouble) { + APFloat One(APFloat::Float8E4M3FN(), "1.0"); + EXPECT_EQ(1.0, One.convertToDouble()); + APFloat Two(APFloat::Float8E4M3FN(), "2.0"); + EXPECT_EQ(2.0, Two.convertToDouble()); + APFloat PosLargest = APFloat::getLargest(APFloat::Float8E4M3FN(), false); + EXPECT_EQ(448., PosLargest.convertToDouble()); + APFloat NegLargest = APFloat::getLargest(APFloat::Float8E4M3FN(), true); + EXPECT_EQ(-448., NegLargest.convertToDouble()); + APFloat PosSmallest = + APFloat::getSmallestNormalized(APFloat::Float8E4M3FN(), false); + EXPECT_EQ(0x1.p-6, PosSmallest.convertToDouble()); + APFloat NegSmallest = + APFloat::getSmallestNormalized(APFloat::Float8E4M3FN(), true); + EXPECT_EQ(-0x1.p-6, NegSmallest.convertToDouble()); + + APFloat SmallestDenorm = APFloat::getSmallest(APFloat::Float8E4M3FN(), false); + EXPECT_TRUE(SmallestDenorm.isDenormal()); + EXPECT_EQ(0x1p-9, SmallestDenorm.convertToDouble()); + + APFloat QNaN = APFloat::getQNaN(APFloat::Float8E4M3FN()); + EXPECT_TRUE(std::isnan(QNaN.convertToDouble())); +} + TEST(APFloatTest, IEEEsingleToFloat) { APFloat FPosZero(0.0F); APFloat FPosZeroToFloat(FPosZero.convertToFloat()); @@ -5085,4 +5505,36 @@ EXPECT_TRUE(std::isnan(QNaN.convertToFloat())); } +TEST(APFloatTest, Float8E4M3FNToFloat) { + APFloat PosZero = APFloat::getZero(APFloat::Float8E4M3FN()); + APFloat PosZeroToFloat(PosZero.convertToFloat()); + EXPECT_TRUE(PosZeroToFloat.isPosZero()); + APFloat NegZero = APFloat::getZero(APFloat::Float8E4M3FN(), true); + APFloat NegZeroToFloat(NegZero.convertToFloat()); + EXPECT_TRUE(NegZeroToFloat.isNegZero()); + + APFloat One(APFloat::Float8E4M3FN(), "1.0"); + EXPECT_EQ(1.0F, One.convertToFloat()); + APFloat 
Two(APFloat::Float8E4M3FN(), "2.0"); + EXPECT_EQ(2.0F, Two.convertToFloat()); + + APFloat PosLargest = APFloat::getLargest(APFloat::Float8E4M3FN(), false); + EXPECT_EQ(448., PosLargest.convertToFloat()); + APFloat NegLargest = APFloat::getLargest(APFloat::Float8E4M3FN(), true); + EXPECT_EQ(-448, NegLargest.convertToFloat()); + APFloat PosSmallest = + APFloat::getSmallestNormalized(APFloat::Float8E4M3FN(), false); + EXPECT_EQ(0x1.p-6, PosSmallest.convertToFloat()); + APFloat NegSmallest = + APFloat::getSmallestNormalized(APFloat::Float8E4M3FN(), true); + EXPECT_EQ(-0x1.p-6, NegSmallest.convertToFloat()); + + APFloat SmallestDenorm = APFloat::getSmallest(APFloat::Float8E4M3FN(), false); + EXPECT_TRUE(SmallestDenorm.isDenormal()); + EXPECT_EQ(0x1.p-9, SmallestDenorm.convertToFloat()); + + APFloat QNaN = APFloat::getQNaN(APFloat::Float8E4M3FN()); + EXPECT_TRUE(std::isnan(QNaN.convertToFloat())); +} + } // namespace