diff --git a/libc/test/src/math/SqrtTest.h b/libc/test/src/math/SqrtTest.h --- a/libc/test/src/math/SqrtTest.h +++ b/libc/test/src/math/SqrtTest.h @@ -1,4 +1,4 @@ -//===-- Utility class to test fabs[f|l] -------------------------*- C++ -*-===// +//===-- Utility class to test sqrt[f|l] -------------------------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -24,7 +24,7 @@ public: typedef T (*SqrtFunc)(T); - void testSpecialNumbers(SqrtFunc func) { + void test_special_numbers(SqrtFunc func) { ASSERT_FP_EQ(aNaN, func(aNaN)); ASSERT_FP_EQ(inf, func(inf)); ASSERT_FP_EQ(aNaN, func(neg_inf)); @@ -36,24 +36,23 @@ ASSERT_FP_EQ(T(3.0), func(T(9.0))); } - void testDenormalValues(SqrtFunc func) { + void test_denormal_values(SqrtFunc func) { for (UIntType mant = 1; mant < HIDDEN_BIT; mant <<= 1) { FPBits denormal(T(0.0)); denormal.set_mantissa(mant); - ASSERT_MPFR_MATCH(mpfr::Operation::Sqrt, T(denormal), func(T(denormal)), - T(0.5)); + test_all_rounding_modes(func, T(denormal)); } constexpr UIntType COUNT = 1'000'001; constexpr UIntType STEP = HIDDEN_BIT / COUNT; for (UIntType i = 0, v = 0; i <= COUNT; ++i, v += STEP) { T x = *reinterpret_cast(&v); - ASSERT_MPFR_MATCH(mpfr::Operation::Sqrt, x, func(x), 0.5); + test_all_rounding_modes(func, x); } } - void testNormalRange(SqrtFunc func) { + void test_normal_range(SqrtFunc func) { constexpr UIntType COUNT = 10'000'001; constexpr UIntType STEP = UIntType(-1) / COUNT; for (UIntType i = 0, v = 0; i <= COUNT; ++i, v += STEP) { @@ -61,13 +60,31 @@ if (isnan(x) || (x < 0)) { continue; } - ASSERT_MPFR_MATCH(mpfr::Operation::Sqrt, x, func(x), 0.5); + test_all_rounding_modes(func, x); } } + + void test_all_rounding_modes(SqrtFunc func, T x) { + mpfr::ForceRoundingMode r1(mpfr::RoundingMode::Nearest); + EXPECT_MPFR_MATCH(mpfr::Operation::Sqrt, x, func(x), 0.5, + mpfr::RoundingMode::Nearest); + + mpfr::ForceRoundingMode 
r2(mpfr::RoundingMode::Upward); + EXPECT_MPFR_MATCH(mpfr::Operation::Sqrt, x, func(x), 0.5, + mpfr::RoundingMode::Upward); + + mpfr::ForceRoundingMode r3(mpfr::RoundingMode::Downward); + EXPECT_MPFR_MATCH(mpfr::Operation::Sqrt, x, func(x), 0.5, + mpfr::RoundingMode::Downward); + + mpfr::ForceRoundingMode r4(mpfr::RoundingMode::TowardZero); + EXPECT_MPFR_MATCH(mpfr::Operation::Sqrt, x, func(x), 0.5, + mpfr::RoundingMode::TowardZero); + } }; #define LIST_SQRT_TESTS(T, func) \ using LlvmLibcSqrtTest = SqrtTest; \ - TEST_F(LlvmLibcSqrtTest, SpecialNumbers) { testSpecialNumbers(&func); } \ - TEST_F(LlvmLibcSqrtTest, DenormalValues) { testDenormalValues(&func); } \ - TEST_F(LlvmLibcSqrtTest, NormalRange) { testNormalRange(&func); } + TEST_F(LlvmLibcSqrtTest, SpecialNumbers) { test_special_numbers(&func); } \ + TEST_F(LlvmLibcSqrtTest, DenormalValues) { test_denormal_values(&func); } \ + TEST_F(LlvmLibcSqrtTest, NormalRange) { test_normal_range(&func); } diff --git a/libc/utils/MPFRWrapper/MPFRUtils.h b/libc/utils/MPFRWrapper/MPFRUtils.h --- a/libc/utils/MPFRWrapper/MPFRUtils.h +++ b/libc/utils/MPFRWrapper/MPFRUtils.h @@ -71,6 +71,18 @@ EndTernaryOperationsSingleOutput, }; +enum class RoundingMode : uint8_t { Upward, Downward, TowardZero, Nearest }; + +int get_fe_rounding(RoundingMode mode); + +struct ForceRoundingMode { + ForceRoundingMode(RoundingMode); + ~ForceRoundingMode(); + + int old_rounding_mode; + int rounding_mode; +}; + template struct BinaryInput { static_assert( __llvm_libc::cpp::IsFloatingPointType::Value, @@ -108,65 +120,72 @@ template bool compare_unary_operation_single_output(Operation op, T input, T libc_output, - double t); + double ulp_tolerance, + RoundingMode rounding); template bool compare_unary_operation_two_outputs(Operation op, T input, const BinaryOutput &libc_output, - double t); + double ulp_tolerance, + RoundingMode rounding); template bool compare_binary_operation_two_outputs(Operation op, const BinaryInput &input, const BinaryOutput 
&libc_output, - double t); + double ulp_tolerance, + RoundingMode rounding); template bool compare_binary_operation_one_output(Operation op, const BinaryInput &input, - T libc_output, double t); + T libc_output, double ulp_tolerance, + RoundingMode rounding); template bool compare_ternary_operation_one_output(Operation op, const TernaryInput &input, - T libc_output, double t); + T libc_output, double ulp_tolerance, + RoundingMode rounding); template void explain_unary_operation_single_output_error(Operation op, T input, T match_value, + double ulp_tolerance, + RoundingMode rounding, testutils::StreamWrapper &OS); template void explain_unary_operation_two_outputs_error( Operation op, T input, const BinaryOutput &match_value, - testutils::StreamWrapper &OS); + double ulp_tolerance, RoundingMode rounding, testutils::StreamWrapper &OS); template void explain_binary_operation_two_outputs_error( Operation op, const BinaryInput &input, - const BinaryOutput &match_value, testutils::StreamWrapper &OS); + const BinaryOutput &match_value, double ulp_tolerance, + RoundingMode rounding, testutils::StreamWrapper &OS); template -void explain_binary_operation_one_output_error(Operation op, - const BinaryInput &input, - T match_value, - testutils::StreamWrapper &OS); +void explain_binary_operation_one_output_error( + Operation op, const BinaryInput &input, T match_value, + double ulp_tolerance, RoundingMode rounding, testutils::StreamWrapper &OS); template -void explain_ternary_operation_one_output_error(Operation op, - const TernaryInput &input, - T match_value, - testutils::StreamWrapper &OS); +void explain_ternary_operation_one_output_error( + Operation op, const TernaryInput &input, T match_value, + double ulp_tolerance, RoundingMode rounding, testutils::StreamWrapper &OS); template class MPFRMatcher : public testing::Matcher { InputType input; OutputType match_value; double ulp_tolerance; + RoundingMode rounding; public: - MPFRMatcher(InputType testInput, double ulp_tolerance) 
- : input(testInput), ulp_tolerance(ulp_tolerance) {} + MPFRMatcher(InputType testInput, double ulp_tolerance, RoundingMode rounding) + : input(testInput), ulp_tolerance(ulp_tolerance), rounding(rounding) {} bool match(OutputType libcResult) { match_value = libcResult; - return match(input, match_value, ulp_tolerance); + return match(input, match_value); } // This method is marked with NOLINT because it the name `explainError` @@ -176,59 +195,64 @@ } private: - template static bool match(T in, T out, double tolerance) { - return compare_unary_operation_single_output(op, in, out, tolerance); + template bool match(T in, T out) { + return compare_unary_operation_single_output(op, in, out, ulp_tolerance, + rounding); } - template - static bool match(T in, const BinaryOutput &out, double tolerance) { - return compare_unary_operation_two_outputs(op, in, out, tolerance); + template bool match(T in, const BinaryOutput &out) { + return compare_unary_operation_two_outputs(op, in, out, ulp_tolerance, + rounding); } - template - static bool match(const BinaryInput &in, T out, double tolerance) { - return compare_binary_operation_one_output(op, in, out, tolerance); + template bool match(const BinaryInput &in, T out) { + return compare_binary_operation_one_output(op, in, out, ulp_tolerance, + rounding); } template - static bool match(BinaryInput in, const BinaryOutput &out, - double tolerance) { - return compare_binary_operation_two_outputs(op, in, out, tolerance); + bool match(BinaryInput in, const BinaryOutput &out) { + return compare_binary_operation_two_outputs(op, in, out, ulp_tolerance, + rounding); } - template - static bool match(const TernaryInput &in, T out, double tolerance) { - return compare_ternary_operation_one_output(op, in, out, tolerance); + template bool match(const TernaryInput &in, T out) { + return compare_ternary_operation_one_output(op, in, out, ulp_tolerance, + rounding); } template - static void explain_error(T in, T out, testutils::StreamWrapper &OS) { 
- explain_unary_operation_single_output_error(op, in, out, OS); + void explain_error(T in, T out, testutils::StreamWrapper &OS) { + explain_unary_operation_single_output_error(op, in, out, ulp_tolerance, + rounding, OS); } template - static void explain_error(T in, const BinaryOutput &out, - testutils::StreamWrapper &OS) { - explain_unary_operation_two_outputs_error(op, in, out, OS); + void explain_error(T in, const BinaryOutput &out, + testutils::StreamWrapper &OS) { + explain_unary_operation_two_outputs_error(op, in, out, ulp_tolerance, + rounding, OS); } template - static void explain_error(const BinaryInput &in, - const BinaryOutput &out, - testutils::StreamWrapper &OS) { - explain_binary_operation_two_outputs_error(op, in, out, OS); + void explain_error(const BinaryInput &in, const BinaryOutput &out, + testutils::StreamWrapper &OS) { + explain_binary_operation_two_outputs_error(op, in, out, ulp_tolerance, + rounding, OS); } template - static void explain_error(const BinaryInput &in, T out, - testutils::StreamWrapper &OS) { - explain_binary_operation_one_output_error(op, in, out, OS); + void explain_error(const BinaryInput &in, T out, + testutils::StreamWrapper &OS) { + explain_binary_operation_one_output_error(op, in, out, ulp_tolerance, + rounding, OS); } template - static void explain_error(const TernaryInput &in, T out, - testutils::StreamWrapper &OS) { - explain_ternary_operation_one_output_error(op, in, out, OS); + void explain_error(const TernaryInput &in, T out, + testutils::StreamWrapper &OS) { + explain_ternary_operation_one_output_error(op, in, out, ulp_tolerance, + rounding, OS); } }; @@ -264,12 +288,12 @@ __attribute__((no_sanitize("address"))) cpp::EnableIfType(), internal::MPFRMatcher> -get_mpfr_matcher(InputType input, OutputType output_unused, double t) { - return internal::MPFRMatcher(input, t); +get_mpfr_matcher(InputType input, OutputType output_unused, + double ulp_tolerance, RoundingMode rounding) { + return internal::MPFRMatcher(input, 
ulp_tolerance, + rounding); } -enum class RoundingMode : uint8_t { Upward, Downward, TowardZero, Nearest }; - template T round(T x, RoundingMode mode); template bool round_to_long(T x, long &result); @@ -279,12 +303,42 @@ } // namespace testing } // namespace __llvm_libc -#define EXPECT_MPFR_MATCH(op, input, match_value, tolerance) \ +// GET_MPFR_DUMMY_ARG is going to be added to the end of GET_MPFR_MACRO as a +// simple way to avoid the compiler warning `gnu-zero-variadic-macro-arguments`. +#define GET_MPFR_DUMMY_ARG(...) 0 + +#define GET_MPFR_MACRO(__1, __2, __3, __4, __5, __NAME, ...) __NAME + +#define EXPECT_MPFR_MATCH_DEFAULT(op, input, match_value, ulp_tolerance) \ + EXPECT_THAT(match_value, \ + __llvm_libc::testing::mpfr::get_mpfr_matcher( \ + input, match_value, ulp_tolerance, \ + __llvm_libc::testing::mpfr::RoundingMode::Nearest)) + +#define EXPECT_MPFR_MATCH_ROUNDING(op, input, match_value, ulp_tolerance, \ + rounding) \ EXPECT_THAT(match_value, __llvm_libc::testing::mpfr::get_mpfr_matcher( \ - input, match_value, tolerance)) + input, match_value, ulp_tolerance, rounding)) + +#define EXPECT_MPFR_MATCH(...) \ + GET_MPFR_MACRO(__VA_ARGS__, EXPECT_MPFR_MATCH_ROUNDING, \ + EXPECT_MPFR_MATCH_DEFAULT, GET_MPFR_DUMMY_ARG) \ + (__VA_ARGS__) -#define ASSERT_MPFR_MATCH(op, input, match_value, tolerance) \ +#define ASSERT_MPFR_MATCH_DEFAULT(op, input, match_value, ulp_tolerance) \ + ASSERT_THAT(match_value, \ + __llvm_libc::testing::mpfr::get_mpfr_matcher( \ + input, match_value, ulp_tolerance, \ + __llvm_libc::testing::mpfr::RoundingMode::Nearest)) + +#define ASSERT_MPFR_MATCH_ROUNDING(op, input, match_value, ulp_tolerance, \ + rounding) \ ASSERT_THAT(match_value, __llvm_libc::testing::mpfr::get_mpfr_matcher( \ - input, match_value, tolerance)) + input, match_value, ulp_tolerance, rounding)) + +#define ASSERT_MPFR_MATCH(...) 
\ + GET_MPFR_MACRO(__VA_ARGS__, ASSERT_MPFR_MATCH_ROUNDING, \ + ASSERT_MPFR_MATCH_DEFAULT, GET_MPFR_DUMMY_ARG) \ + (__VA_ARGS__) #endif // LLVM_LIBC_UTILS_TESTUTILS_MPFRUTILS_H diff --git a/libc/utils/MPFRWrapper/MPFRUtils.cpp b/libc/utils/MPFRWrapper/MPFRUtils.cpp --- a/libc/utils/MPFRWrapper/MPFRUtils.cpp +++ b/libc/utils/MPFRWrapper/MPFRUtils.cpp @@ -14,6 +14,7 @@ #include "utils/UnitTest/FPMatcher.h" #include +#include #include #include #include @@ -55,141 +56,227 @@ }; #endif +// A precision value which allows sufficiently large additional +// precision compared to the floating point precision. +template struct ExtraPrecision; + +template <> struct ExtraPrecision { + static constexpr unsigned int VALUE = 128; +}; + +template <> struct ExtraPrecision { + static constexpr unsigned int VALUE = 256; +}; + +template <> struct ExtraPrecision { + static constexpr unsigned int VALUE = 256; +}; + +// If the ulp tolerance is less than or equal to 0.5, we would check that the +// result is rounded correctly with respect to the rounding mode by using the +// same precision as the inputs. 
+template +static inline unsigned int get_precision(double ulp_tolerance) { + if (ulp_tolerance <= 0.5) { + return Precision::VALUE; + } else { + return ExtraPrecision::VALUE; + } +} + +static inline mpfr_rnd_t get_mpfr_rounding_mode(RoundingMode mode) { + switch (mode) { + case RoundingMode::Upward: + return MPFR_RNDU; + break; + case RoundingMode::Downward: + return MPFR_RNDD; + break; + case RoundingMode::TowardZero: + return MPFR_RNDZ; + break; + case RoundingMode::Nearest: + return MPFR_RNDN; + break; + } +} + +int get_fe_rounding(RoundingMode mode) { + switch (mode) { + case RoundingMode::Upward: + return FE_UPWARD; + break; + case RoundingMode::Downward: + return FE_DOWNWARD; + break; + case RoundingMode::TowardZero: + return FE_TOWARDZERO; + break; + case RoundingMode::Nearest: + return FE_TONEAREST; + break; + } +} + +ForceRoundingMode::ForceRoundingMode(RoundingMode mode) { + old_rounding_mode = fegetround(); + rounding_mode = get_fe_rounding(mode); + if (old_rounding_mode != rounding_mode) + fesetround(rounding_mode); +} + +ForceRoundingMode::~ForceRoundingMode() { + if (old_rounding_mode != rounding_mode) + fesetround(old_rounding_mode); +} + class MPFRNumber { - // A precision value which allows sufficiently large additional - // precision even compared to quad-precision floating point values. unsigned int mpfr_precision; + mpfr_rnd_t mpfr_rounding; mpfr_t value; public: - MPFRNumber() : mpfr_precision(256) { mpfr_init2(value, mpfr_precision); } + MPFRNumber() : mpfr_precision(256), mpfr_rounding(MPFR_RNDN) { + mpfr_init2(value, mpfr_precision); + } // We use explicit EnableIf specializations to disallow implicit // conversions. Implicit conversions can potentially lead to loss of // precision. 
template ::Value, int> = 0> - explicit MPFRNumber(XType x, int precision = 128) - : mpfr_precision(precision) { + explicit MPFRNumber(XType x, int precision = ExtraPrecision::VALUE, + RoundingMode rounding = RoundingMode::Nearest) + : mpfr_precision(precision), + mpfr_rounding(get_mpfr_rounding_mode(rounding)) { mpfr_init2(value, mpfr_precision); - mpfr_set_flt(value, x, MPFR_RNDN); + mpfr_set_flt(value, x, mpfr_rounding); } template ::Value, int> = 0> - explicit MPFRNumber(XType x, int precision = 128) - : mpfr_precision(precision) { + explicit MPFRNumber(XType x, int precision = ExtraPrecision::VALUE, + RoundingMode rounding = RoundingMode::Nearest) + : mpfr_precision(precision), + mpfr_rounding(get_mpfr_rounding_mode(rounding)) { mpfr_init2(value, mpfr_precision); - mpfr_set_d(value, x, MPFR_RNDN); + mpfr_set_d(value, x, mpfr_rounding); } template ::Value, int> = 0> - explicit MPFRNumber(XType x, int precision = 128) - : mpfr_precision(precision) { + explicit MPFRNumber(XType x, int precision = ExtraPrecision::VALUE, + RoundingMode rounding = RoundingMode::Nearest) + : mpfr_precision(precision), + mpfr_rounding(get_mpfr_rounding_mode(rounding)) { mpfr_init2(value, mpfr_precision); - mpfr_set_ld(value, x, MPFR_RNDN); + mpfr_set_ld(value, x, mpfr_rounding); } template ::Value, int> = 0> - explicit MPFRNumber(XType x, int precision = 128) - : mpfr_precision(precision) { + explicit MPFRNumber(XType x, int precision = ExtraPrecision::VALUE, + RoundingMode rounding = RoundingMode::Nearest) + : mpfr_precision(precision), + mpfr_rounding(get_mpfr_rounding_mode(rounding)) { mpfr_init2(value, mpfr_precision); - mpfr_set_sj(value, x, MPFR_RNDN); + mpfr_set_sj(value, x, mpfr_rounding); } - MPFRNumber(const MPFRNumber &other) : mpfr_precision(other.mpfr_precision) { + MPFRNumber(const MPFRNumber &other) + : mpfr_precision(other.mpfr_precision), + mpfr_rounding(other.mpfr_rounding) { mpfr_init2(value, mpfr_precision); - mpfr_set(value, other.value, MPFR_RNDN); + 
mpfr_set(value, other.value, mpfr_rounding); } ~MPFRNumber() { mpfr_clear(value); } MPFRNumber &operator=(const MPFRNumber &rhs) { mpfr_precision = rhs.mpfr_precision; - mpfr_set(value, rhs.value, MPFR_RNDN); + mpfr_rounding = rhs.mpfr_rounding; + mpfr_set(value, rhs.value, mpfr_rounding); return *this; } MPFRNumber abs() const { - MPFRNumber result; - mpfr_abs(result.value, value, MPFR_RNDN); + MPFRNumber result(*this); + mpfr_abs(result.value, value, mpfr_rounding); return result; } MPFRNumber ceil() const { - MPFRNumber result; + MPFRNumber result(*this); mpfr_ceil(result.value, value); return result; } MPFRNumber cos() const { - MPFRNumber result; - mpfr_cos(result.value, value, MPFR_RNDN); + MPFRNumber result(*this); + mpfr_cos(result.value, value, mpfr_rounding); return result; } MPFRNumber exp() const { - MPFRNumber result; - mpfr_exp(result.value, value, MPFR_RNDN); + MPFRNumber result(*this); + mpfr_exp(result.value, value, mpfr_rounding); return result; } MPFRNumber exp2() const { - MPFRNumber result; - mpfr_exp2(result.value, value, MPFR_RNDN); + MPFRNumber result(*this); + mpfr_exp2(result.value, value, mpfr_rounding); return result; } MPFRNumber expm1() const { - MPFRNumber result; - mpfr_expm1(result.value, value, MPFR_RNDN); + MPFRNumber result(*this); + mpfr_expm1(result.value, value, mpfr_rounding); return result; } MPFRNumber floor() const { - MPFRNumber result; + MPFRNumber result(*this); mpfr_floor(result.value, value); return result; } MPFRNumber frexp(int &exp) { - MPFRNumber result; + MPFRNumber result(*this); mpfr_exp_t resultExp; - mpfr_frexp(&resultExp, result.value, value, MPFR_RNDN); + mpfr_frexp(&resultExp, result.value, value, mpfr_rounding); exp = resultExp; return result; } MPFRNumber hypot(const MPFRNumber &b) { - MPFRNumber result; - mpfr_hypot(result.value, value, b.value, MPFR_RNDN); + MPFRNumber result(*this); + mpfr_hypot(result.value, value, b.value, mpfr_rounding); return result; } MPFRNumber log() const { - MPFRNumber 
result; - mpfr_log(result.value, value, MPFR_RNDN); + MPFRNumber result(*this); + mpfr_log(result.value, value, mpfr_rounding); return result; } MPFRNumber remquo(const MPFRNumber &divisor, int &quotient) { - MPFRNumber remainder; + MPFRNumber remainder(*this); long q; - mpfr_remquo(remainder.value, &q, value, divisor.value, MPFR_RNDN); + mpfr_remquo(remainder.value, &q, value, divisor.value, mpfr_rounding); quotient = q; return remainder; } MPFRNumber round() const { - MPFRNumber result; + MPFRNumber result(*this); mpfr_round(result.value, value); return result; } - bool roung_to_long(long &result) const { + bool round_to_long(long &result) const { // We first calculate the rounded value. This way, when converting // to long using mpfr_get_si, the rounding direction of MPFR_RNDN // (or any other rounding mode), does not have an influence. @@ -199,14 +286,14 @@ return mpfr_erangeflag_p(); } - bool roung_to_long(mpfr_rnd_t rnd, long &result) const { - MPFRNumber rint_result; + bool round_to_long(mpfr_rnd_t rnd, long &result) const { + MPFRNumber rint_result(*this); mpfr_rint(rint_result.value, value, rnd); - return rint_result.roung_to_long(result); + return rint_result.round_to_long(result); } MPFRNumber rint(mpfr_rnd_t rnd) const { - MPFRNumber result; + MPFRNumber result(*this); mpfr_rint(result.value, value, rnd); return result; } @@ -239,32 +326,32 @@ } MPFRNumber sin() const { - MPFRNumber result; - mpfr_sin(result.value, value, MPFR_RNDN); + MPFRNumber result(*this); + mpfr_sin(result.value, value, mpfr_rounding); return result; } MPFRNumber sqrt() const { - MPFRNumber result; - mpfr_sqrt(result.value, value, MPFR_RNDN); + MPFRNumber result(*this); + mpfr_sqrt(result.value, value, mpfr_rounding); return result; } MPFRNumber tan() const { - MPFRNumber result; - mpfr_tan(result.value, value, MPFR_RNDN); + MPFRNumber result(*this); + mpfr_tan(result.value, value, mpfr_rounding); return result; } MPFRNumber trunc() const { - MPFRNumber result; + MPFRNumber
result(*this); mpfr_trunc(result.value, value); return result; } MPFRNumber fma(const MPFRNumber &b, const MPFRNumber &c) { MPFRNumber result(*this); - mpfr_fma(result.value, value, b.value, c.value, MPFR_RNDN); + mpfr_fma(result.value, value, b.value, c.value, mpfr_rounding); return result; } @@ -282,10 +369,14 @@ // These functions are useful for debugging. template T as() const; - template <> float as() const { return mpfr_get_flt(value, MPFR_RNDN); } - template <> double as() const { return mpfr_get_d(value, MPFR_RNDN); } + template <> float as() const { + return mpfr_get_flt(value, mpfr_rounding); + } + template <> double as() const { + return mpfr_get_d(value, mpfr_rounding); + } template <> long double as() const { - return mpfr_get_ld(value, MPFR_RNDN); + return mpfr_get_ld(value, mpfr_rounding); } void dump(const char *msg) const { mpfr_printf("%s%.128Rf\n", msg, value); } @@ -378,8 +469,9 @@ template cpp::EnableIfType::Value, MPFRNumber> -unary_operation(Operation op, InputType input) { - MPFRNumber mpfrInput(input); +unary_operation(Operation op, InputType input, unsigned int precision, + RoundingMode rounding) { + MPFRNumber mpfrInput(input, precision, rounding); switch (op) { case Operation::Abs: return mpfrInput.abs(); @@ -420,8 +512,9 @@ template cpp::EnableIfType::Value, MPFRNumber> -unary_operation_two_outputs(Operation op, InputType input, int &output) { - MPFRNumber mpfrInput(input); +unary_operation_two_outputs(Operation op, InputType input, int &output, + unsigned int precision, RoundingMode rounding) { + MPFRNumber mpfrInput(input, precision, rounding); switch (op) { case Operation::Frexp: return mpfrInput.frexp(output); @@ -432,8 +525,10 @@ template cpp::EnableIfType::Value, MPFRNumber> -binary_operation_one_output(Operation op, InputType x, InputType y) { - MPFRNumber inputX(x), inputY(y); +binary_operation_one_output(Operation op, InputType x, InputType y, + unsigned int precision, RoundingMode rounding) { + MPFRNumber inputX(x, precision, 
rounding); + MPFRNumber inputY(y, precision, rounding); switch (op) { case Operation::Hypot: return inputX.hypot(inputY); @@ -445,8 +540,10 @@ template cpp::EnableIfType::Value, MPFRNumber> binary_operation_two_outputs(Operation op, InputType x, InputType y, - int &output) { - MPFRNumber inputX(x), inputY(y); + int &output, unsigned int precision, + RoundingMode rounding) { + MPFRNumber inputX(x, precision, rounding); + MPFRNumber inputY(y, precision, rounding); switch (op) { case Operation::RemQuo: return inputX.remquo(inputY, output); @@ -458,12 +555,14 @@ template cpp::EnableIfType::Value, MPFRNumber> ternary_operation_one_output(Operation op, InputType x, InputType y, - InputType z) { + InputType z, unsigned int precision, + RoundingMode rounding) { // For FMA function, we just need to compare with the mpfr_fma with the same // precision as InputType. Using higher precision as the intermediate results // to compare might incorrectly fail due to double-rounding errors. - constexpr unsigned int prec = Precision::VALUE; - MPFRNumber inputX(x, prec), inputY(y, prec), inputZ(z, prec); + MPFRNumber inputX(x, precision, rounding); + MPFRNumber inputY(y, precision, rounding); + MPFRNumber inputZ(z, precision, rounding); switch (op) { case Operation::Fma: return inputX.fma(inputY, inputZ); @@ -475,13 +574,14 @@ template void explain_unary_operation_single_output_error(Operation op, T input, T matchValue, + double ulp_tolerance, + RoundingMode rounding, testutils::StreamWrapper &OS) { - MPFRNumber mpfrInput(input); - MPFRNumber mpfr_result = unary_operation(op, input); + unsigned int precision = get_precision(ulp_tolerance); + MPFRNumber mpfrInput(input, precision); + MPFRNumber mpfr_result; + mpfr_result = unary_operation(op, input, precision, rounding); MPFRNumber mpfrMatchValue(matchValue); - FPBits inputBits(input); - FPBits matchBits(matchValue); - FPBits mpfr_resultBits(mpfr_result.as()); OS << "Match value not within tolerance value of MPFR result:\n" << " Input 
decimal: " << mpfrInput.str() << '\n'; __llvm_libc::fputil::testing::describeValue(" Input bits: ", input, OS); @@ -498,21 +598,24 @@ template void explain_unary_operation_single_output_error(Operation op, float, float, + double, RoundingMode, testutils::StreamWrapper &); template void explain_unary_operation_single_output_error( - Operation op, double, double, testutils::StreamWrapper &); + Operation op, double, double, double, RoundingMode, + testutils::StreamWrapper &); template void explain_unary_operation_single_output_error( - Operation op, long double, long double, testutils::StreamWrapper &); + Operation op, long double, long double, double, RoundingMode, + testutils::StreamWrapper &); template void explain_unary_operation_two_outputs_error( Operation op, T input, const BinaryOutput &libc_result, - testutils::StreamWrapper &OS) { - MPFRNumber mpfrInput(input); - FPBits inputBits(input); + double ulp_tolerance, RoundingMode rounding, testutils::StreamWrapper &OS) { + unsigned int precision = get_precision(ulp_tolerance); + MPFRNumber mpfrInput(input, precision); int mpfrIntResult; - MPFRNumber mpfr_result = - unary_operation_two_outputs(op, input, mpfrIntResult); + MPFRNumber mpfr_result = unary_operation_two_outputs(op, input, mpfrIntResult, + precision, rounding); if (mpfrIntResult != libc_result.i) { OS << "MPFR integral result: " << mpfrIntResult << '\n' @@ -541,26 +644,26 @@ } template void explain_unary_operation_two_outputs_error( - Operation, float, const BinaryOutput &, testutils::StreamWrapper &); -template void -explain_unary_operation_two_outputs_error(Operation, double, - const BinaryOutput &, - testutils::StreamWrapper &); -template void explain_unary_operation_two_outputs_error( - Operation, long double, const BinaryOutput &, + Operation, float, const BinaryOutput &, double, RoundingMode, testutils::StreamWrapper &); +template void explain_unary_operation_two_outputs_error( + Operation, double, const BinaryOutput &, double, RoundingMode, + 
testutils::StreamWrapper &); +template void explain_unary_operation_two_outputs_error( + Operation, long double, const BinaryOutput &, double, + RoundingMode, testutils::StreamWrapper &); template void explain_binary_operation_two_outputs_error( Operation op, const BinaryInput &input, - const BinaryOutput &libc_result, testutils::StreamWrapper &OS) { - MPFRNumber mpfrX(input.x); - MPFRNumber mpfrY(input.y); - FPBits xbits(input.x); - FPBits ybits(input.y); + const BinaryOutput &libc_result, double ulp_tolerance, + RoundingMode rounding, testutils::StreamWrapper &OS) { + unsigned int precision = get_precision(ulp_tolerance); + MPFRNumber mpfrX(input.x, precision); + MPFRNumber mpfrY(input.y, precision); int mpfrIntResult; - MPFRNumber mpfr_result = - binary_operation_two_outputs(op, input.x, input.y, mpfrIntResult); + MPFRNumber mpfr_result = binary_operation_two_outputs( + op, input.x, input.y, mpfrIntResult, precision, rounding); MPFRNumber mpfrMatchValue(libc_result.f); OS << "Input decimal: x: " << mpfrX.str() << " y: " << mpfrY.str() << '\n' @@ -576,25 +679,27 @@ } template void explain_binary_operation_two_outputs_error( - Operation, const BinaryInput &, const BinaryOutput &, - testutils::StreamWrapper &); + Operation, const BinaryInput &, const BinaryOutput &, double, + RoundingMode, testutils::StreamWrapper &); template void explain_binary_operation_two_outputs_error( Operation, const BinaryInput &, const BinaryOutput &, - testutils::StreamWrapper &); + double, RoundingMode, testutils::StreamWrapper &); template void explain_binary_operation_two_outputs_error( Operation, const BinaryInput &, - const BinaryOutput &, testutils::StreamWrapper &); + const BinaryOutput &, double, RoundingMode, + testutils::StreamWrapper &); template -void explain_binary_operation_one_output_error(Operation op, - const BinaryInput &input, - T libc_result, - testutils::StreamWrapper &OS) { - MPFRNumber mpfrX(input.x); - MPFRNumber mpfrY(input.y); +void 
explain_binary_operation_one_output_error( + Operation op, const BinaryInput &input, T libc_result, + double ulp_tolerance, RoundingMode rounding, testutils::StreamWrapper &OS) { + unsigned int precision = get_precision(ulp_tolerance); + MPFRNumber mpfrX(input.x, precision); + MPFRNumber mpfrY(input.y, precision); FPBits xbits(input.x); FPBits ybits(input.y); - MPFRNumber mpfr_result = binary_operation_one_output(op, input.x, input.y); + MPFRNumber mpfr_result = + binary_operation_one_output(op, input.x, input.y, precision, rounding); MPFRNumber mpfrMatchValue(libc_result); OS << "Input decimal: x: " << mpfrX.str() << " y: " << mpfrY.str() << '\n'; @@ -613,26 +718,28 @@ } template void explain_binary_operation_one_output_error( - Operation, const BinaryInput &, float, testutils::StreamWrapper &); + Operation, const BinaryInput &, float, double, RoundingMode, + testutils::StreamWrapper &); template void explain_binary_operation_one_output_error( - Operation, const BinaryInput &, double, testutils::StreamWrapper &); -template void explain_binary_operation_one_output_error( - Operation, const BinaryInput &, long double, + Operation, const BinaryInput &, double, double, RoundingMode, testutils::StreamWrapper &); +template void explain_binary_operation_one_output_error( + Operation, const BinaryInput &, long double, double, + RoundingMode, testutils::StreamWrapper &); template -void explain_ternary_operation_one_output_error(Operation op, - const TernaryInput &input, - T libc_result, - testutils::StreamWrapper &OS) { - MPFRNumber mpfrX(input.x, Precision::VALUE); - MPFRNumber mpfrY(input.y, Precision::VALUE); - MPFRNumber mpfrZ(input.z, Precision::VALUE); +void explain_ternary_operation_one_output_error( + Operation op, const TernaryInput &input, T libc_result, + double ulp_tolerance, RoundingMode rounding, testutils::StreamWrapper &OS) { + unsigned int precision = get_precision(ulp_tolerance); + MPFRNumber mpfrX(input.x, precision); + MPFRNumber mpfrY(input.y, 
precision); + MPFRNumber mpfrZ(input.z, precision); FPBits xbits(input.x); FPBits ybits(input.y); FPBits zbits(input.z); - MPFRNumber mpfr_result = - ternary_operation_one_output(op, input.x, input.y, input.z); + MPFRNumber mpfr_result = ternary_operation_one_output( + op, input.x, input.y, input.z, precision, rounding); MPFRNumber mpfrMatchValue(libc_result); OS << "Input decimal: x: " << mpfrX.str() << " y: " << mpfrY.str() @@ -654,68 +761,70 @@ } template void explain_ternary_operation_one_output_error( - Operation, const TernaryInput &, float, testutils::StreamWrapper &); + Operation, const TernaryInput &, float, double, RoundingMode, + testutils::StreamWrapper &); template void explain_ternary_operation_one_output_error( - Operation, const TernaryInput &, double, + Operation, const TernaryInput &, double, double, RoundingMode, testutils::StreamWrapper &); template void explain_ternary_operation_one_output_error( - Operation, const TernaryInput &, long double, - testutils::StreamWrapper &); + Operation, const TernaryInput &, long double, double, + RoundingMode, testutils::StreamWrapper &); template bool compare_unary_operation_single_output(Operation op, T input, T libc_result, - double ulp_error) { - // If the ulp error is exactly 0.5 (i.e a tie), we would check that the result - // is rounded to the nearest even. 
- MPFRNumber mpfr_result = unary_operation(op, input); + double ulp_tolerance, + RoundingMode rounding) { + unsigned int precision = get_precision(ulp_tolerance); + MPFRNumber mpfr_result; + mpfr_result = unary_operation(op, input, precision, rounding); double ulp = mpfr_result.ulp(libc_result); - bool bits_are_even = ((FPBits(libc_result).uintval() & 1) == 0); - return (ulp < ulp_error) || - ((ulp == ulp_error) && ((ulp != 0.5) || bits_are_even)); + return (ulp <= ulp_tolerance); } template bool compare_unary_operation_single_output(Operation, float, - float, double); + float, double, + RoundingMode); template bool compare_unary_operation_single_output(Operation, double, - double, double); -template bool compare_unary_operation_single_output(Operation, - long double, - long double, - double); + double, double, + RoundingMode); +template bool compare_unary_operation_single_output( + Operation, long double, long double, double, RoundingMode); template bool compare_unary_operation_two_outputs(Operation op, T input, const BinaryOutput &libc_result, - double ulp_error) { + double ulp_tolerance, + RoundingMode rounding) { int mpfrIntResult; - MPFRNumber mpfr_result = - unary_operation_two_outputs(op, input, mpfrIntResult); + unsigned int precision = get_precision(ulp_tolerance); + MPFRNumber mpfr_result = unary_operation_two_outputs(op, input, mpfrIntResult, + precision, rounding); double ulp = mpfr_result.ulp(libc_result.f); if (mpfrIntResult != libc_result.i) return false; - bool bits_are_even = ((FPBits(libc_result.f).uintval() & 1) == 0); - return (ulp < ulp_error) || - ((ulp == ulp_error) && ((ulp != 0.5) || bits_are_even)); + return (ulp <= ulp_tolerance); } -template bool -compare_unary_operation_two_outputs(Operation, float, - const BinaryOutput &, double); +template bool compare_unary_operation_two_outputs( + Operation, float, const BinaryOutput &, double, RoundingMode); template bool compare_unary_operation_two_outputs( - Operation, double, const BinaryOutput 
&, double); + Operation, double, const BinaryOutput &, double, RoundingMode); template bool compare_unary_operation_two_outputs( - Operation, long double, const BinaryOutput &, double); + Operation, long double, const BinaryOutput &, double, + RoundingMode); template bool compare_binary_operation_two_outputs(Operation op, const BinaryInput &input, const BinaryOutput &libc_result, - double ulp_error) { + double ulp_tolerance, + RoundingMode rounding) { int mpfrIntResult; - MPFRNumber mpfr_result = - binary_operation_two_outputs(op, input.x, input.y, mpfrIntResult); + unsigned int precision = get_precision(ulp_tolerance); + MPFRNumber mpfr_result = binary_operation_two_outputs( + op, input.x, input.y, mpfrIntResult, precision, rounding); double ulp = mpfr_result.ulp(libc_result.f); if (mpfrIntResult != libc_result.i) { @@ -727,81 +836,66 @@ } } - bool bits_are_even = ((FPBits(libc_result.f).uintval() & 1) == 0); - return (ulp < ulp_error) || - ((ulp == ulp_error) && ((ulp != 0.5) || bits_are_even)); + return (ulp <= ulp_tolerance); } template bool compare_binary_operation_two_outputs( - Operation, const BinaryInput &, const BinaryOutput &, double); + Operation, const BinaryInput &, const BinaryOutput &, double, + RoundingMode); template bool compare_binary_operation_two_outputs( Operation, const BinaryInput &, const BinaryOutput &, - double); + double, RoundingMode); template bool compare_binary_operation_two_outputs( Operation, const BinaryInput &, - const BinaryOutput &, double); + const BinaryOutput &, double, RoundingMode); template bool compare_binary_operation_one_output(Operation op, const BinaryInput &input, - T libc_result, double ulp_error) { - MPFRNumber mpfr_result = binary_operation_one_output(op, input.x, input.y); + T libc_result, double ulp_tolerance, + RoundingMode rounding) { + unsigned int precision = get_precision(ulp_tolerance); + MPFRNumber mpfr_result = + binary_operation_one_output(op, input.x, input.y, precision, rounding); double ulp = 
mpfr_result.ulp(libc_result); - bool bits_are_even = ((FPBits(libc_result).uintval() & 1) == 0); - return (ulp < ulp_error) || - ((ulp == ulp_error) && ((ulp != 0.5) || bits_are_even)); + return (ulp <= ulp_tolerance); } template bool compare_binary_operation_one_output( - Operation, const BinaryInput &, float, double); + Operation, const BinaryInput &, float, double, RoundingMode); template bool compare_binary_operation_one_output( - Operation, const BinaryInput &, double, double); + Operation, const BinaryInput &, double, double, RoundingMode); template bool compare_binary_operation_one_output( - Operation, const BinaryInput &, long double, double); + Operation, const BinaryInput &, long double, double, + RoundingMode); template bool compare_ternary_operation_one_output(Operation op, const TernaryInput &input, - T libc_result, double ulp_error) { - MPFRNumber mpfr_result = - ternary_operation_one_output(op, input.x, input.y, input.z); + T libc_result, double ulp_tolerance, + RoundingMode rounding) { + unsigned int precision = get_precision(ulp_tolerance); + MPFRNumber mpfr_result = ternary_operation_one_output( + op, input.x, input.y, input.z, precision, rounding); double ulp = mpfr_result.ulp(libc_result); - bool bits_are_even = ((FPBits(libc_result).uintval() & 1) == 0); - return (ulp < ulp_error) || - ((ulp == ulp_error) && ((ulp != 0.5) || bits_are_even)); + return (ulp <= ulp_tolerance); } template bool compare_ternary_operation_one_output( - Operation, const TernaryInput &, float, double); + Operation, const TernaryInput &, float, double, RoundingMode); template bool compare_ternary_operation_one_output( - Operation, const TernaryInput &, double, double); + Operation, const TernaryInput &, double, double, RoundingMode); template bool compare_ternary_operation_one_output( - Operation, const TernaryInput &, long double, double); - -static mpfr_rnd_t get_mpfr_rounding_mode(RoundingMode mode) { - switch (mode) { - case RoundingMode::Upward: - return MPFR_RNDU; 
- break; - case RoundingMode::Downward: - return MPFR_RNDD; - break; - case RoundingMode::TowardZero: - return MPFR_RNDZ; - break; - case RoundingMode::Nearest: - return MPFR_RNDN; - break; - } -} + Operation, const TernaryInput &, long double, double, + RoundingMode); } // namespace internal template bool round_to_long(T x, long &result) { MPFRNumber mpfr(x); - return mpfr.roung_to_long(result); + return mpfr.round_to_long(result); } template bool round_to_long(float, long &); @@ -810,7 +904,7 @@ template bool round_to_long(T x, RoundingMode mode, long &result) { MPFRNumber mpfr(x); - return mpfr.roung_to_long(internal::get_mpfr_rounding_mode(mode), result); + return mpfr.round_to_long(get_mpfr_rounding_mode(mode), result); } template bool round_to_long(float, RoundingMode, long &); @@ -819,7 +913,7 @@ template T round(T x, RoundingMode mode) { MPFRNumber mpfr(x); - MPFRNumber result = mpfr.rint(internal::get_mpfr_rounding_mode(mode)); + MPFRNumber result = mpfr.rint(get_mpfr_rounding_mode(mode)); return result.as(); }