diff --git a/libc/src/__support/FPUtil/CMakeLists.txt b/libc/src/__support/FPUtil/CMakeLists.txt --- a/libc/src/__support/FPUtil/CMakeLists.txt +++ b/libc/src/__support/FPUtil/CMakeLists.txt @@ -177,4 +177,15 @@ ROUND_OPT ) +add_header_library( + big_float + HDRS + big_float.h + DEPENDS + .fp_bits + .multiply_add + libc.src.__support.uint + libc.src.__support.CPP.type_traits +) + add_subdirectory(generic) diff --git a/libc/src/__support/FPUtil/big_float.h b/libc/src/__support/FPUtil/big_float.h new file mode 100644 --- /dev/null +++ b/libc/src/__support/FPUtil/big_float.h @@ -0,0 +1,207 @@ +//===-- A class for high precision floating point numbers -------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_LIBC_SRC_SUPPORT_FPUTIL_BIG_FLOAT_H +#define LLVM_LIBC_SRC_SUPPORT_FPUTIL_BIG_FLOAT_H + +#include "src/__support/CPP/type_traits.h" +#include "src/__support/FPUtil/FPBits.h" +#include "src/__support/FPUtil/FloatProperties.h" +#include "src/__support/FPUtil/multiply_add.h" +#include "src/__support/UInt.h" + +#include +#include + +namespace __llvm_libc { +namespace fputil { + +// A class for computations of high precision floating point. User will need to +// perform normalization if needed. +// +// We store the value in dyadic format: +// The exponent field will be the exponent value of the least significant bit of +// the mantissa. So the real value that is stored is: +// real value = (-1)^sign * 2^exponent * (mantissa as unsigned integer) +template struct BigFloat { + static_assert(MANTISSA_LENGTH >= 64 && MANTISSA_LENGTH % 64 == 0, + "Unsupported mantissa length."); + using mantissa_type = cpp::UInt; + + bool sign; + int exponent; // Exponent of the least significant bit of the mantissa.
+ mantissa_type mantissa; + + BigFloat() = default; + + // Assumption: a is normalized and not Inf/NaN. + BigFloat(double a) { + FPBits a_bits(a); + sign = a_bits.get_sign(); + exponent = a_bits.get_exponent() - 52; + mantissa = mantissa_type(a_bits.get_explicit_mantissa()); + } + + BigFloat(bool s, int e, mantissa_type m) + : sign(s), exponent(e), mantissa(m) {} + + // Explicit conversion to bigger float. + template MANTISSA_LENGTH), void>> + explicit operator BigFloat() const { + return {sign, exponent, + static_cast::mantissa_type>(mantissa)}; + } + + bool is_zero() const { return mantissa == 0; } + + void shift_left(int shift_length) { + exponent -= shift_length; + mantissa <<= shift_length; + } + + void shift_right(int shift_length) { + exponent += shift_length; + mantissa >>= shift_length; + } + + // Convert to double, respecting the current rounding mode. The result is + // expected to be a normal double; no subnormal or overflow handling is done. + double to_double() const { + if (is_zero()) + return 0.0; + + int shift = static_cast(mantissa.clz()); + mantissa_type m(mantissa); + int e = exponent; + + // Normalize: shift the leading 1 to the MSB and lower e by the same amount.
+ if (shift) { + e -= shift; + m <<= shift; + } + + mantissa_type m_hi(m >> (MANTISSA_LENGTH - 53)); + auto d_hi = FPBits::create_value( + sign, + e + (MANTISSA_LENGTH - 1) + FloatProperties::EXPONENT_BIAS, + uint64_t(m_hi) & FloatProperties::MANTISSA_MASK); + + constexpr mantissa_type ROUND_MASK = mantissa_type(1) + << (MANTISSA_LENGTH - 54); + constexpr mantissa_type STICKY_MASK = ROUND_MASK - 1; + + bool round_bit = m & ROUND_MASK; + bool sticky_bit = m & STICKY_MASK; + int round_and_sticky = int(round_bit) * 2 + int(sticky_bit); + auto d_lo = FPBits::create_value( + sign, + e + (MANTISSA_LENGTH - 55) + FloatProperties::EXPONENT_BIAS, + uint64_t(0)); + + if (!round_and_sticky) + return d_hi.get_val(); + + return multiply_add(d_lo.get_val(), double(round_and_sticky), + d_hi.get_val()); + } + + // TODO: Add in-place addition and multiplication operations. +}; + +// Split a BigFloat into {lo, hi} halves of LENGTH / 2 mantissa bits each. +template +constexpr util::pair> split(const BigFloat &a) { + constexpr size_t SHIFT_LENGTH = LENGTH / 2; + using mtype = typename BigFloat::mantissa_type; + return {{a.sign, a.exponent, static_cast(a.mantissa)}, + {a.sign, a.exponent + static_cast(SHIFT_LENGTH), + static_cast(a.mantissa >> SHIFT_LENGTH)}}; +} + +template +constexpr BigFloat quick_add(BigFloat a, BigFloat b) { + if (unlikely(a.is_zero())) + return b; + if (unlikely(b.is_zero())) + return a; + + // Move the leading 1's to the most significant bits.
+ int shift_a = static_cast(a.mantissa.clz()); + if (shift_a) + a.shift_left(shift_a); + int shift_b = static_cast(b.mantissa.clz()); + if (shift_b) + b.shift_left(shift_b); + + // Align exponents + if (a.exponent > b.exponent) + b.shift_right(a.exponent - b.exponent); + else if (b.exponent > a.exponent) + a.shift_right(b.exponent - a.exponent); + + BigFloat result; + + if (a.sign == b.sign) { + // Addition + result.sign = a.sign; + result.exponent = a.exponent; + result.mantissa = a.mantissa; + if (result.mantissa.add(b.mantissa)) { + // Mantissa addition overflow. + result.shift_right(1); + result.mantissa[LENGTH / 64 - 1] |= (uint64_t(1) << 63); + } + return result; + } + + // Subtraction + if (a.mantissa >= b.mantissa) { + result.sign = a.sign; + result.exponent = a.exponent; + result.mantissa = a.mantissa - b.mantissa; + } else { + result.sign = b.sign; + result.exponent = b.exponent; + result.mantissa = b.mantissa - a.mantissa; + } + + return result; +} + +// Multiply 2 big floats, round toward 0.
+template +constexpr BigFloat quick_mul(const BigFloat &a, + const BigFloat &b) { + BigFloat result; + result.sign = (a.sign != b.sign); + result.exponent = a.exponent + b.exponent; + if (a.is_zero() || b.is_zero()) { + result.mantissa = cpp::UInt(0); + return result; + } + + cpp::UInt<2 *LENGTH> full_prod = cpp::full_mul(a.mantissa, b.mantissa); + int shift = static_cast(full_prod.clz()); + if (shift) { + full_prod <<= shift; + result.exponent -= shift; + } + result.exponent += LENGTH; + for (size_t i = 0; i < LENGTH / 64; ++i) { + result.mantissa[i] = full_prod[i + LENGTH / 64]; + } + + return result; +} + +} // namespace fputil + +} // namespace __llvm_libc + +#endif // LLVM_LIBC_SRC_SUPPORT_FPUTIL_BIG_FLOAT_H diff --git a/libc/src/__support/UInt.h b/libc/src/__support/UInt.h --- a/libc/src/__support/UInt.h +++ b/libc/src/__support/UInt.h @@ -18,7 +18,62 @@ #include // For size_t #include -namespace __llvm_libc::cpp { +namespace __llvm_libc { + +namespace util { + +template struct pair { + T lo = 0; + T hi = 0; +}; + +constexpr inline pair split(uint64_t a) { + return {/*lo*/ a & 0xFFFF'FFFF, /*hi*/ a >> 32}; +} + +constexpr inline pair full_add(uint64_t a, uint64_t b) { + uint64_t b_complement = ~b; + if (a <= b_complement) { + // No carry + return {a + b, 0}; + } + // Carry + return {a - b_complement - 1, 1}; +} + +constexpr inline bool half_add(pair &a, uint64_t b) { + uint64_t b_complement = ~b; + if (a.lo <= b_complement) { + // No carry + a.lo += b; + return false; + } + // Carry + a.lo -= b_complement + 1; + ++a.hi; + // Return true if overflowed.
+ return (a.hi == 0); +} + +constexpr inline pair full_mul(uint64_t a, uint64_t b) { + pair pa = split(a); + pair pb = split(b); + + pair result{/*lo*/ pa.lo * pb.lo, /*hi*/ pa.hi * pb.hi}; + + pair lo_hi = split(pa.lo * pb.hi); + pair hi_lo = split(pa.hi * pb.lo); + + result.hi += lo_hi.hi + hi_lo.hi; + half_add(result, lo_hi.lo << 32); + half_add(result, hi_lo.lo << 32); + + return result; +} + +} // namespace util + +namespace cpp { template class UInt { @@ -88,6 +143,20 @@ return uint8_t(uint64_t(*this)); } + // Zero-extend to a bigger UInt or truncate to a smaller UInt. + template = 0> + constexpr explicit operator UInt() const { + UInt result(0); + + constexpr size_t iter = (Bits > OtherBits) ? (OtherBits / 64) : WordCount; + + for (size_t i = 0; i < iter; ++i) { + result[i] = val[i]; + } + + return result; + } + UInt &operator=(const UInt &other) { for (size_t i = 0; i < WordCount; ++i) val[i] = other.val[i]; @@ -515,6 +584,36 @@ const uint64_t *data() const { return val; } }; +template +constexpr UInt full_mul(const UInt &a, + const UInt &b) { + UInt results(0); + + for (size_t i = 0; i < Bits1 / 64; ++i) { + util::pair carry; + + for (size_t j = 0; j < Bits2 / 64; ++j) { + util::pair p = util::full_mul(a[i], b[j]); + util::half_add(carry, p.lo); + util::half_add(carry, results[i + j]); + results[i + j] = carry.lo; + carry.lo = carry.hi; + carry.hi = 0; + util::half_add(carry, p.hi); + } + size_t pos = i + Bits2 / 64; + while (carry.lo | carry.hi) { + util::half_add(carry, results[pos]); + results[pos] = carry.lo; + carry.lo = carry.hi; + carry.hi = 0; + ++pos; + } + } + + return results; +} + template <> constexpr UInt<128> UInt<128>::operator*(const UInt<128> &other) const { // temp low covers bits 0-63, middle covers 32-95, high covers 64-127, and @@ -558,7 +657,11 @@ // Provides is_integral of UInt<128>.
template <> struct is_integral> : public cpp::true_type {}; +template <> struct is_integral> : public cpp::true_type {}; +template <> struct is_integral> : public cpp::true_type {}; + +} // namespace cpp -} // namespace __llvm_libc::cpp +} // namespace __llvm_libc #endif // LLVM_LIBC_UTILS_UINT_H diff --git a/libc/test/src/__support/uint128_test.cpp b/libc/test/src/__support/uint128_test.cpp --- a/libc/test/src/__support/uint128_test.cpp +++ b/libc/test/src/__support/uint128_test.cpp @@ -15,6 +15,7 @@ // we use a sugar which does not conflict with the UInt128 type which can // resolve to __uint128_t if the platform has it. using LL_UInt128 = __llvm_libc::cpp::UInt<128>; +using LL_UInt256 = __llvm_libc::cpp::UInt<256>; TEST(LlvmLibcUInt128ClassTest, BasicInit) { LL_UInt128 empty; @@ -427,3 +428,22 @@ EXPECT_LE(a, a); EXPECT_GE(a, a); } + +TEST(LlvmLibcUInt128ClassTest, FullMulTest) { + LL_UInt128 a1({0xffffffff00000000, 0xffff00000000ffff}); + LL_UInt128 b1({0xff00ff0000ff00ff, 0xf0f0f0f00f0f0f0f}); + + LL_UInt256 expected1({0xff00ff0100000000, 0xeef1f1ef01fe00ff, + 0xe0e10f0d201e0e10, 0xf0efffff1e1ff0f1}); + LL_UInt256 result1(__llvm_libc::cpp::full_mul(a1, b1)); + + EXPECT_EQ(expected1, result1); + + LL_UInt128 a2({0xffffffffffffffff, 0xffffffffffffffff}); + LL_UInt128 b2({0xffffffffffffffff, 0xffffffffffffffff}); + + LL_UInt256 expected2({0x0000000000000001, 0x0000000000000000, + 0xfffffffffffffffe, 0xffffffffffffffff}); + LL_UInt256 result2(__llvm_libc::cpp::full_mul(a2, b2)); + EXPECT_EQ(expected2, result2); +} diff --git a/libc/utils/UnitTest/LibcTest.cpp b/libc/utils/UnitTest/LibcTest.cpp --- a/libc/utils/UnitTest/LibcTest.cpp +++ b/libc/utils/UnitTest/LibcTest.cpp @@ -54,9 +54,8 @@ // UInt128. // TODO(lntue): Investigate why UInt<128> was printed backward, with the lower // 64-bits first.
-template -std::string describeValue128(UInt128Type Value) { - std::string S(sizeof(UInt128) * 2, '0'); +template std::string describeValueUInt(UIntType Value) { + std::string S(sizeof(UIntType) * 2, '0'); for (auto I = S.rbegin(), End = S.rend(); I != End; ++I, Value >>= 4) { unsigned char Mod = static_cast(Value) & 15; @@ -68,14 +67,26 @@ #ifdef __SIZEOF_INT128__ template <> std::string describeValue<__uint128_t>(__uint128_t Value) { - return describeValue128(Value); + return describeValueUInt(Value); } #endif template <> std::string describeValue<__llvm_libc::cpp::UInt<128>>(__llvm_libc::cpp::UInt<128> Value) { - return describeValue128(Value); + return describeValueUInt(Value); +} + +template <> +std::string +describeValue<__llvm_libc::cpp::UInt<192>>(__llvm_libc::cpp::UInt<192> Value) { + return describeValueUInt(Value); +} + +template <> +std::string +describeValue<__llvm_libc::cpp::UInt<256>>(__llvm_libc::cpp::UInt<256> Value) { + return describeValueUInt(Value); } template @@ -282,6 +293,16 @@ __llvm_libc::cpp::UInt<128> RHS, const char *LHSStr, const char *RHSStr, const char *File, unsigned long Line); +template bool test<__llvm_libc::cpp::UInt<192>>( + RunContext *Ctx, TestCondition Cond, __llvm_libc::cpp::UInt<192> LHS, + __llvm_libc::cpp::UInt<192> RHS, const char *LHSStr, const char *RHSStr, + const char *File, unsigned long Line); + +template bool test<__llvm_libc::cpp::UInt<256>>( + RunContext *Ctx, TestCondition Cond, __llvm_libc::cpp::UInt<256> LHS, + __llvm_libc::cpp::UInt<256> RHS, const char *LHSStr, const char *RHSStr, + const char *File, unsigned long Line); + template bool test<__llvm_libc::cpp::string_view>( RunContext *Ctx, TestCondition Cond, __llvm_libc::cpp::string_view LHS, __llvm_libc::cpp::string_view RHS, const char *LHSStr, const char *RHSStr,