diff --git a/openmp/libomptarget/plugins/cuda/src/rtl.cpp b/openmp/libomptarget/plugins/cuda/src/rtl.cpp
--- a/openmp/libomptarget/plugins/cuda/src/rtl.cpp
+++ b/openmp/libomptarget/plugins/cuda/src/rtl.cpp
@@ -167,30 +167,17 @@
 /// Functions \p create and \p destroy return OFFLOAD_SUCCESS and OFFLOAD_FAIL
 /// accordingly. The implementation should not raise any exception.
 template <typename T> struct AllocatorTy {
-  AllocatorTy(CUcontext C) noexcept : Context(C) {}
   using ElementTy = T;
-
-  virtual ~AllocatorTy() {}
-
   /// Create a resource and assign to R.
   virtual int create(T &R) noexcept = 0;
   /// Destroy the resource.
   virtual int destroy(T) noexcept = 0;
-
-protected:
-  CUcontext Context;
 };
 
 /// Allocator for CUstream.
 struct StreamAllocatorTy final : public AllocatorTy<CUstream> {
-  StreamAllocatorTy(CUcontext C) noexcept : AllocatorTy(C) {}
-
   /// See AllocatorTy::create.
   int create(CUstream &Stream) noexcept override {
-    if (!checkResult(cuCtxSetCurrent(Context),
-                     "Error returned from cuCtxSetCurrent\n"))
-      return OFFLOAD_FAIL;
-
     if (!checkResult(cuStreamCreate(&Stream, CU_STREAM_NON_BLOCKING),
                      "Error returned from cuStreamCreate\n"))
       return OFFLOAD_FAIL;
@@ -200,9 +187,6 @@
 
   /// See AllocatorTy::destroy.
   int destroy(CUstream Stream) noexcept override {
-    if (!checkResult(cuCtxSetCurrent(Context),
-                     "Error returned from cuCtxSetCurrent\n"))
-      return OFFLOAD_FAIL;
    if (!checkResult(cuStreamDestroy(Stream),
                     "Error returned from cuStreamDestroy\n"))
      return OFFLOAD_FAIL;
@@ -213,8 +197,6 @@
 
 /// Allocator for CUevent.
 struct EventAllocatorTy final : public AllocatorTy<CUevent> {
-  EventAllocatorTy(CUcontext C) noexcept : AllocatorTy(C) {}
-
   /// See AllocatorTy::create.
   int create(CUevent &Event) noexcept override {
     if (!checkResult(cuEventCreate(&Event, CU_EVENT_DEFAULT),
@@ -363,23 +345,15 @@
 /// A class responsible for interacting with device native runtime library to
 /// allocate and free memory.
 class CUDADeviceAllocatorTy : public DeviceAllocatorTy {
-  const int DeviceId;
-  const std::vector<DeviceDataTy> &DeviceData;
   std::unordered_map<void *, TargetAllocTy> HostPinnedAllocs;
 
 public:
-  CUDADeviceAllocatorTy(int DeviceId, std::vector<DeviceDataTy> &DeviceData)
-      : DeviceId(DeviceId), DeviceData(DeviceData) {}
-
   void *allocate(size_t Size, void *, TargetAllocTy Kind) override {
     if (Size == 0)
       return nullptr;
 
-    CUresult Err = cuCtxSetCurrent(DeviceData[DeviceId].Context);
-    if (!checkResult(Err, "Error returned from cuCtxSetCurrent\n"))
-      return nullptr;
-
     void *MemAlloc = nullptr;
+    CUresult Err;
     switch (Kind) {
     case TARGET_ALLOC_DEFAULT:
     case TARGET_ALLOC_DEVICE:
@@ -410,10 +384,7 @@
   }
 
   int free(void *TgtPtr) override {
-    CUresult Err = cuCtxSetCurrent(DeviceData[DeviceId].Context);
-    if (!checkResult(Err, "Error returned from cuCtxSetCurrent\n"))
-      return OFFLOAD_FAIL;
-
+    CUresult Err;
     // Host pinned memory must be freed differently.
     TargetAllocTy Kind =
        (HostPinnedAllocs.find(TgtPtr) == HostPinnedAllocs.end())
@@ -566,7 +537,7 @@
     }
 
     for (int I = 0; I < NumberOfDevices; ++I)
-      DeviceAllocators.emplace_back(I, DeviceData);
+      DeviceAllocators.emplace_back();
 
     // Get the size threshold from environment variable
     std::pair<size_t, bool> Res = MemoryManagerTy::getSizeThresholdFromEnv();
@@ -641,13 +612,13 @@
 
     // Initialize the stream pool.
     if (!StreamPool[DeviceId])
-      StreamPool[DeviceId] = std::make_unique<StreamPoolTy>(
-          StreamAllocatorTy(DeviceData[DeviceId].Context), NumInitialStreams);
+      StreamPool[DeviceId] = std::make_unique<StreamPoolTy>(StreamAllocatorTy(),
+                                                            NumInitialStreams);
 
     // Initialize the event pool.
     if (!EventPool[DeviceId])
-      EventPool[DeviceId] = std::make_unique<EventPoolTy>(
-          EventAllocatorTy(DeviceData[DeviceId].Context), NumInitialEvents);
+      EventPool[DeviceId] =
+          std::make_unique<EventPoolTy>(EventAllocatorTy(), NumInitialEvents);
 
     // Query attributes to determine number of threads/block and blocks/grid.
     int MaxGridDimX;
@@ -801,18 +772,14 @@
 
   __tgt_target_table *loadBinary(const int DeviceId,
                                  const __tgt_device_image *Image) {
-    // Set the context we are using
-    CUresult Err = cuCtxSetCurrent(DeviceData[DeviceId].Context);
-    if (!checkResult(Err, "Error returned from cuCtxSetCurrent\n"))
-      return nullptr;
-
     // Clear the offload table as we are going to create a new one.
     clearOffloadEntriesTable(DeviceId);
 
     // Create the module and extract the function pointers.
     CUmodule Module;
     DP("Load data from image " DPxMOD "\n", DPxPTR(Image->ImageStart));
-    Err = cuModuleLoadDataEx(&Module, Image->ImageStart, 0, nullptr, nullptr);
+    CUresult Err =
+        cuModuleLoadDataEx(&Module, Image->ImageStart, 0, nullptr, nullptr);
     if (!checkResult(Err, "Error returned from cuModuleLoadDataEx\n"))
       return nullptr;
 
@@ -999,13 +966,8 @@
                  const int64_t Size, __tgt_async_info *AsyncInfo) const {
     assert(AsyncInfo && "AsyncInfo is nullptr");
 
-    CUresult Err = cuCtxSetCurrent(DeviceData[DeviceId].Context);
-    if (!checkResult(Err, "Error returned from cuCtxSetCurrent\n"))
-      return OFFLOAD_FAIL;
-
     CUstream Stream = getStream(DeviceId, AsyncInfo);
-
-    Err = cuMemcpyHtoDAsync((CUdeviceptr)TgtPtr, HstPtr, Size, Stream);
+    CUresult Err = cuMemcpyHtoDAsync((CUdeviceptr)TgtPtr, HstPtr, Size, Stream);
     if (Err != CUDA_SUCCESS) {
       DP("Error when copying data from host to device. Pointers: host "
          "= " DPxMOD ", device = " DPxMOD ", size = %" PRId64 "\n",
@@ -1021,13 +983,8 @@
                    const int64_t Size, __tgt_async_info *AsyncInfo) const {
     assert(AsyncInfo && "AsyncInfo is nullptr");
 
-    CUresult Err = cuCtxSetCurrent(DeviceData[DeviceId].Context);
-    if (!checkResult(Err, "Error returned from cuCtxSetCurrent\n"))
-      return OFFLOAD_FAIL;
-
     CUstream Stream = getStream(DeviceId, AsyncInfo);
-
-    Err = cuMemcpyDtoHAsync(HstPtr, (CUdeviceptr)TgtPtr, Size, Stream);
+    CUresult Err = cuMemcpyDtoHAsync(HstPtr, (CUdeviceptr)TgtPtr, Size, Stream);
     if (Err != CUDA_SUCCESS) {
       DP("Error when copying data from device to host. Pointers: host "
          "= " DPxMOD ", device = " DPxMOD ", size = %" PRId64 "\n",
@@ -1043,10 +1000,7 @@
                     int64_t Size, __tgt_async_info *AsyncInfo) const {
     assert(AsyncInfo && "AsyncInfo is nullptr");
 
-    CUresult Err = cuCtxSetCurrent(DeviceData[SrcDevId].Context);
-    if (!checkResult(Err, "Error returned from cuCtxSetCurrent\n"))
-      return OFFLOAD_FAIL;
-
+    CUresult Err;
     CUstream Stream = getStream(SrcDevId, AsyncInfo);
 
     // If they are two devices, we try peer to peer copy first
@@ -1102,10 +1056,6 @@
                           const int TeamNum, const int ThreadLimit,
                           const unsigned int LoopTripCount,
                           __tgt_async_info *AsyncInfo) const {
-    CUresult Err = cuCtxSetCurrent(DeviceData[DeviceId].Context);
-    if (!checkResult(Err, "Error returned from cuCtxSetCurrent\n"))
-      return OFFLOAD_FAIL;
-
     // All args are references.
     std::vector<void *> Args(ArgNum);
     std::vector<void *> Ptrs(ArgNum);
@@ -1145,6 +1095,7 @@
       CudaThreadsPerBlock = DeviceData[DeviceId].ThreadsPerBlock;
     }
 
+    CUresult Err;
     if (!KernelInfo->MaxThreadsPerBlock) {
       Err = cuFuncGetAttribute(&KernelInfo->MaxThreadsPerBlock,
                                CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK,
@@ -1471,10 +1422,6 @@
   }
 
   int initAsyncInfo(int DeviceId, __tgt_async_info **AsyncInfo) const {
-    CUresult Err = cuCtxSetCurrent(DeviceData[DeviceId].Context);
-    if (!checkResult(Err, "error returned from cuCtxSetCurrent"))
-      return OFFLOAD_FAIL;
-
     *AsyncInfo = new __tgt_async_info;
     getStream(DeviceId, *AsyncInfo);
     return OFFLOAD_SUCCESS;
@@ -1498,6 +1445,16 @@
     }
     return OFFLOAD_SUCCESS;
   }
+
+  int setContext(int DeviceId) {
+    assert(InitializedFlags[DeviceId] && "Device is not initialized");
+
+    CUresult Err = cuCtxSetCurrent(DeviceData[DeviceId].Context);
+    if (!checkResult(Err, "error returned from cuCtxSetCurrent"))
+      return OFFLOAD_FAIL;
+
+    return OFFLOAD_SUCCESS;
+  }
 };
 
 DeviceRTLTy DeviceRTL;
@@ -1544,6 +1501,9 @@
                                           __tgt_device_image *image) {
   assert(DeviceRTL.isValidDeviceId(device_id) && "device_id is invalid");
 
+  if (DeviceRTL.setContext(device_id) != OFFLOAD_SUCCESS)
+    return nullptr;
+
   return DeviceRTL.loadBinary(device_id, image);
 }
 
@@ -1551,6 +1511,9 @@
                            int32_t kind) {
   assert(DeviceRTL.isValidDeviceId(device_id) && "device_id is invalid");
 
+  if (DeviceRTL.setContext(device_id) != OFFLOAD_SUCCESS)
+    return nullptr;
+
   return DeviceRTL.dataAlloc(device_id, size, (TargetAllocTy)kind);
 }
 
@@ -1573,6 +1536,9 @@
   assert(DeviceRTL.isValidDeviceId(device_id) && "device_id is invalid");
   assert(async_info_ptr && "async_info_ptr is nullptr");
 
+  if (DeviceRTL.setContext(device_id) != OFFLOAD_SUCCESS)
+    return OFFLOAD_FAIL;
+
   return DeviceRTL.dataSubmit(device_id, tgt_ptr, hst_ptr, size,
                               async_info_ptr);
 }
@@ -1596,6 +1562,9 @@
   assert(DeviceRTL.isValidDeviceId(device_id) && "device_id is invalid");
   assert(async_info_ptr && "async_info_ptr is nullptr");
 
+  if (DeviceRTL.setContext(device_id) != OFFLOAD_SUCCESS)
+    return OFFLOAD_FAIL;
+
   return DeviceRTL.dataRetrieve(device_id, hst_ptr, tgt_ptr, size,
                                 async_info_ptr);
 }
@@ -1630,6 +1599,9 @@
 int32_t __tgt_rtl_data_delete(int32_t device_id, void *tgt_ptr) {
   assert(DeviceRTL.isValidDeviceId(device_id) && "device_id is invalid");
 
+  if (DeviceRTL.setContext(device_id) != OFFLOAD_SUCCESS)
+    return OFFLOAD_FAIL;
+
   return DeviceRTL.dataDelete(device_id, tgt_ptr);
 }
 
@@ -1658,6 +1630,9 @@
     __tgt_async_info *async_info_ptr) {
   assert(DeviceRTL.isValidDeviceId(device_id) && "device_id is invalid");
+  if (DeviceRTL.setContext(device_id) != OFFLOAD_SUCCESS)
+    return OFFLOAD_FAIL;
+
   return DeviceRTL.runTargetTeamRegion(
       device_id, tgt_entry_ptr, tgt_args, tgt_offsets, arg_num, team_num,
       thread_limit, loop_tripcount, async_info_ptr);
 }
@@ -1711,6 +1686,10 @@
 
 int32_t __tgt_rtl_create_event(int32_t device_id, void **event) {
   assert(event && "event is nullptr");
+
+  if (DeviceRTL.setContext(device_id) != OFFLOAD_SUCCESS)
+    return OFFLOAD_FAIL;
+
   return DeviceRTL.createEvent(device_id, event);
 }
 
@@ -1741,6 +1720,9 @@
 int32_t __tgt_rtl_destroy_event(int32_t device_id, void *event_ptr) {
   assert(event_ptr && "event is nullptr");
 
+  if (DeviceRTL.setContext(device_id) != OFFLOAD_SUCCESS)
+    return OFFLOAD_FAIL;
+
   return DeviceRTL.destroyEvent(device_id, event_ptr);
 }
 
@@ -1749,6 +1731,9 @@
   assert(DeviceRTL.isValidDeviceId(device_id) && "device_id is invalid");
   assert(async_info && "async_info is nullptr");
 
+  if (DeviceRTL.setContext(device_id) != OFFLOAD_SUCCESS)
+    return OFFLOAD_FAIL;
+
   return DeviceRTL.releaseAsyncInfo(device_id, async_info);
 }
 
@@ -1757,6 +1742,9 @@
   assert(DeviceRTL.isValidDeviceId(device_id) && "device_id is invalid");
   assert(async_info && "async_info is nullptr");
 
+  if (DeviceRTL.setContext(device_id) != OFFLOAD_SUCCESS)
+    return OFFLOAD_FAIL;
+
   return DeviceRTL.initAsyncInfo(device_id, async_info);
 }
 
@@ -1766,6 +1754,9 @@
   assert(DeviceRTL.isValidDeviceId(device_id) && "device_id is invalid");
   assert(device_info_ptr && "device_info_ptr is nullptr");
 
+  if (DeviceRTL.setContext(device_id) != OFFLOAD_SUCCESS)
+    return OFFLOAD_FAIL;
+
   return DeviceRTL.initDeviceInfo(device_id, device_info_ptr, err_str);
 }
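
For reference, the pattern this patch adopts is to bind the device's CUcontext to the calling thread once at the __tgt_rtl_* entry point (via DeviceRTLTy::setContext) and let every inner helper assume the correct context is already current, instead of calling cuCtxSetCurrent inside each allocator, data-transfer, and launch routine. The standalone program below is a minimal sketch of that idea, not part of the patch: check() and setContext() here are hypothetical stand-ins for the plugin's checkResult and DeviceRTLTy::setContext, and only standard CUDA driver API calls are used.

// Sketch: set the CUDA context once at the API boundary, then do the work.
#include <cstdio>
#include <cuda.h>

#define OFFLOAD_SUCCESS 0
#define OFFLOAD_FAIL (~0)

// Stand-in for the plugin's checkResult helper: print a message on failure.
static bool check(CUresult Err, const char *Msg) {
  if (Err == CUDA_SUCCESS)
    return true;
  const char *Str = nullptr;
  cuGetErrorString(Err, &Str);
  std::fprintf(stderr, "%s: %s\n", Msg, Str ? Str : "unknown error");
  return false;
}

// Analogue of DeviceRTLTy::setContext: make the device's context current for
// the calling thread before any other driver call is issued.
static int setContext(CUcontext Ctx) {
  if (!check(cuCtxSetCurrent(Ctx), "cuCtxSetCurrent failed"))
    return OFFLOAD_FAIL;
  return OFFLOAD_SUCCESS;
}

int main() {
  if (!check(cuInit(0), "cuInit failed"))
    return 1;

  CUdevice Dev;
  CUcontext Ctx;
  if (!check(cuDeviceGet(&Dev, 0), "cuDeviceGet failed") ||
      !check(cuDevicePrimaryCtxRetain(&Ctx, Dev),
             "cuDevicePrimaryCtxRetain failed"))
    return 1;

  // Entry-point style: set the context exactly once...
  if (setContext(Ctx) != OFFLOAD_SUCCESS)
    return 1;

  // ...so the "inner" work (stream creation, allocation, copies) needs no
  // cuCtxSetCurrent of its own, mirroring the simplified StreamAllocatorTy,
  // EventAllocatorTy, and CUDADeviceAllocatorTy in the patch above.
  CUstream Stream;
  CUdeviceptr Ptr;
  if (!check(cuStreamCreate(&Stream, CU_STREAM_NON_BLOCKING),
             "cuStreamCreate failed") ||
      !check(cuMemAlloc(&Ptr, 1024), "cuMemAlloc failed"))
    return 1;

  check(cuMemFree(Ptr), "cuMemFree failed");
  check(cuStreamDestroy(Stream), "cuStreamDestroy failed");
  cuDevicePrimaryCtxRelease(Dev);
  return 0;
}

The sketch also shows the intended use of the return convention: callers compare against OFFLOAD_SUCCESS rather than relying on the integer's truthiness, which is the form used for the setContext checks in the entry points above.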