diff --git a/libc/utils/MPFRWrapper/MPFRUtils.h b/libc/utils/MPFRWrapper/MPFRUtils.h
--- a/libc/utils/MPFRWrapper/MPFRUtils.h
+++ b/libc/utils/MPFRWrapper/MPFRUtils.h
@@ -48,6 +48,7 @@
   Floor,
   Round,
   Sin,
+  Sqrt,
   Trunc
 };
 
@@ -56,6 +57,11 @@
 template <typename T>
 bool compare(Operation op, T input, T libcOutput, const Tolerance &t);
 
+// Returns true if the ULP error of |libcOutput| with respect to the MPFR
+// result of |op| applied to |input| is at most |t|.
+template <typename T>
+bool compare(Operation op, T input, T libcOutput, double t);
+
 template <typename T> class MPFRMatcher : public testing::Matcher<T> {
   static_assert(__llvm_libc::cpp::IsFloatingPointType<T>::Value,
                 "MPFRMatcher can only be used with floating point values.");
@@ -64,14 +70,24 @@
   T input;
   Tolerance tolerance;
   T matchValue;
+  // Maximum allowed ULP error; only used when |useULP| is true.
+  double ulpTolerance;
+  // True if match() compares by ULP error, false if by |tolerance|.
+  bool useULP;
 
 public:
   MPFRMatcher(Operation op, T testInput, Tolerance &t)
-      : operation(op), input(testInput), tolerance(t) {}
+      : operation(op), input(testInput), tolerance(t), ulpTolerance(0.0),
+        useULP(false) {}
+  MPFRMatcher(Operation op, T testInput, double ulpTolerance)
+      : operation(op), input(testInput), tolerance(),
+        ulpTolerance(ulpTolerance), useULP(true) {}
 
   bool match(T libcResult) {
     matchValue = libcResult;
-    return internal::compare(operation, input, libcResult, tolerance);
+    return (useULP
+                ? internal::compare(operation, input, libcResult, ulpTolerance)
+                : internal::compare(operation, input, libcResult, tolerance));
   }
 
   void explainError(testutils::StreamWrapper &OS) override;
@@ -79,9 +95,14 @@
 
 } // namespace internal
 
-template <typename T>
+// Returns a matcher comparing the result of |op| on |input| against MPFR.
+// |t| is either a Tolerance or a double giving the maximum ULP error.
+template <typename T, typename U>
 __attribute__((no_sanitize("address")))
-internal::MPFRMatcher<T> getMPFRMatcher(Operation op, T input, Tolerance t) {
+typename cpp::EnableIfType<cpp::IsSameV<U, Tolerance> ||
+                               cpp::IsSameV<U, double>,
+                           internal::MPFRMatcher<T>>
+getMPFRMatcher(Operation op, T input, U t) {
   static_assert(
       __llvm_libc::cpp::IsFloatingPointType<T>::Value,
       "getMPFRMatcher can only be used to match floating point results.");
diff --git a/libc/utils/MPFRWrapper/MPFRUtils.cpp b/libc/utils/MPFRWrapper/MPFRUtils.cpp
--- a/libc/utils/MPFRWrapper/MPFRUtils.cpp
+++ b/libc/utils/MPFRWrapper/MPFRUtils.cpp
@@ -9,6 +9,7 @@
 #include "MPFRUtils.h"
 
 #include "utils/FPUtil/FPBits.h"
+#include "utils/FPUtil/TestHelpers.h"
 
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/StringRef.h"
@@ -119,6 +120,9 @@
     case Operation::Sin:
       mpfr_sin(value, mpfrInput.value, MPFR_RNDN);
       break;
+    case Operation::Sqrt:
+      mpfr_sqrt(value, mpfrInput.value, MPFR_RNDN);
+      break;
     case Operation::Trunc:
       mpfr_trunc(value, mpfrInput.value);
       break;
@@ -157,29 +161,95 @@
   // These functions are useful for debugging.
   float asFloat() const { return mpfr_get_flt(value, MPFR_RNDN); }
   double asDouble() const { return mpfr_get_d(value, MPFR_RNDN); }
+  long double asLongDouble() const { return mpfr_get_ld(value, MPFR_RNDN); }
+
   void dump(const char *msg) const { mpfr_printf("%s%.128Rf\n", msg, value); }
+
+  // Return the ULP (units-in-the-last-place) difference between the
+  // stored MPFR and a floating point number.
+  //
+  // We define:
+  //   ULP(mpfr_value, value) = abs(mpfr_value - value) / eps(value)
+  // where eps(value) is the distance from |value| to the adjacent floating
+  // point number in the direction of mpfr_value.
+  //
+  // ULP < 0.5 will imply that the value is correctly rounded.
+  template <typename T>
+  cpp::EnableIfType<cpp::IsFloatingPointType<T>::Value, double> ulp(T input) {
+    fputil::FPBits<T> bits(input);
+    MPFRNumber mpfrInput(input);
+
+    // Remember on which side of |input| the stored value lies before
+    // |mpfrInput| is overwritten with the difference.
+    const bool valueBelowInput = mpfr_less_p(value, mpfrInput.value) != 0;
+
+    // abs(value - input)
+    mpfr_sub(mpfrInput.value, value, mpfrInput.value, MPFR_RNDN);
+    mpfr_abs(mpfrInput.value, mpfrInput.value, MPFR_RNDN);
+
+    // get eps(input)
+    int epsExponent =
+        bits.getExponent() - __llvm_libc::fputil::MantissaWidth<T>::value;
+    if (bits.exponent == 0) {
+      // correcting denormal exponent
+      ++epsExponent;
+    } else if (bits.mantissa == 0 && bits.exponent > 1 && valueBelowInput) {
+      // |input| is a power of two, so the next smaller floating point
+      // number is only half as far away as the next larger one. Using the
+      // larger gap here would halve the reported error.
+      --epsExponent;
+    }
+
+    // Since eps(value) is of the form 2^e, instead of dividing such number,
+    // we multiply by its inverse 2^{-e}.
+    mpfr_mul_2si(mpfrInput.value, mpfrInput.value, -epsExponent, MPFR_RNDN);
+
+    return mpfrInput.asDouble();
+  }
 };
 
 namespace internal {
 
+// Rounds |mpfrInput| to the nearest value of type T. The argument is taken
+// by const reference to avoid copying the MPFR value.
+template <typename T> T toFloatingPoint(const MPFRNumber &mpfrInput);
+
+template <> float toFloatingPoint<float>(const MPFRNumber &mpfrInput) {
+  return mpfrInput.asFloat();
+}
+
+template <> double toFloatingPoint<double>(const MPFRNumber &mpfrInput) {
+  return mpfrInput.asDouble();
+}
+
+template <>
+long double toFloatingPoint<long double>(const MPFRNumber &mpfrInput) {
+  return mpfrInput.asLongDouble();
+}
+
 template <typename T>
 void MPFRMatcher<T>::explainError(testutils::StreamWrapper &OS) {
   MPFRNumber mpfrResult(operation, input);
   MPFRNumber mpfrInput(input);
   MPFRNumber mpfrMatchValue(matchValue);
-  MPFRNumber mpfrToleranceValue(matchValue, tolerance);
-  FPBits<T> inputBits(input);
-  FPBits<T> matchBits(matchValue);
-  // TODO: Call to llvm::utohexstr implicitly converts __uint128_t values to
-  // uint64_t values. This can be fixed using a custom wrapper for
-  // llvm::utohexstr to handle __uint128_t values correctly.
   OS << "Match value not within tolerance value of MPFR result:\n"
-     << "  Input decimal: " << mpfrInput.str() << '\n'
-     << "     Input bits: 0x" << llvm::utohexstr(inputBits.bitsAsUInt()) << '\n'
-     << "  Match decimal: " << mpfrMatchValue.str() << '\n'
-     << "     Match bits: 0x" << llvm::utohexstr(matchBits.bitsAsUInt()) << '\n'
-     << "    MPFR result: " << mpfrResult.str() << '\n'
-     << "Tolerance value: " << mpfrToleranceValue.str() << '\n';
+     << "  Input decimal: " << mpfrInput.str() << '\n';
+  __llvm_libc::fputil::testing::describeValue("     Input bits: ", input, OS);
+  OS << '\n' << "  Match decimal: " << mpfrMatchValue.str() << '\n';
+  __llvm_libc::fputil::testing::describeValue("     Match bits: ", matchValue,
+                                              OS);
+  OS << '\n' << "    MPFR result: " << mpfrResult.str() << '\n';
+  __llvm_libc::fputil::testing::describeValue(
+      "   MPFR rounded: ", toFloatingPoint<T>(mpfrResult), OS);
+  OS << '\n';
+  // Report only the criterion this matcher actually compared with.
+  if (useULP) {
+    OS << "     ULP errors: " << std::to_string(mpfrResult.ulp(matchValue))
+       << '\n';
+  } else {
+    MPFRNumber mpfrToleranceValue(matchValue, tolerance);
+    OS << "Tolerance value: " << mpfrToleranceValue.str() << '\n';
+  }
 }
 
 template void MPFRMatcher<float>::explainError(testutils::StreamWrapper &);
@@ -201,6 +271,18 @@
 template bool compare<long double>(Operation, long double, long double,
                                    const Tolerance &);
 
+// Returns true if the ULP error of |libcResult| with respect to the MPFR
+// result of |op| applied to |input| is at most |ulpError|.
+template <typename T>
+bool compare(Operation op, T input, T libcResult, double ulpError) {
+  MPFRNumber mpfrResult(op, input);
+  return mpfrResult.ulp(libcResult) <= ulpError;
+}
+
+template bool compare<float>(Operation, float, float, double);
+template bool compare<double>(Operation, double, double, double);
+template bool compare<long double>(Operation, long double, long double, double);
+
 } // namespace internal
 } // namespace mpfr