diff --git a/libc/src/math/generic/hypotf.cpp b/libc/src/math/generic/hypotf.cpp
--- a/libc/src/math/generic/hypotf.cpp
+++ b/libc/src/math/generic/hypotf.cpp
@@ -6,13 +6,69 @@
 //
 //===----------------------------------------------------------------------===//
 #include "src/math/hypotf.h"
-#include "src/__support/FPUtil/Hypot.h"
+#include "src/__support/FPUtil/FPBits.h"
 #include "src/__support/common.h"
 
 namespace __llvm_libc {
 
 LLVM_LIBC_FUNCTION(float, hypotf, (float x, float y)) {
-  return __llvm_libc::fputil::hypot(x, y);
+  using DoubleBits = fputil::FPBits<double>;
+  using FPBits = fputil::FPBits<float>;
+
+  double xd = static_cast<double>(x);
+  double yd = static_cast<double>(y);
+
+  // These squares are exact: a float significand has 24 bits, so its square
+  // fits in the 53-bit significand of a double.
+  double x_sq = xd * xd;
+  double y_sq = yd * yd;
+
+  // Compute the sum of squares. This is the only rounding before the sqrt.
+  double sum_sq = x_sq + y_sq;
+
+  // Compute the rounding error of the sum with Dekker's Fast2Sum algorithm,
+  // so that x_sq + y_sq == sum_sq - err exactly. Fast2Sum is only exact when
+  // the operand of larger magnitude is subtracted first.
+  double err = (x_sq >= y_sq) ? (sum_sq - x_sq) - y_sq
+                              : (sum_sq - y_sq) - x_sq;
+
+  // Take sqrt in double precision.
+  DoubleBits result(__builtin_sqrt(sum_sq));
+
+  if (!DoubleBits(sum_sq).is_inf_or_nan()) {
+    // Correct double rounding. Converting `result` to float drops its low 29
+    // significand bits; when those bits together with the float LSB form an
+    // exact halfway pattern, round-to-nearest-even may go the wrong way.
+    // Such a `result` has at most 25 significant bits, so `r_sq` and `diff`
+    // are exact, and the sign of (diff - err) is the sign of the true
+    // (x^2 + y^2) - result^2. This must also run when err == 0: the sum can
+    // be exact while the double sqrt still rounds onto a halfway pattern.
+    // NOTE: the 30-bit mask assumes the float result is a normal number.
+    double r_sq = static_cast<double>(result) * static_cast<double>(result);
+    double diff = sum_sq - r_sq;
+    constexpr uint64_t mask = 0x0000'0000'3FFF'FFFFULL;
+    uint64_t lrs = result.uintval() & mask;
+
+    if (lrs == 0x0000'0000'1000'0000ULL && err < diff) {
+      // True value is above the halfway point: force rounding up.
+      result.bits |= 1ULL;
+    } else if (lrs == 0x0000'0000'3000'0000ULL && err > diff) {
+      // True value is below the halfway point: force rounding down.
+      result.bits -= 1ULL;
+    }
+  } else {
+    // sum_sq is only inf/NaN when an input is; an infinite input wins over NaN.
+    FPBits bits_x(x), bits_y(y);
+    if (bits_x.is_inf_or_nan() || bits_y.is_inf_or_nan()) {
+      if (bits_x.is_inf() || bits_y.is_inf())
+        return static_cast<float>(FPBits::inf());
+      if (bits_x.is_nan())
+        return x;
+      return y;
+    }
+  }
+
+  return static_cast<float>(static_cast<double>(result));
 }
 
 } // namespace __llvm_libc