diff --git a/libc/utils/MPFRWrapper/MPFRUtils.h b/libc/utils/MPFRWrapper/MPFRUtils.h --- a/libc/utils/MPFRWrapper/MPFRUtils.h +++ b/libc/utils/MPFRWrapper/MPFRUtils.h @@ -48,6 +48,7 @@ Floor, Round, Sin, + Sqrt, Trunc }; @@ -56,6 +57,9 @@ template bool compare(Operation op, T input, T libcOutput, const Tolerance &t); +template +bool compare(Operation op, T input, T libcOutput, double t); /* ULP-based overload: passes if the ULP error of libcOutput against the MPFR result of op(input) is below t, or equal to t unless that is a 0.5 tie not rounded to even. */ + template class MPFRMatcher : public testing::Matcher { static_assert(__llvm_libc::cpp::IsFloatingPointType::Value, "MPFRMatcher can only be used with floating point values."); @@ -64,14 +68,21 @@ T input; Tolerance tolerance; T matchValue; + double ulpTolerance; /* read by match() only when useULP is true */ + bool useULP; /* true: ULP-based compare(); false: Tolerance-based compare() */ public: MPFRMatcher(Operation op, T testInput, Tolerance &t) - : operation(op), input(testInput), tolerance(t) {} + : operation(op), input(testInput), tolerance(t), useULP(false) {} /* ulpTolerance is left uninitialized in this mode; it is never read here */ + MPFRMatcher(Operation op, T testInput, double ulpTolerance) + : operation(op), input(testInput), ulpTolerance(ulpTolerance), + useULP(true) {} bool match(T libcResult) { matchValue = libcResult; - return internal::compare(operation, input, libcResult, tolerance); + return (useULP + ? 
internal::compare(operation, input, libcResult, ulpTolerance) + : internal::compare(operation, input, libcResult, tolerance)); } void explainError(testutils::StreamWrapper &OS) override; @@ -79,9 +90,12 @@ } // namespace internal -template +template __attribute__((no_sanitize("address"))) -internal::MPFRMatcher getMPFRMatcher(Operation op, T input, Tolerance t) { +typename cpp::EnableIfType || + cpp::IsSameV, + internal::MPFRMatcher> +getMPFRMatcher(Operation op, T input, U t) { /* presumably the EnableIfType limits U to Tolerance or double (a max ULP error) -- TODO confirm */ static_assert( __llvm_libc::cpp::IsFloatingPointType::Value, "getMPFRMatcher can only be used to match floating point results."); diff --git a/libc/utils/MPFRWrapper/MPFRUtils.cpp b/libc/utils/MPFRWrapper/MPFRUtils.cpp --- a/libc/utils/MPFRWrapper/MPFRUtils.cpp +++ b/libc/utils/MPFRWrapper/MPFRUtils.cpp @@ -9,6 +9,7 @@ #include "MPFRUtils.h" #include "utils/FPUtil/FPBits.h" +#include "utils/FPUtil/TestHelpers.h" /* for fputil::testing::describeValue, used by explainError */ #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" @@ -119,6 +120,9 @@ case Operation::Sin: mpfr_sin(value, mpfrInput.value, MPFR_RNDN); break; + case Operation::Sqrt: + mpfr_sqrt(value, mpfrInput.value, MPFR_RNDN); + break; case Operation::Trunc: mpfr_trunc(value, mpfrInput.value); break; @@ -157,29 +161,88 @@ // These functions are useful for debugging. float asFloat() const { return mpfr_get_flt(value, MPFR_RNDN); } double asDouble() const { return mpfr_get_d(value, MPFR_RNDN); } + long double asLongDouble() const { return mpfr_get_ld(value, MPFR_RNDN); } + void dump(const char *msg) const { mpfr_printf("%s%.128Rf\n", msg, value); } /* prints msg, then the value in fixed notation with 128 fractional digits */ + + // Return the ULP (units-in-the-last-place) difference between the + // stored MPFR value and a floating point number. + // We define: + // ULP(mpfr_value, value) = abs(mpfr_value - value) / eps(value) + // ULP < 0.5 will imply that the value is correctly rounded.
+ template + cpp::EnableIfType::Value, double> ulp(T input) { + fputil::FPBits bits(input); + MPFRNumber mpfrInput(input); + const bool belowInput = mpfr_cmpabs(value, mpfrInput.value) < 0; + // abs(value - input) + mpfr_sub(mpfrInput.value, value, mpfrInput.value, MPFR_RNDN); + mpfr_abs(mpfrInput.value, mpfrInput.value, MPFR_RNDN); + + // get eps(input) + int epsExponent = bits.exponent - fputil::FPBits::exponentBias - + fputil::MantissaWidth::value; + if (bits.exponent == 0) { + // correcting denormal exponent + ++epsExponent; + } else if ((bits.mantissa == 0) && (bits.exponent > 1) && + belowInput) { + // When the input is exactly 2^n, the float just below it in magnitude is + // half as far away as the float just above it. So if the exact value is + // smaller in magnitude than the input, use the smaller epsilon. (Compared + // up front: mpfrInput now holds abs(value - input), not the input.) + --epsExponent; + } + + // Since eps(value) is of the form 2^e, instead of dividing such number, + // we multiply by its inverse 2^{-e}. + mpfr_mul_2si(mpfrInput.value, mpfrInput.value, -epsExponent, MPFR_RNDN); + + return mpfrInput.asDouble(); + } }; namespace internal { +// Rounds an MPFRNumber to the nearest value of type T (MPFR_RNDN). +template T toFloatingPoint(const MPFRNumber &mpfrInput); + +template <> float toFloatingPoint(const MPFRNumber &mpfrInput) { + return mpfrInput.asFloat(); +} + +template <> double toFloatingPoint(const MPFRNumber &mpfrInput) { + return mpfrInput.asDouble(); +} + +template <> long double toFloatingPoint(const MPFRNumber &mpfrInput) { + return mpfrInput.asLongDouble(); +} + template void MPFRMatcher::explainError(testutils::StreamWrapper &OS) { MPFRNumber mpfrResult(operation, input); MPFRNumber mpfrInput(input); MPFRNumber mpfrMatchValue(matchValue); - MPFRNumber mpfrToleranceValue(matchValue, tolerance); FPBits inputBits(input); FPBits matchBits(matchValue); - // TODO: Call to llvm::utohexstr implicitly converts __uint128_t values to - // uint64_t values. This can be fixed using a custom wrapper for - // llvm::utohexstr to handle __uint128_t values correctly.
OS << "Match value not within tolerance value of MPFR result:\n" - << " Input decimal: " << mpfrInput.str() << '\n' - << " Input bits: 0x" << llvm::utohexstr(inputBits.bitsAsUInt()) << '\n' - << " Match decimal: " << mpfrMatchValue.str() << '\n' - << " Match bits: 0x" << llvm::utohexstr(matchBits.bitsAsUInt()) << '\n' - << " MPFR result: " << mpfrResult.str() << '\n' - << "Tolerance value: " << mpfrToleranceValue.str() << '\n'; + << " Input decimal: " << mpfrInput.str() << '\n'; + __llvm_libc::fputil::testing::describeValue(" Input bits: ", input, OS); + OS << '\n' << " Match decimal: " << mpfrMatchValue.str() << '\n'; + __llvm_libc::fputil::testing::describeValue(" Match bits: ", matchValue, + OS); + OS << '\n' << " MPFR result: " << mpfrResult.str() << '\n'; + __llvm_libc::fputil::testing::describeValue( + " MPFR rounded: ", toFloatingPoint(mpfrResult), OS); + OS << '\n'; + if (useULP) { + OS << " ULP error: " << std::to_string(mpfrResult.ulp(matchValue)) + << '\n'; + } else { + MPFRNumber mpfrToleranceValue(matchValue, tolerance); + OS << "Tolerance value: " << mpfrToleranceValue.str() << '\n'; + } } template void MPFRMatcher::explainError(testutils::StreamWrapper &); @@ -201,6 +264,21 @@ template bool compare(Operation, long double, long double, const Tolerance &); +template +bool compare(Operation op, T input, T libcResult, double ulpError) { + // Passes if the ULP error is below ulpError. An error equal to ulpError + // also passes, unless it is a 0.5 tie that was not rounded to even.
+ MPFRNumber mpfrResult(op, input); + double ulp = mpfrResult.ulp(libcResult); + bool bitsAreEven = ((FPBits(libcResult).bitsAsUInt() & 1) == 0); + return (ulp < ulpError) || + ((ulp == ulpError) && ((ulp != 0.5) || bitsAreEven)); +} + +template bool compare(Operation, float, float, double); +template bool compare(Operation, double, double, double); +template bool compare(Operation, long double, long double, double); + } // namespace internal } // namespace mpfr