diff --git a/libc/src/__support/FPUtil/PolyEval.h b/libc/src/__support/FPUtil/PolyEval.h
--- a/libc/src/__support/FPUtil/PolyEval.h
+++ b/libc/src/__support/FPUtil/PolyEval.h
@@ -35,6 +35,12 @@
 } // namespace fputil
 } // namespace __llvm_libc
 
+#ifdef LLVM_LIBC_ARCH_X86_64
+
+#include "x86_64/PolyEval.h"
+
+#endif // LLVM_LIBC_ARCH_X86_64
+
 #else
 
 namespace __llvm_libc {
diff --git a/libc/src/__support/FPUtil/x86_64/FMA.h b/libc/src/__support/FPUtil/x86_64/FMA.h
--- a/libc/src/__support/FPUtil/x86_64/FMA.h
+++ b/libc/src/__support/FPUtil/x86_64/FMA.h
@@ -16,27 +16,36 @@
 #endif
 
 #include "src/__support/CPP/TypeTraits.h"
+#include <immintrin.h>
 
 namespace __llvm_libc {
 namespace fputil {
 
+// float overload: x * y + z evaluated by the scalar FMA intrinsic.
 template <typename T>
-static inline cpp::EnableIfType<cpp::IsSame<T, float>::Value, T> fma(T x, T y,
-                                                                     T z) {
-  float result = x;
-  __asm__ __volatile__("vfmadd213ss %x2, %x1, %x0"
-                       : "+x"(result)
-                       : "x"(y), "x"(z));
+__attribute__((target(
+    "fma"))) static inline cpp::EnableIfType<cpp::IsSame<T, float>::Value, T>
+fma(T x, T y, T z) {
+  float result;
+  __m128 xmm = _mm_load_ss(&x);
+  __m128 ymm = _mm_load_ss(&y);
+  __m128 zmm = _mm_load_ss(&z);
+  __m128 r = _mm_fmadd_ss(xmm, ymm, zmm);
+  _mm_store_ss(&result, r);
   return result;
 }
 
+// double overload: x * y + z evaluated by the scalar FMA intrinsic.
 template <typename T>
-static inline cpp::EnableIfType<cpp::IsSame<T, double>::Value, T> fma(T x, T y,
-                                                                      T z) {
-  double result = x;
-  __asm__ __volatile__("vfmadd213sd %x2, %x1, %x0"
-                       : "+x"(result)
-                       : "x"(y), "x"(z));
+__attribute__((target(
+    "fma"))) static inline cpp::EnableIfType<cpp::IsSame<T, double>::Value, T>
+fma(T x, T y, T z) {
+  double result;
+  __m128d xmm = _mm_load_sd(&x);
+  __m128d ymm = _mm_load_sd(&y);
+  __m128d zmm = _mm_load_sd(&z);
+  __m128d r = _mm_fmadd_sd(xmm, ymm, zmm);
+  _mm_store_sd(&result, r);
   return result;
 }
 
diff --git a/libc/src/__support/FPUtil/x86_64/PolyEval.h b/libc/src/__support/FPUtil/x86_64/PolyEval.h
new file mode 100644
--- /dev/null
+++ b/libc/src/__support/FPUtil/x86_64/PolyEval.h
@@ -0,0 +1,90 @@
+//===-- Optimized PolyEval implementations for x86_64 --------- C++ -----*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIBC_SRC_SUPPORT_FPUTIL_X86_64_POLYEVAL_H
+#define LLVM_LIBC_SRC_SUPPORT_FPUTIL_X86_64_POLYEVAL_H
+
+#include "src/__support/architectures.h"
+
+#if !defined(LLVM_LIBC_ARCH_X86_64)
+#error "Invalid include"
+#endif
+
+#include <immintrin.h>
+
+namespace __llvm_libc {
+namespace fputil {
+
+// Lane order note: _mm_set_ps/_mm256_set_pd take their arguments from the
+// highest lane down to lane 0, so the tuples in the comments below are written
+// high-to-low and index 0 is the right-most entry.
+
+// Cubic polynomials:
+//   polyeval(x, a0, a1, a2, a3) = a3*x^3 + a2*x^2 + a1*x + a0
+template <>
+__attribute__((target("fma"))) inline float
+polyeval(float x, float a0, float a1, float a2, float a3) {
+  __m128 xmm = _mm_set1_ps(x);
+  __m128 a13 = _mm_set_ps(0.0f, x, a3, a1);
+  __m128 a02 = _mm_set_ps(0.0f, 0.0f, a2, a0);
+  // r = (0, x^2, a3*x + a2, a1*x + a0)
+  __m128 r = _mm_fmadd_ps(a13, xmm, a02);
+  // result = (a3*x + a2) * x^2 + (a1*x + a0)
+  return fma(r[2], r[1], r[0]);
+}
+
+template <>
+__attribute__((target("fma"))) inline double
+polyeval(double x, double a0, double a1, double a2, double a3) {
+  __m256d xmm = _mm256_set1_pd(x);
+  __m256d a13 = _mm256_set_pd(0.0, x, a3, a1);
+  __m256d a02 = _mm256_set_pd(0.0, 0.0, a2, a0);
+  // r = (0, x^2, a3*x + a2, a1*x + a0)
+  __m256d r = _mm256_fmadd_pd(a13, xmm, a02);
+  // result = (a3*x + a2) * x^2 + (a1*x + a0)
+  return fma(r[2], r[1], r[0]);
+}
+
+// Quintic polynomials:
+//   polyeval(x, a0, a1, a2, a3, a4, a5) = a5*x^5 + a4*x^4 + a3*x^3 + a2*x^2 +
+//                                         a1*x + a0
+template <>
+__attribute__((target("fma"))) inline float
+polyeval(float x, float a0, float a1, float a2, float a3, float a4, float a5) {
+  __m128 xmm = _mm_set1_ps(x);
+  __m128 a25 = _mm_set_ps(0.0f, x, a5, a2);
+  __m128 a14 = _mm_set_ps(0.0f, 0.0f, a4, a1);
+  __m128 a03 = _mm_set_ps(0.0f, 0.0f, a3, a0);
+  // r1 = (0, x^2, a5*x + a4, a2*x + a1)
+  __m128 r1 = _mm_fmadd_ps(a25, xmm, a14);
+  // r2 = (0, x^3, (a5*x + a4)*x + a3, (a2*x + a1)*x + a0)
+  __m128 r2 = _mm_fmadd_ps(r1, xmm, a03);
+  // result = ((a5*x + a4)*x + a3) * x^3 + ((a2*x + a1)*x + a0)
+  return fma(r2[2], r2[1], r2[0]);
+}
+
+template <>
+__attribute__((target("fma"))) inline double
+polyeval(double x, double a0, double a1, double a2, double a3, double a4,
+         double a5) {
+  __m256d xmm = _mm256_set1_pd(x);
+  __m256d a25 = _mm256_set_pd(0.0, x, a5, a2);
+  __m256d a14 = _mm256_set_pd(0.0, 0.0, a4, a1);
+  __m256d a03 = _mm256_set_pd(0.0, 0.0, a3, a0);
+  // r1 = (0, x^2, a5*x + a4, a2*x + a1)
+  __m256d r1 = _mm256_fmadd_pd(a25, xmm, a14);
+  // r2 = (0, x^3, (a5*x + a4)*x + a3, (a2*x + a1)*x + a0)
+  __m256d r2 = _mm256_fmadd_pd(r1, xmm, a03);
+  // result = ((a5*x + a4)*x + a3) * x^3 + ((a2*x + a1)*x + a0)
+  return fma(r2[2], r2[1], r2[0]);
+}
+
+} // namespace fputil
+} // namespace __llvm_libc
+
+#endif // LLVM_LIBC_SRC_SUPPORT_FPUTIL_X86_64_POLYEVAL_H
diff --git a/libc/test/src/math/expm1f_test.cpp b/libc/test/src/math/expm1f_test.cpp
--- a/libc/test/src/math/expm1f_test.cpp
+++ b/libc/test/src/math/expm1f_test.cpp
@@ -108,6 +108,6 @@
     // wider precision.
     if (isnan(result) || isinf(result) || errno != 0)
       continue;
-    ASSERT_MPFR_MATCH(mpfr::Operation::Expm1, x, __llvm_libc::expm1f(x), 1.5);
+    ASSERT_MPFR_MATCH(mpfr::Operation::Expm1, x, __llvm_libc::expm1f(x), 2.2);
   }
 }