diff --git a/libc/config/linux/aarch64/entrypoints.txt b/libc/config/linux/aarch64/entrypoints.txt --- a/libc/config/linux/aarch64/entrypoints.txt +++ b/libc/config/linux/aarch64/entrypoints.txt @@ -75,6 +75,9 @@ libc.src.math.roundl libc.src.math.sincosf libc.src.math.sinf + libc.src.math.sqrt + libc.src.math.sqrtf + libc.src.math.sqrtl libc.src.math.trunc libc.src.math.truncf libc.src.math.truncl diff --git a/libc/config/linux/api.td b/libc/config/linux/api.td --- a/libc/config/linux/api.td +++ b/libc/config/linux/api.td @@ -204,6 +204,9 @@ "roundl", "sincosf", "sinf", + "sqrt", + "sqrtf", + "sqrtl", "trunc", "truncf", "truncl", diff --git a/libc/config/linux/x86_64/entrypoints.txt b/libc/config/linux/x86_64/entrypoints.txt --- a/libc/config/linux/x86_64/entrypoints.txt +++ b/libc/config/linux/x86_64/entrypoints.txt @@ -108,6 +108,9 @@ libc.src.math.roundl libc.src.math.sincosf libc.src.math.sinf + libc.src.math.sqrt + libc.src.math.sqrtf + libc.src.math.sqrtl libc.src.math.trunc libc.src.math.truncf libc.src.math.truncl diff --git a/libc/spec/stdc.td b/libc/spec/stdc.td --- a/libc/spec/stdc.td +++ b/libc/spec/stdc.td @@ -314,6 +314,10 @@ FunctionSpec<"roundf", RetValSpec, [ArgSpec]>, FunctionSpec<"roundl", RetValSpec, [ArgSpec]>, + FunctionSpec<"sqrt", RetValSpec, [ArgSpec]>, + FunctionSpec<"sqrtf", RetValSpec, [ArgSpec]>, + FunctionSpec<"sqrtl", RetValSpec, [ArgSpec]>, + FunctionSpec<"trunc", RetValSpec, [ArgSpec]>, FunctionSpec<"truncf", RetValSpec, [ArgSpec]>, FunctionSpec<"truncl", RetValSpec, [ArgSpec]>, diff --git a/libc/src/math/CMakeLists.txt b/libc/src/math/CMakeLists.txt --- a/libc/src/math/CMakeLists.txt +++ b/libc/src/math/CMakeLists.txt @@ -485,3 +485,39 @@ COMPILE_OPTIONS -O2 ) + +add_entrypoint_object( + sqrt + SRCS + sqrt.cpp + HDRS + sqrt.h + DEPENDS + libc.utils.FPUtil.fputil + COMPILE_OPTIONS + -O2 +) + +add_entrypoint_object( + sqrtf + SRCS + sqrtf.cpp + HDRS + sqrtf.h + DEPENDS + libc.utils.FPUtil.fputil + COMPILE_OPTIONS + -O2 +) + 
+add_entrypoint_object( + sqrtl + SRCS + sqrtl.cpp + HDRS + sqrtl.h + DEPENDS + libc.utils.FPUtil.fputil + COMPILE_OPTIONS + -O2 +) diff --git a/libc/src/math/sqrt.h b/libc/src/math/sqrt.h new file mode 100644 --- /dev/null +++ b/libc/src/math/sqrt.h @@ -0,0 +1,18 @@ +//===-- Implementation header for sqrt --------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_LIBC_SRC_MATH_SQRT_H +#define LLVM_LIBC_SRC_MATH_SQRT_H + +namespace __llvm_libc { + +double sqrt(double x); + +} // namespace __llvm_libc + +#endif // LLVM_LIBC_SRC_MATH_SQRT_H diff --git a/libc/src/math/sqrt.cpp b/libc/src/math/sqrt.cpp new file mode 100644 --- /dev/null +++ b/libc/src/math/sqrt.cpp @@ -0,0 +1,16 @@ +//===-- Implementation of sqrt function -----------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "utils/FPUtil/Sqrt.h" +#include "src/__support/common.h" + +namespace __llvm_libc { + +double LLVM_LIBC_ENTRYPOINT(sqrt)(double x) { return fputil::sqrt(x); } + +} // namespace __llvm_libc diff --git a/libc/src/math/sqrtf.h b/libc/src/math/sqrtf.h new file mode 100644 --- /dev/null +++ b/libc/src/math/sqrtf.h @@ -0,0 +1,18 @@ +//===-- Implementation header for sqrtf -------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_LIBC_SRC_MATH_SQRTF_H +#define LLVM_LIBC_SRC_MATH_SQRTF_H + +namespace __llvm_libc { + +float sqrtf(float x); + +} // namespace __llvm_libc + +#endif // LLVM_LIBC_SRC_MATH_SQRTF_H diff --git a/libc/src/math/sqrtf.cpp b/libc/src/math/sqrtf.cpp new file mode 100644 --- /dev/null +++ b/libc/src/math/sqrtf.cpp @@ -0,0 +1,16 @@ +//===-- Implementation of sqrtf function ----------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "src/__support/common.h" +#include "utils/FPUtil/Sqrt.h" + +namespace __llvm_libc { + +float LLVM_LIBC_ENTRYPOINT(sqrtf)(float x) { return fputil::sqrt(x); } + +} // namespace __llvm_libc diff --git a/libc/src/math/sqrtl.h b/libc/src/math/sqrtl.h new file mode 100644 --- /dev/null +++ b/libc/src/math/sqrtl.h @@ -0,0 +1,18 @@ +//===-- Implementation header for sqrtl -------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_LIBC_SRC_MATH_SQRTL_H +#define LLVM_LIBC_SRC_MATH_SQRTL_H + +namespace __llvm_libc { + +long double sqrtl(long double x); + +} // namespace __llvm_libc + +#endif // LLVM_LIBC_SRC_MATH_SQRTL_H diff --git a/libc/src/math/sqrtl.cpp b/libc/src/math/sqrtl.cpp new file mode 100644 --- /dev/null +++ b/libc/src/math/sqrtl.cpp @@ -0,0 +1,18 @@ +//===-- Implementation of sqrtl function ----------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "src/__support/common.h" +#include "utils/FPUtil/Sqrt.h" + +namespace __llvm_libc { + +long double LLVM_LIBC_ENTRYPOINT(sqrtl)(long double x) { + return fputil::sqrt(x); +} + +} // namespace __llvm_libc diff --git a/libc/test/src/math/CMakeLists.txt b/libc/test/src/math/CMakeLists.txt --- a/libc/test/src/math/CMakeLists.txt +++ b/libc/test/src/math/CMakeLists.txt @@ -510,3 +510,42 @@ libc.src.math.fmaxl libc.utils.FPUtil.fputil ) + +add_fp_unittest( + sqrtf_test + NEED_MPFR + SUITE + libc_math_unittests + SRCS + sqrtf_test.cpp + DEPENDS + libc.include.math + libc.src.math.sqrtf + libc.utils.FPUtil.fputil +) + +add_fp_unittest( + sqrt_test + NEED_MPFR + SUITE + libc_math_unittests + SRCS + sqrt_test.cpp + DEPENDS + libc.include.math + libc.src.math.sqrt + libc.utils.FPUtil.fputil +) + +add_fp_unittest( + sqrtl_test + NEED_MPFR + SUITE + libc_math_unittests + SRCS + sqrtl_test.cpp + DEPENDS + libc.include.math + libc.src.math.sqrtl + libc.utils.FPUtil.fputil +) diff --git a/libc/test/src/math/sqrt_test.cpp b/libc/test/src/math/sqrt_test.cpp new file mode 100644 --- /dev/null +++ 
b/libc/test/src/math/sqrt_test.cpp @@ -0,0 +1,67 @@ +//===-- Unittests for sqrt -----------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===---------------------------------------------------------------------===// + +#include "include/math.h" +#include "src/math/sqrt.h" +#include "utils/FPUtil/FPBits.h" +#include "utils/FPUtil/TestHelpers.h" +#include "utils/MPFRWrapper/MPFRUtils.h" + +using FPBits = __llvm_libc::fputil::FPBits; +using UIntType = typename FPBits::UIntType; + +namespace mpfr = __llvm_libc::testing::mpfr; + +constexpr UIntType HiddenBit = + UIntType(1) << __llvm_libc::fputil::MantissaWidth::value; + +double nan = FPBits::buildNaN(1); +double inf = FPBits::inf(); +double negInf = FPBits::negInf(); + +TEST(SqrtTest, SpecialValues) { + ASSERT_FP_EQ(nan, __llvm_libc::sqrt(nan)); + ASSERT_FP_EQ(inf, __llvm_libc::sqrt(inf)); + ASSERT_FP_EQ(nan, __llvm_libc::sqrt(negInf)); + ASSERT_FP_EQ(0.0, __llvm_libc::sqrt(0.0)); + ASSERT_FP_EQ(-0.0, __llvm_libc::sqrt(-0.0)); + ASSERT_FP_EQ(nan, __llvm_libc::sqrt(-1.0)); + ASSERT_FP_EQ(1.0, __llvm_libc::sqrt(1.0)); + ASSERT_FP_EQ(2.0, __llvm_libc::sqrt(4.0)); + ASSERT_FP_EQ(3.0, __llvm_libc::sqrt(9.0)); +} + +TEST(SqrtTest, DenormalValues) { + for (UIntType mant = 1; mant < HiddenBit; mant <<= 1) { + FPBits denormal(0.0); + denormal.mantissa = mant; + + ASSERT_MPFR_MATCH(mpfr::Operation::Sqrt, double(denormal), + __llvm_libc::sqrt(denormal), 0.5); + } + + constexpr UIntType count = 1'000'001; + constexpr UIntType step = HiddenBit / count; + for (UIntType i = 0, v = 0; i <= count; ++i, v += step) { + double x = *reinterpret_cast(&v); + ASSERT_MPFR_MATCH(mpfr::Operation::Sqrt, x, __llvm_libc::sqrt(x), 0.5); + } +} + +TEST(SqrtTest, InDoubleRange) { + constexpr UIntType count = 10'000'001; + constexpr UIntType step 
= UIntType(-1) / count; + for (UIntType i = 0, v = 0; i <= count; ++i, v += step) { + double x = *reinterpret_cast(&v); + if (isnan(x) || (x < 0)) { + continue; + } + + ASSERT_MPFR_MATCH(mpfr::Operation::Sqrt, x, __llvm_libc::sqrt(x), 0.5); + } +} diff --git a/libc/test/src/math/sqrtf_test.cpp b/libc/test/src/math/sqrtf_test.cpp new file mode 100644 --- /dev/null +++ b/libc/test/src/math/sqrtf_test.cpp @@ -0,0 +1,67 @@ +//===-- Unittests for sqrtf -----------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===---------------------------------------------------------------------===// + +#include "include/math.h" +#include "src/math/sqrtf.h" +#include "utils/FPUtil/FPBits.h" +#include "utils/FPUtil/TestHelpers.h" +#include "utils/MPFRWrapper/MPFRUtils.h" + +using FPBits = __llvm_libc::fputil::FPBits; +using UIntType = typename FPBits::UIntType; + +namespace mpfr = __llvm_libc::testing::mpfr; + +constexpr UIntType HiddenBit = + UIntType(1) << __llvm_libc::fputil::MantissaWidth::value; + +float nan = FPBits::buildNaN(1); +float inf = FPBits::inf(); +float negInf = FPBits::negInf(); + +TEST(SqrtfTest, SpecialValues) { + ASSERT_FP_EQ(nan, __llvm_libc::sqrtf(nan)); + ASSERT_FP_EQ(inf, __llvm_libc::sqrtf(inf)); + ASSERT_FP_EQ(nan, __llvm_libc::sqrtf(negInf)); + ASSERT_FP_EQ(0.0f, __llvm_libc::sqrtf(0.0f)); + ASSERT_FP_EQ(-0.0f, __llvm_libc::sqrtf(-0.0f)); + ASSERT_FP_EQ(nan, __llvm_libc::sqrtf(-1.0f)); + ASSERT_FP_EQ(1.0f, __llvm_libc::sqrtf(1.0f)); + ASSERT_FP_EQ(2.0f, __llvm_libc::sqrtf(4.0f)); + ASSERT_FP_EQ(3.0f, __llvm_libc::sqrtf(9.0f)); +} + +TEST(SqrtfTest, DenormalValues) { + for (UIntType mant = 1; mant < HiddenBit; mant <<= 1) { + FPBits denormal(0.0f); + denormal.mantissa = mant; + + ASSERT_MPFR_MATCH(mpfr::Operation::Sqrt, float(denormal), + 
__llvm_libc::sqrtf(denormal), 0.5); + } + + constexpr UIntType count = 1'000'001; + constexpr UIntType step = HiddenBit / count; + for (UIntType i = 0, v = 0; i <= count; ++i, v += step) { + float x = *reinterpret_cast(&v); + ASSERT_MPFR_MATCH(mpfr::Operation::Sqrt, x, __llvm_libc::sqrtf(x), 0.5); + } +} + +TEST(SqrtfTest, InFloatRange) { + constexpr UIntType count = 10'000'001; + constexpr UIntType step = UIntType(-1) / count; + for (UIntType i = 0, v = 0; i <= count; ++i, v += step) { + float x = *reinterpret_cast(&v); + if (isnan(x) || (x < 0)) { + continue; + } + + ASSERT_MPFR_MATCH(mpfr::Operation::Sqrt, x, __llvm_libc::sqrtf(x), 0.5); + } +} diff --git a/libc/test/src/math/sqrtl_test.cpp b/libc/test/src/math/sqrtl_test.cpp new file mode 100644 --- /dev/null +++ b/libc/test/src/math/sqrtl_test.cpp @@ -0,0 +1,67 @@ +//===-- Unittests for sqrtl ----------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===---------------------------------------------------------------------===// + +#include "include/math.h" +#include "src/math/sqrtl.h" +#include "utils/FPUtil/FPBits.h" +#include "utils/FPUtil/TestHelpers.h" +#include "utils/MPFRWrapper/MPFRUtils.h" + +using FPBits = __llvm_libc::fputil::FPBits; +using UIntType = typename FPBits::UIntType; + +namespace mpfr = __llvm_libc::testing::mpfr; + +constexpr UIntType HiddenBit = + UIntType(1) << __llvm_libc::fputil::MantissaWidth::value; + +long double nan = FPBits::buildNaN(1); +long double inf = FPBits::inf(); +long double negInf = FPBits::negInf(); + +TEST(SqrtlTest, SpecialValues) { + ASSERT_FP_EQ(nan, __llvm_libc::sqrtl(nan)); + ASSERT_FP_EQ(inf, __llvm_libc::sqrtl(inf)); + ASSERT_FP_EQ(nan, __llvm_libc::sqrtl(negInf)); + ASSERT_FP_EQ(0.0L, __llvm_libc::sqrtl(0.0L)); + ASSERT_FP_EQ(-0.0L, __llvm_libc::sqrtl(-0.0L)); + ASSERT_FP_EQ(nan, __llvm_libc::sqrtl(-1.0L)); + ASSERT_FP_EQ(1.0L, __llvm_libc::sqrtl(1.0L)); + ASSERT_FP_EQ(2.0L, __llvm_libc::sqrtl(4.0L)); + ASSERT_FP_EQ(3.0L, __llvm_libc::sqrtl(9.0L)); +} + +TEST(SqrtlTest, DenormalValues) { + for (UIntType mant = 1; mant < HiddenBit; mant <<= 1) { + FPBits denormal(0.0L); + denormal.mantissa = mant; + + ASSERT_MPFR_MATCH(mpfr::Operation::Sqrt, static_cast(denormal), + __llvm_libc::sqrtl(denormal), 0.5); + } + + constexpr UIntType count = 1'000'001; + constexpr UIntType step = HiddenBit / count; + for (UIntType i = 0, v = 0; i <= count; ++i, v += step) { + long double x = *reinterpret_cast(&v); + ASSERT_MPFR_MATCH(mpfr::Operation::Sqrt, x, __llvm_libc::sqrtl(x), 0.5); + } +} + +TEST(SqrtlTest, InLongDoubleRange) { + constexpr UIntType count = 10'000'001; + constexpr UIntType step = UIntType(-1) / count; + for (UIntType i = 0, v = 0; i <= count; ++i, v += step) { + long double x = *reinterpret_cast(&v); + if (isnan(x) || (x < 0)) { + continue; + } + + ASSERT_MPFR_MATCH(mpfr::Operation::Sqrt, x, 
__llvm_libc::sqrtl(x), 0.5); + } +} diff --git a/libc/utils/FPUtil/Sqrt.h b/libc/utils/FPUtil/Sqrt.h new file mode 100644 --- /dev/null +++ b/libc/utils/FPUtil/Sqrt.h @@ -0,0 +1,186 @@ +//===-- Square root of IEEE 754 floating point numbers ----------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_LIBC_UTILS_FPUTIL_SQRT_H +#define LLVM_LIBC_UTILS_FPUTIL_SQRT_H + +#include "FPBits.h" + +#include "utils/CPP/TypeTraits.h" + +namespace __llvm_libc { +namespace fputil { + +namespace internal { + +template +static inline void normalize(int &exponent, + typename FPBits::UIntType &mantissa); + +template <> inline void normalize(int &exponent, uint32_t &mantissa) { + // Use binary search to shift the leading 1 bit. + // With MantissaWidth = 23, it will take + // ceil(log2(23)) = 5 steps checking the mantissa bits as follows: + // Step 1: 0000 0000 0000 XXXX XXXX XXXX + // Step 2: 0000 00XX XXXX XXXX XXXX XXXX + // Step 3: 000X XXXX XXXX XXXX XXXX XXXX + // Step 4: 00XX XXXX XXXX XXXX XXXX XXXX + // Step 5: 0XXX XXXX XXXX XXXX XXXX XXXX + constexpr int nsteps = 5; // = ceil(log2(MantissaWidth)) + constexpr uint32_t bounds[nsteps] = {1 << 12, 1 << 18, 1 << 21, 1 << 22, + 1 << 23}; + constexpr int shifts[nsteps] = {12, 6, 3, 2, 1}; + + for (int i = 0; i < nsteps; ++i) { + if (mantissa < bounds[i]) { + exponent -= shifts[i]; + mantissa <<= shifts[i]; + } + } +} + +template <> inline void normalize(int &exponent, uint64_t &mantissa) { + // Use binary search to shift the leading 1 bit similar to float. + // With MantissaWidth = 52, it will take + // ceil(log2(52)) = 6 steps checking the mantissa bits. 
+ constexpr int nsteps = 6; // = ceil(log2(MantissaWidth)) + constexpr uint64_t bounds[nsteps] = {1ULL << 26, 1ULL << 39, 1ULL << 46, + 1ULL << 49, 1ULL << 51, 1ULL << 52}; + constexpr int shifts[nsteps] = {27, 14, 7, 4, 2, 1}; + + for (int i = 0; i < nsteps; ++i) { + if (mantissa < bounds[i]) { + exponent -= shifts[i]; + mantissa <<= shifts[i]; + } + } +} + +#if !(defined(__x86_64__) || defined(__i386__)) +template <> +inline void normalize(int &exponent, __uint128_t &mantissa) { + // Use binary search to shift the leading 1 bit similar to float. + // With MantissaWidth = 112, it will take + // ceil(log2(112)) = 7 steps checking the mantissa bits. + constexpr int nsteps = 7; // = ceil(log2(MantissaWidth)) + constexpr __uint128_t bounds[nsteps] = { + __uint128_t(1) << 56, __uint128_t(1) << 84, __uint128_t(1) << 98, + __uint128_t(1) << 105, __uint128_t(1) << 109, __uint128_t(1) << 111, + __uint128_t(1) << 112}; + constexpr int shifts[nsteps] = {57, 29, 15, 8, 4, 2, 1}; + + for (int i = 0; i < nsteps; ++i) { + if (mantissa < bounds[i]) { + exponent -= shifts[i]; + mantissa <<= shifts[i]; + } + } +} +#endif + +} // namespace internal + +// Correctly rounded IEEE 754 SQRT with round to nearest, ties to even. +// Shift-and-add algorithm. 
+template ::Value, int> = 0> +static inline T sqrt(T x) { + using UIntType = typename FPBits::UIntType; + constexpr UIntType One = UIntType(1) << MantissaWidth::value; + + FPBits bits(x); + + if (bits.isInfOrNaN()) { + if (bits.sign && (bits.mantissa == 0)) { + // sqrt(-Inf) = NaN + return FPBits::buildNaN(One >> 1); + } else { + // sqrt(NaN) = NaN + // sqrt(+Inf) = +Inf + return x; + } + } else if (bits.isZero()) { + // sqrt(+0) = +0 + // sqrt(-0) = -0 + return x; + } else if (bits.sign) { + // sqrt( negative numbers ) = NaN + return FPBits::buildNaN(One >> 1); + } else { + int xExp = bits.getExponent(); + UIntType xMant = bits.mantissa; + + // Step 1a: Normalize denormal input and append hidden bit to the mantissa + if (bits.exponent == 0) { + ++xExp; // let xExp be the correct exponent of One bit. + internal::normalize(xExp, xMant); + } else { + xMant |= One; + } + + // Step 1b: Make sure the exponent is even. + if (xExp & 1) { + --xExp; + xMant <<= 1; + } + + // After step 1b, x = 2^(xExp) * xMant, where xExp is even, and + // 1 <= xMant < 4. So sqrt(x) = 2^(xExp / 2) * y, with 1 <= y < 2. + // Notice that the output of sqrt is always in the normal range. + // To perform shift-and-add algorithm to find y, let us denote: + // y(n) = 1.y_1 y_2 ... y_n, we can define the nth residue to be: + // r(n) = 2^n ( xMant - y(n)^2 ). + // That leads to the following recurrence formula: + // r(n) = 2*r(n-1) - y_n*[ 2*y(n-1) + 2^(-n) ] + // with the initial conditions: y(0) = 1, and r(0) = xMant - 1. + // So the nth digit y_n of the mantissa of sqrt(x) can be found by: + // y_n = 1 if 2*r(n-1) >= 2*y(n - 1) + 2^(-n) + // 0 otherwise. + UIntType y = One; + UIntType r = xMant - One; + + for (UIntType current_bit = One >> 1; current_bit; current_bit >>= 1) { + r <<= 1; + UIntType tmp = (y << 1) + current_bit; // 2*y(n - 1) + 2^(-n) + if (r >= tmp) { + r -= tmp; + y += current_bit; + } + } + + // We compute one more iteration in order to round correctly. 
+ bool lsb = y & 1; // Least significant bit + bool rb = false; // Round bit + r <<= 2; + UIntType tmp = (y << 2) + 1; + if (r >= tmp) { + r -= tmp; + rb = true; + } + + // Remove hidden bit and append the exponent field. + xExp = ((xExp >> 1) + FPBits::exponentBias); + + y = (y - One) | (static_cast(xExp) << MantissaWidth::value); + // Round to nearest, ties to even + if (rb && (lsb || (r != 0))) { + ++y; + } + + return *reinterpret_cast(&y); + } +} + +} // namespace fputil +} // namespace __llvm_libc + +#if (defined(__x86_64__) || defined(__i386__)) +#include "SqrtLongDoubleX86.h" +#endif // defined(__x86_64__) || defined(__i386__) + +#endif // LLVM_LIBC_UTILS_FPUTIL_SQRT_H diff --git a/libc/utils/FPUtil/SqrtLongDoubleX86.h b/libc/utils/FPUtil/SqrtLongDoubleX86.h new file mode 100644 --- /dev/null +++ b/libc/utils/FPUtil/SqrtLongDoubleX86.h @@ -0,0 +1,142 @@ +//===-- Square root of x86 long double numbers ------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_LIBC_UTILS_FPUTIL_SQRT_LONG_DOUBLE_X86_H +#define LLVM_LIBC_UTILS_FPUTIL_SQRT_LONG_DOUBLE_X86_H + +#include "FPBits.h" +#include "utils/CPP/TypeTraits.h" + +namespace __llvm_libc { +namespace fputil { + +#if (defined(__x86_64__) || defined(__i386__)) +namespace internal { + +template <> +inline void normalize(int &exponent, __uint128_t &mantissa) { + // Use binary search to shift the leading 1 bit similar to float. + // With MantissaWidth = 63, it will take + // ceil(log2(63)) = 6 steps checking the mantissa bits. 
+ constexpr int nsteps = 6; // = ceil(log2(MantissaWidth)) + constexpr __uint128_t bounds[nsteps] = { + __uint128_t(1) << 32, __uint128_t(1) << 48, __uint128_t(1) << 56, + __uint128_t(1) << 60, __uint128_t(1) << 62, __uint128_t(1) << 63}; + constexpr int shifts[nsteps] = {32, 16, 8, 4, 2, 1}; + + for (int i = 0; i < nsteps; ++i) { + if (mantissa < bounds[i]) { + exponent -= shifts[i]; + mantissa <<= shifts[i]; + } + } +} + +} // namespace internal + +// Correctly rounded SQRT with round to nearest, ties to even. +// Shift-and-add algorithm. +template <> inline long double sqrt(long double x) { + using UIntType = typename FPBits::UIntType; + constexpr UIntType One = UIntType(1) + << int(MantissaWidth::value); + + FPBits bits(x); + + if (bits.isInfOrNaN()) { + if (bits.sign && (bits.mantissa == 0)) { + // sqrt(-Inf) = NaN + return FPBits::buildNaN(One >> 1); + } else { + // sqrt(NaN) = NaN + // sqrt(+Inf) = +Inf + return x; + } + } else if (bits.isZero()) { + // sqrt(+0) = +0 + // sqrt(-0) = -0 + return x; + } else if (bits.sign) { + // sqrt( negative numbers ) = NaN + return FPBits::buildNaN(One >> 1); + } else { + int xExp = bits.getExponent(); + UIntType xMant = bits.mantissa; + + // Step 1a: Normalize denormal input + if (bits.implicitBit) { + xMant |= One; + } else if (bits.exponent == 0) { + internal::normalize(xExp, xMant); + } + + // Step 1b: Make sure the exponent is even. + if (xExp & 1) { + --xExp; + xMant <<= 1; + } + + // After step 1b, x = 2^(xExp) * xMant, where xExp is even, and + // 1 <= xMant < 4. So sqrt(x) = 2^(xExp / 2) * y, with 1 <= y < 2. + // Notice that the output of sqrt is always in the normal range. + // To perform shift-and-add algorithm to find y, let us denote: + // y(n) = 1.y_1 y_2 ... y_n, we can define the nth residue to be: + // r(n) = 2^n ( xMant - y(n)^2 ). + // That leads to the following recurrence formula: + // r(n) = 2*r(n-1) - y_n*[ 2*y(n-1) + 2^(-n) ] + // with the initial conditions: y(0) = 1, and r(0) = xMant - 1. 
+ // So the nth digit y_n of the mantissa of sqrt(x) can be found by: + // y_n = 1 if 2*r(n-1) >= 2*y(n - 1) + 2^(-n) + // 0 otherwise. + UIntType y = One; + UIntType r = xMant - One; + + for (UIntType current_bit = One >> 1; current_bit; current_bit >>= 1) { + r <<= 1; + UIntType tmp = (y << 1) + current_bit; // 2*y(n - 1) + 2^(-n) + if (r >= tmp) { + r -= tmp; + y += current_bit; + } + } + + // We compute one more iteration in order to round correctly. + bool lsb = y & 1; // Least significant bit + bool rb = false; // Round bit + r <<= 2; + UIntType tmp = (y << 2) + 1; + if (r >= tmp) { + r -= tmp; + rb = true; + } + + // Append the exponent field. + xExp = ((xExp >> 1) + FPBits::exponentBias); + y |= (static_cast(xExp) + << (MantissaWidth::value + 1)); + + // Round to nearest, ties to even + if (rb && (lsb || (r != 0))) { + ++y; + } + + // Extract output + FPBits out(0.0L); + out.exponent = xExp; + out.implicitBit = 1; + out.mantissa = (y & (One - 1)); + + return out; + } +} +#endif // defined(__x86_64__) || defined(__i386__) + +} // namespace fputil +} // namespace __llvm_libc + +#endif // LLVM_LIBC_UTILS_FPUTIL_SQRT_LONG_DOUBLE_X86_H