diff --git a/libc/utils/MPFRWrapper/MPFRUtils.cpp b/libc/utils/MPFRWrapper/MPFRUtils.cpp --- a/libc/utils/MPFRWrapper/MPFRUtils.cpp +++ b/libc/utils/MPFRWrapper/MPFRUtils.cpp @@ -11,11 +11,39 @@ #include "llvm/Support/raw_ostream.h" #include +#include namespace __llvm_libc { namespace testing { namespace mpfr { +template struct FloatProperties {}; + +template <> struct FloatProperties { + typedef uint32_t BitsType; + static constexpr uint32_t mantissaWidth = 23; + static constexpr uint32_t signMask = 0x7FFFFFFF; + static constexpr uint32_t exponentOffset = 127; +}; + +template <> struct FloatProperties { + typedef uint64_t BitsType; + static constexpr uint32_t mantissaWidth = 52; + static constexpr uint64_t signMask = 0x7FFFFFFFFFFFFFFFULL; + static constexpr uint32_t exponentOffset = 1023; +}; + +// Returns the zero adjusted exponent value of abs(x). +template int getExponent(T x) { + using Properties = FloatProperties; + using BitsType = typename Properties::BitsType; + BitsType bits = *reinterpret_cast(&x); + bits &= Properties::signMask; // Zero the sign bit. + int e = (bits >> Properties::mantissaWidth); // Shift out the mantissa. + e -= Properties::exponentOffset; // Remove the exponent bias. + return e; +} + class MPFRNumber { // A precision value which allows sufficiently large additional // precision even compared to double precision floating point values. @@ -43,6 +71,13 @@ mpfr_set_d(value, x, MPFR_RNDN); } + template ::Value, int> = 0> + explicit MPFRNumber(XType x) { + mpfr_init2(value, mpfrPrecision); + mpfr_set_sj(value, x, MPFR_RNDN); + } + MPFRNumber(const MPFRNumber &other) { mpfr_set(value, other.value, MPFR_RNDN); } @@ -51,24 +86,43 @@ // Returns true if |other| is within the tolerance value |t| of this // number.
- bool isEqual(const MPFRNumber &other, const Tolerance &t) { - MPFRNumber tolerance(0.0); + template bool isEqual(T other, const Tolerance &t) { + MPFRNumber otherExponent(getExponent(other)); + MPFRNumber exp2otherExponent; + // exp2otherExponent = 2^(exponent of other) + mpfr_exp2(exp2otherExponent.value, otherExponent.value, MPFR_RNDN); + + MPFRNumber tolerance(0.0f); uint32_t bitMask = 1 << (t.width - 1); for (int exponent = -t.basePrecision; bitMask > 0; bitMask >>= 1) { --exponent; if (t.bits & bitMask) { - MPFRNumber delta; - mpfr_set_ui_2exp(delta.value, 1, exponent, MPFR_RNDN); + // delta = -n where n is the n-th additional bit. + MPFRNumber delta(exponent); + + // delta = 2^(-n) + mpfr_exp2(delta.value, delta.value, MPFR_RNDN); + + // delta = 2^E * 2^(-n) where E is the exponent of other. + mpfr_mul(delta.value, delta.value, exp2otherExponent.value, MPFR_RNDN); + + // tolerance += delta mpfr_add(tolerance.value, tolerance.value, delta.value, MPFR_RNDN); } } + MPFRNumber mpfrOther(other); MPFRNumber difference; - if (mpfr_cmp(value, other.value) >= 0) - mpfr_sub(difference.value, value, other.value, MPFR_RNDN); + if (mpfr_cmp(value, mpfrOther.value) >= 0) + mpfr_sub(difference.value, value, mpfrOther.value, MPFR_RNDN); else - mpfr_sub(difference.value, other.value, value, MPFR_RNDN); + mpfr_sub(difference.value, mpfrOther.value, value, MPFR_RNDN); + if (!mpfr_lessequal_p(difference.value, tolerance.value)) { + tolerance.dump("Tolr:"); + dump("This:"); + mpfrOther.dump("Othr:"); + } return mpfr_lessequal_p(difference.value, tolerance.value); } @@ -105,7 +159,7 @@ break; } - if (!mpfrResult.isEqual(mpfrLibcResult, t)) { + if (!mpfrResult.isEqual(libcResult, t)) { llvm::outs() << llvm::raw_ostream::RED << "Libc result is not within tolerance value of the MPFR " << "result:\n"