diff --git a/openmp/libomptarget/plugins/cuda/src/rtl.cpp b/openmp/libomptarget/plugins/cuda/src/rtl.cpp --- a/openmp/libomptarget/plugins/cuda/src/rtl.cpp +++ b/openmp/libomptarget/plugins/cuda/src/rtl.cpp @@ -355,6 +355,10 @@ /// devices. std::vector InitializedFlags; + enum class PeerAccessState : uint8_t { Unkown, Enabled, Yes, No }; + std::vector> PeerAccessMatrix; + std::mutex PeerAccessMatrixLock; + /// A class responsible for interacting with device native runtime library to /// allocate and free memory. class CUDADeviceAllocatorTy : public DeviceAllocatorTy { @@ -520,6 +524,9 @@ Modules.resize(NumberOfDevices); StreamPool.resize(NumberOfDevices); EventPool.resize(NumberOfDevices); + PeerAccessMatrix.resize(NumberOfDevices); + for (auto &V : PeerAccessMatrix) + V.resize(NumberOfDevices, PeerAccessState::Unkown); // Get environment variables regarding teams if (const char *EnvStr = getenv("OMP_TEAM_LIMIT")) { @@ -581,6 +588,8 @@ void setRequiresFlag(const int64_t Flags) { this->RequiresFlags = Flags; } int initDevice(const int DeviceId) { + assert(InitializedFlags[DeviceId] == false && "Reinitializing device!"); + CUdevice Device; DP("Getting device %d\n", DeviceId); @@ -588,9 +597,6 @@ if (!checkResult(Err, "Error returned from cuDeviceGet\n")) return OFFLOAD_FAIL; - assert(InitializedFlags[DeviceId] == false && "Reinitializing device!"); - InitializedFlags[DeviceId] = true; - // Query the current flags of the primary context and set its flags if // it is inactive unsigned int FormerPrimaryCtxFlags = 0; @@ -623,6 +629,8 @@ if (!checkResult(Err, "Error returned from cuCtxSetCurrent\n")) return OFFLOAD_FAIL; + InitializedFlags[DeviceId] = true; + // Initialize the stream pool. if (!StreamPool[DeviceId]) StreamPool[DeviceId] = std::make_unique(StreamAllocatorTy(), @@ -1015,7 +1023,7 @@ } int dataExchange(int SrcDevId, const void *SrcPtr, int DstDevId, void *DstPtr, - int64_t Size, __tgt_async_info *AsyncInfo) const { + int64_t Size, __tgt_async_info *AsyncInfo) { assert(AsyncInfo && "AsyncInfo is nullptr"); CUresult Err; @@ -1023,40 +1031,72 @@ // If they are two devices, we try peer to peer copy first if (SrcDevId != DstDevId) { - int CanAccessPeer = 0; - Err = cuDeviceCanAccessPeer(&CanAccessPeer, SrcDevId, DstDevId); - if (Err != CUDA_SUCCESS) { - REPORT("Error returned from cuDeviceCanAccessPeer. src = %" PRId32 - ", dst = %" PRId32 "\n", + std::lock_guard LG(PeerAccessMatrixLock); + + switch (PeerAccessMatrix[SrcDevId][DstDevId]) { + case PeerAccessState::No: { + REPORT("Peer access from %" PRId32 " to %" PRId32 + " is not supported. Fall back to D2D memcpy.\n", SrcDevId, DstDevId); + return memcpyDtoD(SrcPtr, DstPtr, Size, Stream); + } + case PeerAccessState::Unkown: { + int CanAccessPeer = 0; + Err = cuDeviceCanAccessPeer(&CanAccessPeer, SrcDevId, DstDevId); + if (Err != CUDA_SUCCESS) { + REPORT("Error returned from cuDeviceCanAccessPeer. src = %" PRId32 + ", dst = %" PRId32 ". Fall back to D2D memcpy.\n", + SrcDevId, DstDevId); + CUDA_ERR_STRING(Err); + PeerAccessMatrix[SrcDevId][DstDevId] = PeerAccessState::No; + return memcpyDtoD(SrcPtr, DstPtr, Size, Stream); + } + + if (!CanAccessPeer) { + REPORT("P2P access from %d to %d is not supported. Fall back to D2D " + "memcpy.\n", + SrcDevId, DstDevId); + PeerAccessMatrix[SrcDevId][DstDevId] = PeerAccessState::No; + return memcpyDtoD(SrcPtr, DstPtr, Size, Stream); + } + + LLVM_FALLTHROUGH; + } + case PeerAccessState::Yes: { + Err = cuCtxEnablePeerAccess(DeviceData[DstDevId].Context, 0); + if (Err != CUDA_SUCCESS) { + REPORT("Error returned from cuCtxEnablePeerAccess. src = %" PRId32 + ", dst = %" PRId32 ". Fall back to D2D memcpy.\n", + SrcDevId, DstDevId); + CUDA_ERR_STRING(Err); + PeerAccessMatrix[SrcDevId][DstDevId] = PeerAccessState::No; + return memcpyDtoD(SrcPtr, DstPtr, Size, Stream); + } + + PeerAccessMatrix[SrcDevId][DstDevId] = PeerAccessState::Enabled; + + LLVM_FALLTHROUGH; + } + case PeerAccessState::Enabled: { + Err = cuMemcpyPeerAsync( + (CUdeviceptr)DstPtr, DeviceData[DstDevId].Context, + (CUdeviceptr)SrcPtr, DeviceData[SrcDevId].Context, Size, Stream); + if (Err == CUDA_SUCCESS) + return OFFLOAD_SUCCESS; + + DP("Error returned from cuMemcpyPeerAsync. src_ptr = " DPxMOD + ", src_id =%" PRId32 ", dst_ptr = " DPxMOD ", dst_id =%" PRId32 + ". Fall back to D2D memcpy.\n", + DPxPTR(SrcPtr), SrcDevId, DPxPTR(DstPtr), DstDevId); CUDA_ERR_STRING(Err); + return memcpyDtoD(SrcPtr, DstPtr, Size, Stream); } - - if (!CanAccessPeer) { - DP("P2P memcpy not supported so fall back to D2D memcpy"); - return memcpyDtoD(SrcPtr, DstPtr, Size, Stream); + default: + REPORT("Unknown PeerAccessState %d.\n", + int(PeerAccessMatrix[SrcDevId][DstDevId])); + return OFFLOAD_FAIL; } - - Err = cuCtxEnablePeerAccess(DeviceData[DstDevId].Context, 0); - if (Err != CUDA_SUCCESS) { - REPORT("Error returned from cuCtxEnablePeerAccess. src = %" PRId32 - ", dst = %" PRId32 "\n", - SrcDevId, DstDevId); - CUDA_ERR_STRING(Err); - return memcpyDtoD(SrcPtr, DstPtr, Size, Stream); - } - - Err = cuMemcpyPeerAsync((CUdeviceptr)DstPtr, DeviceData[DstDevId].Context, - (CUdeviceptr)SrcPtr, DeviceData[SrcDevId].Context, - Size, Stream); - if (Err == CUDA_SUCCESS) - return OFFLOAD_SUCCESS; - - DP("Error returned from cuMemcpyPeerAsync. src_ptr = " DPxMOD - ", src_id =%" PRId32 ", dst_ptr = " DPxMOD ", dst_id =%" PRId32 "\n", - DPxPTR(SrcPtr), SrcDevId, DPxPTR(DstPtr), DstDevId); - CUDA_ERR_STRING(Err); } return memcpyDtoD(SrcPtr, DstPtr, Size, Stream); @@ -1464,7 +1504,7 @@ return OFFLOAD_SUCCESS; } - int setContext(int DeviceId) { + int setContext(int DeviceId) const { assert(InitializedFlags[DeviceId] && "Device is not initialized"); CUresult Err = cuCtxSetCurrent(DeviceData[DeviceId].Context); @@ -1598,8 +1638,10 @@ assert(DeviceRTL.isValidDeviceId(src_dev_id) && "src_dev_id is invalid"); assert(DeviceRTL.isValidDeviceId(dst_dev_id) && "dst_dev_id is invalid"); assert(AsyncInfo && "AsyncInfo is nullptr"); - // NOTE: We don't need to set context for data exchange as the device contexts - // are passed to CUDA function directly. + + if (DeviceRTL.setContext(src_dev_id) != OFFLOAD_SUCCESS) + return OFFLOAD_FAIL; + return DeviceRTL.dataExchange(src_dev_id, src_ptr, dst_dev_id, dst_ptr, size, AsyncInfo); }