diff --git a/libc/src/__support/FPUtil/CMakeLists.txt b/libc/src/__support/FPUtil/CMakeLists.txt --- a/libc/src/__support/FPUtil/CMakeLists.txt +++ b/libc/src/__support/FPUtil/CMakeLists.txt @@ -177,4 +177,17 @@ ROUND_OPT ) +add_header_library( +  dyadic_float +  HDRS +    dyadic_float.h +  DEPENDS +    .float_properties +    .fp_bits +    .multiply_add +    libc.src.__support.CPP.type_traits +    libc.src.__support.common +    libc.src.__support.uint +) + add_subdirectory(generic) diff --git a/libc/src/__support/FPUtil/dyadic_float.h b/libc/src/__support/FPUtil/dyadic_float.h new file mode 100644 --- /dev/null +++ b/libc/src/__support/FPUtil/dyadic_float.h @@ -0,0 +1,189 @@ +//===-- A class to store high precision floating point numbers --*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_LIBC_SRC_SUPPORT_FPUTIL_DYADIC_FLOAT_H +#define LLVM_LIBC_SRC_SUPPORT_FPUTIL_DYADIC_FLOAT_H + +#include "FPBits.h" +#include "multiply_add.h" +#include "src/__support/CPP/type_traits.h" +#include "src/__support/UInt.h" + +#include + +namespace __llvm_libc::fputil { + +// A class for computations of high precision floating point. We store the +// value in dyadic format: The exponent field will be the exponent value of the +// least significant bit of the mantissa.
 So the real value that is stored is: +// real value = (-1)^sign * 2^exponent * (mantissa as unsigned integer) +template struct DyadicFloat { +  using mantissa_type = __llvm_libc::cpp::UInt; + +  bool sign = false; +  int exponent = 0; +  mantissa_type mantissa = mantissa_type(0); + +  DyadicFloat() = default; + +  template , int> = 0> +  DyadicFloat(T x) { +    FPBits x_bits(x); +    sign = x_bits.get_sign(); +    exponent = x_bits.get_exponent() - FloatProperties::MANTISSA_WIDTH; +    // Subnormal inputs (biased exponent field 0) have an effective exponent +    // of 1 - EXPONENT_BIAS, one more than get_exponent() is assumed to +    // return for them (0 - EXPONENT_BIAS); NOTE(review): confirm in FPBits. +    if (x_bits.get_unbiased_exponent() == 0) +      ++exponent; +    mantissa = mantissa_type(x_bits.get_explicit_mantissa()); +    normalize(); +  } + +  DyadicFloat(bool s, int e, mantissa_type m) +      : sign(s), exponent(e), mantissa(m) { +    normalize(); +  }; + +  // Normalizing the mantissa, bringing the leading 1 bit to the most +  // significant bit. +  DyadicFloat &normalize() { +    if (!mantissa.is_zero()) { +      int shift_length = static_cast(mantissa.clz()); +      exponent -= shift_length; +      mantissa.shift_left(static_cast(shift_length)); +    } +    return *this; +  } + +  // Used for aligning exponents. Output might not be normalized. +  DyadicFloat &shift_left(int shift_length) { +    exponent -= shift_length; +    mantissa <<= static_cast(shift_length); +    return *this; +  } + +  // Used for aligning exponents. Output might not be normalized. +  DyadicFloat &shift_right(int shift_length) { +    exponent += shift_length; +    mantissa >>= static_cast(shift_length); +    return *this; +  } + +  // Assume that it is already normalized and output is also normal. +  // Output is rounded correctly with respect to the current rounding mode. +  // TODO(lntue): Test or add support for denormal output. +  // TODO(lntue): Test or add specialization for x86 long double. +  template , void>> +  explicit operator T() const { +    // TODO(lntue): Do we need to treat signed zeros properly? +    if (mantissa.is_zero()) +      return 0.0; + +    // Assume that it is normalized, and output is also normal.
+    constexpr size_t PRECISION = FloatProperties::MANTISSA_WIDTH + 1; +    using output_bits_t = typename FPBits::UIntType; + +    mantissa_type m_hi(mantissa >> (Bits - PRECISION)); +    auto d_hi = FPBits::create_value( +        sign, exponent + (Bits - 1) + FloatProperties::EXPONENT_BIAS, +        output_bits_t(m_hi) & FloatProperties::MANTISSA_MASK); + +    const mantissa_type ROUND_MASK = mantissa_type(1) << (Bits - PRECISION - 1); +    const mantissa_type STICKY_MASK = ROUND_MASK - mantissa_type(1); + +    bool round_bit = !(mantissa & ROUND_MASK).is_zero(); +    bool sticky_bit = !(mantissa & STICKY_MASK).is_zero(); +    int round_and_sticky = int(round_bit) * 2 + int(sticky_bit); +    auto d_lo = FPBits::create_value(sign, +                                        exponent + (Bits - PRECISION - 2) + +                                            FloatProperties::EXPONENT_BIAS, +                                        output_bits_t(0)); + +    return multiply_add(d_lo.get_val(), T(round_and_sticky), d_hi.get_val()); +  } +}; + +// Quick add. +// Assume inputs are normalized. +// Output will be normalized. +template +constexpr DyadicFloat quick_add(DyadicFloat a, +                                      DyadicFloat b) { +  if (unlikely(a.mantissa.is_zero())) +    return b; +  if (unlikely(b.mantissa.is_zero())) +    return a; + +  // Align exponents. Bits shifted out of the operand with the smaller +  // exponent are discarded, so the result may not be correctly rounded. +  if (a.exponent > b.exponent) +    b.shift_right(a.exponent - b.exponent); +  else if (b.exponent > a.exponent) +    a.shift_right(b.exponent - a.exponent); + +  DyadicFloat result; + +  if (a.sign == b.sign) { +    // Addition +    result.sign = a.sign; +    result.exponent = a.exponent; +    result.mantissa = a.mantissa; +    if (result.mantissa.add(b.mantissa)) { +      // Mantissa addition overflow. +      result.shift_right(1); +      result.mantissa.val[DyadicFloat::mantissa_type::WordCount - 1] |= +          (uint64_t(1) << 63); +    } +    // Result is already normalized.
+    return result; +  } + +  // Subtraction +  if (a.mantissa >= b.mantissa) { +    result.sign = a.sign; +    result.exponent = a.exponent; +    result.mantissa = a.mantissa - b.mantissa; +  } else { +    result.sign = b.sign; +    result.exponent = b.exponent; +    result.mantissa = b.mantissa - a.mantissa; +  } + +  return result.normalize(); +} + +// Quick Mul. +// Assume inputs are normalized. +// Output will be normalized. +// Error bound: 2 * errors of quick_mul_hi. +template +constexpr DyadicFloat quick_mul(DyadicFloat a, +                                      DyadicFloat b) { +  DyadicFloat result; +  result.sign = (a.sign != b.sign); +  result.exponent = a.exponent + b.exponent + int(Bits); + +  if (!(a.mantissa.is_zero() || b.mantissa.is_zero())) { +    result.mantissa = a.mantissa.quick_mul_hi(b.mantissa); +    // Check the leading bit directly, should be faster than using clz in +    // normalize(). +    if (result.mantissa.val[DyadicFloat::mantissa_type::WordCount - 1] >> +            63 == +        0) +      result.shift_left(1); +  } else { +    result.mantissa = (typename DyadicFloat::mantissa_type)(0); +  } +  return result; +} + +} // namespace __llvm_libc::fputil + +#endif // LLVM_LIBC_SRC_SUPPORT_FPUTIL_DYADIC_FLOAT_H diff --git a/libc/src/__support/UInt.h b/libc/src/__support/UInt.h --- a/libc/src/__support/UInt.h +++ b/libc/src/__support/UInt.h @@ -14,6 +14,7 @@ #include "src/__support/CPP/optional.h" #include "src/__support/CPP/type_traits.h" #include "src/__support/builtin_wrappers.h" +#include "src/__support/common.h" #include "src/__support/integer_utils.h" #include "src/__support/number_pair.h" @@ -95,6 +96,14 @@ return *this; } +  constexpr bool is_zero() const { +    for (size_t i = 0; i < WordCount; ++i) { +      if (val[i] != 0) +        return false; +    } +    return true; +  } + // Add x to this number and store the result in this number. // Returns the carry value produced by the addition operation.
constexpr uint64_t add(const UInt &x) { @@ -356,6 +365,9 @@ return; } #endif // __SIZEOF_INT128__ +    if (unlikely(s == 0)) +      return; + const size_t drop = s / 64; // Number of words to drop const size_t shift = s % 64; // Bits to shift in the remaining words. size_t i = WordCount; @@ -386,6 +398,8 @@ } constexpr void shift_right(size_t s) { +    if (unlikely(s == 0)) +      return; const size_t drop = s / 64; // Number of words to drop const size_t shift = s % 64; // Bit shift in the remaining words. diff --git a/libc/test/src/__support/CMakeLists.txt b/libc/test/src/__support/CMakeLists.txt --- a/libc/test/src/__support/CMakeLists.txt +++ b/libc/test/src/__support/CMakeLists.txt @@ -111,3 +111,4 @@ add_subdirectory(CPP) add_subdirectory(File) add_subdirectory(OSUtil) +add_subdirectory(FPUtil) diff --git a/libc/test/src/__support/FPUtil/CMakeLists.txt b/libc/test/src/__support/FPUtil/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/libc/test/src/__support/FPUtil/CMakeLists.txt @@ -0,0 +1,12 @@ +add_libc_testsuite(libc_fputil_unittests) + +add_fp_unittest( +  dyadic_float_test +  NEED_MPFR +  SUITE +    libc_fputil_unittests +  SRCS +    dyadic_float_test.cpp +  DEPENDS +    libc.src.__support.FPUtil.dyadic_float +) diff --git a/libc/test/src/__support/FPUtil/dyadic_float_test.cpp b/libc/test/src/__support/FPUtil/dyadic_float_test.cpp new file mode 100644 --- /dev/null +++ b/libc/test/src/__support/FPUtil/dyadic_float_test.cpp @@ -0,0 +1,67 @@ +//===-- Unittests for the DyadicFloat class -------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "src/__support/FPUtil/dyadic_float.h" +#include "src/__support/UInt.h" +#include "utils/MPFRWrapper/MPFRUtils.h" +#include "utils/UnitTest/FPMatcher.h" +#include "utils/UnitTest/Test.h" + +using Float128 = __llvm_libc::fputil::DyadicFloat; +using Float192 = __llvm_libc::fputil::DyadicFloat; +using Float256 = __llvm_libc::fputil::DyadicFloat; + +TEST(LlvmLibcDyadicFloatTest, BasicConversions) { +  Float128 x(/*sign*/ false, /*exponent*/ 0, +             /*mantissa*/ Float128::mantissa_type(1)); +  volatile float xf = float(x); +  volatile double xd = double(x); +  ASSERT_FP_EQ(1.0f, xf); +  ASSERT_FP_EQ(1.0, xd); + +  Float128 y(0x1.0p-53); +  volatile float yf = float(y); +  volatile double yd = double(y); +  ASSERT_FP_EQ(0x1.0p-53f, yf); +  ASSERT_FP_EQ(0x1.0p-53, yd); + +  Float128 z = quick_add(x, y); + +  EXPECT_FP_EQ_ALL_ROUNDING(xf + yf, float(z)); +  EXPECT_FP_EQ_ALL_ROUNDING(xd + yd, double(z)); +} + +TEST(LlvmLibcDyadicFloatTest, QuickAdd) { +  Float192 x(/*sign*/ false, /*exponent*/ 0, +             /*mantissa*/ Float192::mantissa_type(0x123456)); +  volatile double xd = double(x); +  ASSERT_FP_EQ(0x1.23456p20, xd); + +  Float192 y(0x1.abcdefp-20); +  volatile double yd = double(y); +  ASSERT_FP_EQ(0x1.abcdefp-20, yd); + +  Float192 z = quick_add(x, y); + +  EXPECT_FP_EQ_ALL_ROUNDING(xd + yd, (volatile double)(z)); +} + +TEST(LlvmLibcDyadicFloatTest, QuickMul) { +  Float256 x(/*sign*/ false, /*exponent*/ 0, +             /*mantissa*/ Float256::mantissa_type(0x123456)); +  volatile double xd = double(x); +  ASSERT_FP_EQ(0x1.23456p20, xd); + +  Float256 y(0x1.abcdefp-25); +  volatile double yd = double(y); +  ASSERT_FP_EQ(0x1.abcdefp-25, yd); + +  Float256 z = quick_mul(x, y); + +  EXPECT_FP_EQ_ALL_ROUNDING(xd * yd, double(z)); +} diff --git a/libc/utils/UnitTest/CMakeLists.txt b/libc/utils/UnitTest/CMakeLists.txt --- 
a/libc/utils/UnitTest/CMakeLists.txt +++ b/libc/utils/UnitTest/CMakeLists.txt @@ -33,7 +33,7 @@ FPMatcher.h ) target_include_directories(LibcFPTestHelpers PUBLIC ${LIBC_SOURCE_DIR}) -target_link_libraries(LibcFPTestHelpers LibcUnitTest) +target_link_libraries(LibcFPTestHelpers LibcUnitTest libc_test_utils) add_dependencies( LibcFPTestHelpers LibcUnitTest diff --git a/libc/utils/UnitTest/FPMatcher.h b/libc/utils/UnitTest/FPMatcher.h --- a/libc/utils/UnitTest/FPMatcher.h +++ b/libc/utils/UnitTest/FPMatcher.h @@ -1,4 +1,4 @@ -//===-- TestMatchers.h ------------------------------------------*- C++ -*-===// +//===-- FPMatcher.h ---------------------------------------------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -12,6 +12,7 @@ #include "src/__support/FPUtil/FEnvImpl.h" #include "src/__support/FPUtil/FPBits.h" #include "utils/UnitTest/Test.h" +#include "utils/testutils/RoundingModeUtils.h" #include #include @@ -132,4 +133,20 @@ } \ } while (0) +// Checks EXPECT_FP_EQ(expected, actual) once under each of the rounding modes +// Nearest, Upward, Downward and TowardZero (selected via ForceRoundingMode); +// both macro arguments are re-evaluated in every mode. +#define EXPECT_FP_EQ_ALL_ROUNDING(expected, actual) \ +  do { \ +    using namespace __llvm_libc::testutils; \ +    ForceRoundingMode __r1(RoundingMode::Nearest); \ +    EXPECT_FP_EQ((expected), (actual)); \ +    ForceRoundingMode __r2(RoundingMode::Upward); \ +    EXPECT_FP_EQ((expected), (actual)); \ +    ForceRoundingMode __r3(RoundingMode::Downward); \ +    EXPECT_FP_EQ((expected), (actual)); \ +    ForceRoundingMode __r4(RoundingMode::TowardZero); \ +    EXPECT_FP_EQ((expected), (actual)); \ +  } while (0) + #endif // LLVM_LIBC_UTILS_UNITTEST_FPMATCHER_H