diff --git a/libc/src/__support/CMakeLists.txt b/libc/src/__support/CMakeLists.txt --- a/libc/src/__support/CMakeLists.txt +++ b/libc/src/__support/CMakeLists.txt @@ -115,12 +115,23 @@ libc.src.__support.CPP.array ) +add_header_library( + number_pair + HDRS + number_pair.h + DEPENDS + .builtin_wrappers # number_pair.h includes builtin_wrappers.h and calls add_with_carry; presumably declared there. + libc.src.__support.CPP.type_traits +) + add_header_library( uint HDRS UInt.h DEPENDS + .number_pair libc.src.__support.CPP.array + libc.src.__support.CPP.type_traits ) add_header_library( diff --git a/libc/src/__support/UInt.h b/libc/src/__support/UInt.h --- a/libc/src/__support/UInt.h +++ b/libc/src/__support/UInt.h @@ -6,21 +6,22 @@ // //===----------------------------------------------------------------------===// -#ifndef LLVM_LIBC_UTILS_UINT_H -#define LLVM_LIBC_UTILS_UINT_H +#ifndef LLVM_LIBC_SRC_SUPPORT_UINT_H +#define LLVM_LIBC_SRC_SUPPORT_UINT_H #include "src/__support/CPP/array.h" #include "src/__support/CPP/limits.h" #include "src/__support/CPP/optional.h" #include "src/__support/CPP/type_traits.h" #include "src/__support/builtin_wrappers.h" +#include "src/__support/number_pair.h" // number_pair and full_mul, used by mul() and operator*. #include // For size_t #include namespace __llvm_libc::cpp { -template class UInt { +template struct UInt { // struct instead of class: presumably so data members such as val are public, which mul() and operator* rely on when they touch UInt<128>::val from other instantiations. static_assert(Bits > 0 && Bits % 64 == 0, "Number of bits in UInt should be a multiple of 64."); @@ -32,7 +33,6 @@ static constexpr uint64_t low(uint64_t v) { return v & MASK32; } static constexpr uint64_t high(uint64_t v) { return (v >> 32) & MASK32; } -public: constexpr UInt() {} constexpr UInt(const UInt &other) { @@ -100,37 +100,28 @@ // property of unsigned integers: // x + (~x) = 2^(sizeof(x)) - 1.
constexpr uint64_t add(const UInt &x) { - bool carry = false; + uint64_t carry_in = 0; + uint64_t carry_out = 0; for (size_t i = 0; i < WordCount; ++i) { - uint64_t complement = ~x.val[i]; - if (!carry) { - if (val[i] <= complement) - val[i] += x.val[i]; - else { - val[i] -= complement + 1; - carry = true; - } - } else { - if (val[i] < complement) { - val[i] += x.val[i] + 1; - carry = false; - } else - val[i] -= complement; - } + val[i] = add_with_carry(val[i], x.val[i], carry_in, carry_out); // Presumably val[i] + x.val[i] + carry_in with the carry out stored in carry_out; add_with_carry is not defined here, confirm in builtin_wrappers.h. + carry_in = carry_out; // Ripple the carry into the next, more significant word. } - return carry ? 1 : 0; + return carry_out; // Carry out of the most significant word. } constexpr UInt operator+(const UInt &other) const { - UInt result(*this); - result.add(other); - // TODO(lntue): Set overflow flag / errno when carry is true. + UInt result; // Every word is assigned in the loop below. + uint64_t carry_in = 0; + uint64_t carry_out = 0; + for (size_t i = 0; i < WordCount; ++i) { + result.val[i] = add_with_carry(val[i], other.val[i], carry_in, carry_out); + carry_in = carry_out; // The carry out of the top word is dropped, so the sum wraps around. + } return result; } constexpr UInt operator+=(const UInt &other) { - // TODO(lntue): Set overflow flag / errno when carry is true. - add(other); + add(other); // Returned carry value is ignored. return *this; } @@ -183,68 +174,40 @@ // carry bits. // Returns the carry value produced by the multiplication operation.
constexpr uint64_t mul(uint64_t x) { - uint64_t x_lo = low(x); - uint64_t x_hi = high(x); - - cpp::array row1; + UInt<128> partial_sum(0); uint64_t carry = 0; for (size_t i = 0; i < WordCount; ++i) { - uint64_t l = low(val[i]); - uint64_t h = high(val[i]); - uint64_t p1 = x_lo * l; - uint64_t p2 = x_lo * h; - - uint64_t res_lo = low(p1) + carry; - carry = high(res_lo); - uint64_t res_hi = high(p1) + low(p2) + carry; - carry = high(res_hi) + high(p2); - - res_lo = low(res_lo); - res_hi = low(res_hi); - row1[i] = res_lo + (res_hi << 32); - } - row1[WordCount] = carry; - - cpp::array row2; - row2[0] = 0; - carry = 0; - for (size_t i = 0; i < WordCount; ++i) { - uint64_t l = low(val[i]); - uint64_t h = high(val[i]); - uint64_t p1 = x_hi * l; - uint64_t p2 = x_hi * h; - - uint64_t res_lo = low(p1) + carry; - carry = high(res_lo); - uint64_t res_hi = high(p1) + low(p2) + carry; - carry = high(res_hi) + high(p2); - - res_lo = low(res_lo); - res_hi = low(res_hi); - row2[i] = res_lo + (res_hi << 32); + number_pair prod = full_mul(val[i], x); + UInt<128> tmp({prod.lo, prod.hi}); + carry += partial_sum.add(tmp); + val[i] = partial_sum.val[0]; + partial_sum.val[0] = partial_sum.val[1]; + partial_sum.val[1] = carry; + carry = 0; } - row2[WordCount] = carry; - - UInt<(WordCount + 1) * 64> r1(row1), r2(row2); - r2.shift_left(32); - r1.add(r2); - for (size_t i = 0; i < WordCount; ++i) { - val[i] = r1[i]; - } - return r1[WordCount]; + return partial_sum.val[0]; // High (carry) word of the product; partial_sum.val[1] is always 0 here because the running sum stays below 2^128. } constexpr UInt operator*(const UInt &other) const { - UInt result(0); - for (size_t i = 0; i < WordCount; ++i) { - if (other[i] == 0) - continue; - UInt row_result(*this); - row_result.mul(other[i]); - row_result.shift_left(64 * i); - result = result + row_result; + if constexpr (WordCount == 1) { + return {val[0] * other.val[0]}; + } else { + UInt result(0); + UInt<128> partial_sum(0); + uint64_t carry = 0; + for (size_t i = 0; i < WordCount; ++i) { + for (size_t j = 0; j <= i; j++) { + number_pair prod = full_mul(val[j], 
other.val[i - j]); + UInt<128> tmp({prod.lo, prod.hi}); + carry += partial_sum.add(tmp); + } + result.val[i] = partial_sum.val[0]; + partial_sum.val[0] = partial_sum.val[1]; + partial_sum.val[1] = carry; + carry = 0; + } + return result; } - return result; } // pow takes a power and sets this to its starting value to that power. Zero @@ -325,21 +288,36 @@ } constexpr void shift_left(size_t s) { +#ifdef __SIZEOF_INT128__ + if constexpr (Bits == 128) { + // Use builtin 128 bits if available; + if (s >= 128) { + val[0] = 0; + val[1] = 0; + return; + } + __uint128_t tmp = __uint128_t(val[0]) + (__uint128_t(val[1]) << 64); + tmp <<= s; + val[0] = uint64_t(tmp); + val[1] = uint64_t(tmp >> 64); + return; + } +#endif // __SIZEOF_INT128__ const size_t drop = s / 64; // Number of words to drop const size_t shift = s % 64; // Bits to shift in the remaining words. - const uint64_t mask = ((uint64_t(1) << shift) - 1) << (64 - shift); + size_t i = WordCount; - for (size_t i = WordCount; drop > 0 && i > 0; --i) { - if (i > drop) - val[i - 1] = val[i - drop - 1]; - else - val[i - 1] = 0; + if (drop < WordCount) { + i = WordCount - 1; + size_t j = WordCount - 1 - drop; + for (; j; --i, --j) { + val[i] = (val[j] << shift) | (shift ? val[j - 1] >> (64 - shift) : 0); // Guard: shift == 0 would shift by 64, which is undefined. + } + val[i] = val[0] << shift; } - for (size_t i = WordCount; shift > 0 && i > drop; --i) { - uint64_t drop_val = (val[i - 1] & mask) >> (64 - shift); - val[i - 1] <<= shift; - if (i < WordCount) - val[i] |= drop_val; + + for (size_t j = 0; j < i; ++j) { + val[j] = 0; } } @@ -357,19 +335,20 @@ constexpr void shift_right(size_t s) { const size_t drop = s / 64; // Number of words to drop const size_t shift = s % 64; // Bit shift in the remaining words.
- const uint64_t mask = (uint64_t(1) << shift) - 1; - for (size_t i = 0; drop > 0 && i < WordCount; ++i) { - if (i + drop < WordCount) - val[i] = val[i + drop]; - else - val[i] = 0; + size_t i = 0; + + if (drop < WordCount) { + size_t j = drop; + for (; j < WordCount - 1; ++i, ++j) { + val[i] = (val[j] >> shift) | (shift ? val[j + 1] << (64 - shift) : 0); // Guard: shift == 0 would shift by 64, which is undefined. + } + val[i] = val[j] >> shift; + ++i; } - for (size_t i = 0; shift > 0 && i < WordCount; ++i) { - uint64_t drop_val = ((val[i] & mask) << (64 - shift)); - val[i] >>= shift; - if (i > 0) - val[i - 1] |= drop_val; + + for (; i < WordCount; ++i) { + val[i] = 0; } } @@ -556,9 +535,16 @@ static constexpr UInt<128> min() { return 0; } }; -// Provides is_integral of UInt<128>. +// Provides is_integral of UInt<128>, UInt<192>, UInt<256>. template <> struct is_integral> : public cpp::true_type {}; +template <> struct is_integral> : public cpp::true_type {}; +template <> struct is_integral> : public cpp::true_type {}; + +// Provides is_unsigned of UInt<128>, UInt<192>, UInt<256>. +template <> struct is_unsigned> : public cpp::true_type {}; +template <> struct is_unsigned> : public cpp::true_type {}; +template <> struct is_unsigned> : public cpp::true_type {}; } // namespace __llvm_libc::cpp -#endif // LLVM_LIBC_UTILS_UINT_H +#endif // LLVM_LIBC_SRC_SUPPORT_UINT_H diff --git a/libc/src/__support/number_pair.h b/libc/src/__support/number_pair.h new file mode 100644 --- /dev/null +++ b/libc/src/__support/number_pair.h @@ -0,0 +1,80 @@ +//===-- Utilities for pairs of numbers. -------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_LIBC_SRC_SUPPORT_NUMBER_PAIR_H +#define LLVM_LIBC_SRC_SUPPORT_NUMBER_PAIR_H + +#include "CPP/type_traits.h" +#include "builtin_wrappers.h" + +#include + +namespace __llvm_libc { + +template struct number_pair { + T lo; + T hi; +}; + +using DoubleDouble = number_pair; + +template +cpp::enable_if_t && cpp::is_unsigned_v, number_pair> +split(T a) { + constexpr size_t HALF_BIT_WIDTH = sizeof(T) * 4; + constexpr T LOWER_HALF_MASK = (T(1) << HALF_BIT_WIDTH) - T(1); + number_pair result; + result.lo = a & LOWER_HALF_MASK; + result.hi = a >> HALF_BIT_WIDTH; + return result; +} + +template number_pair full_mul(T a, T b); + +template <> +inline number_pair full_mul(uint32_t a, uint32_t b) { + uint64_t prod = uint64_t(a) * uint64_t(b); + number_pair result; + result.lo = uint32_t(prod); + result.hi = uint32_t(prod >> 32); + return result; +} + +template <> +inline number_pair full_mul(uint64_t a, uint64_t b) { +#ifdef __SIZEOF_INT128__ // Predefined by GCC/Clang (note the trailing underscores) when __uint128_t is available. + __uint128_t prod = __uint128_t(a) * __uint128_t(b); + number_pair result; + result.lo = uint64_t(prod); + result.hi = uint64_t(prod >> 64); + return result; +#else + number_pair pa = split(a); + number_pair pb = split(b); + number_pair prod; + + prod.lo = pa.lo * pb.lo; // exact + prod.hi = pa.hi * pb.hi; // exact + number_pair lo_hi = split(pa.lo * pb.hi); // exact + number_pair hi_lo = split(pa.hi * pb.lo); // exact + + uint64_t carry_in = 0; + uint64_t carry_out = 0; + uint64_t carry_unused = 0; + prod.lo = add_with_carry(prod.lo, lo_hi.lo << 32, carry_in, carry_out); + prod.hi = add_with_carry(prod.hi, lo_hi.hi, carry_out, carry_unused); + prod.lo = add_with_carry(prod.lo, hi_lo.lo << 32, carry_in, carry_out); + prod.hi = add_with_carry(prod.hi, hi_lo.hi, carry_out, carry_unused); + + return prod; +#endif // __SIZEOF_INT128__ +} + +} // namespace __llvm_libc + +#endif // 
LLVM_LIBC_SRC_SUPPORT_NUMBER_PAIR_H diff --git a/libc/test/src/__support/uint128_test.cpp b/libc/test/src/__support/uint128_test.cpp --- a/libc/test/src/__support/uint128_test.cpp +++ b/libc/test/src/__support/uint128_test.cpp @@ -15,6 +15,8 @@ // we use a sugar which does not conflict with the UInt128 type which can // resolve to __uint128_t if the platform has it. using LL_UInt128 = __llvm_libc::cpp::UInt<128>; +using LL_UInt192 = __llvm_libc::cpp::UInt<192>; +using LL_UInt256 = __llvm_libc::cpp::UInt<256>; TEST(LlvmLibcUInt128ClassTest, BasicInit) { LL_UInt128 empty; @@ -43,6 +45,24 @@ LL_UInt128 result3({0x12346789bcdf1233, 0xa987765443210fed}); EXPECT_EQ(val5 + val6, result3); EXPECT_EQ(val5 + val6, val6 + val5); + + // Test 192-bit addition: word 1 overflows and its carry propagates into word 2; the carry out of word 2 is dropped. + LL_UInt192 val7({0x0123456789abcdef, 0xfedcba9876543210, 0xfedcba9889abcdef}); + LL_UInt192 val8({0x1111222233334444, 0xaaaabbbbccccdddd, 0xeeeeffffeeeeffff}); + LL_UInt192 result4( + {0x12346789bcdf1233, 0xa987765443210fed, 0xedcbba98789acdef}); + EXPECT_EQ(val7 + val8, result4); + EXPECT_EQ(val7 + val8, val8 + val7); + + // Test 256-bit addition: the carry out of word 1 propagates into word 2; the carry out of word 3 is dropped. + LL_UInt256 val9({0x1f1e1d1c1b1a1918, 0xf1f2f3f4f5f6f7f8, 0x0123456789abcdef, + 0xfedcba9876543210}); + LL_UInt256 val10({0x1111222233334444, 0xaaaabbbbccccdddd, 0x1111222233334444, + 0xaaaabbbbccccdddd}); + LL_UInt256 result5({0x302f3f3e4e4d5d5c, 0x9c9dafb0c2c3d5d5, + 0x12346789bcdf1234, 0xa987765443210fed}); + EXPECT_EQ(val9 + val10, result5); + EXPECT_EQ(val9 + val10, val10 + val9); } TEST(LlvmLibcUInt128ClassTest, SubtractionTests) { @@ -112,6 +132,26 @@ LL_UInt128 result5({0x917cf11d1e039c50, 0x3a4f32d17f40d08f}); EXPECT_EQ((val9 * val10), result5); EXPECT_EQ((val9 * val10), (val10 * val9)); + + // Test 192-bit multiplication (product truncated to 192 bits). Both low words are 2^64 - 1, so the low word of the product is 1. + LL_UInt192 val11( + {0xffffffffffffffff, 0x01D762422C946590, 0x9F4F2726179A2245}); + LL_UInt192 val12( + {0xffffffffffffffff, 0x3792F412CB06794D, 0xCDB02555653131B6}); + + LL_UInt192 result6( + {0x0000000000000001, 0xc695a9ab08652121, 
0x5de7faf698d32732}); + EXPECT_EQ((val11 * val12), result6); + EXPECT_EQ((val11 * val12), (val12 * val11)); + + LL_UInt256 val13({0xffffffffffffffff, 0x01D762422C946590, 0x9F4F2726179A2245, + 0xffffffffffffffff}); + LL_UInt256 val14({0xffffffffffffffff, 0xffffffffffffffff, 0x3792F412CB06794D, + 0xCDB02555653131B6}); + LL_UInt256 result7({0x0000000000000001, 0xfe289dbdd36b9a6f, + 0x291de4c71d5f646c, 0xfd37221cb06d4978}); + EXPECT_EQ((val13 * val14), result7); // 256-bit multiplication, product truncated to 256 bits. + EXPECT_EQ((val13 * val14), (val14 * val13)); } TEST(LlvmLibcUInt128ClassTest, DivisionTests) { diff --git a/libc/utils/UnitTest/LibcTest.cpp b/libc/utils/UnitTest/LibcTest.cpp --- a/libc/utils/UnitTest/LibcTest.cpp +++ b/libc/utils/UnitTest/LibcTest.cpp @@ -35,47 +35,43 @@ namespace internal { -// When the value is of integral type, just display it as normal. -template -cpp::enable_if_t, std::string> -describeValue(ValType Value) { - return std::to_string(Value); -} - -std::string describeValue(std::string Value) { return std::string(Value); } -std::string describeValue(cpp::string_view Value) { - return std::string(Value.data(), Value.size()); -} - // When the value is UInt128 or __uint128_t, show its hexadecimal digits. // We cannot just use a UInt128 specialization as that resolves to only // one type, UInt<128> or __uint128_t. We want both overloads as we want to // be able to unittest UInt<128> on platforms where UInt128 resolves to // UInt128. -// TODO(lntue): Investigate why UInt<128> was printed backward, with the lower -// 64-bits first. -template -std::string describeValue128(UInt128Type Value) { - std::string S(sizeof(UInt128) * 2, '0'); - - for (auto I = S.rbegin(), End = S.rend(); I != End; ++I, Value >>= 4) { - unsigned char Mod = static_cast(Value) & 15; - *I = Mod < 10 ? 
'0' + Mod : 'a' + Mod - 10; +template +cpp::enable_if_t && cpp::is_unsigned_v, std::string> +describeValueUInt(T Value) { + static_assert(sizeof(T) % 8 == 0, "Unsupported size of UInt"); + std::string S(sizeof(T) * 2, '0'); + + constexpr char HEXADECIMALS[16] = {'0', '1', '2', '3', '4', '5', '6', '7', + '8', '9', 'a', 'b', 'c', 'd', 'e', 'f'}; + + for (auto I = S.rbegin(), End = S.rend(); I != End; ++I, Value >>= 8) { // One byte (two hex digits) per iteration, filled from the least significant end; S has even length, so I lands exactly on End. + unsigned char Mod = static_cast(Value) & 0xFF; + *(I++) = HEXADECIMALS[Mod & 0x0F]; + *I = HEXADECIMALS[Mod >> 4]; } return "0x" + S; } -#ifdef __SIZEOF_INT128__ -template <> std::string describeValue<__uint128_t>(__uint128_t Value) { - return describeValue128(Value); +// Integral values up to 64 bits wide are printed in decimal with std::to_string; wider ones (e.g. UInt<128>, UInt<192>, UInt<256>) are printed in hexadecimal via describeValueUInt. +template +cpp::enable_if_t, std::string> +describeValue(ValType Value) { + if constexpr (sizeof(ValType) <= sizeof(uint64_t)) { + return std::to_string(Value); + } else { + return describeValueUInt(Value); + } } -#endif -template <> -std::string -describeValue<__llvm_libc::cpp::UInt<128>>(__llvm_libc::cpp::UInt<128> Value) { - return describeValue128(Value); +std::string describeValue(std::string Value) { return std::string(Value); } +std::string describeValue(cpp::string_view Value) { + return std::string(Value.data(), Value.size()); } template @@ -282,6 +278,16 @@ __llvm_libc::cpp::UInt<128> RHS, const char *LHSStr, const char *RHSStr, const char *File, unsigned long Line); +template bool test<__llvm_libc::cpp::UInt<192>>( + RunContext *Ctx, TestCondition Cond, __llvm_libc::cpp::UInt<192> LHS, + __llvm_libc::cpp::UInt<192> RHS, const char *LHSStr, const char *RHSStr, + const char *File, unsigned long Line); + +template bool test<__llvm_libc::cpp::UInt<256>>( + RunContext *Ctx, TestCondition Cond, __llvm_libc::cpp::UInt<256> LHS, + 
__llvm_libc::cpp::UInt<256> RHS, const char *LHSStr, const char *RHSStr, + const char *File, unsigned long Line); + template bool test<__llvm_libc::cpp::string_view>( RunContext *Ctx,
TestCondition Cond, __llvm_libc::cpp::string_view LHS, __llvm_libc::cpp::string_view RHS, const char *LHSStr, const char *RHSStr,