diff --git a/libc/src/__support/FPUtil/CMakeLists.txt b/libc/src/__support/FPUtil/CMakeLists.txt
--- a/libc/src/__support/FPUtil/CMakeLists.txt
+++ b/libc/src/__support/FPUtil/CMakeLists.txt
@@ -22,3 +22,14 @@
     libc.src.__support.common
     libc.src.__support.CPP.standalone_cpp
 )
+
+add_header_library(
+  sqrt
+  HDRS
+    sqrt.h
+  DEPENDS
+    .fputil
+    libc.src.__support.FPUtil.generic.sqrt
+)
+
+add_subdirectory(generic)
diff --git a/libc/src/__support/FPUtil/Sqrt.h b/libc/src/__support/FPUtil/Sqrt.h
deleted file mode 100644
--- a/libc/src/__support/FPUtil/Sqrt.h
+++ /dev/null
@@ -1,192 +0,0 @@
-//===-- Square root of IEEE 754 floating point numbers ----------*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef LLVM_LIBC_SRC_SUPPORT_FPUTIL_SQRT_H
-#define LLVM_LIBC_SRC_SUPPORT_FPUTIL_SQRT_H
-
-#include "FPBits.h"
-#include "PlatformDefs.h"
-
-#include "src/__support/CPP/TypeTraits.h"
-
-namespace __llvm_libc {
-namespace fputil {
-
-namespace internal {
-
-template <typename T>
-static inline void normalize(int &exponent,
-                             typename FPBits<T>::UIntType &mantissa);
-
-template <> inline void normalize<float>(int &exponent, uint32_t &mantissa) {
-  // Use binary search to shift the leading 1 bit.
-  // With MantissaWidth = 23, it will take
-  // ceil(log2(23)) = 5 steps checking the mantissa bits as followed:
-  // Step 1: 0000 0000 0000 XXXX XXXX XXXX
-  // Step 2: 0000 00XX XXXX XXXX XXXX XXXX
-  // Step 3: 000X XXXX XXXX XXXX XXXX XXXX
-  // Step 4: 00XX XXXX XXXX XXXX XXXX XXXX
-  // Step 5: 0XXX XXXX XXXX XXXX XXXX XXXX
-  constexpr int NSTEPS = 5; // = ceil(log2(MantissaWidth))
-  constexpr uint32_t BOUNDS[NSTEPS] = {1 << 12, 1 << 18, 1 << 21, 1 << 22,
-                                       1 << 23};
-  constexpr int SHIFTS[NSTEPS] = {12, 6, 3, 2, 1};
-
-  for (int i = 0; i < NSTEPS; ++i) {
-    if (mantissa < BOUNDS[i]) {
-      exponent -= SHIFTS[i];
-      mantissa <<= SHIFTS[i];
-    }
-  }
-}
-
-template <> inline void normalize<double>(int &exponent, uint64_t &mantissa) {
-  // Use binary search to shift the leading 1 bit similar to float.
-  // With MantissaWidth = 52, it will take
-  // ceil(log2(52)) = 6 steps checking the mantissa bits.
-  constexpr int NSTEPS = 6; // = ceil(log2(MantissaWidth))
-  constexpr uint64_t BOUNDS[NSTEPS] = {1ULL << 26, 1ULL << 39, 1ULL << 46,
-                                       1ULL << 49, 1ULL << 51, 1ULL << 52};
-  constexpr int SHIFTS[NSTEPS] = {27, 14, 7, 4, 2, 1};
-
-  for (int i = 0; i < NSTEPS; ++i) {
-    if (mantissa < BOUNDS[i]) {
-      exponent -= SHIFTS[i];
-      mantissa <<= SHIFTS[i];
-    }
-  }
-}
-
-#ifdef LONG_DOUBLE_IS_DOUBLE
-template <>
-inline void normalize<long double>(int &exponent, uint64_t &mantissa) {
-  normalize<double>(exponent, mantissa);
-}
-#elif !defined(SPECIAL_X86_LONG_DOUBLE)
-template <>
-inline void normalize<long double>(int &exponent, __uint128_t &mantissa) {
-  // Use binary search to shift the leading 1 bit similar to float.
-  // With MantissaWidth = 112, it will take
-  // ceil(log2(112)) = 7 steps checking the mantissa bits.
-  constexpr int NSTEPS = 7; // = ceil(log2(MantissaWidth))
-  constexpr __uint128_t BOUNDS[NSTEPS] = {
-      __uint128_t(1) << 56,  __uint128_t(1) << 84,  __uint128_t(1) << 98,
-      __uint128_t(1) << 105, __uint128_t(1) << 109, __uint128_t(1) << 111,
-      __uint128_t(1) << 112};
-  constexpr int SHIFTS[NSTEPS] = {57, 29, 15, 8, 4, 2, 1};
-
-  for (int i = 0; i < NSTEPS; ++i) {
-    if (mantissa < BOUNDS[i]) {
-      exponent -= SHIFTS[i];
-      mantissa <<= SHIFTS[i];
-    }
-  }
-}
-#endif
-
-} // namespace internal
-
-// Correctly rounded IEEE 754 SQRT with round to nearest, ties to even.
-// Shift-and-add algorithm.
-template <typename T,
-          cpp::EnableIfType<cpp::IsFloatingPointType<T>::Value, int> = 0>
-static inline T sqrt(T x) {
-  using UIntType = typename FPBits<T>::UIntType;
-  constexpr UIntType ONE = UIntType(1) << MantissaWidth<T>::VALUE;
-
-  FPBits<T> bits(x);
-
-  if (bits.is_inf_or_nan()) {
-    if (bits.get_sign() && (bits.get_mantissa() == 0)) {
-      // sqrt(-Inf) = NaN
-      return FPBits<T>::build_nan(ONE >> 1);
-    } else {
-      // sqrt(NaN) = NaN
-      // sqrt(+Inf) = +Inf
-      return x;
-    }
-  } else if (bits.is_zero()) {
-    // sqrt(+0) = +0
-    // sqrt(-0) = -0
-    return x;
-  } else if (bits.get_sign()) {
-    // sqrt( negative numbers ) = NaN
-    return FPBits<T>::build_nan(ONE >> 1);
-  } else {
-    int x_exp = bits.get_exponent();
-    UIntType x_mant = bits.get_mantissa();
-
-    // Step 1a: Normalize denormal input and append hidden bit to the mantissa
-    if (bits.get_unbiased_exponent() == 0) {
-      ++x_exp; // let x_exp be the correct exponent of ONE bit.
-      internal::normalize<T>(x_exp, x_mant);
-    } else {
-      x_mant |= ONE;
-    }
-
-    // Step 1b: Make sure the exponent is even.
-    if (x_exp & 1) {
-      --x_exp;
-      x_mant <<= 1;
-    }
-
-    // After step 1b, x = 2^(x_exp) * x_mant, where x_exp is even, and
-    // 1 <= x_mant < 4. So sqrt(x) = 2^(x_exp / 2) * y, with 1 <= y < 2.
-    // Notice that the output of sqrt is always in the normal range.
-    // To perform shift-and-add algorithm to find y, let denote:
-    //   y(n) = 1.y_1 y_2 ... y_n, we can define the nth residue to be:
-    //   r(n) = 2^n ( x_mant - y(n)^2 ).
-    // That leads to the following recurrence formula:
-    //   r(n) = 2*r(n-1) - y_n*[ 2*y(n-1) + 2^(-n-1) ]
-    // with the initial conditions: y(0) = 1, and r(0) = x - 1.
-    // So the nth digit y_n of the mantissa of sqrt(x) can be found by:
-    //   y_n = 1 if 2*r(n-1) >= 2*y(n - 1) + 2^(-n-1)
-    //         0 otherwise.
-    UIntType y = ONE;
-    UIntType r = x_mant - ONE;
-
-    for (UIntType current_bit = ONE >> 1; current_bit; current_bit >>= 1) {
-      r <<= 1;
-      UIntType tmp = (y << 1) + current_bit; // 2*y(n - 1) + 2^(-n-1)
-      if (r >= tmp) {
-        r -= tmp;
-        y += current_bit;
-      }
-    }
-
-    // We compute one more iteration in order to round correctly.
-    bool lsb = y & 1; // Least significant bit
-    bool rb = false;  // Round bit
-    r <<= 2;
-    UIntType tmp = (y << 2) + 1;
-    if (r >= tmp) {
-      r -= tmp;
-      rb = true;
-    }
-
-    // Remove hidden bit and append the exponent field.
-    x_exp = ((x_exp >> 1) + FPBits<T>::EXPONENT_BIAS);
-
-    y = (y - ONE) | (static_cast<UIntType>(x_exp) << MantissaWidth<T>::VALUE);
-    // Round to nearest, ties to even
-    if (rb && (lsb || (r != 0))) {
-      ++y;
-    }
-
-    return *reinterpret_cast<T *>(&y);
-  }
-}
-
-} // namespace fputil
-} // namespace __llvm_libc
-
-#ifdef SPECIAL_X86_LONG_DOUBLE
-#include "x86_64/SqrtLongDouble.h"
-#endif // SPECIAL_X86_LONG_DOUBLE
-
-#endif // LLVM_LIBC_SRC_SUPPORT_FPUTIL_SQRT_H
diff --git a/libc/src/__support/FPUtil/aarch64/sqrt.h b/libc/src/__support/FPUtil/aarch64/sqrt.h
new file mode 100644
--- /dev/null
+++ b/libc/src/__support/FPUtil/aarch64/sqrt.h
@@ -0,0 +1,38 @@
+//===-- Square root of IEEE 754 floating point numbers ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIBC_SRC_SUPPORT_FPUTIL_AARCH64_SQRT_H
+#define LLVM_LIBC_SRC_SUPPORT_FPUTIL_AARCH64_SQRT_H
+
+#include "src/__support/architectures.h"
+
+#if !defined(LLVM_LIBC_ARCH_AARCH64)
+#error "Invalid include"
+#endif
+
+#include "src/__support/FPUtil/generic/sqrt.h"
+
+namespace __llvm_libc {
+namespace fputil {
+
+template <> inline float sqrt<float>(float x) {
+  float y;
+  __asm__ __volatile__("fsqrt %s0, %s1\n\t" : "=w"(y) : "w"(x));
+  return y;
+}
+
+template <> inline double sqrt<double>(double x) {
+  double y;
+  __asm__ __volatile__("fsqrt %d0, %d1\n\t" : "=w"(y) : "w"(x));
+  return y;
+}
+
+} // namespace fputil
+} // namespace __llvm_libc
+
+#endif // LLVM_LIBC_SRC_SUPPORT_FPUTIL_AARCH64_SQRT_H
diff --git a/libc/src/__support/FPUtil/generic/CMakeLists.txt b/libc/src/__support/FPUtil/generic/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/libc/src/__support/FPUtil/generic/CMakeLists.txt
@@ -0,0 +1,8 @@
+add_header_library(
+  sqrt
+  HDRS
+    sqrt.h
+    sqrt_80_bit_long_double.h
+  DEPENDS
+    libc.src.__support.FPUtil.fputil
+)
diff --git a/libc/src/__support/FPUtil/generic/sqrt.h b/libc/src/__support/FPUtil/generic/sqrt.h
new file mode 100644
--- /dev/null
+++ b/libc/src/__support/FPUtil/generic/sqrt.h
@@ -0,0 +1,214 @@
+//===-- Square root of IEEE 754 floating point numbers ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIBC_SRC_SUPPORT_FPUTIL_GENERIC_SQRT_H
+#define LLVM_LIBC_SRC_SUPPORT_FPUTIL_GENERIC_SQRT_H
+
+#include "sqrt_80_bit_long_double.h"
+#include "src/__support/CPP/TypeTraits.h"
+#include "src/__support/FPUtil/FEnvImpl.h"
+#include "src/__support/FPUtil/FPBits.h"
+#include "src/__support/FPUtil/PlatformDefs.h"
+
+namespace __llvm_libc {
+namespace fputil {
+
+namespace internal {
+
+template <typename T> struct SpecialLongDouble {
+  static constexpr bool VALUE = false;
+};
+
+#if defined(SPECIAL_X86_LONG_DOUBLE)
+template <> struct SpecialLongDouble<long double> {
+  static constexpr bool VALUE = true;
+};
+#endif // SPECIAL_X86_LONG_DOUBLE
+
+template <typename T>
+static inline void normalize(int &exponent,
+                             typename FPBits<T>::UIntType &mantissa);
+
+template <> inline void normalize<float>(int &exponent, uint32_t &mantissa) {
+  // Use binary search to shift the leading 1 bit.
+  // With MantissaWidth = 23, it will take
+  // ceil(log2(23)) = 5 steps checking the mantissa bits as followed:
+  // Step 1: 0000 0000 0000 XXXX XXXX XXXX
+  // Step 2: 0000 00XX XXXX XXXX XXXX XXXX
+  // Step 3: 000X XXXX XXXX XXXX XXXX XXXX
+  // Step 4: 00XX XXXX XXXX XXXX XXXX XXXX
+  // Step 5: 0XXX XXXX XXXX XXXX XXXX XXXX
+  constexpr int NSTEPS = 5; // = ceil(log2(MantissaWidth))
+  constexpr uint32_t BOUNDS[NSTEPS] = {1 << 12, 1 << 18, 1 << 21, 1 << 22,
+                                       1 << 23};
+  constexpr int SHIFTS[NSTEPS] = {12, 6, 3, 2, 1};
+
+  for (int i = 0; i < NSTEPS; ++i) {
+    if (mantissa < BOUNDS[i]) {
+      exponent -= SHIFTS[i];
+      mantissa <<= SHIFTS[i];
+    }
+  }
+}
+
+template <> inline void normalize<double>(int &exponent, uint64_t &mantissa) {
+  // Use binary search to shift the leading 1 bit similar to float.
+  // With MantissaWidth = 52, it will take
+  // ceil(log2(52)) = 6 steps checking the mantissa bits.
+  constexpr int NSTEPS = 6; // = ceil(log2(MantissaWidth))
+  constexpr uint64_t BOUNDS[NSTEPS] = {1ULL << 26, 1ULL << 39, 1ULL << 46,
+                                       1ULL << 49, 1ULL << 51, 1ULL << 52};
+  constexpr int SHIFTS[NSTEPS] = {27, 14, 7, 4, 2, 1};
+
+  for (int i = 0; i < NSTEPS; ++i) {
+    if (mantissa < BOUNDS[i]) {
+      exponent -= SHIFTS[i];
+      mantissa <<= SHIFTS[i];
+    }
+  }
+}
+
+#ifdef LONG_DOUBLE_IS_DOUBLE
+template <>
+inline void normalize<long double>(int &exponent, uint64_t &mantissa) {
+  normalize<double>(exponent, mantissa);
+}
+#elif !defined(SPECIAL_X86_LONG_DOUBLE)
+template <>
+inline void normalize<long double>(int &exponent, __uint128_t &mantissa) {
+  // Use binary search to shift the leading 1 bit similar to float.
+  // With MantissaWidth = 112, it will take
+  // ceil(log2(112)) = 7 steps checking the mantissa bits.
+  constexpr int NSTEPS = 7; // = ceil(log2(MantissaWidth))
+  constexpr __uint128_t BOUNDS[NSTEPS] = {
+      __uint128_t(1) << 56,  __uint128_t(1) << 84,  __uint128_t(1) << 98,
+      __uint128_t(1) << 105, __uint128_t(1) << 109, __uint128_t(1) << 111,
+      __uint128_t(1) << 112};
+  constexpr int SHIFTS[NSTEPS] = {57, 29, 15, 8, 4, 2, 1};
+
+  for (int i = 0; i < NSTEPS; ++i) {
+    if (mantissa < BOUNDS[i]) {
+      exponent -= SHIFTS[i];
+      mantissa <<= SHIFTS[i];
+    }
+  }
+}
+#endif
+
+} // namespace internal
+
+// Correctly rounded IEEE 754 SQRT for all rounding modes.
+// Shift-and-add algorithm.
+template <typename T>
+static inline cpp::EnableIfType<cpp::IsFloatingPointType<T>::Value, T>
+sqrt(T x) {
+
+  if constexpr (internal::SpecialLongDouble<T>::VALUE) {
+    // Special 80-bit long double.
+    return x86::sqrt(x);
+  } else {
+    // IEEE floating points formats.
+    using UIntType = typename FPBits<T>::UIntType;
+    constexpr UIntType ONE = UIntType(1) << MantissaWidth<T>::VALUE;
+
+    FPBits<T> bits(x);
+
+    if (bits.is_inf_or_nan()) {
+      if (bits.get_sign() && (bits.get_mantissa() == 0)) {
+        // sqrt(-Inf) = NaN
+        return FPBits<T>::build_nan(ONE >> 1);
+      } else {
+        // sqrt(NaN) = NaN
+        // sqrt(+Inf) = +Inf
+        return x;
+      }
+    } else if (bits.is_zero()) {
+      // sqrt(+0) = +0
+      // sqrt(-0) = -0
+      return x;
+    } else if (bits.get_sign()) {
+      // sqrt( negative numbers ) = NaN
+      return FPBits<T>::build_nan(ONE >> 1);
+    } else {
+      int x_exp = bits.get_exponent();
+      UIntType x_mant = bits.get_mantissa();
+
+      // Step 1a: Normalize denormal input and append hidden bit to the mantissa
+      if (bits.get_unbiased_exponent() == 0) {
+        ++x_exp; // let x_exp be the correct exponent of ONE bit.
+        internal::normalize<T>(x_exp, x_mant);
+      } else {
+        x_mant |= ONE;
+      }
+
+      // Step 1b: Make sure the exponent is even.
+      if (x_exp & 1) {
+        --x_exp;
+        x_mant <<= 1;
+      }
+
+      // After step 1b, x = 2^(x_exp) * x_mant, where x_exp is even, and
+      // 1 <= x_mant < 4. So sqrt(x) = 2^(x_exp / 2) * y, with 1 <= y < 2.
+      // Notice that the output of sqrt is always in the normal range.
+      // To perform shift-and-add algorithm to find y, let denote:
+      //   y(n) = 1.y_1 y_2 ... y_n, we can define the nth residue to be:
+      //   r(n) = 2^n ( x_mant - y(n)^2 ).
+      // That leads to the following recurrence formula:
+      //   r(n) = 2*r(n-1) - y_n*[ 2*y(n-1) + 2^(-n-1) ]
+      // with the initial conditions: y(0) = 1, and r(0) = x - 1.
+      // So the nth digit y_n of the mantissa of sqrt(x) can be found by:
+      //   y_n = 1 if 2*r(n-1) >= 2*y(n - 1) + 2^(-n-1)
+      //         0 otherwise.
+      UIntType y = ONE;
+      UIntType r = x_mant - ONE;
+
+      for (UIntType current_bit = ONE >> 1; current_bit; current_bit >>= 1) {
+        r <<= 1;
+        UIntType tmp = (y << 1) + current_bit; // 2*y(n - 1) + 2^(-n-1)
+        if (r >= tmp) {
+          r -= tmp;
+          y += current_bit;
+        }
+      }
+
+      // We compute one more iteration in order to round correctly.
+      bool lsb = y & 1; // Least significant bit
+      bool rb = false;  // Round bit
+      r <<= 2;
+      UIntType tmp = (y << 2) + 1;
+      if (r >= tmp) {
+        r -= tmp;
+        rb = true;
+      }
+
+      // Remove hidden bit and append the exponent field.
+      x_exp = ((x_exp >> 1) + FPBits<T>::EXPONENT_BIAS);
+
+      y = (y - ONE) | (static_cast<UIntType>(x_exp) << MantissaWidth<T>::VALUE);
+
+      switch (get_round()) {
+      case FE_TONEAREST:
+        // Round to nearest, ties to even
+        if (rb && (lsb || (r != 0)))
+          ++y;
+        break;
+      case FE_UPWARD:
+        if (rb || (r != 0))
+          ++y;
+        break;
+      }
+
+      return *reinterpret_cast<T *>(&y);
+    }
+  }
+}
+
+} // namespace fputil
+} // namespace __llvm_libc
+
+#endif // LLVM_LIBC_SRC_SUPPORT_FPUTIL_GENERIC_SQRT_H
diff --git a/libc/src/__support/FPUtil/x86_64/SqrtLongDouble.h b/libc/src/__support/FPUtil/generic/sqrt_80_bit_long_double.h
rename from libc/src/__support/FPUtil/x86_64/SqrtLongDouble.h
rename to libc/src/__support/FPUtil/generic/sqrt_80_bit_long_double.h
--- a/libc/src/__support/FPUtil/x86_64/SqrtLongDouble.h
+++ b/libc/src/__support/FPUtil/generic/sqrt_80_bit_long_double.h
@@ -6,26 +6,18 @@
 //
 //===----------------------------------------------------------------------===//
 
-#ifndef LLVM_LIBC_SRC_SUPPORT_FPUTIL_X86_64_SQRT_LONG_DOUBLE_H
-#define LLVM_LIBC_SRC_SUPPORT_FPUTIL_X86_64_SQRT_LONG_DOUBLE_H
+#ifndef LLVM_LIBC_SRC_SUPPORT_FPUTIL_GENERIC_SQRT_80_BIT_LONG_DOUBLE_H
+#define LLVM_LIBC_SRC_SUPPORT_FPUTIL_GENERIC_SQRT_80_BIT_LONG_DOUBLE_H
 
-#include "src/__support/architectures.h"
-
-#if !defined(LLVM_LIBC_ARCH_X86)
-#error "Invalid include"
-#endif
-
-#include "src/__support/CPP/TypeTraits.h"
+#include "src/__support/FPUtil/FEnvImpl.h"
 #include "src/__support/FPUtil/FPBits.h"
-#include "src/__support/FPUtil/Sqrt.h"
+#include "src/__support/FPUtil/PlatformDefs.h"
 
 namespace __llvm_libc {
 namespace fputil {
+namespace x86 {
 
-namespace internal {
-
-template <>
-inline void normalize<long double>(int &exponent, __uint128_t &mantissa) {
+inline void normalize(int &exponent, __uint128_t &mantissa) {
   // Use binary search to shift the leading 1 bit similar to float.
   // With MantissaWidth = 63, it will take
   // ceil(log2(63)) = 6 steps checking the mantissa bits.
@@ -43,11 +35,14 @@
     }
   }
 }
-} // namespace internal
+// if constexpr statement in sqrt.h still requires x86::sqrt to be declared
+// even when it's not used.
+static inline long double sqrt(long double x);
 
-// Correctly rounded SQRT with round to nearest, ties to even.
+// Correctly rounded SQRT for all rounding modes.
 // Shift-and-add algorithm.
-template <> inline long double sqrt<long double>(long double x) {
+#if defined(SPECIAL_X86_LONG_DOUBLE)
+static inline long double sqrt(long double x) {
   using UIntType = typename FPBits<long double>::UIntType;
   constexpr UIntType ONE = UIntType(1)
                            << int(MantissaWidth<long double>::VALUE);
@@ -78,7 +73,7 @@
     if (bits.get_implicit_bit()) {
       x_mant |= ONE;
     } else if (bits.get_unbiased_exponent() == 0) {
-      internal::normalize<long double>(x_exp, x_mant);
+      normalize(x_exp, x_mant);
     }
 
     // Step 1b: Make sure the exponent is even.
@@ -126,9 +121,16 @@
     y |= (static_cast<UIntType>(x_exp)
           << (MantissaWidth<long double>::VALUE + 1));
 
-    // Round to nearest, ties to even
-    if (rb && (lsb || (r != 0))) {
-      ++y;
+    switch (get_round()) {
+    case FE_TONEAREST:
+      // Round to nearest, ties to even
+      if (rb && (lsb || (r != 0)))
+        ++y;
+      break;
+    case FE_UPWARD:
+      if (rb || (r != 0))
+        ++y;
+      break;
     }
 
     // Extract output
@@ -140,8 +142,10 @@
     return out;
   }
 }
+#endif // SPECIAL_X86_LONG_DOUBLE
 
+} // namespace x86
 } // namespace fputil
 } // namespace __llvm_libc
 
-#endif // LLVM_LIBC_SRC_SUPPORT_FPUTIL_X86_64_SQRT_LONG_DOUBLE_H
+#endif // LLVM_LIBC_SRC_SUPPORT_FPUTIL_GENERIC_SQRT_80_BIT_LONG_DOUBLE_H
diff --git a/libc/src/__support/FPUtil/sqrt.h b/libc/src/__support/FPUtil/sqrt.h
new file mode 100644
--- /dev/null
+++ b/libc/src/__support/FPUtil/sqrt.h
@@ -0,0 +1,22 @@
+//===-- Square root of IEEE 754 floating point numbers ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIBC_SRC_SUPPORT_FPUTIL_SQRT_H
+#define LLVM_LIBC_SRC_SUPPORT_FPUTIL_SQRT_H
+
+#include "src/__support/architectures.h"
+
+#if defined(LLVM_LIBC_ARCH_X86_64)
+#include "x86_64/sqrt.h"
+#elif defined(LLVM_LIBC_ARCH_AARCH64)
+#include "aarch64/sqrt.h"
+#else
+#include "generic/sqrt.h"
+#endif // LLVM_LIBC_ARCH_X86_64 / LLVM_LIBC_ARCH_AARCH64
+
+#endif // LLVM_LIBC_SRC_SUPPORT_FPUTIL_SQRT_H
diff --git a/libc/src/__support/FPUtil/x86_64/sqrt.h b/libc/src/__support/FPUtil/x86_64/sqrt.h
new file mode 100644
--- /dev/null
+++ b/libc/src/__support/FPUtil/x86_64/sqrt.h
@@ -0,0 +1,44 @@
+//===-- Square root of IEEE 754 floating point numbers ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIBC_SRC_SUPPORT_FPUTIL_X86_64_SQRT_H
+#define LLVM_LIBC_SRC_SUPPORT_FPUTIL_X86_64_SQRT_H
+
+#include "src/__support/architectures.h"
+
+#if !defined(LLVM_LIBC_ARCH_X86)
+#error "Invalid include"
+#endif
+
+#include "src/__support/FPUtil/generic/sqrt.h"
+
+namespace __llvm_libc {
+namespace fputil {
+
+template <> inline float sqrt<float>(float x) {
+  float result;
+  __asm__ __volatile__("sqrtss %x1, %x0" : "=x"(result) : "x"(x));
+  return result;
+}
+
+template <> inline double sqrt<double>(double x) {
+  double result;
+  __asm__ __volatile__("sqrtsd %x1, %x0" : "=x"(result) : "x"(x));
+  return result;
+}
+
+template <> inline long double sqrt<long double>(long double x) {
+  long double result;
+  __asm__ __volatile__("fsqrt" : "=t"(result) : "t"(x));
+  return result;
+}
+
+} // namespace fputil
+} // namespace __llvm_libc
+
+#endif // LLVM_LIBC_SRC_SUPPORT_FPUTIL_X86_64_SQRT_H
diff --git a/libc/src/math/aarch64/CMakeLists.txt b/libc/src/math/aarch64/CMakeLists.txt
--- a/libc/src/math/aarch64/CMakeLists.txt
+++ b/libc/src/math/aarch64/CMakeLists.txt
@@ -77,23 +77,3 @@
   COMPILE_OPTIONS
     -O2
 )
-
-add_entrypoint_object(
-  sqrt
-  SRCS
-    sqrt.cpp
-  HDRS
-    ../sqrt.h
-  COMPILE_OPTIONS
-    -O2
-)
-
-add_entrypoint_object(
-  sqrtf
-  SRCS
-    sqrtf.cpp
-  HDRS
-    ../sqrtf.h
-  COMPILE_OPTIONS
-    -O2
-)
diff --git a/libc/src/math/generic/CMakeLists.txt b/libc/src/math/generic/CMakeLists.txt
--- a/libc/src/math/generic/CMakeLists.txt
+++ b/libc/src/math/generic/CMakeLists.txt
@@ -859,8 +859,10 @@
     ../sqrt.h
   DEPENDS
     libc.src.__support.FPUtil.fputil
+    libc.src.__support.FPUtil.sqrt
   COMPILE_OPTIONS
-    -O2
+    -O3
+    -Wno-c++17-extensions
 )
 
 add_entrypoint_object(
@@ -871,8 +873,10 @@
     ../sqrtf.h
   DEPENDS
     libc.src.__support.FPUtil.fputil
+    libc.src.__support.FPUtil.sqrt
   COMPILE_OPTIONS
-    -O2
+    -O3
+    -Wno-c++17-extensions
 )
 
 add_entrypoint_object(
@@ -883,8 +887,10 @@
     ../sqrtl.h
   DEPENDS
     libc.src.__support.FPUtil.fputil
+    libc.src.__support.FPUtil.sqrt
   COMPILE_OPTIONS
-    -O2
+    -O3
+    -Wno-c++17-extensions
 )
 
 add_entrypoint_object(
diff --git a/libc/src/math/generic/sqrt.cpp b/libc/src/math/generic/sqrt.cpp
--- a/libc/src/math/generic/sqrt.cpp
+++ b/libc/src/math/generic/sqrt.cpp
@@ -7,7 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "src/math/sqrt.h"
-#include "src/__support/FPUtil/Sqrt.h"
+#include "src/__support/FPUtil/sqrt.h"
 #include "src/__support/common.h"
 
 namespace __llvm_libc {
diff --git a/libc/src/math/generic/sqrtf.cpp b/libc/src/math/generic/sqrtf.cpp
--- a/libc/src/math/generic/sqrtf.cpp
+++ b/libc/src/math/generic/sqrtf.cpp
@@ -7,7 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "src/math/sqrtf.h"
-#include "src/__support/FPUtil/Sqrt.h"
+#include "src/__support/FPUtil/sqrt.h"
 #include "src/__support/common.h"
 
 namespace __llvm_libc {
diff --git a/libc/src/math/generic/sqrtl.cpp b/libc/src/math/generic/sqrtl.cpp
--- a/libc/src/math/generic/sqrtl.cpp
+++ b/libc/src/math/generic/sqrtl.cpp
@@ -7,7 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "src/math/sqrtl.h"
-#include "src/__support/FPUtil/Sqrt.h"
+#include "src/__support/FPUtil/sqrt.h"
 #include "src/__support/common.h"
 
 namespace __llvm_libc {
diff --git a/libc/src/math/x86_64/CMakeLists.txt b/libc/src/math/x86_64/CMakeLists.txt
--- a/libc/src/math/x86_64/CMakeLists.txt
+++ b/libc/src/math/x86_64/CMakeLists.txt
@@ -27,33 +27,3 @@
   COMPILE_OPTIONS
     -O2
 )
-
-add_entrypoint_object(
-  sqrt
-  SRCS
-    sqrt.cpp
-  HDRS
-    ../sqrt.h
-  COMPILE_OPTIONS
-    -O2
-)
-
-add_entrypoint_object(
-  sqrtf
-  SRCS
-    sqrtf.cpp
-  HDRS
-    ../sqrtf.h
-  COMPILE_OPTIONS
-    -O2
-)
-
-add_entrypoint_object(
-  sqrtl
-  SRCS
-    sqrtl.cpp
-  HDRS
-    ../sqrtl.h
-  COMPILE_OPTIONS
-    -O2
-)
diff --git a/libc/src/math/x86_64/sqrt.cpp b/libc/src/math/x86_64/sqrt.cpp
deleted file mode 100644
--- a/libc/src/math/x86_64/sqrt.cpp
+++ /dev/null
@@ -1,20 +0,0 @@
-//===-- Implementation of the sqrt function for x86_64 --------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "src/math/sqrt.h"
-#include "src/__support/common.h"
-
-namespace __llvm_libc {
-
-LLVM_LIBC_FUNCTION(double, sqrt, (double x)) {
-  double result;
-  __asm__ __volatile__("sqrtsd %x1, %x0" : "=x"(result) : "x"(x));
-  return result;
-}
-
-} // namespace __llvm_libc
diff --git a/libc/src/math/x86_64/sqrtf.cpp b/libc/src/math/x86_64/sqrtf.cpp
deleted file mode 100644
--- a/libc/src/math/x86_64/sqrtf.cpp
+++ /dev/null
@@ -1,20 +0,0 @@
-//===-- Implementation of the sqrtf function for x86_64 -------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "src/math/sqrtf.h"
-#include "src/__support/common.h"
-
-namespace __llvm_libc {
-
-LLVM_LIBC_FUNCTION(float, sqrtf, (float x)) {
-  float result;
-  __asm__ __volatile__("sqrtss %x1, %x0" : "=x"(result) : "x"(x));
-  return result;
-}
-
-} // namespace __llvm_libc
diff --git a/libc/src/math/x86_64/sqrtl.cpp b/libc/src/math/x86_64/sqrtl.cpp
deleted file mode 100644
--- a/libc/src/math/x86_64/sqrtl.cpp
+++ /dev/null
@@ -1,20 +0,0 @@
-//===-- Implementation of the sqrtl function for x86_64 -------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "src/math/sqrtl.h"
-#include "src/__support/common.h"
-
-namespace __llvm_libc {
-
-LLVM_LIBC_FUNCTION(long double, sqrtl, (long double x)) {
-  long double result;
-  __asm__ __volatile__("fsqrt" : "=t"(result) : "t"(x));
-  return result;
-}
-
-} // namespace __llvm_libc
diff --git a/libc/test/src/math/CMakeLists.txt b/libc/test/src/math/CMakeLists.txt
--- a/libc/test/src/math/CMakeLists.txt
+++ b/libc/test/src/math/CMakeLists.txt
@@ -983,26 +983,63 @@
     libc.src.__support.FPUtil.fputil
 )
 
-# The quad precision test for sqrt against MPFR currently suffers
-# from insufficient precision in MPFR calculations leading to
-# https://hal.archives-ouvertes.fr/hal-01091186/document. We will
-# renable after fixing the precision issue.
-if(${LIBC_TARGET_ARCHITECTURE_IS_X86})
-  add_fp_unittest(
-    sqrtl_test
-    NEED_MPFR
-    SUITE
-      libc_math_unittests
-    SRCS
-      sqrtl_test.cpp
-    DEPENDS
-      libc.include.math
-      libc.src.math.sqrtl
-      libc.src.__support.FPUtil.fputil
-  )
-else()
-  message(STATUS "Skipping sqrtl_test")
-endif()
+add_fp_unittest(
+  sqrtl_test
+  NEED_MPFR
+  SUITE
+    libc_math_unittests
+  SRCS
+    sqrtl_test.cpp
+  DEPENDS
+    libc.include.math
+    libc.src.math.sqrtl
+    libc.src.__support.FPUtil.fputil
+)
+
+add_fp_unittest(
+  generic_sqrtf_test
+  NEED_MPFR
+  SUITE
+    libc_math_unittests
+  SRCS
+    generic_sqrtf_test.cpp
+  DEPENDS
+    libc.src.__support.FPUtil.fputil
+    libc.src.__support.FPUtil.generic.sqrt
+  COMPILE_OPTIONS
+    -O3
+    -Wno-c++17-extensions
+)
+
+add_fp_unittest(
+  generic_sqrt_test
+  NEED_MPFR
+  SUITE
+    libc_math_unittests
+  SRCS
+    generic_sqrt_test.cpp
+  DEPENDS
+    libc.src.__support.FPUtil.fputil
+    libc.src.__support.FPUtil.generic.sqrt
+  COMPILE_OPTIONS
+    -O3
+    -Wno-c++17-extensions
+)
+
+add_fp_unittest(
+  generic_sqrtl_test
+  NEED_MPFR
+  SUITE
+    libc_math_unittests
+  SRCS
+    generic_sqrtl_test.cpp
+  DEPENDS
+    libc.src.__support.FPUtil.fputil
+    libc.src.__support.FPUtil.generic.sqrt
+  COMPILE_OPTIONS
+    -O3
+    -Wno-c++17-extensions
+)
 
 add_fp_unittest(
   remquof_test
diff --git a/libc/src/math/generic/sqrt.cpp b/libc/test/src/math/generic_sqrt_test.cpp
copy from libc/src/math/generic/sqrt.cpp
copy to libc/test/src/math/generic_sqrt_test.cpp
--- a/libc/src/math/generic/sqrt.cpp
+++ b/libc/test/src/math/generic_sqrt_test.cpp
@@ -1,4 +1,4 @@
-//===-- Implementation of sqrt function -----------------------------------===//
+//===-- Unittests for generic implementation of sqrt ----------------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,12 +6,8 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "src/math/sqrt.h"
-#include "src/__support/FPUtil/Sqrt.h"
-#include "src/__support/common.h"
+#include "SqrtTest.h"
 
-namespace __llvm_libc {
+#include "src/__support/FPUtil/generic/sqrt.h"
 
-LLVM_LIBC_FUNCTION(double, sqrt, (double x)) { return fputil::sqrt<double>(x); }
-
-} // namespace __llvm_libc
+LIST_SQRT_TESTS(double, __llvm_libc::fputil::sqrt<double>)
diff --git a/libc/src/math/generic/sqrt.cpp b/libc/test/src/math/generic_sqrtf_test.cpp
copy from libc/src/math/generic/sqrt.cpp
copy to libc/test/src/math/generic_sqrtf_test.cpp
--- a/libc/src/math/generic/sqrt.cpp
+++ b/libc/test/src/math/generic_sqrtf_test.cpp
@@ -1,4 +1,4 @@
-//===-- Implementation of sqrt function -----------------------------------===//
+//===-- Unittests for generic implementation of sqrtf----------------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,12 +6,8 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "src/math/sqrt.h"
-#include "src/__support/FPUtil/Sqrt.h"
-#include "src/__support/common.h"
+#include "SqrtTest.h"
 
-namespace __llvm_libc {
+#include "src/__support/FPUtil/generic/sqrt.h"
 
-LLVM_LIBC_FUNCTION(double, sqrt, (double x)) { return fputil::sqrt<double>(x); }
-
-} // namespace __llvm_libc
+LIST_SQRT_TESTS(float, __llvm_libc::fputil::sqrt<float>)
diff --git a/libc/src/math/generic/sqrt.cpp b/libc/test/src/math/generic_sqrtl_test.cpp
copy from libc/src/math/generic/sqrt.cpp
copy to libc/test/src/math/generic_sqrtl_test.cpp
--- a/libc/src/math/generic/sqrt.cpp
+++ b/libc/test/src/math/generic_sqrtl_test.cpp
@@ -1,4 +1,4 @@
-//===-- Implementation of sqrt function -----------------------------------===//
+//===-- Unittests for generic implementation of sqrtl----------------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,12 +6,8 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "src/math/sqrt.h"
-#include "src/__support/FPUtil/Sqrt.h"
-#include "src/__support/common.h"
+#include "SqrtTest.h"
 
-namespace __llvm_libc {
+#include "src/__support/FPUtil/generic/sqrt.h"
 
-LLVM_LIBC_FUNCTION(double, sqrt, (double x)) { return fputil::sqrt<double>(x); }
-
-} // namespace __llvm_libc
+LIST_SQRT_TESTS(long double, __llvm_libc::fputil::sqrt<long double>)
diff --git a/utils/bazel/llvm-project-overlay/libc/BUILD.bazel b/utils/bazel/llvm-project-overlay/libc/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/libc/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/libc/BUILD.bazel
@@ -80,7 +80,6 @@
     "src/__support/FPUtil/NearestIntegerOperations.h",
     "src/__support/FPUtil/NormalFloat.h",
     "src/__support/FPUtil/PlatformDefs.h",
-    "src/__support/FPUtil/Sqrt.h",
 ]
 
 fputil_hdrs = selects.with_or({
@@ -88,7 +87,6 @@
     PLATFORM_CPU_X86_64: fputil_common_hdrs + [
         "src/__support/FPUtil/x86_64/LongDoubleBits.h",
         "src/__support/FPUtil/x86_64/NextAfterLongDouble.h",
-        "src/__support/FPUtil/x86_64/SqrtLongDouble.h",
         "src/__support/FPUtil/x86_64/FEnvImpl.h",
     ],
     PLATFORM_CPU_ARM64: fputil_common_hdrs + [
@@ -106,6 +104,31 @@
     ],
 )
 
+sqrt_common_hdrs = [
+    "src/__support/FPUtil/sqrt.h",
+    "src/__support/FPUtil/generic/sqrt.h",
+    "src/__support/FPUtil/generic/sqrt_80_bit_long_double.h",
+]
+
+sqrt_hdrs = selects.with_or({
+    "//conditions:default": sqrt_common_hdrs,
+    PLATFORM_CPU_X86_64: sqrt_common_hdrs + [
+        "src/__support/FPUtil/x86_64/sqrt.h",
+    ],
+    PLATFORM_CPU_ARM64: sqrt_common_hdrs + [
+        "src/__support/FPUtil/aarch64/sqrt.h",
+    ],
+})
+
+cc_library(
+    name = "__support_fputil_sqrt",
+    hdrs = sqrt_hdrs,
+    deps = [
+        ":__support_fputil",
+        ":libc_root",
+    ],
+)
+
 ################################ fenv targets ################################
 
 libc_function(
@@ -438,28 +461,23 @@
 
 libc_math_function(
     name = "sqrt",
-    specializations = [
-        "aarch64",
-        "generic",
-        "x86_64",
-    ],
+    additional_deps = [
+        ":__support_fputil_sqrt",
+    ],
 )
 
 libc_math_function(
     name = "sqrtf",
-    specializations = [
-        "aarch64",
-        "generic",
-        "x86_64",
-    ],
+    additional_deps = [
+        ":__support_fputil_sqrt",
+    ],
 )
 
 libc_math_function(
     name = "sqrtl",
-    specializations = [
-        "generic",
-        "x86_64",
-    ],
+    additional_deps = [
+        ":__support_fputil_sqrt",
+    ],
 )
 
 libc_math_function(name = "copysign")