diff --git a/libc/src/__support/UInt.h b/libc/src/__support/UInt.h --- a/libc/src/__support/UInt.h +++ b/libc/src/__support/UInt.h @@ -191,6 +191,78 @@ } } + // Return the full product with `other`: WordCount + OtherWordCount words. + template + constexpr UInt ful_mul(const UInt &other) const { + UInt result(0); + UInt<128> partial_sum(0); + uint64_t carry = 0; + constexpr size_t OtherWordCount = UInt::WordCount; + for (size_t i = 0; i <= WordCount + OtherWordCount - 2; ++i) { + const size_t lower_idx = i < OtherWordCount ? 0 : i - OtherWordCount + 1; + const size_t upper_idx = i < WordCount ? i : WordCount - 1; + for (size_t j = lower_idx; j <= upper_idx; ++j) { + NumberPair prod = full_mul(val[j], other.val[i - j]); + UInt<128> tmp({prod.lo, prod.hi}); + carry += partial_sum.add(tmp); + } + result.val[i] = partial_sum.val[0]; + partial_sum.val[0] = partial_sum.val[1]; + partial_sum.val[1] = carry; + carry = 0; + } + result.val[WordCount + OtherWordCount - 1] = partial_sum.val[0]; + return result; + } + + // Fast hi part of the full product. The normal product `operator*` returns + // `Bits` least significant bits of the full product, while this function will + // approximate `Bits` most significant bits of the full product with errors + // bounded by: + // 0 <= (a.ful_mul(b) >> Bits) - a.quick_mul_hi(b) <= WordCount - 1. + // + // An example usage of this is to quickly (but less accurately) compute the + // product of (normalized) mantissas of floating point numbers: + // (mant_1, mant_2) -> quick_mul_hi -> normalize leading bit + // is much more efficient than: + // (mant_1, mant_2) -> ful_mul -> normalize leading bit + // -> convert back to same Bits width by shifting/rounding, + // especially for higher precisions. + // + // Performance summary: + // Number of 64-bit x 64-bit -> 128-bit multiplications performed. 
+ // Bits WordCount ful_mul quick_mul_hi Error bound + // 128 2 4 3 1 + // 192 3 9 6 2 + // 256 4 16 10 3 + // 512 8 64 36 7 + constexpr UInt quick_mul_hi(const UInt &other) const { + UInt result(0); + UInt<128> partial_sum(0); + uint64_t carry = 0; + // First round: accumulate the partial products that land on word + // WordCount - 1 of the full product. + for (size_t i = 0; i < WordCount; ++i) { + NumberPair prod = + full_mul(val[i], other.val[WordCount - 1 - i]); + UInt<128> tmp({prod.lo, prod.hi}); + carry += partial_sum.add(tmp); + } + for (size_t i = WordCount; i < 2 * WordCount - 1; ++i) { + partial_sum.val[0] = partial_sum.val[1]; + partial_sum.val[1] = carry; + carry = 0; + for (size_t j = i - WordCount + 1; j < WordCount; ++j) { + NumberPair prod = full_mul(val[j], other.val[i - j]); + UInt<128> tmp({prod.lo, prod.hi}); + carry += partial_sum.add(tmp); + } + result.val[i - WordCount] = partial_sum.val[0]; + } + result.val[WordCount - 1] = partial_sum.val[1]; + return result; + } + // pow takes a power and sets this to its starting value to that power. Zero // to the zeroth power returns 1. 
constexpr void pow_n(uint64_t power) { diff --git a/libc/test/src/__support/CMakeLists.txt b/libc/test/src/__support/CMakeLists.txt --- a/libc/test/src/__support/CMakeLists.txt +++ b/libc/test/src/__support/CMakeLists.txt @@ -64,11 +64,11 @@ ) add_libc_unittest( - uint128_test + uint_test SUITE libc_support_unittests SRCS - uint128_test.cpp + uint_test.cpp DEPENDS libc.src.__support.uint libc.src.__support.CPP.optional diff --git a/libc/test/src/__support/uint128_test.cpp b/libc/test/src/__support/uint_test.cpp rename from libc/test/src/__support/uint128_test.cpp rename to libc/test/src/__support/uint_test.cpp --- a/libc/test/src/__support/uint128_test.cpp +++ b/libc/test/src/__support/uint_test.cpp @@ -1,4 +1,4 @@ -//===-- Unittests for the 128 bit integer class ---------------------------===// +//===-- Unittests for the UInt integer class ------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
@@ -17,15 +17,18 @@ using LL_UInt128 = __llvm_libc::cpp::UInt<128>; using LL_UInt192 = __llvm_libc::cpp::UInt<192>; using LL_UInt256 = __llvm_libc::cpp::UInt<256>; +using LL_UInt320 = __llvm_libc::cpp::UInt<320>; +using LL_UInt512 = __llvm_libc::cpp::UInt<512>; +using LL_UInt1024 = __llvm_libc::cpp::UInt<1024>; -TEST(LlvmLibcUInt128ClassTest, BasicInit) { +TEST(LlvmLibcUIntClassTest, BasicInit) { LL_UInt128 empty; LL_UInt128 half_val(12345); LL_UInt128 full_val({12345, 67890}); ASSERT_TRUE(half_val != full_val); } -TEST(LlvmLibcUInt128ClassTest, AdditionTests) { +TEST(LlvmLibcUIntClassTest, AdditionTests) { LL_UInt128 val1(12345); LL_UInt128 val2(54321); LL_UInt128 result1(66666); @@ -65,7 +68,7 @@ EXPECT_EQ(val9 + val10, val10 + val9); } -TEST(LlvmLibcUInt128ClassTest, SubtractionTests) { +TEST(LlvmLibcUIntClassTest, SubtractionTests) { LL_UInt128 val1(12345); LL_UInt128 val2(54321); LL_UInt128 result1({0xffffffffffff5c08, 0xffffffffffffffff}); @@ -94,7 +97,7 @@ EXPECT_EQ(val6, val5 + result6); } -TEST(LlvmLibcUInt128ClassTest, MultiplicationTests) { +TEST(LlvmLibcUIntClassTest, MultiplicationTests) { LL_UInt128 val1({5, 0}); LL_UInt128 val2({10, 0}); LL_UInt128 result1({50, 0}); @@ -154,7 +157,7 @@ EXPECT_EQ((val13 * val14), (val14 * val13)); } -TEST(LlvmLibcUInt128ClassTest, DivisionTests) { +TEST(LlvmLibcUIntClassTest, DivisionTests) { LL_UInt128 val1({10, 0}); LL_UInt128 val2({5, 0}); LL_UInt128 result1({2, 0}); @@ -201,7 +204,7 @@ EXPECT_FALSE(val13.div(val14).has_value()); } -TEST(LlvmLibcUInt128ClassTest, ModuloTests) { +TEST(LlvmLibcUIntClassTest, ModuloTests) { LL_UInt128 val1({10, 0}); LL_UInt128 val2({5, 0}); LL_UInt128 result1({0, 0}); @@ -248,7 +251,7 @@ EXPECT_EQ((val17 % val18), result9); } -TEST(LlvmLibcUInt128ClassTest, PowerTests) { +TEST(LlvmLibcUIntClassTest, PowerTests) { LL_UInt128 val1({10, 0}); val1.pow_n(30); LL_UInt128 result1({5076944270305263616, 54210108624}); // (10 ^ 30) @@ -299,7 +302,7 @@ } } -TEST(LlvmLibcUInt128ClassTest, 
ShiftLeftTests) { +TEST(LlvmLibcUIntClassTest, ShiftLeftTests) { LL_UInt128 val1(0x0123456789abcdef); LL_UInt128 result1(0x123456789abcdef0); EXPECT_EQ((val1 << 4), result1); @@ -325,7 +328,7 @@ EXPECT_EQ((val2 << 256), result6); } -TEST(LlvmLibcUInt128ClassTest, ShiftRightTests) { +TEST(LlvmLibcUIntClassTest, ShiftRightTests) { LL_UInt128 val1(0x0123456789abcdef); LL_UInt128 result1(0x00123456789abcde); EXPECT_EQ((val1 >> 4), result1); @@ -351,7 +354,7 @@ EXPECT_EQ((val2 >> 256), result6); } -TEST(LlvmLibcUInt128ClassTest, AndTests) { +TEST(LlvmLibcUIntClassTest, AndTests) { LL_UInt128 base({0xffff00000000ffff, 0xffffffff00000000}); LL_UInt128 val128({0xf0f0f0f00f0f0f0f, 0xff00ff0000ff00ff}); uint64_t val64 = 0xf0f0f0f00f0f0f0f; @@ -364,7 +367,7 @@ EXPECT_EQ((base & val32), result32); } -TEST(LlvmLibcUInt128ClassTest, OrTests) { +TEST(LlvmLibcUIntClassTest, OrTests) { LL_UInt128 base({0xffff00000000ffff, 0xffffffff00000000}); LL_UInt128 val128({0xf0f0f0f00f0f0f0f, 0xff00ff0000ff00ff}); uint64_t val64 = 0xf0f0f0f00f0f0f0f; @@ -377,7 +380,7 @@ EXPECT_EQ((base | val32), result32); } -TEST(LlvmLibcUInt128ClassTest, CompoundAssignments) { +TEST(LlvmLibcUIntClassTest, CompoundAssignments) { LL_UInt128 x({0xffff00000000ffff, 0xffffffff00000000}); LL_UInt128 b({0xf0f0f0f00f0f0f0f, 0xff00ff0000ff00ff}); @@ -419,7 +422,7 @@ EXPECT_EQ(a, mul_result); } -TEST(LlvmLibcUInt128ClassTest, UnaryPredecrement) { +TEST(LlvmLibcUIntClassTest, UnaryPredecrement) { LL_UInt128 a = LL_UInt128({0x1111111111111111, 0x1111111111111111}); ++a; EXPECT_EQ(a, LL_UInt128({0x1111111111111112, 0x1111111111111111})); @@ -433,7 +436,7 @@ EXPECT_EQ(a, LL_UInt128({0x0, 0x0})); } -TEST(LlvmLibcUInt128ClassTest, EqualsTests) { +TEST(LlvmLibcUIntClassTest, EqualsTests) { LL_UInt128 a1({0xffffffff00000000, 0xffff00000000ffff}); LL_UInt128 a2({0xffffffff00000000, 0xffff00000000ffff}); LL_UInt128 b({0xff00ff0000ff00ff, 0xf0f0f0f00f0f0f0f}); @@ -449,7 +452,7 @@ ASSERT_TRUE(a_lower != a_upper); } 
-TEST(LlvmLibcUInt128ClassTest, ComparisonTests) { +TEST(LlvmLibcUIntClassTest, ComparisonTests) { LL_UInt128 a({0xffffffff00000000, 0xffff00000000ffff}); LL_UInt128 b({0xff00ff0000ff00ff, 0xf0f0f0f00f0f0f0f}); EXPECT_GT(a, b); @@ -467,3 +470,40 @@ EXPECT_LE(a, a); EXPECT_GE(a, a); } + +TEST(LlvmLibcUIntClassTest, FullMulTests) { + LL_UInt128 a({0xffffffffffffffffULL, 0xffffffffffffffffULL}); + LL_UInt128 b({0xfedcba9876543210ULL, 0xfefdfcfbfaf9f8f7ULL}); + LL_UInt256 r({0x0123456789abcdf0ULL, 0x0102030405060708ULL, + 0xfedcba987654320fULL, 0xfefdfcfbfaf9f8f7ULL}); + LL_UInt128 r_hi({0xfedcba987654320eULL, 0xfefdfcfbfaf9f8f7ULL}); + + EXPECT_EQ(a.ful_mul(b), r); + EXPECT_EQ(a.quick_mul_hi(b), r_hi); + + LL_UInt192 c( + {0x7766554433221101ULL, 0xffeeddccbbaa9988ULL, 0x1f2f3f4f5f6f7f8fULL}); + LL_UInt320 rr({0x8899aabbccddeeffULL, 0x0011223344556677ULL, + 0x583715f4d3b29171ULL, 0xffeeddccbbaa9988ULL, + 0x1f2f3f4f5f6f7f8fULL}); + + EXPECT_EQ(a.ful_mul(c), rr); + EXPECT_EQ(a.ful_mul(c), c.ful_mul(a)); +} + +#define TEST_QUICK_MUL_HI(Bits, Error) \ + do { \ + LL_UInt##Bits a = ~LL_UInt##Bits(0); \ + LL_UInt##Bits hi = a.quick_mul_hi(a); \ + LL_UInt##Bits trunc = static_cast(a.ful_mul(a) >> Bits); \ + uint64_t overflow = trunc.sub(hi); \ + EXPECT_EQ(overflow, uint64_t(0)); \ + EXPECT_LE(uint64_t(trunc), uint64_t(Error)); \ + } while (0) + +TEST(LlvmLibcUIntClassTest, QuickMulHiTests) { + TEST_QUICK_MUL_HI(128, 1); + TEST_QUICK_MUL_HI(192, 2); + TEST_QUICK_MUL_HI(256, 3); + TEST_QUICK_MUL_HI(512, 7); +} diff --git a/libc/utils/UnitTest/LibcTest.cpp b/libc/utils/UnitTest/LibcTest.cpp --- a/libc/utils/UnitTest/LibcTest.cpp +++ b/libc/utils/UnitTest/LibcTest.cpp @@ -288,6 +288,11 @@ __llvm_libc::cpp::UInt<256> RHS, const char *LHSStr, const char *RHSStr, const char *File, unsigned long Line); +template bool test<__llvm_libc::cpp::UInt<320>>( + RunContext *Ctx, TestCondition Cond, __llvm_libc::cpp::UInt<320> LHS, + __llvm_libc::cpp::UInt<320> RHS, const char *LHSStr, const 
char *RHSStr, + const char *File, unsigned long Line); + template bool test<__llvm_libc::cpp::string_view>( RunContext *Ctx, TestCondition Cond, __llvm_libc::cpp::string_view LHS, __llvm_libc::cpp::string_view RHS, const char *LHSStr, const char *RHSStr,