diff --git a/clang/lib/AST/MicrosoftMangle.cpp b/clang/lib/AST/MicrosoftMangle.cpp --- a/clang/lib/AST/MicrosoftMangle.cpp +++ b/clang/lib/AST/MicrosoftMangle.cpp @@ -845,6 +845,7 @@ case APFloat::S_Float8E4M3FN: case APFloat::S_Float8E5M2FNUZ: case APFloat::S_Float8E4M3FNUZ: + case APFloat::S_Float8E4M3B11FNUZ: llvm_unreachable("Tried to mangle unexpected APFloat semantics"); } diff --git a/llvm/include/llvm/ADT/APFloat.h b/llvm/include/llvm/ADT/APFloat.h --- a/llvm/include/llvm/ADT/APFloat.h +++ b/llvm/include/llvm/ADT/APFloat.h @@ -177,6 +177,13 @@ // This format's exponent bias is 8, instead of the 7 (2 ** (4 - 1) - 1) // that IEEE precedent would imply. S_Float8E4M3FNUZ, + // 8-bit floating point number mostly following IEEE-754 conventions + // and bit layout S1E4M3 with expanded range and with no infinity or signed + // zero. + // NaN is represented as negative zero. (FN -> Finite, UZ -> unsigned zero). + // This format's exponent bias is 11, instead of the 7 (2 ** (4 - 1) - 1) + // that IEEE precedent would imply. 
+ S_Float8E4M3B11FNUZ, S_x87DoubleExtended, S_MaxSemantics = S_x87DoubleExtended, @@ -195,6 +202,7 @@ static const fltSemantics &Float8E5M2FNUZ() LLVM_READNONE; static const fltSemantics &Float8E4M3FN() LLVM_READNONE; static const fltSemantics &Float8E4M3FNUZ() LLVM_READNONE; + static const fltSemantics &Float8E4M3B11FNUZ() LLVM_READNONE; static const fltSemantics &x87DoubleExtended() LLVM_READNONE; /// A Pseudo fltsemantic used to construct APFloats that cannot conflict with @@ -590,6 +598,7 @@ APInt convertFloat8E5M2FNUZAPFloatToAPInt() const; APInt convertFloat8E4M3FNAPFloatToAPInt() const; APInt convertFloat8E4M3FNUZAPFloatToAPInt() const; + APInt convertFloat8E4M3B11FNUZAPFloatToAPInt() const; void initFromAPInt(const fltSemantics *Sem, const APInt &api); void initFromHalfAPInt(const APInt &api); void initFromBFloatAPInt(const APInt &api); @@ -602,6 +611,7 @@ void initFromFloat8E5M2FNUZAPInt(const APInt &api); void initFromFloat8E4M3FNAPInt(const APInt &api); void initFromFloat8E4M3FNUZAPInt(const APInt &api); + void initFromFloat8E4M3B11FNUZAPInt(const APInt &api); void assign(const IEEEFloat &); void copySignificand(const IEEEFloat &); diff --git a/llvm/lib/Support/APFloat.cpp b/llvm/lib/Support/APFloat.cpp --- a/llvm/lib/Support/APFloat.cpp +++ b/llvm/lib/Support/APFloat.cpp @@ -60,8 +60,9 @@ IEEE754, // This behavior is present in the Float8ExMyFN* types (Float8E4M3FN, - // Float8E5M2FNUZ, and Float8E4M3FNUZ). There is no representation for Inf, - // and operations that would ordinarily produce Inf produce NaN instead. + // Float8E5M2FNUZ, Float8E4M3FNUZ, and Float8E4M3B11FNUZ). There is no + // representation for Inf, and operations that would ordinarily produce Inf + // produce NaN instead. // The details of the NaN representation(s) in this form are determined by the // `fltNanEncoding` enum. We treat all NaNs as quiet, as the available // encodings do not distinguish between signalling and quiet NaN. 
@@ -138,6 +139,13 @@ 8, -6, 4, 8, fltNonfiniteBehavior::NanOnly, fltNanEncoding::AllOnes}; static const fltSemantics semFloat8E4M3FNUZ = { 7, -7, 4, 8, fltNonfiniteBehavior::NanOnly, fltNanEncoding::NegativeZero}; + static const fltSemantics semFloat8E4M3B11FNUZ = { + 4, + -10, + 4, + 8, + fltNonfiniteBehavior::NanOnly, + fltNanEncoding::NegativeZero}; static const fltSemantics semX87DoubleExtended = {16383, -16382, 64, 80}; static const fltSemantics semBogus = {0, 0, 0, 0}; @@ -201,6 +209,8 @@ return Float8E4M3FN(); case S_Float8E4M3FNUZ: return Float8E4M3FNUZ(); + case S_Float8E4M3B11FNUZ: + return Float8E4M3B11FNUZ(); case S_x87DoubleExtended: return x87DoubleExtended(); } @@ -229,6 +239,8 @@ return S_Float8E4M3FN; else if (&Sem == &llvm::APFloat::Float8E4M3FNUZ()) return S_Float8E4M3FNUZ; + else if (&Sem == &llvm::APFloat::Float8E4M3B11FNUZ()) + return S_Float8E4M3B11FNUZ; else if (&Sem == &llvm::APFloat::x87DoubleExtended()) return S_x87DoubleExtended; else @@ -259,6 +271,9 @@ const fltSemantics &APFloatBase::Float8E4M3FNUZ() { return semFloat8E4M3FNUZ; } + const fltSemantics &APFloatBase::Float8E4M3B11FNUZ() { + return semFloat8E4M3B11FNUZ; + } const fltSemantics &APFloatBase::x87DoubleExtended() { return semX87DoubleExtended; } @@ -3709,6 +3724,33 @@ (mysignificand & 0x7))); } +APInt IEEEFloat::convertFloat8E4M3B11FNUZAPFloatToAPInt() const { + assert(semantics == (const llvm::fltSemantics *)&semFloat8E4M3B11FNUZ); + assert(partCount() == 1); + + uint32_t myexponent, mysignificand; + + if (isFiniteNonZero()) { + myexponent = exponent + 11; // bias + mysignificand = (uint32_t)*significandParts(); + if (myexponent == 1 && !(mysignificand & 0x8)) + myexponent = 0; // denormal + } else if (category == fcZero) { + myexponent = 0; + mysignificand = 0; + } else if (category == fcInfinity) { + myexponent = 0; + mysignificand = 0; + } else { + assert(category == fcNaN && "Unknown category!"); + myexponent = 0; + mysignificand = (uint32_t)*significandParts(); + } + + 
return APInt(8, (((sign & 1) << 7) | ((myexponent & 0xf) << 3) | + (mysignificand & 0x7))); +} + // This function creates an APInt that is just a bit map of the floating // point constant as it would appear in memory. It is not a conversion, // and treating the result as a normal integer is unlikely to be useful. @@ -3744,6 +3786,9 @@ if (semantics == (const llvm::fltSemantics *)&semFloat8E4M3FNUZ) return convertFloat8E4M3FNUZAPFloatToAPInt(); + if (semantics == (const llvm::fltSemantics *)&semFloat8E4M3B11FNUZ) + return convertFloat8E4M3B11FNUZAPFloatToAPInt(); + assert(semantics == (const llvm::fltSemantics*)&semX87DoubleExtended && "unknown format!"); return convertF80LongDoubleAPFloatToAPInt(); @@ -4077,6 +4122,32 @@ } } +void IEEEFloat::initFromFloat8E4M3B11FNUZAPInt(const APInt &api) { + uint32_t i = (uint32_t)*api.getRawData(); + uint32_t myexponent = (i >> 3) & 0xf; + uint32_t mysignificand = i & 0x7; + + initialize(&semFloat8E4M3B11FNUZ); + assert(partCount() == 1); + + sign = i >> 7; + if (myexponent == 0 && mysignificand == 0 && sign == 0) { + makeZero(sign); + } else if (myexponent == 0 && mysignificand == 0 && sign == 1) { + category = fcNaN; + exponent = exponentNaN(); + *significandParts() = mysignificand; + } else { + category = fcNormal; + exponent = myexponent - 11; // bias + *significandParts() = mysignificand; + if (myexponent == 0) // denormal + exponent = -10; + else + *significandParts() |= 0x8; // integer bit + } +} + /// Treat api as containing the bits of a floating point number. 
void IEEEFloat::initFromAPInt(const fltSemantics *Sem, const APInt &api) { assert(api.getBitWidth() == Sem->sizeInBits); @@ -4102,6 +4173,8 @@ return initFromFloat8E4M3FNAPInt(api); if (Sem == &semFloat8E4M3FNUZ) return initFromFloat8E4M3FNUZAPInt(api); + if (Sem == &semFloat8E4M3B11FNUZ) + return initFromFloat8E4M3B11FNUZAPInt(api); llvm_unreachable(nullptr); } diff --git a/llvm/unittests/ADT/APFloatTest.cpp b/llvm/unittests/ADT/APFloatTest.cpp --- a/llvm/unittests/ADT/APFloatTest.cpp +++ b/llvm/unittests/ADT/APFloatTest.cpp @@ -1346,6 +1346,10 @@ { 0x80ULL, APFloat::Float8E4M3FNUZ(), false, true, 0xaaULL }, { 0x80ULL, APFloat::Float8E4M3FNUZ(), true, false, 0xaaULL }, { 0x80ULL, APFloat::Float8E4M3FNUZ(), true, true, 0xaaULL }, + { 0x80ULL, APFloat::Float8E4M3B11FNUZ(), false, false, 0xaaULL }, + { 0x80ULL, APFloat::Float8E4M3B11FNUZ(), false, true, 0xaaULL }, + { 0x80ULL, APFloat::Float8E4M3B11FNUZ(), true, false, 0xaaULL }, + { 0x80ULL, APFloat::Float8E4M3B11FNUZ(), true, true, 0xaaULL }, // clang-format on }; @@ -1774,6 +1778,8 @@ APFloat::getLargest(APFloat::Float8E4M3FNUZ()).convertToDouble()); EXPECT_EQ(57344, APFloat::getLargest(APFloat::Float8E5M2FNUZ()).convertToDouble()); + EXPECT_EQ( + 30, APFloat::getLargest(APFloat::Float8E4M3B11FNUZ()).convertToDouble()); } TEST(APFloatTest, getSmallest) { @@ -1818,6 +1824,13 @@ EXPECT_TRUE(test.isFiniteNonZero()); EXPECT_TRUE(test.isDenormal()); EXPECT_TRUE(test.bitwiseIsEqual(expected)); + + test = APFloat::getSmallest(APFloat::Float8E4M3B11FNUZ(), false); + expected = APFloat(APFloat::Float8E4M3B11FNUZ(), "0x0.2p-10"); + EXPECT_FALSE(test.isNegative()); + EXPECT_TRUE(test.isFiniteNonZero()); + EXPECT_TRUE(test.isDenormal()); + EXPECT_TRUE(test.bitwiseIsEqual(expected)); } TEST(APFloatTest, getSmallestNormalized) { @@ -1884,6 +1897,14 @@ EXPECT_FALSE(test.isDenormal()); EXPECT_TRUE(test.bitwiseIsEqual(expected)); EXPECT_TRUE(test.isSmallestNormalized()); + + test = 
APFloat::getSmallestNormalized(APFloat::Float8E4M3B11FNUZ(), false); + expected = APFloat(APFloat::Float8E4M3B11FNUZ(), "0x1.0p-10"); + EXPECT_FALSE(test.isNegative()); + EXPECT_TRUE(test.isFiniteNonZero()); + EXPECT_FALSE(test.isDenormal()); + EXPECT_TRUE(test.bitwiseIsEqual(expected)); + EXPECT_TRUE(test.isSmallestNormalized()); } TEST(APFloatTest, getZero) { @@ -1913,7 +1934,9 @@ {&APFloat::Float8E4M3FN(), false, true, {0, 0}, 1}, {&APFloat::Float8E4M3FN(), true, true, {0x80ULL, 0}, 1}, {&APFloat::Float8E4M3FNUZ(), false, false, {0, 0}, 1}, - {&APFloat::Float8E4M3FNUZ(), true, false, {0, 0}, 1}}; + {&APFloat::Float8E4M3FNUZ(), true, false, {0, 0}, 1}, + {&APFloat::Float8E4M3B11FNUZ(), false, false, {0, 0}, 1}, + {&APFloat::Float8E4M3B11FNUZ(), true, false, {0, 0}, 1}}; const unsigned NumGetZeroTests = std::size(GetZeroTest); for (unsigned i = 0; i < NumGetZeroTests; ++i) { APFloat test = APFloat::getZero(*GetZeroTest[i].semantics, @@ -1944,14 +1967,14 @@ EXPECT_TRUE(APFloat(42.0).bitwiseIsEqual( APFloat::copySign(APFloat(42.0), APFloat(1.0)))); // For floating-point formats with unsigned 0, copySign() to a zero is a noop - EXPECT_TRUE( - APFloat::getZero(APFloat::Float8E4M3FNUZ()) - .bitwiseIsEqual(APFloat::copySign( - APFloat::getZero(APFloat::Float8E4M3FNUZ()), APFloat(-1.0)))); - EXPECT_TRUE( - APFloat::getNaN(APFloat::Float8E4M3FNUZ(), true) - .bitwiseIsEqual(APFloat::copySign( - APFloat::getNaN(APFloat::Float8E4M3FNUZ(), true), APFloat(1.0)))); + for (APFloat::Semantics S : + {APFloat::S_Float8E4M3FNUZ, APFloat::S_Float8E4M3B11FNUZ}) { + const llvm::fltSemantics &Sem = APFloat::EnumToSemantics(S); + EXPECT_TRUE(APFloat::getZero(Sem).bitwiseIsEqual( + APFloat::copySign(APFloat::getZero(Sem), APFloat(-1.0)))); + EXPECT_TRUE(APFloat::getNaN(Sem, true).bitwiseIsEqual( + APFloat::copySign(APFloat::getNaN(Sem, true), APFloat(1.0)))); + } } TEST(APFloatTest, convert) { @@ -2073,17 +2096,18 @@ {APFloat::getSNaN(APFloat::IEEEsingle(), true), APFloat::opInvalidOp}, 
{APFloat::getInf(APFloat::IEEEsingle(), false), APFloat::opInexact}, {APFloat::getInf(APFloat::IEEEsingle(), true), APFloat::opInexact}}; - for (auto [toTest, expectedRes] : toNaNTests) { - llvm::SmallString<16> value; - toTest.toString(value); - SCOPED_TRACE("toTest = " + value); - for (const fltSemantics *sem : - {&APFloat::Float8E4M3FNUZ(), &APFloat::Float8E5M2FNUZ()}) { - SCOPED_TRACE("Semantics = " + - std::to_string(APFloat::SemanticsToEnum(*sem))); + for (APFloat::Semantics S : + {APFloat::S_Float8E5M2FNUZ, APFloat::S_Float8E4M3FNUZ, + APFloat::S_Float8E4M3B11FNUZ}) { + const llvm::fltSemantics &Sem = APFloat::EnumToSemantics(S); + SCOPED_TRACE("Semantics = " + std::to_string(S)); + for (auto [toTest, expectedRes] : toNaNTests) { + llvm::SmallString<16> value; + toTest.toString(value); + SCOPED_TRACE("toTest = " + value); losesInfo = false; APFloat test = toTest; - EXPECT_EQ(test.convert(*sem, APFloat::rmNearestTiesToAway, &losesInfo), + EXPECT_EQ(test.convert(Sem, APFloat::rmNearestTiesToAway, &losesInfo), expectedRes); EXPECT_TRUE(test.isNaN()); EXPECT_TRUE(test.isNegative()); @@ -2092,37 +2116,34 @@ EXPECT_EQ(0x80, test.bitcastToAPInt()); EXPECT_TRUE(losesInfo); } - } - // Negative zero conversions are information losing. - losesInfo = false; - APFloat test = APFloat::getZero(APFloat::IEEEsingle(), true); - EXPECT_EQ(test.convert(APFloat::Float8E5M2FNUZ(), - APFloat::rmNearestTiesToAway, &losesInfo), - APFloat::opInexact); - EXPECT_TRUE(test.isZero()); - EXPECT_FALSE(test.isNegative()); - EXPECT_TRUE(losesInfo); - EXPECT_EQ(0x0, test.bitcastToAPInt()); - - losesInfo = true; - test = APFloat::getZero(APFloat::IEEEsingle(), false); - EXPECT_EQ(test.convert(APFloat::Float8E5M2FNUZ(), - APFloat::rmNearestTiesToAway, &losesInfo), - APFloat::opOK); - EXPECT_TRUE(test.isZero()); - EXPECT_FALSE(test.isNegative()); - EXPECT_FALSE(losesInfo); - EXPECT_EQ(0x0, test.bitcastToAPInt()); + // Negative zero conversions are information losing. 
+ losesInfo = false; + APFloat test = APFloat::getZero(APFloat::IEEEsingle(), true); + EXPECT_EQ(test.convert(Sem, APFloat::rmNearestTiesToAway, &losesInfo), + APFloat::opInexact); + EXPECT_TRUE(test.isZero()); + EXPECT_FALSE(test.isNegative()); + EXPECT_TRUE(losesInfo); + EXPECT_EQ(0x0, test.bitcastToAPInt()); + + losesInfo = true; + test = APFloat::getZero(APFloat::IEEEsingle(), false); + EXPECT_EQ(test.convert(Sem, APFloat::rmNearestTiesToAway, &losesInfo), + APFloat::opOK); + EXPECT_TRUE(test.isZero()); + EXPECT_FALSE(test.isNegative()); + EXPECT_FALSE(losesInfo); + EXPECT_EQ(0x0, test.bitcastToAPInt()); - // Except in casts between ourselves. - losesInfo = true; - test = APFloat::getZero(APFloat::Float8E5M2FNUZ()); - EXPECT_EQ(test.convert(APFloat::Float8E4M3FNUZ(), - APFloat::rmNearestTiesToAway, &losesInfo), - APFloat::opOK); - EXPECT_FALSE(losesInfo); - EXPECT_EQ(0x0, test.bitcastToAPInt()); + // Except in casts between ourselves. + losesInfo = true; + test = APFloat::getZero(Sem); + EXPECT_EQ(test.convert(Sem, APFloat::rmNearestTiesToAway, &losesInfo), + APFloat::opOK); + EXPECT_FALSE(losesInfo); + EXPECT_EQ(0x0, test.bitcastToAPInt()); + } } TEST(APFloatTest, PPCDoubleDouble) { @@ -5003,7 +5024,7 @@ // Test each pair of 8-bit floats with non-standard semantics for (APFloat::Semantics Sem : {APFloat::S_Float8E4M3FN, APFloat::S_Float8E5M2FNUZ, - APFloat::S_Float8E4M3FNUZ}) { + APFloat::S_Float8E4M3FNUZ, APFloat::S_Float8E4M3B11FNUZ}) { const llvm::fltSemantics &S = APFloat::EnumToSemantics(Sem); for (int i = 0; i < 256; i++) { for (int j = 0; j < 256; j++) { @@ -5483,50 +5504,54 @@ // cases and so are not repeated here. 
// The IEEE round towards negative rule doesn't apply - APFloat test = APFloat::getSmallest(APFloat::Float8E4M3FNUZ()); - APFloat rhs = test; - EXPECT_EQ(test.subtract(rhs, APFloat::rmTowardNegative), APFloat::opOK); - EXPECT_TRUE(test.isZero()); - EXPECT_FALSE(test.isNegative()); - - // Multiplication of (small) * (-small) is +0 - test = APFloat::getSmallestNormalized(APFloat::Float8E4M3FNUZ()); - rhs = -test; - EXPECT_EQ(test.multiply(rhs, APFloat::rmNearestTiesToAway), - APFloat::opInexact | APFloat::opUnderflow); - EXPECT_TRUE(test.isZero()); - EXPECT_FALSE(test.isNegative()); + for (APFloat::Semantics S : + {APFloat::S_Float8E4M3FNUZ, APFloat::S_Float8E4M3B11FNUZ}) { + const llvm::fltSemantics &Sem = APFloat::EnumToSemantics(S); + APFloat test = APFloat::getSmallest(Sem); + APFloat rhs = test; + EXPECT_EQ(test.subtract(rhs, APFloat::rmTowardNegative), APFloat::opOK); + EXPECT_TRUE(test.isZero()); + EXPECT_FALSE(test.isNegative()); - // Dividing the negatize float_min by anything gives +0 - test = APFloat::getSmallest(APFloat::Float8E4M3FNUZ(), true); - rhs = APFloat(APFloat::Float8E4M3FNUZ(), "2.0"); - EXPECT_EQ(test.divide(rhs, APFloat::rmNearestTiesToEven), - APFloat::opInexact | APFloat::opUnderflow); - EXPECT_TRUE(test.isZero()); - EXPECT_FALSE(test.isNegative()); + // Multiplication of (small) * (-small) is +0 + test = APFloat::getSmallestNormalized(Sem); + rhs = -test; + EXPECT_EQ(test.multiply(rhs, APFloat::rmNearestTiesToAway), + APFloat::opInexact | APFloat::opUnderflow); + EXPECT_TRUE(test.isZero()); + EXPECT_FALSE(test.isNegative()); - // Remainder can't copy sign because there's only one zero - test = APFloat(APFloat::Float8E4M3FNUZ(), "-4.0"); - rhs = APFloat(APFloat::Float8E4M3FNUZ(), "2.0"); - EXPECT_EQ(test.remainder(rhs), APFloat::opOK); - EXPECT_TRUE(test.isZero()); - EXPECT_FALSE(test.isNegative()); + // Dividing the negative float_min by anything gives +0 + test = APFloat::getSmallest(Sem, true); + rhs = APFloat(Sem, "2.0"); + 
EXPECT_EQ(test.divide(rhs, APFloat::rmNearestTiesToEven), + APFloat::opInexact | APFloat::opUnderflow); + EXPECT_TRUE(test.isZero()); + EXPECT_FALSE(test.isNegative()); - // And same for mod - test = APFloat(APFloat::Float8E4M3FNUZ(), "-4.0"); - rhs = APFloat(APFloat::Float8E4M3FNUZ(), "2.0"); - EXPECT_EQ(test.mod(rhs), APFloat::opOK); - EXPECT_TRUE(test.isZero()); - EXPECT_FALSE(test.isNegative()); + // Remainder can't copy sign because there's only one zero + test = APFloat(Sem, "-4.0"); + rhs = APFloat(Sem, "2.0"); + EXPECT_EQ(test.remainder(rhs), APFloat::opOK); + EXPECT_TRUE(test.isZero()); + EXPECT_FALSE(test.isNegative()); - // FMA correctly handles both the multiply and add parts of all this - test = APFloat(APFloat::Float8E4M3FNUZ(), "2.0"); - rhs = test; - APFloat addend = APFloat(APFloat::Float8E4M3FNUZ(), "-4.0"); - EXPECT_EQ(test.fusedMultiplyAdd(rhs, addend, APFloat::rmTowardNegative), - APFloat::opOK); - EXPECT_TRUE(test.isZero()); - EXPECT_FALSE(test.isNegative()); + // And same for mod + test = APFloat(Sem, "-4.0"); + rhs = APFloat(Sem, "2.0"); + EXPECT_EQ(test.mod(rhs), APFloat::opOK); + EXPECT_TRUE(test.isZero()); + EXPECT_FALSE(test.isNegative()); + + // FMA correctly handles both the multiply and add parts of all this + test = APFloat(Sem, "2.0"); + rhs = test; + APFloat addend = APFloat(Sem, "-4.0"); + EXPECT_EQ(test.fusedMultiplyAdd(rhs, addend, APFloat::rmTowardNegative), + APFloat::opOK); + EXPECT_TRUE(test.isZero()); + EXPECT_FALSE(test.isNegative()); + } } TEST(APFloatTest, Float8E5M2FNUZAdd) { @@ -5590,7 +5615,8 @@ const double largest; const double smallest; } const exhaustiveTests[] = {{&APFloat::Float8E5M2FNUZ(), 57344., 0x1.0p-17}, - {&APFloat::Float8E4M3FNUZ(), 240., 0x1.0p-10}}; + {&APFloat::Float8E4M3FNUZ(), 240., 0x1.0p-10}, + {&APFloat::Float8E4M3B11FNUZ(), 30., 0x1.0p-13}}; for (const auto &testInfo : exhaustiveTests) { const fltSemantics &sem = *testInfo.semantics; SCOPED_TRACE("Semantics=" + 
std::to_string(APFloat::SemanticsToEnum(sem))); @@ -5634,71 +5660,79 @@ } TEST(APFloatTest, Float8E4M3FNUZNext) { - APFloat test(APFloat::Float8E4M3FNUZ(), APFloat::uninitialized); - APFloat expected(APFloat::Float8E4M3FNUZ(), APFloat::uninitialized); - - // 1. NextUp of largest bit pattern is nan - test = APFloat::getLargest(APFloat::Float8E4M3FNUZ()); - expected = APFloat::getNaN(APFloat::Float8E4M3FNUZ()); - EXPECT_EQ(test.next(false), APFloat::opOK); - EXPECT_FALSE(test.isInfinity()); - EXPECT_FALSE(test.isZero()); - EXPECT_TRUE(test.isNaN()); - EXPECT_TRUE(test.bitwiseIsEqual(expected)); + for (APFloat::Semantics S : + {APFloat::S_Float8E4M3FNUZ, APFloat::S_Float8E4M3B11FNUZ}) { + const llvm::fltSemantics &Sem = APFloat::EnumToSemantics(S); + APFloat test(Sem, APFloat::uninitialized); + APFloat expected(Sem, APFloat::uninitialized); + + // 1. NextUp of largest bit pattern is nan + test = APFloat::getLargest(Sem); + expected = APFloat::getNaN(Sem); + EXPECT_EQ(test.next(false), APFloat::opOK); + EXPECT_FALSE(test.isInfinity()); + EXPECT_FALSE(test.isZero()); + EXPECT_TRUE(test.isNaN()); + EXPECT_TRUE(test.bitwiseIsEqual(expected)); - // 2. NextUp of smallest negative denormal is +0 - test = APFloat::getSmallest(APFloat::Float8E4M3FNUZ(), true); - expected = APFloat::getZero(APFloat::Float8E4M3FNUZ(), false); - EXPECT_EQ(test.next(false), APFloat::opOK); - EXPECT_FALSE(test.isNegZero()); - EXPECT_TRUE(test.isPosZero()); - EXPECT_TRUE(test.bitwiseIsEqual(expected)); + // 2. NextUp of smallest negative denormal is +0 + test = APFloat::getSmallest(Sem, true); + expected = APFloat::getZero(Sem, false); + EXPECT_EQ(test.next(false), APFloat::opOK); + EXPECT_FALSE(test.isNegZero()); + EXPECT_TRUE(test.isPosZero()); + EXPECT_TRUE(test.bitwiseIsEqual(expected)); - // 3. 
nextDown of negative of largest value is NaN - test = APFloat::getLargest(APFloat::Float8E4M3FNUZ(), true); - expected = APFloat::getNaN(APFloat::Float8E4M3FNUZ()); - EXPECT_EQ(test.next(true), APFloat::opOK); - EXPECT_FALSE(test.isInfinity()); - EXPECT_FALSE(test.isZero()); - EXPECT_TRUE(test.isNaN()); - EXPECT_TRUE(test.bitwiseIsEqual(expected)); + // 3. nextDown of negative of largest value is NaN + test = APFloat::getLargest(Sem, true); + expected = APFloat::getNaN(Sem); + EXPECT_EQ(test.next(true), APFloat::opOK); + EXPECT_FALSE(test.isInfinity()); + EXPECT_FALSE(test.isZero()); + EXPECT_TRUE(test.isNaN()); + EXPECT_TRUE(test.bitwiseIsEqual(expected)); - // 4. nextDown of +0 is smallest negative denormal - test = APFloat::getZero(APFloat::Float8E4M3FNUZ(), false); - expected = APFloat::getSmallest(APFloat::Float8E4M3FNUZ(), true); - EXPECT_EQ(test.next(true), APFloat::opOK); - EXPECT_FALSE(test.isZero()); - EXPECT_TRUE(test.isDenormal()); - EXPECT_TRUE(test.bitwiseIsEqual(expected)); + // 4. nextDown of +0 is smallest negative denormal + test = APFloat::getZero(Sem, false); + expected = APFloat::getSmallest(Sem, true); + EXPECT_EQ(test.next(true), APFloat::opOK); + EXPECT_FALSE(test.isZero()); + EXPECT_TRUE(test.isDenormal()); + EXPECT_TRUE(test.bitwiseIsEqual(expected)); - // 5. nextUp of NaN is NaN - test = APFloat::getNaN(APFloat::Float8E4M3FNUZ(), false); - expected = APFloat::getNaN(APFloat::Float8E4M3FNUZ(), true); - EXPECT_EQ(test.next(false), APFloat::opOK); - EXPECT_TRUE(test.isNaN()); + // 5. nextUp of NaN is NaN + test = APFloat::getNaN(Sem, false); + expected = APFloat::getNaN(Sem, true); + EXPECT_EQ(test.next(false), APFloat::opOK); + EXPECT_TRUE(test.isNaN()); - // 6. nextDown of NaN is NaN - test = APFloat::getNaN(APFloat::Float8E4M3FNUZ(), false); - expected = APFloat::getNaN(APFloat::Float8E4M3FNUZ(), true); - EXPECT_EQ(test.next(true), APFloat::opOK); - EXPECT_TRUE(test.isNaN()); + // 6. 
nextDown of NaN is NaN + test = APFloat::getNaN(Sem, false); + expected = APFloat::getNaN(Sem, true); + EXPECT_EQ(test.next(true), APFloat::opOK); + EXPECT_TRUE(test.isNaN()); + } } TEST(APFloatTest, Float8E4M3FNUZChangeSign) { - APFloat test = APFloat(APFloat::Float8E4M3FNUZ(), "1.0"); - APFloat expected = APFloat(APFloat::Float8E4M3FNUZ(), "-1.0"); - test.changeSign(); - EXPECT_TRUE(test.bitwiseIsEqual(expected)); + for (APFloat::Semantics S : + {APFloat::S_Float8E4M3FNUZ, APFloat::S_Float8E4M3B11FNUZ}) { + const llvm::fltSemantics &Sem = APFloat::EnumToSemantics(S); + APFloat test = APFloat(Sem, "1.0"); + APFloat expected = APFloat(Sem, "-1.0"); + test.changeSign(); + EXPECT_TRUE(test.bitwiseIsEqual(expected)); - test = APFloat::getZero(APFloat::Float8E4M3FNUZ()); - expected = test; - test.changeSign(); - EXPECT_TRUE(test.bitwiseIsEqual(expected)); + test = APFloat::getZero(Sem); + expected = test; + test.changeSign(); + EXPECT_TRUE(test.bitwiseIsEqual(expected)); - test = APFloat::getNaN(APFloat::Float8E4M3FNUZ()); - expected = test; - test.changeSign(); - EXPECT_TRUE(test.bitwiseIsEqual(expected)); + test = APFloat::getNaN(Sem); + expected = test; + test.changeSign(); + EXPECT_TRUE(test.bitwiseIsEqual(expected)); + } } TEST(APFloatTest, Float8E4M3FNUZFromString) { diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h --- a/mlir/include/mlir-c/BuiltinTypes.h +++ b/mlir/include/mlir-c/BuiltinTypes.h @@ -95,6 +95,13 @@ /// context. MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3FNUZTypeGet(MlirContext ctx); +/// Checks whether the given type is an f8E4M3B11FNUZ type. +MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3B11FNUZ(MlirType type); + +/// Creates an f8E4M3B11FNUZ type in the given context. The type is owned by the +/// context. +MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3B11FNUZTypeGet(MlirContext ctx); + /// Checks whether the given type is a bf16 type. 
MLIR_CAPI_EXPORTED bool mlirTypeIsABF16(MlirType type); diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -64,6 +64,7 @@ FloatType getFloat8E4M3FNType(); FloatType getFloat8E5M2FNUZType(); FloatType getFloat8E4M3FNUZType(); + FloatType getFloat8E4M3B11FNUZType(); FloatType getBF16Type(); FloatType getF16Type(); FloatType getF32Type(); diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h --- a/mlir/include/mlir/IR/BuiltinTypes.h +++ b/mlir/include/mlir/IR/BuiltinTypes.h @@ -49,6 +49,7 @@ static FloatType getFloat8E4M3FN(MLIRContext *ctx); static FloatType getFloat8E5M2FNUZ(MLIRContext *ctx); static FloatType getFloat8E4M3FNUZ(MLIRContext *ctx); + static FloatType getFloat8E4M3B11FNUZ(MLIRContext *ctx); /// Methods for support type inquiry through isa, cast, and dyn_cast. static bool classof(Type type); @@ -376,9 +377,10 @@ } inline bool FloatType::classof(Type type) { - return type.isa(); + return type + .isa(); } inline FloatType FloatType::getFloat8E5M2(MLIRContext *ctx) { @@ -397,6 +399,10 @@ return Float8E4M3FNUZType::get(ctx); } +inline FloatType FloatType::getFloat8E4M3B11FNUZ(MLIRContext *ctx) { + return Float8E4M3B11FNUZType::get(ctx); +} + inline FloatType FloatType::getBF16(MLIRContext *ctx) { return BFloat16Type::get(ctx); } diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td --- a/mlir/include/mlir/IR/BuiltinTypes.td +++ b/mlir/include/mlir/IR/BuiltinTypes.td @@ -162,6 +162,28 @@ }]; } +//===----------------------------------------------------------------------===// +// Float8E4M3B11FNUZType + +def Builtin_Float8E4M3B11FNUZ : Builtin_FloatType<"Float8E4M3B11FNUZ"> { + let summary = "8-bit floating point with 3 bit mantissa"; + let description = [{ + An 8-bit floating point type with 1 sign bit, 4 bits exponent and 3 bits + mantissa. 
This is not a standard type as defined by IEEE-754, but it follows + similar conventions, with the exception that there are no infinity values, + no negative zero, and only one NaN representation. This type has the + following characteristics: + + * bit encoding: S1E4M3 + * exponent bias: 11 + * infinities: Not supported + * NaNs: Supported with sign bit set to 1, exponent bits and mantissa bits set to all 0s + * denormals when exponent is 0 + + Related to: https://dl.acm.org/doi/10.5555/3454287.3454728 + }]; +} + //===----------------------------------------------------------------------===// // BFloat16Type diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -490,6 +490,8 @@ BuildableType<"$_builder.getFloat8E5M2Type()">; def F8E4M3FNUZ : Type, "f8E4M3FNUZ type">, BuildableType<"$_builder.getFloat8E4M3FNUZType()">; +def F8E4M3B11FNUZ : Type, "f8E4M3B11FNUZ type">, + BuildableType<"$_builder.getFloat8E4M3B11FNUZType()">; def F8E5M2FNUZ : Type, "f8E5M2FNUZ type">, BuildableType<"$_builder.getFloat8E5M2FNUZType()">; diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h --- a/mlir/include/mlir/IR/Types.h +++ b/mlir/include/mlir/IR/Types.h @@ -124,6 +124,7 @@ bool isFloat8E4M3FN() const; bool isFloat8E5M2FNUZ() const; bool isFloat8E4M3FNUZ() const; + bool isFloat8E4M3B11FNUZ() const; bool isBF16() const; bool isF16() const; bool isF32() const; diff --git a/mlir/lib/AsmParser/TokenKinds.def b/mlir/lib/AsmParser/TokenKinds.def --- a/mlir/lib/AsmParser/TokenKinds.def +++ b/mlir/lib/AsmParser/TokenKinds.def @@ -97,6 +97,7 @@ TOK_KEYWORD(f8E4M3FN) TOK_KEYWORD(f8E5M2FNUZ) TOK_KEYWORD(f8E4M3FNUZ) +TOK_KEYWORD(f8E4M3B11FNUZ) TOK_KEYWORD(f128) TOK_KEYWORD(false) TOK_KEYWORD(floordiv) diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp --- a/mlir/lib/AsmParser/TypeParser.cpp +++ b/mlir/lib/AsmParser/TypeParser.cpp @@ -35,6 +35,7 
@@ case Token::kw_f8E4M3FN: case Token::kw_f8E5M2FNUZ: case Token::kw_f8E4M3FNUZ: + case Token::kw_f8E4M3B11FNUZ: case Token::kw_bf16: case Token::kw_f16: case Token::kw_f32: @@ -303,6 +304,9 @@ case Token::kw_f8E4M3FNUZ: consumeToken(Token::kw_f8E4M3FNUZ); return builder.getFloat8E4M3FNUZType(); + case Token::kw_f8E4M3B11FNUZ: + consumeToken(Token::kw_f8E4M3B11FNUZ); + return builder.getFloat8E4M3B11FNUZType(); case Token::kw_bf16: consumeToken(Token::kw_bf16); return builder.getBF16Type(); diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp --- a/mlir/lib/Bindings/Python/IRTypes.cpp +++ b/mlir/lib/Bindings/Python/IRTypes.cpp @@ -157,6 +157,24 @@ } }; +/// Floating Point Type subclass - Float8E4M3B11FNUZ. +class PyFloat8E4M3B11FNUZType : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3B11FNUZ; + static constexpr const char *pyClassName = "Float8E4M3B11FNUZType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirFloat8E4M3B11FNUZTypeGet(context->get()); + return PyFloat8E4M3B11FNUZType(context->getRef(), t); + }, + py::arg("context") = py::none(), "Create a float8_e4m3b11fnuz type."); + } +}; + /// Floating Point Type subclass - Float8E5M2FNUZ.
class PyFloat8E5M2FNUZType : public PyConcreteType { public: @@ -730,6 +748,7 @@ PyFloat8E4M3FNType::bind(m); PyFloat8E5M2Type::bind(m); PyFloat8E4M3FNUZType::bind(m); + PyFloat8E4M3B11FNUZType::bind(m); PyFloat8E5M2FNUZType::bind(m); PyBF16Type::bind(m); PyF16Type::bind(m); diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp --- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp @@ -100,6 +100,14 @@ return wrap(FloatType::getFloat8E4M3FNUZ(unwrap(ctx))); } +bool mlirTypeIsAFloat8E4M3B11FNUZ(MlirType type) { + return unwrap(type).isFloat8E4M3B11FNUZ(); +} + +MlirType mlirFloat8E4M3B11FNUZTypeGet(MlirContext ctx) { + return wrap(FloatType::getFloat8E4M3B11FNUZ(unwrap(ctx))); +} + bool mlirTypeIsABF16(MlirType type) { return unwrap(type).isBF16(); } MlirType mlirBF16TypeGet(MlirContext ctx) { diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -2428,6 +2428,7 @@ .Case([&](Type) { os << "f8E4M3FN"; }) .Case([&](Type) { os << "f8E5M2FNUZ"; }) .Case([&](Type) { os << "f8E4M3FNUZ"; }) + .Case([&](Type) { os << "f8E4M3B11FNUZ"; }) .Case([&](Type) { os << "bf16"; }) .Case([&](Type) { os << "f16"; }) .Case([&](Type) { os << "f32"; }) diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -49,6 +49,10 @@ return FloatType::getFloat8E4M3FNUZ(context); } +FloatType Builder::getFloat8E4M3B11FNUZType() { + return FloatType::getFloat8E4M3B11FNUZ(context); +} + FloatType Builder::getBF16Type() { return FloatType::getBF16(context); } FloatType Builder::getF16Type() { return FloatType::getF16(context); } diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp --- a/mlir/lib/IR/BuiltinTypes.cpp +++ b/mlir/lib/IR/BuiltinTypes.cpp @@ -89,7 +89,7 @@ unsigned FloatType::getWidth() { if (isa()) + Float8E4M3FNUZType, Float8E4M3B11FNUZType>()) return 8; if (isa()) return 16; @@ -114,6 +114,8 @@ return APFloat::Float8E5M2FNUZ(); if (isa()) return APFloat::Float8E4M3FNUZ(); + if (isa()) + return APFloat::Float8E4M3B11FNUZ(); if (isa()) return APFloat::BFloat(); if (isa())
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -214,6 +214,7 @@ Float8E4M3FNType f8E4M3FNTy; Float8E5M2FNUZType f8E5M2FNUZTy; Float8E4M3FNUZType f8E4M3FNUZTy; + Float8E4M3B11FNUZType f8E4M3B11FNUZTy; BFloat16Type bf16Ty; Float16Type f16Ty; Float32Type f32Ty; @@ -288,6 +289,7 @@ impl->f8E4M3FNTy = TypeUniquer::get(this); impl->f8E5M2FNUZTy = TypeUniquer::get(this); impl->f8E4M3FNUZTy = TypeUniquer::get(this); + impl->f8E4M3B11FNUZTy = TypeUniquer::get(this); impl->bf16Ty = TypeUniquer::get(this); impl->f16Ty = TypeUniquer::get(this); impl->f32Ty = TypeUniquer::get(this); @@ -892,6 +894,9 @@ Float8E4M3FNUZType Float8E4M3FNUZType::get(MLIRContext *context) { return context->getImpl().f8E4M3FNUZTy; } +Float8E4M3B11FNUZType Float8E4M3B11FNUZType::get(MLIRContext *context) { + return context->getImpl().f8E4M3B11FNUZTy; +} BFloat16Type BFloat16Type::get(MLIRContext *context) { return context->getImpl().bf16Ty; } diff --git a/mlir/lib/IR/Types.cpp b/mlir/lib/IR/Types.cpp --- a/mlir/lib/IR/Types.cpp +++ b/mlir/lib/IR/Types.cpp @@ -38,6 +38,7 @@ bool Type::isFloat8E4M3FN() const { return isa(); } bool Type::isFloat8E5M2FNUZ() const { return isa(); } bool Type::isFloat8E4M3FNUZ() const { return isa(); } +bool Type::isFloat8E4M3B11FNUZ() const { return isa(); } bool Type::isBF16() const { return isa(); } bool Type::isF16() const { return isa(); } bool Type::isF32() const { return isa(); } diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi --- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi @@ -53,6 +53,7 @@ "Float8E4M3FNType", "Float8E5M2Type", "Float8E4M3FNUZType", + "Float8E4M3B11FNUZType", "Float8E5M2FNUZType", "F16Type", "F32Type", @@ -602,6 +603,13 @@ @staticmethod def isinstance(arg: Any) -> bool: ... +class Float8E4M3B11FNUZType(Type): + def __init__(self, cast_from_type: Type) -> None: ... 
+ @staticmethod + def get(*args, **kwargs) -> Float8E4M3B11FNUZType: ... + @staticmethod + def isinstance(arg: Any) -> bool: ... + class Float8E5M2FNUZType(Type): def __init__(self, cast_from_type: Type) -> None: ... @staticmethod diff --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir --- a/mlir/test/IR/attribute.mlir +++ b/mlir/test/IR/attribute.mlir @@ -52,6 +52,10 @@ // CHECK: float_attr = 2.000000e+00 : f8E4M3FNUZ float_attr = 2. : f8E4M3FNUZ } : () -> () + "test.float_attrs"() { + // CHECK: float_attr = 2.000000e+00 : f8E4M3B11FNUZ + float_attr = 2. : f8E4M3B11FNUZ + } : () -> () "test.float_attrs"() { // CHECK: float_attr = 2.000000e+00 : f16 float_attr = 2. : f16 diff --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py --- a/mlir/test/python/ir/builtin_types.py +++ b/mlir/test/python/ir/builtin_types.py @@ -202,6 +202,8 @@ print("float:", Float8E5M2FNUZType.get()) # CHECK: float: f8E4M3FNUZ print("float:", Float8E4M3FNUZType.get()) + # CHECK: float: f8E4M3B11FNUZ + print("float:", Float8E4M3B11FNUZType.get()) # CHECK: float: bf16 print("float:", BF16Type.get()) # CHECK: float: f16 diff --git a/mlir/utils/lldb-scripts/mlirDataFormatters.py b/mlir/utils/lldb-scripts/mlirDataFormatters.py --- a/mlir/utils/lldb-scripts/mlirDataFormatters.py +++ b/mlir/utils/lldb-scripts/mlirDataFormatters.py @@ -54,6 +54,7 @@ "mlir::Float8E4M3FNType": '"f8E4M3FN"', "mlir::Float8E5M2FNUZType": '"f8E5M2FNUZ"', "mlir::Float8E4M3FNUZType": '"f8E4M3FNUZ"', + "mlir::Float8E4M3B11FNUZType": '"f8E4M3B11FNUZ"', "mlir::BFloat16Type": '"bf16"', "mlir::Float16Type": '"f16"', "mlir::Float32Type": '"f32"',