Resolve the versioned CUDA driver entry points at load time instead of through
preprocessor aliases.

cuda.h:   remove the #define lines that renamed eleven cu* functions to their
          _v2 names.
cuda.cpp: in checkForCUDA(), for every wrapped symbol that has an entry in
          TryFirst, try dlsym() on the _v2 name first and fall back to the
          unsuffixed name when the _v2 lookup returns null; report through DP()
          which name each symbol was bound to.

diff --git a/openmp/libomptarget/plugins/cuda/dynamic_cuda/cuda.h b/openmp/libomptarget/plugins/cuda/dynamic_cuda/cuda.h
--- a/openmp/libomptarget/plugins/cuda/dynamic_cuda/cuda.h
+++ b/openmp/libomptarget/plugins/cuda/dynamic_cuda/cuda.h
@@ -48,18 +48,6 @@
   CU_CTX_SCHED_MASK = 0x07,
 } CUctx_flags;
 
-#define cuMemFree cuMemFree_v2
-#define cuMemAlloc cuMemAlloc_v2
-#define cuMemcpyDtoH cuMemcpyDtoH_v2
-#define cuMemcpyHtoD cuMemcpyHtoD_v2
-#define cuStreamDestroy cuStreamDestroy_v2
-#define cuModuleGetGlobal cuModuleGetGlobal_v2
-#define cuMemcpyDtoHAsync cuMemcpyDtoHAsync_v2
-#define cuMemcpyDtoDAsync cuMemcpyDtoDAsync_v2
-#define cuMemcpyHtoDAsync cuMemcpyHtoDAsync_v2
-#define cuDevicePrimaryCtxRelease cuDevicePrimaryCtxRelease_v2
-#define cuDevicePrimaryCtxSetFlags cuDevicePrimaryCtxSetFlags_v2
-
 CUresult cuCtxGetDevice(CUdevice *);
 CUresult cuDeviceGet(CUdevice *, int);
 CUresult cuDeviceGetAttribute(int *, CUdevice_attribute, CUdevice);
diff --git a/openmp/libomptarget/plugins/cuda/dynamic_cuda/cuda.cpp b/openmp/libomptarget/plugins/cuda/dynamic_cuda/cuda.cpp
--- a/openmp/libomptarget/plugins/cuda/dynamic_cuda/cuda.cpp
+++ b/openmp/libomptarget/plugins/cuda/dynamic_cuda/cuda.cpp
@@ -15,6 +15,9 @@
 #include "Debug.h"
 #include "dlwrap.h"
 
+#include <string>
+#include <unordered_map>
+
 #include <dlfcn.h>
 
 DLWRAP_INTERNAL(cuInit, 1);
@@ -67,6 +70,21 @@
 static bool checkForCUDA() {
   // return true if dlopen succeeded and all functions found
 
+  // Prefer _v2 versions of functions if found in the library
+  std::unordered_map<std::string, const char *> TryFirst = {
+      {"cuMemAlloc", "cuMemAlloc_v2"},
+      {"cuMemFree", "cuMemFree_v2"},
+      {"cuMemcpyDtoH", "cuMemcpyDtoH_v2"},
+      {"cuMemcpyHtoD", "cuMemcpyHtoD_v2"},
+      {"cuStreamDestroy", "cuStreamDestroy_v2"},
+      {"cuModuleGetGlobal", "cuModuleGetGlobal_v2"},
+      {"cuMemcpyDtoHAsync", "cuMemcpyDtoHAsync_v2"},
+      {"cuMemcpyDtoDAsync", "cuMemcpyDtoDAsync_v2"},
+      {"cuMemcpyHtoDAsync", "cuMemcpyHtoDAsync_v2"},
+      {"cuDevicePrimaryCtxRelease", "cuDevicePrimaryCtxRelease_v2"},
+      {"cuDevicePrimaryCtxSetFlags", "cuDevicePrimaryCtxSetFlags_v2"},
+  };
+
   const char *CudaLib = DYNAMIC_CUDA_PATH;
   void *DynlibHandle = dlopen(CudaLib, RTLD_NOW);
   if (!DynlibHandle) {
@@ -77,11 +95,23 @@
   for (size_t I = 0; I < dlwrap::size(); I++) {
     const char *Sym = dlwrap::symbol(I);
 
+    auto It = TryFirst.find(Sym);
+    if (It != TryFirst.end()) {
+      const char *First = It->second;
+      void *P = dlsym(DynlibHandle, First);
+      if (P) {
+        DP("Implementing %s with dlsym(%s) -> %p\n", Sym, First, P);
+        *dlwrap::pointer(I) = P;
+        continue;
+      }
+    }
+
     void *P = dlsym(DynlibHandle, Sym);
     if (P == nullptr) {
       DP("Unable to find '%s' in '%s'!\n", Sym, CudaLib);
       return false;
     }
+    DP("Implementing %s with dlsym(%s) -> %p\n", Sym, Sym, P);
     *dlwrap::pointer(I) = P;
   }
 