diff --git a/libc/utils/MPFRWrapper/MPFRUtils.cpp b/libc/utils/MPFRWrapper/MPFRUtils.cpp --- a/libc/utils/MPFRWrapper/MPFRUtils.cpp +++ b/libc/utils/MPFRWrapper/MPFRUtils.cpp @@ -8,15 +8,55 @@ #include "MPFRUtils.h" +#include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" #include +#include #include namespace __llvm_libc { namespace testing { namespace mpfr { +template struct FloatProperties {}; + +template <> struct FloatProperties { + typedef uint32_t BitsType; + static_assert(sizeof(BitsType) == sizeof(float), + "Unexpected size of 'float' type."); + + static constexpr uint32_t mantissaWidth = 23; + static constexpr uint32_t signMask = 0x7FFFFFFFU; + static constexpr uint32_t exponentOffset = 127; +}; + +template <> struct FloatProperties { + typedef uint64_t BitsType; + static_assert(sizeof(BitsType) == sizeof(double), + "Unexpected size of 'double' type."); + + static constexpr uint32_t mantissaWidth = 52; + static constexpr uint64_t signMask = 0x7FFFFFFFFFFFFFFFULL; + static constexpr uint32_t exponentOffset = 1023; +}; + +template typename FloatProperties::BitsType getBits(T x) { + using BitsType = typename FloatProperties::BitsType; + return *reinterpret_cast(&x); +} + +// Returns the zero adjusted exponent value of abs(x). +template int getExponent(T x) { + using Properties = FloatProperties; + using BitsType = typename Properties::BitsType; + BitsType bits = *reinterpret_cast(&x); + bits &= Properties::signMask; // Zero the sign bit. + int e = (bits >> Properties::mantissaWidth); // Shift out the mantissa. + e -= Properties::exponentOffset; // Zero adjust + return e; +} + class MPFRNumber { // A precision value which allows sufficiently large additional // precision even compared to double precision floating point values. @@ -44,6 +84,38 @@ mpfr_set_d(value, x, MPFR_RNDN); } + template ::Value, int> = 0> + explicit MPFRNumber(XType x) { + mpfr_init2(value, mpfrPrecision); + mpfr_set_sj(value, x, MPFR_RNDN); + } + + template MPFRNumber(XType x, const Tolerance &t) { + mpfr_init2(value, mpfrPrecision); + mpfr_set_zero(value, 1); // Set to positive zero. + MPFRNumber xExponent(getExponent(x)); + // E = 2^E + mpfr_exp2(xExponent.value, xExponent.value, MPFR_RNDN); + uint32_t bitMask = 1 << (t.width - 1); + for (int n = -t.basePrecision; bitMask > 0; bitMask >>= 1) { + --n; + if (t.bits & bitMask) { + // delta = -n + MPFRNumber delta(n); + + // delta = 2^(-n) + mpfr_exp2(delta.value, delta.value, MPFR_RNDN); + + // delta = E * 2^(-n) + mpfr_mul(delta.value, delta.value, xExponent.value, MPFR_RNDN); + + // tolerance += delta + mpfr_add(value, value, delta.value, MPFR_RNDN); + } + } + } + template ::Value, int> = 0> MPFRNumber(Operation op, XType rawValue) { @@ -65,20 +137,9 @@ ~MPFRNumber() { mpfr_clear(value); } - // Returns true if |other| is within the tolerance value |t| of this + // Returns true if |other| is within the |tolerance| value of this // number. - bool isEqual(const MPFRNumber &other, const Tolerance &t) { - MPFRNumber tolerance(0.0); - uint32_t bitMask = 1 << (t.width - 1); - for (int exponent = -t.basePrecision; bitMask > 0; bitMask >>= 1) { - --exponent; - if (t.bits & bitMask) { - MPFRNumber delta; - mpfr_set_ui_2exp(delta.value, 1, exponent, MPFR_RNDN); - mpfr_add(tolerance.value, tolerance.value, delta.value, MPFR_RNDN); - } - } - + bool isEqual(const MPFRNumber &other, const MPFRNumber &tolerance) const { MPFRNumber difference; if (mpfr_cmp(value, other.value) >= 0) mpfr_sub(difference.value, value, other.value, MPFR_RNDN); @@ -112,10 +173,14 @@ MPFRNumber mpfrResult(operation, input); MPFRNumber mpfrInput(input); MPFRNumber mpfrMatchValue(matchValue); + MPFRNumber mpfrToleranceValue(matchValue, tolerance); OS << "Match value not within tolerance value of MPFR result:\n" - << "Operation input: " << mpfrInput.str() << '\n' - << " Match value: " << mpfrMatchValue.str() << '\n' - << " MPFR result: " << mpfrResult.str() << '\n'; + << " Input decimal: " << mpfrInput.str() << '\n' + << " Input bits: 0x" << llvm::utohexstr(getBits(input)) << '\n' + << " Match decimal: " << mpfrMatchValue.str() << '\n' + << " Match bits: 0x" << llvm::utohexstr(getBits(matchValue)) << '\n' + << " MPFR result: " << mpfrResult.str() << '\n' + << "Tolerance value: " << mpfrToleranceValue.str() << '\n'; } template void MPFRMatcher::explainError(testutils::StreamWrapper &); @@ -126,7 +191,9 @@ MPFRNumber mpfrResult(op, input); MPFRNumber mpfrInput(input); MPFRNumber mpfrLibcResult(libcResult); - return mpfrResult.isEqual(mpfrLibcResult, t); + MPFRNumber mpfrToleranceValue(libcResult, t); + + return mpfrResult.isEqual(mpfrLibcResult, mpfrToleranceValue); }; template bool compare(Operation, float, float, const Tolerance &);