Index: include/clang/Driver/ToolChain.h
===================================================================
--- include/clang/Driver/ToolChain.h
+++ include/clang/Driver/ToolChain.h
@@ -572,6 +572,12 @@
   virtual void AddCudaIncludeArgs(const llvm::opt::ArgList &DriverArgs,
                                   llvm::opt::ArgStringList &CC1Args) const;
 
+  /// Add arguments that make device math functions available when compiling
+  /// offloaded code. The default implementation adds nothing.
+  virtual void
+  AddMathDeviceFunctions(const llvm::opt::ArgList &DriverArgs,
+                         llvm::opt::ArgStringList &CC1Args) const {}
+
   /// Add arguments to use MCU GCC toolchain includes.
   virtual void AddIAMCUIncludeArgs(const llvm::opt::ArgList &DriverArgs,
                                    llvm::opt::ArgStringList &CC1Args) const;
Index: lib/Driver/ToolChains/Clang.cpp
===================================================================
--- lib/Driver/ToolChains/Clang.cpp
+++ lib/Driver/ToolChains/Clang.cpp
@@ -1150,6 +1150,14 @@
   if (JA.isOffloading(Action::OFK_Cuda))
     getToolChain().AddCudaIncludeArgs(Args, CmdArgs);
 
+  // When compiling the device side of an OpenMP offload to an NVIDIA GPU,
+  // ask the toolchain to force-include its OpenMP device math header so
+  // that math calls in the offloaded code are routed to device-side
+  // implementations.
+  if (JA.isDeviceOffloading(Action::OFK_OpenMP) &&
+      getToolChain().getTriple().isNVPTX())
+    getToolChain().AddMathDeviceFunctions(Args, CmdArgs);
+
   // Add -i* options, and automatically translate to
   // -include-pch/-include-pth for transparent PCH support. It's
   // wonky, but we include looking for .gch so we can support seamless
Index: lib/Driver/ToolChains/Cuda.h
===================================================================
--- lib/Driver/ToolChains/Cuda.h
+++ lib/Driver/ToolChains/Cuda.h
@@ -48,6 +48,10 @@
   void AddCudaIncludeArgs(const llvm::opt::ArgList &DriverArgs,
                           llvm::opt::ArgStringList &CC1Args) const;
 
+  /// Add CC1 arguments that force-include the OpenMP device math header.
+  void AddMathDeviceFunctions(const llvm::opt::ArgList &DriverArgs,
+                              llvm::opt::ArgStringList &CC1Args) const;
+
   /// Emit an error if Version does not support the given Arch.
   ///
   /// If either Version or Arch is unknown, does not emit an error.  Emits at
@@ -165,6 +169,9 @@
   void AddCudaIncludeArgs(const llvm::opt::ArgList &DriverArgs,
                           llvm::opt::ArgStringList &CC1Args) const override;
 
+  void AddMathDeviceFunctions(const llvm::opt::ArgList &DriverArgs,
+                              llvm::opt::ArgStringList &CC1Args) const override;
+
   void addClangWarningOptions(llvm::opt::ArgStringList &CC1Args) const override;
   CXXStdlibType GetCXXStdlibType(const llvm::opt::ArgList &Args) const override;
   void
Index: lib/Driver/ToolChains/Cuda.cpp
===================================================================
--- lib/Driver/ToolChains/Cuda.cpp
+++ lib/Driver/ToolChains/Cuda.cpp
@@ -255,6 +255,20 @@
   CC1Args.push_back("__clang_cuda_runtime_wrapper.h");
 }
 
+void CudaInstallationDetector::AddMathDeviceFunctions(
+    const ArgList &DriverArgs, ArgStringList &CC1Args) const {
+  // Only add the CUDA include directory as a system include path when a CUDA
+  // installation was detected and the user did not pass -nocudainc.
+  if (!DriverArgs.hasArg(options::OPT_nocudainc) && isValid()) {
+    CC1Args.push_back("-internal-isystem");
+    CC1Args.push_back(DriverArgs.MakeArgString(getIncludePath()));
+  }
+  // __clang_openmp_math.h ships in Clang's resource directory, which is on
+  // the builtin include path, so no extra user -I path is needed to find it.
+  CC1Args.push_back("-include");
+  CC1Args.push_back("__clang_openmp_math.h");
+}
+
 void CudaInstallationDetector::CheckCudaVersionSupportsArch(
     CudaArch Arch) const {
   if (Arch == CudaArch::UNKNOWN || Version == CudaVersion::UNKNOWN ||
@@ -898,6 +912,11 @@
   CudaInstallation.AddCudaIncludeArgs(DriverArgs, CC1Args);
 }
 
+void CudaToolChain::AddMathDeviceFunctions(
+    const ArgList &DriverArgs, ArgStringList &CC1Args) const {
+  CudaInstallation.AddMathDeviceFunctions(DriverArgs, CC1Args);
+}
+
 llvm::opt::DerivedArgList *
 CudaToolChain::TranslateArgs(const llvm::opt::DerivedArgList &Args,
                              StringRef BoundArch,
Index: lib/Headers/CMakeLists.txt
===================================================================
--- lib/Headers/CMakeLists.txt
+++ lib/Headers/CMakeLists.txt
@@ -31,6 +31,7 @@
   avxintrin.h
   bmi2intrin.h
   bmiintrin.h
+  __clang_openmp_math.h
   __clang_cuda_builtin_vars.h
   __clang_cuda_cmath.h
   __clang_cuda_complex_builtins.h
Index: lib/Headers/__clang_openmp_math.h
===================================================================
--- /dev/null
+++ lib/Headers/__clang_openmp_math.h
@@ -0,0 +1,74 @@
+/*===---- __clang_openmp_math.h - Target OpenMP math support ---------------===
+ *
+ * Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+ * See https://llvm.org/LICENSE.txt for license information.
+ * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+ *
+ *===-----------------------------------------------------------------------===
+ */
+
+#ifndef __CLANG_OPENMP_MATH_H__
+#define __CLANG_OPENMP_MATH_H__
+
+/* This header is force-included (-include) into OpenMP device compilations
+ * for NVPTX targets. It maps the standard math functions defined below onto
+ * their __kmpc_* counterparts in libomptarget. C-style comments are used
+ * throughout because the header is also included into C translation units. */
+
+#pragma omp declare target
+
+/* Declarations of the functions implemented in libomptarget. */
+#if defined(__cplusplus)
+extern "C" {
+#endif
+
+/* pow */
+float __kmpc_powf(float, float);
+double __kmpc_pow(double, double);
+long double __kmpc_powl(long double, long double);
+
+/* sin */
+float __kmpc_sinf(float);
+double __kmpc_sin(double);
+long double __kmpc_sinl(long double);
+
+#if defined(__cplusplus)
+}
+#endif
+
+/* Wrappers with the standard C names. Each is static (private to the
+ * including translation unit), always inlined into its callers, and marked
+ * `used` so its definition is emitted even when the TU never calls it.
+ * NOTE(review): confirm `long double` is supported in NVPTX device code
+ * before relying on the *l variants. */
+
+/* pow */
+__attribute__((always_inline, used)) static float powf(float a, float b) {
+  return __kmpc_powf(a, b);
+}
+
+__attribute__((always_inline, used)) static double pow(double a, double b) {
+  return __kmpc_pow(a, b);
+}
+
+__attribute__((always_inline, used)) static long double powl(
+    long double a, long double b) {
+  return __kmpc_powl(a, b);
+}
+
+/* sin */
+__attribute__((always_inline, used)) static float sinf(float a) {
+  return __kmpc_sinf(a);
+}
+
+__attribute__((always_inline, used)) static double sin(double a) {
+  return __kmpc_sin(a);
+}
+
+__attribute__((always_inline, used)) static long double sinl(long double a) {
+  return __kmpc_sinl(a);
+}
+
+#pragma omp end declare target
+
+#endif /* __CLANG_OPENMP_MATH_H__ */