diff --git a/libc/test/src/math/cosf_test.cpp b/libc/test/src/math/cosf_test.cpp --- a/libc/test/src/math/cosf_test.cpp +++ b/libc/test/src/math/cosf_test.cpp @@ -76,7 +76,7 @@ float x = as_float(v); if (isnan(x) || isinf(x)) continue; - EXPECT_TRUE(mpfr::equalsCos(x, __llvm_libc::cosf(x), tolerance)); + ASSERT_MPFR_MATCH(mpfr::OP_Cos, x, __llvm_libc::cosf(x), tolerance); } } @@ -84,12 +84,12 @@ TEST(CosfTest, SmallValues) { float x = as_float(0x17800000); float result = __llvm_libc::cosf(x); - EXPECT_TRUE(mpfr::equalsCos(x, result, tolerance)); + EXPECT_MPFR_MATCH(mpfr::OP_Cos, x, result, tolerance); EXPECT_EQ(FloatBits::One, as_uint32_bits(result)); x = as_float(0x00400000); result = __llvm_libc::cosf(x); - EXPECT_TRUE(mpfr::equalsCos(x, result, tolerance)); + EXPECT_MPFR_MATCH(mpfr::OP_Cos, x, result, tolerance); EXPECT_EQ(FloatBits::One, as_uint32_bits(result)); } @@ -98,6 +98,6 @@ TEST(CosfTest, SDCOMP_26094) { for (uint32_t v : sdcomp26094Values) { float x = as_float(v); - EXPECT_TRUE(mpfr::equalsCos(x, __llvm_libc::cosf(x), tolerance)); + ASSERT_MPFR_MATCH(mpfr::OP_Cos, x, __llvm_libc::cosf(x), tolerance); } } diff --git a/libc/test/src/math/sincosf_test.cpp b/libc/test/src/math/sincosf_test.cpp --- a/libc/test/src/math/sincosf_test.cpp +++ b/libc/test/src/math/sincosf_test.cpp @@ -87,8 +87,8 @@ float sin, cos; __llvm_libc::sincosf(x, &sin, &cos); - EXPECT_TRUE(mpfr::equalsCos(x, cos, tolerance)); - EXPECT_TRUE(mpfr::equalsSin(x, sin, tolerance)); + ASSERT_MPFR_MATCH(mpfr::OP_Cos, x, cos, tolerance); + ASSERT_MPFR_MATCH(mpfr::OP_Sin, x, sin, tolerance); } } @@ -98,16 +98,16 @@ float x = as_float(bits); float result_cos, result_sin; __llvm_libc::sincosf(x, &result_sin, &result_cos); - EXPECT_TRUE(mpfr::equalsCos(x, result_cos, tolerance)); - EXPECT_TRUE(mpfr::equalsSin(x, result_sin, tolerance)); + EXPECT_MPFR_MATCH(mpfr::OP_Cos, x, result_cos, tolerance); + EXPECT_MPFR_MATCH(mpfr::OP_Sin, x, result_sin, tolerance); 
EXPECT_EQ(FloatBits::One, as_uint32_bits(result_cos)); EXPECT_EQ(bits, as_uint32_bits(result_sin)); bits = 0x00400000; x = as_float(bits); __llvm_libc::sincosf(x, &result_sin, &result_cos); - EXPECT_TRUE(mpfr::equalsCos(x, result_cos, tolerance)); - EXPECT_TRUE(mpfr::equalsSin(x, result_sin, tolerance)); + EXPECT_MPFR_MATCH(mpfr::OP_Cos, x, result_cos, tolerance); + EXPECT_MPFR_MATCH(mpfr::OP_Sin, x, result_sin, tolerance); EXPECT_EQ(FloatBits::One, as_uint32_bits(result_cos)); EXPECT_EQ(bits, as_uint32_bits(result_sin)); } @@ -119,7 +119,7 @@ float x = as_float(v); float sin, cos; __llvm_libc::sincosf(x, &sin, &cos); - EXPECT_TRUE(mpfr::equalsCos(x, cos, tolerance)); - EXPECT_TRUE(mpfr::equalsSin(x, sin, tolerance)); + EXPECT_MPFR_MATCH(mpfr::OP_Cos, x, cos, tolerance); + EXPECT_MPFR_MATCH(mpfr::OP_Sin, x, sin, tolerance); } } diff --git a/libc/test/src/math/sinf_test.cpp b/libc/test/src/math/sinf_test.cpp --- a/libc/test/src/math/sinf_test.cpp +++ b/libc/test/src/math/sinf_test.cpp @@ -76,13 +76,13 @@ float x = as_float(v); if (isnan(x) || isinf(x)) continue; - EXPECT_TRUE(mpfr::equalsSin(x, __llvm_libc::sinf(x), tolerance)); + ASSERT_MPFR_MATCH(mpfr::OP_Sin, x, __llvm_libc::sinf(x), tolerance); } } TEST(SinfTest, SpecificBitPatterns) { float x = as_float(0xc70d39a1); - EXPECT_TRUE(mpfr::equalsSin(x, __llvm_libc::sinf(x), tolerance)); + EXPECT_MPFR_MATCH(mpfr::OP_Sin, x, __llvm_libc::sinf(x), tolerance); } // For small values, sin(x) is x. 
@@ -90,13 +90,13 @@ uint32_t bits = 0x17800000; float x = as_float(bits); float result = __llvm_libc::sinf(x); - EXPECT_TRUE(mpfr::equalsSin(x, result, tolerance)); + EXPECT_MPFR_MATCH(mpfr::OP_Sin, x, result, tolerance); EXPECT_EQ(bits, as_uint32_bits(result)); bits = 0x00400000; x = as_float(bits); result = __llvm_libc::sinf(x); - EXPECT_TRUE(mpfr::equalsSin(x, result, tolerance)); + EXPECT_MPFR_MATCH(mpfr::OP_Sin, x, result, tolerance); EXPECT_EQ(bits, as_uint32_bits(result)); } @@ -105,6 +105,6 @@ TEST(SinfTest, SDCOMP_26094) { for (uint32_t v : sdcomp26094Values) { float x = as_float(v); - EXPECT_TRUE(mpfr::equalsSin(x, __llvm_libc::sinf(x), tolerance)); + EXPECT_MPFR_MATCH(mpfr::OP_Sin, x, __llvm_libc::sinf(x), tolerance); } } diff --git a/libc/utils/CPP/TypeTraits.h b/libc/utils/CPP/TypeTraits.h --- a/libc/utils/CPP/TypeTraits.h +++ b/libc/utils/CPP/TypeTraits.h @@ -46,6 +46,22 @@ template struct IsSame : public FalseValue {}; template struct IsSame : public TrueValue {}; +template struct TypeIdentity { typedef T Type; }; + +template struct RemoveCV : public TypeIdentity {}; +template struct RemoveCV : public TypeIdentity {}; +template struct RemoveCV : public TypeIdentity {}; +template +struct RemoveCV : public TypeIdentity {}; + +template using RemoveCVType = typename RemoveCV::Type; + +template struct IsFloatingPointType { + static constexpr bool Value = IsSame>::Value || + IsSame>::Value || + IsSame>::Value; +}; + } // namespace cpp } // namespace __llvm_libc diff --git a/libc/utils/MPFRWrapper/CMakeLists.txt b/libc/utils/MPFRWrapper/CMakeLists.txt --- a/libc/utils/MPFRWrapper/CMakeLists.txt +++ b/libc/utils/MPFRWrapper/CMakeLists.txt @@ -12,7 +12,8 @@ MPFRUtils.cpp MPFRUtils.h ) - target_link_libraries(libcMPFRWrapper -lmpfr -lgmp) + add_dependencies(libcMPFRWrapper libc.utils.CPP.standalone_cpp LibcUnitTest LLVMSupport) + target_link_libraries(libcMPFRWrapper -lmpfr -lgmp LibcUnitTest LLVMSupport) else() message(WARNING "Math tests using MPFR will be 
skipped.") endif() diff --git a/libc/utils/MPFRWrapper/MPFRUtils.h b/libc/utils/MPFRWrapper/MPFRUtils.h --- a/libc/utils/MPFRWrapper/MPFRUtils.h +++ b/libc/utils/MPFRWrapper/MPFRUtils.h @@ -9,6 +9,9 @@ #ifndef LLVM_LIBC_UTILS_TESTUTILS_MPFRUTILS_H #define LLVM_LIBC_UTILS_TESTUTILS_MPFRUTILS_H +#include "utils/CPP/TypeTraits.h" +#include "utils/UnitTest/Test.h" + #include namespace __llvm_libc { @@ -36,16 +39,56 @@ uint32_t bits; }; -// Return true if |libcOutput| is within the tolerance |t| of the cos(x) -// value as evaluated by MPFR. -bool equalsCos(float x, float libcOutput, const Tolerance &t); +enum Operation { + OP_Cos, + OP_Sin, +}; + +namespace internal { + +template +bool compare(Operation op, T input, T libcOutput, const Tolerance &t); + +template class MPFRMatcher : public testing::Matcher { + static_assert(__llvm_libc::cpp::IsFloatingPointType::Value, + "MPFRMatcher can only be used with floating point values."); + + Operation operation; + T input; + Tolerance tolerance; + T matchValue; + +public: + MPFRMatcher(Operation op, T testInput, Tolerance &t) + : operation(op), input(testInput), tolerance(t) {} -// Return true if |libcOutput| is within the tolerance |t| of the sin(x) -// value as evaluated by MPFR. 
-bool equalsSin(float x, float libcOutput, const Tolerance &t); + bool match(T libcResult) { + matchValue = libcResult; + return internal::compare(operation, input, libcResult, tolerance); + } + + void explainError(testutils::StreamWrapper &OS) override; +}; + +} // namespace internal + +template ::Value, int> = 0> +internal::MPFRMatcher getMPFRMatcher(Operation op, T input, Tolerance t) { + return internal::MPFRMatcher(op, input, t); +} } // namespace mpfr } // namespace testing } // namespace __llvm_libc +#define EXPECT_MPFR_MATCH(op, input, matchValue, tolerance) \ + EXPECT_THAT(matchValue, __llvm_libc::testing::mpfr::getMPFRMatcher( \ + op, input, tolerance)) + +#define ASSERT_MPFR_MATCH(op, input, matchValue, tolerance) \ + ASSERT_THAT(matchValue, __llvm_libc::testing::mpfr::getMPFRMatcher( \ + op, input, tolerance)) + #endif // LLVM_LIBC_UTILS_TESTUTILS_MPFRUTILS_H diff --git a/libc/utils/MPFRWrapper/MPFRUtils.cpp b/libc/utils/MPFRWrapper/MPFRUtils.cpp --- a/libc/utils/MPFRWrapper/MPFRUtils.cpp +++ b/libc/utils/MPFRWrapper/MPFRUtils.cpp @@ -8,7 +8,8 @@ #include "MPFRUtils.h" -#include +#include "llvm/Support/raw_ostream.h" + #include namespace __llvm_libc { @@ -25,11 +26,35 @@ public: MPFRNumber() { mpfr_init2(value, mpfrPrecision); } - explicit MPFRNumber(float x) { + template ::Value, int> = 0> + explicit MPFRNumber(XType x) { mpfr_init2(value, mpfrPrecision); mpfr_set_flt(value, x, MPFR_RNDN); } + template ::Value, int> = 0> + explicit MPFRNumber(XType x) { + mpfr_init2(value, mpfrPrecision); + mpfr_set_d(value, x, MPFR_RNDN); + } + + template ::Value, int> = 0> + MPFRNumber(Operation op, XType rawValue) { + mpfr_init2(value, mpfrPrecision); + MPFRNumber mpfrInput(rawValue); + switch (op) { + case OP_Cos: + mpfr_cos(value, mpfrInput.value, MPFR_RNDN); + break; + case OP_Sin: + mpfr_sin(value, mpfrInput.value, MPFR_RNDN); + break; + } + } + MPFRNumber(const MPFRNumber &other) { mpfr_set(value, other.value, MPFR_RNDN); } @@ -59,38 +84,51 @@ return 
mpfr_lessequal_p(difference.value, tolerance.value); } + std::string str() const { + // 200 bytes should be more than sufficient to hold a 100-digit number + // plus additional bytes for the decimal point, '-' sign etc. + constexpr size_t printBufSize = 200; + char buffer[printBufSize]; + mpfr_snprintf(buffer, printBufSize, "%100.50Rf", value); + llvm::StringRef ref(buffer); + ref = ref.trim(); + return ref.str(); + } + // These functions are useful for debugging. float asFloat() const { return mpfr_get_flt(value, MPFR_RNDN); } double asDouble() const { return mpfr_get_d(value, MPFR_RNDN); } void dump(const char *msg) const { mpfr_printf("%s%.128Rf\n", msg, value); } +}; -public: - static MPFRNumber cos(float x) { - MPFRNumber result; - MPFRNumber mpfrX(x); - mpfr_cos(result.value, mpfrX.value, MPFR_RNDN); - return result; - } +namespace internal { + +template +void MPFRMatcher::explainError(testutils::StreamWrapper &OS) { + MPFRNumber mpfrResult(operation, input); + MPFRNumber mpfrInput(input); + MPFRNumber mpfrMatchValue(matchValue); + OS << "Match value not within tolerance value of MPFR result:\n" + << "Operation input: " << mpfrInput.str() << '\n' + << " Match value: " << mpfrMatchValue.str() << '\n' + << " MPFR result: " << mpfrResult.str() << '\n'; +} - static MPFRNumber sin(float x) { - MPFRNumber result; - MPFRNumber mpfrX(x); - mpfr_sin(result.value, mpfrX.value, MPFR_RNDN); - return result; - } +template void MPFRMatcher::explainError(testutils::StreamWrapper &); +template void MPFRMatcher::explainError(testutils::StreamWrapper &); + +template +bool compare(Operation op, T input, T libcResult, const Tolerance &t) { + MPFRNumber mpfrResult(op, input); + MPFRNumber mpfrInput(input); + MPFRNumber mpfrLibcResult(libcResult); + return mpfrResult.isEqual(mpfrLibcResult, t); }; -bool equalsCos(float input, float libcOutput, const Tolerance &t) { - MPFRNumber mpfrResult = MPFRNumber::cos(input); - MPFRNumber libcResult(libcOutput); - return 
mpfrResult.isEqual(libcResult, t); -} +template bool compare(Operation, float, float, const Tolerance &); +template bool compare(Operation, double, double, const Tolerance &); -bool equalsSin(float input, float libcOutput, const Tolerance &t) { - MPFRNumber mpfrResult = MPFRNumber::sin(input); - MPFRNumber libcResult(libcOutput); - return mpfrResult.isEqual(libcResult, t); -} +} // namespace internal } // namespace mpfr } // namespace testing diff --git a/libc/utils/testutils/StreamWrapper.cpp b/libc/utils/testutils/StreamWrapper.cpp --- a/libc/utils/testutils/StreamWrapper.cpp +++ b/libc/utils/testutils/StreamWrapper.cpp @@ -10,6 +10,7 @@ #include "llvm/Support/raw_ostream.h" #include #include +#include namespace __llvm_libc { namespace testutils { @@ -41,6 +42,7 @@ template StreamWrapper & StreamWrapper::operator<<(unsigned long long t); template StreamWrapper &StreamWrapper::operator<<(bool t); +template StreamWrapper &StreamWrapper::operator<<(std::string t); } // namespace testutils } // namespace __llvm_libc