Index: include/clang/Driver/ToolChain.h
===================================================================
--- include/clang/Driver/ToolChain.h
+++ include/clang/Driver/ToolChain.h
@@ -572,6 +572,12 @@
   virtual void AddCudaIncludeArgs(const llvm::opt::ArgList &DriverArgs,
                                   llvm::opt::ArgStringList &CC1Args) const;
 
+  /// Add arguments that make the target's device math functions available
+  /// to OpenMP offloaded code. The default implementation adds nothing.
+  virtual void
+  AddMathDeviceFunctions(const llvm::opt::ArgList &DriverArgs,
+                         llvm::opt::ArgStringList &CC1Args) const {}
+
   /// Add arguments to use MCU GCC toolchain includes.
   virtual void AddIAMCUIncludeArgs(const llvm::opt::ArgList &DriverArgs,
                                    llvm::opt::ArgStringList &CC1Args) const;
Index: lib/Driver/ToolChains/Clang.cpp
===================================================================
--- lib/Driver/ToolChains/Clang.cpp
+++ lib/Driver/ToolChains/Clang.cpp
@@ -1150,6 +1150,14 @@
   if (JA.isOffloading(Action::OFK_Cuda))
     getToolChain().AddCudaIncludeArgs(Args, CmdArgs);
 
+  // If we are offloading to a target via OpenMP and this target happens
+  // to be an NVIDIA GPU, let the toolchain add its device math wrapper
+  // (__clang_openmp_math.h for CUDA) so that standard math calls in the
+  // offloaded code resolve to libomptarget's __kmpc_* implementations.
+  if (JA.isDeviceOffloading(Action::OFK_OpenMP) &&
+      getToolChain().getTriple().isNVPTX())
+    getToolChain().AddMathDeviceFunctions(Args, CmdArgs);
+
   // Add -i* options, and automatically translate to
   // -include-pch/-include-pth for transparent PCH support. It's
   // wonky, but we include looking for .gch so we can support seamless
Index: lib/Driver/ToolChains/Cuda.h
===================================================================
--- lib/Driver/ToolChains/Cuda.h
+++ lib/Driver/ToolChains/Cuda.h
@@ -48,6 +48,11 @@
   void AddCudaIncludeArgs(const llvm::opt::ArgList &DriverArgs,
                           llvm::opt::ArgStringList &CC1Args) const;
 
+  /// Add arguments that force-include __clang_openmp_math.h for OpenMP
+  /// device compilation (plus the CUDA include path, if one was found).
+  void AddMathDeviceFunctions(const llvm::opt::ArgList &DriverArgs,
+                              llvm::opt::ArgStringList &CC1Args) const;
+
   /// Emit an error if Version does not support the given Arch.
   ///
   /// If either Version or Arch is unknown, does not emit an error.  Emits at
@@ -165,6 +170,9 @@
   void AddCudaIncludeArgs(const llvm::opt::ArgList &DriverArgs,
                           llvm::opt::ArgStringList &CC1Args) const override;
 
+  void AddMathDeviceFunctions(const llvm::opt::ArgList &DriverArgs,
+                              llvm::opt::ArgStringList &CC1Args) const override;
+
   void addClangWarningOptions(llvm::opt::ArgStringList &CC1Args) const override;
   CXXStdlibType GetCXXStdlibType(const llvm::opt::ArgList &Args) const override;
   void
Index: lib/Driver/ToolChains/Cuda.cpp
===================================================================
--- lib/Driver/ToolChains/Cuda.cpp
+++ lib/Driver/ToolChains/Cuda.cpp
@@ -255,6 +255,19 @@
   CC1Args.push_back("__clang_cuda_runtime_wrapper.h");
 }
 
+void CudaInstallationDetector::AddMathDeviceFunctions(
+    const ArgList &DriverArgs, ArgStringList &CC1Args) const {
+  // __clang_openmp_math.h ships with clang's builtin headers, so it needs no
+  // extra search path. Add the CUDA include directory only when a CUDA
+  // installation was detected, and only once, as a system directory.
+  if (!DriverArgs.hasArg(options::OPT_nocudainc) && isValid()) {
+    CC1Args.push_back("-internal-isystem");
+    CC1Args.push_back(DriverArgs.MakeArgString(getIncludePath()));
+  }
+  CC1Args.push_back("-include");
+  CC1Args.push_back("__clang_openmp_math.h");
+}
+
 void CudaInstallationDetector::CheckCudaVersionSupportsArch(
     CudaArch Arch) const {
   if (Arch == CudaArch::UNKNOWN || Version == CudaVersion::UNKNOWN ||
@@ -898,6 +911,11 @@
   CudaInstallation.AddCudaIncludeArgs(DriverArgs, CC1Args);
 }
 
+void CudaToolChain::AddMathDeviceFunctions(
+    const ArgList &DriverArgs, ArgStringList &CC1Args) const {
+  CudaInstallation.AddMathDeviceFunctions(DriverArgs, CC1Args);
+}
+
 llvm::opt::DerivedArgList *
 CudaToolChain::TranslateArgs(const llvm::opt::DerivedArgList &Args,
                              StringRef BoundArch,
Index: lib/Headers/CMakeLists.txt
===================================================================
--- lib/Headers/CMakeLists.txt
+++ lib/Headers/CMakeLists.txt
@@ -31,6 +31,7 @@
   avxintrin.h
   bmi2intrin.h
   bmiintrin.h
+  __clang_openmp_math.h
   __clang_cuda_builtin_vars.h
   __clang_cuda_cmath.h
   __clang_cuda_complex_builtins.h
Index: lib/Headers/__clang_openmp_math.h
===================================================================
--- /dev/null
+++ lib/Headers/__clang_openmp_math.h
@@ -0,0 +1,108 @@
+/*===---- __clang_openmp_math.h - Target OpenMP math support ---------------===
+ *
+ * Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+ * See https://llvm.org/LICENSE.txt for license information.
+ * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+ *
+ *===-----------------------------------------------------------------------===
+ */
+
+#ifndef __CLANG_OPENMP_MATH_H__
+#define __CLANG_OPENMP_MATH_H__
+
+#ifdef __NVPTX__
+#pragma omp declare target
+
+// Declarations of the math functions implemented in libomptarget.
+#if defined(__cplusplus)
+extern "C" {
+#endif
+
+// POW
+float __kmpc_powf(float, float);
+double __kmpc_pow(double, double);
+long double __kmpc_powl(long double, long double);
+
+// LOG
+double __kmpc_log(double);
+float __kmpc_logf(float);
+double __kmpc_log10(double);
+float __kmpc_log10f(float);
+double __kmpc_log1p(double);
+float __kmpc_log1pf(float);
+double __kmpc_log2(double);
+float __kmpc_log2f(float);
+double __kmpc_logb(double);
+float __kmpc_logbf(float);
+
+// SIN
+float __kmpc_sinf(float);
+double __kmpc_sin(double);
+long double __kmpc_sinl(long double);
+
+// COS
+float __kmpc_cosf(float);
+double __kmpc_cos(double);
+long double __kmpc_cosl(long double);
+
+#if defined(__cplusplus)
+}
+#endif
+
+// The helpers below define a standard math function that forwards its
+// arguments to the matching libomptarget (__kmpc_*) function. Each expansion
+// is a complete function definition, so invocations take no trailing ';'
+// (the resulting empty declaration can trigger extra-semicolon warnings).
+// The helpers are #undef'd at the end of this header so that they do not
+// leak into the code this header is force-included into.
+//
+// NOTE(review): these static definitions precede any later <math.h>/<cmath>
+// declaration of the same names; confirm such redeclarations are accepted in
+// both C and C++ modes (linkage, exception specifications).
+
+// Single argument functions
+#define __OPENMP_MATH_FUNC_1(__ty, __fn, __kmpc_fn) \
+  __attribute__((always_inline, used)) static __ty \
+  __fn(__ty __x) { \
+    return __kmpc_fn(__x); \
+  }
+
+// Double argument functions
+#define __OPENMP_MATH_FUNC_2(__ty, __fn, __kmpc_fn) \
+  __attribute__((always_inline, used)) static __ty \
+  __fn(__ty __x, __ty __y) { \
+    return __kmpc_fn(__x, __y); \
+  }
+
+// POW
+__OPENMP_MATH_FUNC_2(float, powf, __kmpc_powf)
+__OPENMP_MATH_FUNC_2(double, pow, __kmpc_pow)
+__OPENMP_MATH_FUNC_2(long double, powl, __kmpc_powl)
+
+// LOG
+__OPENMP_MATH_FUNC_1(double, log, __kmpc_log)
+__OPENMP_MATH_FUNC_1(float, logf, __kmpc_logf)
+__OPENMP_MATH_FUNC_1(double, log10, __kmpc_log10)
+__OPENMP_MATH_FUNC_1(float, log10f, __kmpc_log10f)
+__OPENMP_MATH_FUNC_1(double, log1p, __kmpc_log1p)
+__OPENMP_MATH_FUNC_1(float, log1pf, __kmpc_log1pf)
+__OPENMP_MATH_FUNC_1(double, log2, __kmpc_log2)
+__OPENMP_MATH_FUNC_1(float, log2f, __kmpc_log2f)
+__OPENMP_MATH_FUNC_1(double, logb, __kmpc_logb)
+__OPENMP_MATH_FUNC_1(float, logbf, __kmpc_logbf)
+
+// SIN
+__OPENMP_MATH_FUNC_1(float, sinf, __kmpc_sinf)
+__OPENMP_MATH_FUNC_1(double, sin, __kmpc_sin)
+__OPENMP_MATH_FUNC_1(long double, sinl, __kmpc_sinl)
+
+// COS
+__OPENMP_MATH_FUNC_1(float, cosf, __kmpc_cosf)
+__OPENMP_MATH_FUNC_1(double, cos, __kmpc_cos)
+__OPENMP_MATH_FUNC_1(long double, cosl, __kmpc_cosl)
+
+#undef __OPENMP_MATH_FUNC_1
+#undef __OPENMP_MATH_FUNC_2
+
+#pragma omp end declare target
+#endif // __NVPTX__
+#endif // __CLANG_OPENMP_MATH_H__