diff --git a/libc/test/src/stdlib/strtod_test.cpp b/libc/test/src/stdlib/strtod_test.cpp
--- a/libc/test/src/stdlib/strtod_test.cpp
+++ b/libc/test/src/stdlib/strtod_test.cpp
@@ -10,12 +10,19 @@
 #include "src/stdlib/strtod.h"
 
 #include "utils/UnitTest/Test.h"
+#include "utils/testutils/RoundingModeUtils.h"
 
 #include <errno.h>
 #include <limits.h>
 #include <stddef.h>
 
-class LlvmLibcStrToDTest : public __llvm_libc::testing::Test {
+using __llvm_libc::testutils::ForceRoundingModeTest;
+using __llvm_libc::testutils::RoundingMode;
+
+// Forces round-to-nearest for the fixture's lifetime so parse results do not
+// depend on whatever rounding mode was active beforehand.
+class LlvmLibcStrToDTest : public __llvm_libc::testing::Test,
+                           ForceRoundingModeTest<RoundingMode::Nearest> {
 public:
   void run_test(const char *inputString, const ptrdiff_t expectedStrLen,
                 const uint64_t expectedRawData, const int expectedErrno = 0) {
diff --git a/libc/test/src/stdlib/strtof_test.cpp b/libc/test/src/stdlib/strtof_test.cpp
--- a/libc/test/src/stdlib/strtof_test.cpp
+++ b/libc/test/src/stdlib/strtof_test.cpp
@@ -10,12 +10,19 @@
 #include "src/stdlib/strtof.h"
 
 #include "utils/UnitTest/Test.h"
+#include "utils/testutils/RoundingModeUtils.h"
 
 #include <errno.h>
 #include <limits.h>
 #include <stddef.h>
 
-class LlvmLibcStrToFTest : public __llvm_libc::testing::Test {
+using __llvm_libc::testutils::ForceRoundingModeTest;
+using __llvm_libc::testutils::RoundingMode;
+
+// Forces round-to-nearest for the fixture's lifetime so parse results do not
+// depend on whatever rounding mode was active beforehand.
+class LlvmLibcStrToFTest : public __llvm_libc::testing::Test,
+                           ForceRoundingModeTest<RoundingMode::Nearest> {
 public:
   void run_test(const char *inputString, const ptrdiff_t expectedStrLen,
                 const uint32_t expectedRawData, const int expectedErrno = 0) {
diff --git a/libc/utils/MPFRWrapper/CMakeLists.txt b/libc/utils/MPFRWrapper/CMakeLists.txt
--- a/libc/utils/MPFRWrapper/CMakeLists.txt
+++ b/libc/utils/MPFRWrapper/CMakeLists.txt
@@ -12,12 +12,13 @@
     libc.src.__support.CPP.type_traits
     libc.src.__support.FPUtil.fputil
     LibcUnitTest
+    libc_test_utils
   )
   if(EXISTS ${LLVM_LIBC_MPFR_INSTALL_PATH})
     target_include_directories(libcMPFRWrapper PUBLIC ${LLVM_LIBC_MPFR_INSTALL_PATH}/include)
     target_link_directories(libcMPFRWrapper PUBLIC ${LLVM_LIBC_MPFR_INSTALL_PATH}/lib)
   endif()
-  target_link_libraries(libcMPFRWrapper LibcFPTestHelpers LibcUnitTest mpfr gmp)
+  target_link_libraries(libcMPFRWrapper LibcFPTestHelpers LibcUnitTest mpfr gmp libc_test_utils)
 else()
   message(WARNING "Math tests using MPFR will be skipped.")
 endif()
diff --git a/libc/utils/MPFRWrapper/MPFRUtils.h b/libc/utils/MPFRWrapper/MPFRUtils.h
--- a/libc/utils/MPFRWrapper/MPFRUtils.h
+++ b/libc/utils/MPFRWrapper/MPFRUtils.h
@@ -11,6 +11,7 @@
 
 #include "src/__support/CPP/TypeTraits.h"
 #include "utils/UnitTest/Test.h"
+#include "utils/testutils/RoundingModeUtils.h"
 
 #include <stdint.h>
 
@@ -75,17 +76,10 @@
   EndTernaryOperationsSingleOutput,
 };
 
-enum class RoundingMode : uint8_t { Upward, Downward, TowardZero, Nearest };
-
-int get_fe_rounding(RoundingMode mode);
-
-struct ForceRoundingMode {
-  ForceRoundingMode(RoundingMode);
-  ~ForceRoundingMode();
-
-  int old_rounding_mode;
-  int rounding_mode;
-};
+// Re-exported from testutils so existing code that names these types through
+// this namespace keeps compiling.
+using __llvm_libc::testutils::ForceRoundingMode;
+using __llvm_libc::testutils::RoundingMode;
 
 template <typename T> struct BinaryInput {
   static_assert(
diff --git a/libc/utils/MPFRWrapper/MPFRUtils.cpp b/libc/utils/MPFRWrapper/MPFRUtils.cpp
--- a/libc/utils/MPFRWrapper/MPFRUtils.cpp
+++ b/libc/utils/MPFRWrapper/MPFRUtils.cpp
@@ -106,35 +106,6 @@
   }
 }
 
-int get_fe_rounding(RoundingMode mode) {
-  switch (mode) {
-  case RoundingMode::Upward:
-    return FE_UPWARD;
-    break;
-  case RoundingMode::Downward:
-    return FE_DOWNWARD;
-    break;
-  case RoundingMode::TowardZero:
-    return FE_TOWARDZERO;
-    break;
-  case RoundingMode::Nearest:
-    return FE_TONEAREST;
-    break;
-  }
-}
-
-ForceRoundingMode::ForceRoundingMode(RoundingMode mode) {
-  old_rounding_mode = fegetround();
-  rounding_mode = get_fe_rounding(mode);
-  if (old_rounding_mode != rounding_mode)
-    fesetround(rounding_mode);
-}
-
-ForceRoundingMode::~ForceRoundingMode() {
-  if (old_rounding_mode != rounding_mode)
-    fesetround(old_rounding_mode);
-}
-
 class MPFRNumber {
   unsigned int mpfr_precision;
   mpfr_rnd_t mpfr_rounding;
diff --git a/libc/utils/testutils/CMakeLists.txt
b/libc/utils/testutils/CMakeLists.txt
--- a/libc/utils/testutils/CMakeLists.txt
+++ b/libc/utils/testutils/CMakeLists.txt
@@ -15,4 +15,6 @@
   FDReader.h
   Timer.h
   Timer.cpp
+  RoundingModeUtils.h
+  RoundingModeUtils.cpp
 )
diff --git a/libc/utils/testutils/RoundingModeUtils.h b/libc/utils/testutils/RoundingModeUtils.h
new file mode 100644
--- /dev/null
+++ b/libc/utils/testutils/RoundingModeUtils.h
@@ -0,0 +1,48 @@
+//===-- RoundingModeUtils.h -------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIBC_UTILS_TESTUTILS_ROUNDINGMODEUTILS_H
+#define LLVM_LIBC_UTILS_TESTUTILS_ROUNDINGMODEUTILS_H
+
+#include <stdint.h>
+
+namespace __llvm_libc {
+namespace testutils {
+
+// Rounding modes a test can force; get_fe_rounding() maps each one to the
+// corresponding FE_* macro from <fenv.h>.
+enum class RoundingMode : uint8_t { Upward, Downward, TowardZero, Nearest };
+
+// Returns the FE_* rounding-mode macro value that corresponds to `mode`.
+int get_fe_rounding(RoundingMode mode);
+
+// Scoped guard: switches the floating-point rounding mode to the requested
+// one on construction and switches back to the previously active mode on
+// destruction. Copying is disabled so the old mode is restored only once.
+struct ForceRoundingMode {
+  explicit ForceRoundingMode(RoundingMode);
+  ~ForceRoundingMode();
+
+  ForceRoundingMode(const ForceRoundingMode &) = delete;
+  ForceRoundingMode &operator=(const ForceRoundingMode &) = delete;
+
+  int old_rounding_mode; // fegetround() value observed at construction.
+  int rounding_mode;     // FE_* value of the requested mode.
+  bool success;          // False if fesetround() rejected rounding_mode.
+};
+
+// Test-fixture mix-in that keeps rounding mode `R` in effect for the whole
+// lifetime of the derived object (not just while one test body runs).
+template <RoundingMode R> struct ForceRoundingModeTest : ForceRoundingMode {
+  ForceRoundingModeTest() : ForceRoundingMode(R) {}
+};
+
+} // namespace testutils
+} // namespace __llvm_libc
+
+#endif // LLVM_LIBC_UTILS_TESTUTILS_ROUNDINGMODEUTILS_H
diff --git a/libc/utils/testutils/RoundingModeUtils.cpp b/libc/utils/testutils/RoundingModeUtils.cpp
new file mode 100644
--- /dev/null
+++ b/libc/utils/testutils/RoundingModeUtils.cpp
@@ -0,0 +1,49 @@
+//===-- RoundingModeUtils.cpp ---------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "RoundingModeUtils.h"
+
+#include <fenv.h>
+
+namespace __llvm_libc {
+namespace testutils {
+
+int get_fe_rounding(RoundingMode mode) {
+  switch (mode) {
+  case RoundingMode::Upward:
+    return FE_UPWARD;
+  case RoundingMode::Downward:
+    return FE_DOWNWARD;
+  case RoundingMode::TowardZero:
+    return FE_TOWARDZERO;
+  case RoundingMode::Nearest:
+    return FE_TONEAREST;
+  }
+  // Every enumerator returns above; this tells the compiler the end is not
+  // reachable (silencing -Wreturn-type). Passing a value outside the
+  // enumeration is a caller bug.
+  __builtin_unreachable();
+}
+
+ForceRoundingMode::ForceRoundingMode(RoundingMode mode)
+    : old_rounding_mode(fegetround()), rounding_mode(get_fe_rounding(mode)),
+      success(true) {
+  // Only touch the floating-point environment when the mode actually
+  // changes; record whether the switch was accepted.
+  if (old_rounding_mode != rounding_mode)
+    success = (fesetround(rounding_mode) == 0);
+}
+
+ForceRoundingMode::~ForceRoundingMode() {
+  // Put back the mode that was active when this guard was constructed.
+  if (old_rounding_mode != rounding_mode)
+    fesetround(old_rounding_mode);
+}
+
+} // namespace testutils
+} // namespace __llvm_libc