diff --git a/libc/config/linux/aarch64/entrypoints.txt b/libc/config/linux/aarch64/entrypoints.txt --- a/libc/config/linux/aarch64/entrypoints.txt +++ b/libc/config/linux/aarch64/entrypoints.txt @@ -65,6 +65,7 @@ libc.src.math.floor libc.src.math.floorf libc.src.math.floorl + libc.src.math.fmaf libc.src.math.fmax libc.src.math.fmaxf libc.src.math.fmaxl diff --git a/libc/config/linux/x86_64/entrypoints.txt b/libc/config/linux/x86_64/entrypoints.txt --- a/libc/config/linux/x86_64/entrypoints.txt +++ b/libc/config/linux/x86_64/entrypoints.txt @@ -106,6 +106,7 @@ libc.src.math.floor libc.src.math.floorf libc.src.math.floorl + libc.src.math.fmaf libc.src.math.fmin libc.src.math.fminf libc.src.math.fminl diff --git a/libc/spec/stdc.td b/libc/spec/stdc.td --- a/libc/spec/stdc.td +++ b/libc/spec/stdc.td @@ -322,6 +322,8 @@ FunctionSpec<"fmaxf", RetValSpec, [ArgSpec, ArgSpec]>, FunctionSpec<"fmaxl", RetValSpec, [ArgSpec, ArgSpec]>, + FunctionSpec<"fmaf", RetValSpec, [ArgSpec, ArgSpec, ArgSpec]>, + FunctionSpec<"frexp", RetValSpec, [ArgSpec, ArgSpec]>, FunctionSpec<"frexpf", RetValSpec, [ArgSpec, ArgSpec]>, FunctionSpec<"frexpl", RetValSpec, [ArgSpec, ArgSpec]>, diff --git a/libc/src/math/CMakeLists.txt b/libc/src/math/CMakeLists.txt --- a/libc/src/math/CMakeLists.txt +++ b/libc/src/math/CMakeLists.txt @@ -978,3 +978,14 @@ -O2 ) +add_entrypoint_object( + fmaf + SRCS + fmaf.cpp + HDRS + fmaf.h + DEPENDS + libc.utils.FPUtil.fputil + COMPILE_OPTIONS + -O2 +) diff --git a/libc/src/math/fmaf.h b/libc/src/math/fmaf.h new file mode 100644 --- /dev/null +++ b/libc/src/math/fmaf.h @@ -0,0 +1,18 @@ +//===-- Implementation header for fmaf --------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_LIBC_SRC_MATH_FMAF_H +#define LLVM_LIBC_SRC_MATH_FMAF_H + +namespace __llvm_libc { + +float fmaf(float x, float y, float z); + +} // namespace __llvm_libc + +#endif // LLVM_LIBC_SRC_MATH_FMAF_H diff --git a/libc/src/math/fmaf.cpp b/libc/src/math/fmaf.cpp new file mode 100644 --- /dev/null +++ b/libc/src/math/fmaf.cpp @@ -0,0 +1,64 @@ +//===-- Implementation of fmaf function -----------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "src/__support/common.h" + +#include "utils/FPUtil/FEnv.h" +#include "utils/FPUtil/FPBits.h" + +namespace __llvm_libc { + +float LLVM_LIBC_ENTRYPOINT(fmaf)(float x, float y, float z) { + // Product is exact. + double prod = static_cast(x) * static_cast(y); + double z_d = static_cast(z); + double sum = prod + z_d; + fputil::FPBits bit_prod(prod), bitz(z_d), bit_sum(sum); + + if (!(bit_sum.isInfOrNaN() || bit_sum.isZero())) { + // Since the sum is computed in double precision, rounding might happen + // (for instance, when bitz.exponent > bit_prod.exponent + 5, or + // bit_prod.exponent > bitz.exponent + 40). In that case, when we round + // the sum back to float, double rounding error might occur. 
+ // A concrete example of this phenomenon is as follows: + // x = y = 1 + 2^(-12), z = 2^(-53) + // The exact value of x*y + z is 1 + 2^(-11) + 2^(-24) + 2^(-53) + // So when rounding to float, fmaf(x, y, z) = 1 + 2^(-11) + 2^(-23) + // On the other hand, with the default rounding mode, + // double(x*y + z) = 1 + 2^(-11) + 2^(-24) + // and casting again to float gives us: + // float(double(x*y + z)) = 1 + 2^(-11). + // + // In order to correct this possible double rounding error, first we use + // Dekker's 2Sum algorithm to find t such that sum - t = prod + z exactly, + // assuming the (default) rounding mode is round-to-the-nearest, + // tie-to-even. Moreover, t satisfies the condition that t < eps(sum), + // i.e., t.exponent < sum.exponent - 52. So if t is not 0, meaning rounding + // occurs when computing the sum, we just need to use t to adjust (any) last + // bit of sum, so that the sticky bits used when rounding sum to float are + // correct (when it matters). + fputil::FPBits t( + (bit_prod.exponent >= bitz.exponent) + ? ((static_cast(bit_sum) - bit_prod) - bitz) + : ((static_cast(bit_sum) - bitz) - bit_prod)); + + // Update sticky bits if t != 0.0 and the least (52 - 23 - 1 = 28) bits are + // zero. 
+ if (!t.isZero() && ((bit_sum.mantissa & 0xfff'ffffULL) == 0)) { + if (bit_sum.sign != t.sign) { + ++bit_sum.mantissa; + } else if (bit_sum.mantissa) { + --bit_sum.mantissa; + } + } + } + + return static_cast(static_cast(bit_sum)); +} + +} // namespace __llvm_libc diff --git a/libc/test/src/math/CMakeLists.txt b/libc/test/src/math/CMakeLists.txt --- a/libc/test/src/math/CMakeLists.txt +++ b/libc/test/src/math/CMakeLists.txt @@ -1049,3 +1049,16 @@ libc.src.math.nextafterl libc.utils.FPUtil.fputil ) + +add_fp_unittest( + fmaf_test + NEED_MPFR + SUITE + libc_math_unittests + SRCS + fmaf_test.cpp + DEPENDS + libc.include.math + libc.src.math.fmaf + libc.utils.FPUtil.fputil +) diff --git a/libc/test/src/math/FmaTest.h b/libc/test/src/math/FmaTest.h new file mode 100644 --- /dev/null +++ b/libc/test/src/math/FmaTest.h @@ -0,0 +1,94 @@ +//===-- Utility class to test different flavors of fma --------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_LIBC_TEST_SRC_MATH_FMATEST_H +#define LLVM_LIBC_TEST_SRC_MATH_FMATEST_H + +#include "utils/FPUtil/FPBits.h" +#include "utils/FPUtil/TestHelpers.h" +#include "utils/MPFRWrapper/MPFRUtils.h" +#include "utils/UnitTest/Test.h" + +#include + +namespace mpfr = __llvm_libc::testing::mpfr; + +template +class FmaTestTemplate : public __llvm_libc::testing::Test { +private: + using Func = T (*)(T, T, T); + using FPBits = __llvm_libc::fputil::FPBits; + using UIntType = typename FPBits::UIntType; + const T nan = __llvm_libc::fputil::FPBits::buildNaN(1); + const T inf = __llvm_libc::fputil::FPBits::inf(); + const T negInf = __llvm_libc::fputil::FPBits::negInf(); + const T zero = __llvm_libc::fputil::FPBits::zero(); + const T negZero = __llvm_libc::fputil::FPBits::negZero(); + + UIntType getRandomBitPattern() { + UIntType bits{0}; + for (size_t i = 0; i < sizeof(UIntType) / 2; ++i) { + bits = (bits << 2) + static_cast(std::rand()); + } + return bits; + } + +public: + void testSpecialNumbers(Func func) { + EXPECT_FP_EQ(func(zero, zero, zero), zero); + EXPECT_FP_EQ(func(zero, negZero, negZero), negZero); + EXPECT_FP_EQ(func(inf, inf, zero), inf); + EXPECT_FP_EQ(func(negInf, inf, negInf), negInf); + EXPECT_FP_EQ(func(inf, zero, zero), nan); + EXPECT_FP_EQ(func(inf, negInf, inf), nan); + EXPECT_FP_EQ(func(nan, zero, inf), nan); + EXPECT_FP_EQ(func(inf, negInf, nan), nan); + + // Test underflow rounding up. + EXPECT_FP_EQ(func(T(0.5), FPBits(FPBits::minSubnormal), + FPBits(FPBits::minSubnormal)), + FPBits(UIntType(2))); + // Test underflow rounding down. + FPBits v(FPBits::minNormal + UIntType(1)); + EXPECT_FP_EQ( + func(T(1) / T(FPBits::minNormal << 1), v, FPBits(FPBits::minNormal)), + v); + // Test overflow. 
+ FPBits z(FPBits::maxNormal); + EXPECT_FP_EQ(func(T(1.75), z, -z), T(0.75) * z); + } + + void testSubnormalRange(Func func) { + constexpr UIntType count = 1000001; + constexpr UIntType step = + (FPBits::maxSubnormal - FPBits::minSubnormal) / count; + for (UIntType v = FPBits::minSubnormal, w = FPBits::maxSubnormal; + v <= FPBits::maxSubnormal && w >= FPBits::minSubnormal; + v += step, w -= step) { + T x = FPBits(getRandomBitPattern()), y = FPBits(v), z = FPBits(w); + T result = func(x, y, z); + mpfr::TernaryInput input{x, y, z}; + ASSERT_MPFR_MATCH(mpfr::Operation::Fma, input, result, 0.5); + } + } + + void testNormalRange(Func func) { + constexpr UIntType count = 1000001; + constexpr UIntType step = (FPBits::maxNormal - FPBits::minNormal) / count; + for (UIntType v = FPBits::minNormal, w = FPBits::maxNormal; + v <= FPBits::maxNormal && w >= FPBits::minNormal; + v += step, w -= step) { + T x = FPBits(v), y = FPBits(w), z = FPBits(getRandomBitPattern()); + T result = func(x, y, z); + mpfr::TernaryInput input{x, y, z}; + ASSERT_MPFR_MATCH(mpfr::Operation::Fma, input, result, 0.5); + } + } +}; + +#endif // LLVM_LIBC_TEST_SRC_MATH_FMATEST_H diff --git a/libc/test/src/math/fmaf_test.cpp b/libc/test/src/math/fmaf_test.cpp new file mode 100644 --- /dev/null +++ b/libc/test/src/math/fmaf_test.cpp @@ -0,0 +1,19 @@ +//===-- Unittests for fmaf ------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "FmaTest.h" + +#include "src/math/fmaf.h" + +using FmaTest = FmaTestTemplate; + +TEST_F(FmaTest, SpecialNumbers) { testSpecialNumbers(&__llvm_libc::fmaf); } + +TEST_F(FmaTest, SubnormalRange) { testSubnormalRange(&__llvm_libc::fmaf); } + +TEST_F(FmaTest, NormalRange) { testNormalRange(&__llvm_libc::fmaf); } diff --git a/libc/utils/FPUtil/FPBits.h b/libc/utils/FPUtil/FPBits.h --- a/libc/utils/FPUtil/FPBits.h +++ b/libc/utils/FPUtil/FPBits.h @@ -84,7 +84,10 @@ // We don't want accidental type promotions/conversions so we require exact // type match. template ::Value, int> = 0> + cpp::EnableIfType::Value || + (cpp::IsIntegral::Value && + (sizeof(XType) == sizeof(UIntType))), + int> = 0> explicit FPBits(XType x) { *this = *reinterpret_cast *>(&x); } @@ -106,13 +109,6 @@ // the potential software implementations of UIntType will not slow real // code. - template ::Value, int> = 0> - explicit FPBits(XType x) { - // The last 4 bytes of v are ignored in case of i386. - *this = *reinterpret_cast *>(&x); - } - UIntType bitsAsUInt() const { return *reinterpret_cast(this); } diff --git a/libc/utils/MPFRWrapper/MPFRUtils.h b/libc/utils/MPFRWrapper/MPFRUtils.h --- a/libc/utils/MPFRWrapper/MPFRUtils.h +++ b/libc/utils/MPFRWrapper/MPFRUtils.h @@ -57,8 +57,11 @@ RemQuo, // The first output, the floating point output, is the remainder. EndBinaryOperationsTwoOutputs, + // Operations which take three floating point numbers of the same type as + // input and produce a single floating point number of the same type as + // output. BeginTernaryOperationsSingleOuput, - // TODO: Add operations like fma. 
+ Fma, EndTernaryOperationsSingleOutput, }; @@ -113,6 +116,11 @@ bool compareBinaryOperationOneOutput(Operation op, const BinaryInput &input, T libcOutput, double t); +template +bool compareTernaryOperationOneOutput(Operation op, + const TernaryInput &input, + T libcOutput, double t); + template void explainUnaryOperationSingleOutputError(Operation op, T input, T matchValue, testutils::StreamWrapper &OS); @@ -132,6 +140,12 @@ T matchValue, testutils::StreamWrapper &OS); +template +void explainTernaryOperationOneOutputError(Operation op, + const TernaryInput &input, + T matchValue, + testutils::StreamWrapper &OS); + template class MPFRMatcher : public testing::Matcher { InputType input; @@ -174,7 +188,7 @@ template static bool match(const TernaryInput &in, T out, double tolerance) { - // TODO: Implement the comparision function and error reporter. + return compareTernaryOperationOneOutput(op, in, out, tolerance); } template @@ -199,6 +213,12 @@ testutils::StreamWrapper &OS) { explainBinaryOperationOneOutputError(op, in, out, OS); } + + template + static void explainError(const TernaryInput &in, T out, + testutils::StreamWrapper &OS) { + explainTernaryOperationOneOutputError(op, in, out, OS); + } }; } // namespace internal diff --git a/libc/utils/MPFRWrapper/MPFRUtils.cpp b/libc/utils/MPFRWrapper/MPFRUtils.cpp --- a/libc/utils/MPFRWrapper/MPFRUtils.cpp +++ b/libc/utils/MPFRWrapper/MPFRUtils.cpp @@ -35,48 +35,69 @@ namespace testing { namespace mpfr { +template struct Precision; + +template <> struct Precision { + static constexpr unsigned int value = 24; +}; + +template <> struct Precision { + static constexpr unsigned int value = 53; +}; + +#if defined(__x86_64__) || defined(__i386__) +template <> struct Precision { + static constexpr unsigned int value = 64; +}; +#else +template <> struct Precision { + static constexpr unsigned int value = 113; +}; +#endif + class MPFRNumber { // A precision value which allows sufficiently large additional // precision even 
compared to quad-precision floating point values. - static constexpr unsigned int mpfrPrecision = 128; + unsigned int mpfrPrecision; mpfr_t value; public: - MPFRNumber() { mpfr_init2(value, mpfrPrecision); } + MPFRNumber() : mpfrPrecision(128) { mpfr_init2(value, mpfrPrecision); } // We use explicit EnableIf specializations to disallow implicit // conversions. Implicit conversions can potentially lead to loss of // precision. template ::Value, int> = 0> - explicit MPFRNumber(XType x) { + explicit MPFRNumber(XType x, int precision = 128) : mpfrPrecision(precision) { mpfr_init2(value, mpfrPrecision); mpfr_set_flt(value, x, MPFR_RNDN); } template ::Value, int> = 0> - explicit MPFRNumber(XType x) { + explicit MPFRNumber(XType x, int precision = 128) : mpfrPrecision(precision) { mpfr_init2(value, mpfrPrecision); mpfr_set_d(value, x, MPFR_RNDN); } template ::Value, int> = 0> - explicit MPFRNumber(XType x) { + explicit MPFRNumber(XType x, int precision = 128) : mpfrPrecision(precision) { mpfr_init2(value, mpfrPrecision); mpfr_set_ld(value, x, MPFR_RNDN); } template ::Value, int> = 0> - explicit MPFRNumber(XType x) { + explicit MPFRNumber(XType x, int precision = 128) : mpfrPrecision(precision) { mpfr_init2(value, mpfrPrecision); mpfr_set_sj(value, x, MPFR_RNDN); } - MPFRNumber(const MPFRNumber &other) { + MPFRNumber(const MPFRNumber &other) : mpfrPrecision(other.mpfrPrecision) { + mpfr_init2(value, mpfrPrecision); mpfr_set(value, other.value, MPFR_RNDN); } @@ -85,6 +106,7 @@ } MPFRNumber &operator=(const MPFRNumber &rhs) { + mpfrPrecision = rhs.mpfrPrecision; mpfr_set(value, rhs.value, MPFR_RNDN); return *this; } @@ -193,6 +215,12 @@ return result; } + MPFRNumber fma(const MPFRNumber &b, const MPFRNumber &c) { + MPFRNumber result(*this); + mpfr_fma(result.value, value, b.value, c.value, MPFR_RNDN); + return result; + } + std::string str() const { // 200 bytes should be more than sufficient to hold a 100-digit number // plus additional bytes for the decimal point, '-' sign 
etc. @@ -328,6 +356,22 @@ } } +template +cpp::EnableIfType::Value, MPFRNumber> +ternaryOperationOneOutput(Operation op, InputType x, InputType y, InputType z) { + // For FMA function, we just need to compare with the mpfr_fma with the same + // precision as InputType. Using higher precision as the intermediate results + // to compare might incorrectly fail due to double-rounding errors. + constexpr unsigned int prec = Precision::value; + MPFRNumber inputX(x, prec), inputY(y, prec), inputZ(z, prec); + switch (op) { + case Operation::Fma: + return inputX.fma(inputY, inputZ); + default: + __builtin_unreachable(); + } +} + template void explainUnaryOperationSingleOutputError(Operation op, T input, T matchValue, testutils::StreamWrapper &OS) { @@ -476,6 +520,48 @@ Operation, const BinaryInput &, long double, testutils::StreamWrapper &); +template +void explainTernaryOperationOneOutputError(Operation op, + const TernaryInput &input, + T libcResult, + testutils::StreamWrapper &OS) { + MPFRNumber mpfrX(input.x, Precision::value); + MPFRNumber mpfrY(input.y, Precision::value); + MPFRNumber mpfrZ(input.z, Precision::value); + FPBits xbits(input.x); + FPBits ybits(input.y); + FPBits zbits(input.z); + MPFRNumber mpfrResult = + ternaryOperationOneOutput(op, input.x, input.y, input.z); + MPFRNumber mpfrMatchValue(libcResult); + + OS << "Input decimal: x: " << mpfrX.str() << " y: " << mpfrY.str() + << " z: " << mpfrZ.str() << '\n'; + __llvm_libc::fputil::testing::describeValue("First input bits: ", input.x, + OS); + __llvm_libc::fputil::testing::describeValue("Second input bits: ", input.y, + OS); + __llvm_libc::fputil::testing::describeValue("Third input bits: ", input.z, + OS); + + OS << "Libc result: " << mpfrMatchValue.str() << '\n' + << "MPFR result: " << mpfrResult.str() << '\n'; + __llvm_libc::fputil::testing::describeValue( + "Libc floating point result bits: ", libcResult, OS); + __llvm_libc::fputil::testing::describeValue( + " MPFR rounded bits: ", mpfrResult.as(), OS); 
+ OS << "ULP error: " << std::to_string(mpfrResult.ulp(libcResult)) << '\n'; +} + +template void explainTernaryOperationOneOutputError( + Operation, const TernaryInput &, float, testutils::StreamWrapper &); +template void explainTernaryOperationOneOutputError( + Operation, const TernaryInput &, double, + testutils::StreamWrapper &); +template void explainTernaryOperationOneOutputError( + Operation, const TernaryInput &, long double, + testutils::StreamWrapper &); + template bool compareUnaryOperationSingleOutput(Operation op, T input, T libcResult, double ulpError) { @@ -575,6 +661,27 @@ template bool compareBinaryOperationOneOutput( Operation, const BinaryInput &, long double, double); +template +bool compareTernaryOperationOneOutput(Operation op, + const TernaryInput &input, + T libcResult, double ulpError) { + MPFRNumber mpfrResult = + ternaryOperationOneOutput(op, input.x, input.y, input.z); + double ulp = mpfrResult.ulp(libcResult); + + bool bitsAreEven = ((FPBits(libcResult).bitsAsUInt() & 1) == 0); + return (ulp < ulpError) || + ((ulp == ulpError) && ((ulp != 0.5) || bitsAreEven)); +} + +template bool +compareTernaryOperationOneOutput(Operation, const TernaryInput &, + float, double); +template bool compareTernaryOperationOneOutput( + Operation, const TernaryInput &, double, double); +template bool compareTernaryOperationOneOutput( + Operation, const TernaryInput &, long double, double); + static mpfr_rnd_t getMPFRRoundingMode(RoundingMode mode) { switch (mode) { case RoundingMode::Upward: