diff --git a/openmp/libomptarget/plugins/cuda/src/rtl.cpp b/openmp/libomptarget/plugins/cuda/src/rtl.cpp --- a/openmp/libomptarget/plugins/cuda/src/rtl.cpp +++ b/openmp/libomptarget/plugins/cuda/src/rtl.cpp @@ -1047,6 +1047,23 @@ return memcpyDtoD(SrcPtr, DstPtr, Size, Stream); } + // Switch to destination context to enable peer access. + if (setContext(DstDevId) != OFFLOAD_SUCCESS) + return OFFLOAD_FAIL; + + Err = cuCtxEnablePeerAccess(DeviceData[SrcDevId].Context, 0); + if (Err != CUDA_SUCCESS) { + REPORT("Error returned from cuCtxEnablePeerAccess. src = %" PRId32 + ", dst = %" PRId32 "\n", + SrcDevId, DstDevId); + CUDA_ERR_STRING(Err); + return memcpyDtoD(SrcPtr, DstPtr, Size, Stream); + } + + // Switch back to source context to issue memcpy. + if (setContext(SrcDevId) != OFFLOAD_SUCCESS) + return OFFLOAD_FAIL; + Err = cuMemcpyPeerAsync((CUdeviceptr)DstPtr, DeviceData[DstDevId].Context, (CUdeviceptr)SrcPtr, DeviceData[SrcDevId].Context, Size, Stream); @@ -1464,7 +1481,7 @@ return OFFLOAD_SUCCESS; } - int setContext(int DeviceId) { + int setContext(int DeviceId) const { assert(InitializedFlags[DeviceId] && "Device is not initialized"); CUresult Err = cuCtxSetCurrent(DeviceData[DeviceId].Context); @@ -1597,9 +1614,11 @@ __tgt_async_info *AsyncInfo) { assert(DeviceRTL.isValidDeviceId(src_dev_id) && "src_dev_id is invalid"); assert(DeviceRTL.isValidDeviceId(dst_dev_id) && "dst_dev_id is invalid"); + assert(AsyncInfo && "AsyncInfo is nullptr"); - // NOTE: We don't need to set context for data exchange as the device contexts - // are passed to CUDA function directly. + if (DeviceRTL.setContext(src_dev_id) != OFFLOAD_SUCCESS) + return OFFLOAD_FAIL; + return DeviceRTL.dataExchange(src_dev_id, src_ptr, dst_dev_id, dst_ptr, size, AsyncInfo); }