diff --git a/openmp/libomptarget/plugins/cuda/src/rtl.cpp b/openmp/libomptarget/plugins/cuda/src/rtl.cpp --- a/openmp/libomptarget/plugins/cuda/src/rtl.cpp +++ b/openmp/libomptarget/plugins/cuda/src/rtl.cpp @@ -167,30 +167,17 @@ /// Functions \p create and \p destroy return OFFLOAD_SUCCESS and OFFLOAD_FAIL /// accordingly. The implementation should not raise any exception. template struct AllocatorTy { - AllocatorTy(CUcontext C) noexcept : Context(C) {} using ElementTy = T; - - virtual ~AllocatorTy() {} - /// Create a resource and assign to R. virtual int create(T &R) noexcept = 0; /// Destroy the resource. virtual int destroy(T) noexcept = 0; - -protected: - CUcontext Context; }; /// Allocator for CUstream. struct StreamAllocatorTy final : public AllocatorTy { - StreamAllocatorTy(CUcontext C) noexcept : AllocatorTy(C) {} - /// See AllocatorTy::create. int create(CUstream &Stream) noexcept override { - if (!checkResult(cuCtxSetCurrent(Context), - "Error returned from cuCtxSetCurrent\n")) - return OFFLOAD_FAIL; - if (!checkResult(cuStreamCreate(&Stream, CU_STREAM_NON_BLOCKING), "Error returned from cuStreamCreate\n")) return OFFLOAD_FAIL; @@ -200,9 +187,6 @@ /// See AllocatorTy::destroy. int destroy(CUstream Stream) noexcept override { - if (!checkResult(cuCtxSetCurrent(Context), - "Error returned from cuCtxSetCurrent\n")) - return OFFLOAD_FAIL; if (!checkResult(cuStreamDestroy(Stream), "Error returned from cuStreamDestroy\n")) return OFFLOAD_FAIL; @@ -213,8 +197,6 @@ /// Allocator for CUevent. struct EventAllocatorTy final : public AllocatorTy { - EventAllocatorTy(CUcontext C) noexcept : AllocatorTy(C) {} - /// See AllocatorTy::create. int create(CUevent &Event) noexcept override { if (!checkResult(cuEventCreate(&Event, CU_EVENT_DEFAULT), @@ -363,23 +345,15 @@ /// A class responsible for interacting with device native runtime library to /// allocate and free memory. class CUDADeviceAllocatorTy : public DeviceAllocatorTy { - const int DeviceId; - const std::vector &DeviceData; std::unordered_map HostPinnedAllocs; public: - CUDADeviceAllocatorTy(int DeviceId, std::vector &DeviceData) - : DeviceId(DeviceId), DeviceData(DeviceData) {} - void *allocate(size_t Size, void *, TargetAllocTy Kind) override { if (Size == 0) return nullptr; - CUresult Err = cuCtxSetCurrent(DeviceData[DeviceId].Context); - if (!checkResult(Err, "Error returned from cuCtxSetCurrent\n")) - return nullptr; - void *MemAlloc = nullptr; + CUresult Err; switch (Kind) { case TARGET_ALLOC_DEFAULT: case TARGET_ALLOC_DEVICE: @@ -410,10 +384,7 @@ } int free(void *TgtPtr) override { - CUresult Err = cuCtxSetCurrent(DeviceData[DeviceId].Context); - if (!checkResult(Err, "Error returned from cuCtxSetCurrent\n")) - return OFFLOAD_FAIL; - + CUresult Err; // Host pinned memory must be freed differently. TargetAllocTy Kind = (HostPinnedAllocs.find(TgtPtr) == HostPinnedAllocs.end()) @@ -566,7 +537,7 @@ } for (int I = 0; I < NumberOfDevices; ++I) - DeviceAllocators.emplace_back(I, DeviceData); + DeviceAllocators.emplace_back(); // Get the size threshold from environment variable std::pair Res = MemoryManagerTy::getSizeThresholdFromEnv(); @@ -641,13 +612,13 @@ // Initialize the stream pool. if (!StreamPool[DeviceId]) - StreamPool[DeviceId] = std::make_unique( - StreamAllocatorTy(DeviceData[DeviceId].Context), NumInitialStreams); + StreamPool[DeviceId] = std::make_unique(StreamAllocatorTy(), + NumInitialStreams); // Initialize the event pool. if (!EventPool[DeviceId]) - EventPool[DeviceId] = std::make_unique( - EventAllocatorTy(DeviceData[DeviceId].Context), NumInitialEvents); + EventPool[DeviceId] = + std::make_unique(EventAllocatorTy(), NumInitialEvents); // Query attributes to determine number of threads/block and blocks/grid. int MaxGridDimX; @@ -806,18 +777,14 @@ __tgt_target_table *loadBinary(const int DeviceId, const __tgt_device_image *Image) { - // Set the context we are using - CUresult Err = cuCtxSetCurrent(DeviceData[DeviceId].Context); - if (!checkResult(Err, "Error returned from cuCtxSetCurrent\n")) - return nullptr; - // Clear the offload table as we are going to create a new one. clearOffloadEntriesTable(DeviceId); // Create the module and extract the function pointers. CUmodule Module; DP("Load data from image " DPxMOD "\n", DPxPTR(Image->ImageStart)); - Err = cuModuleLoadDataEx(&Module, Image->ImageStart, 0, nullptr, nullptr); + CUresult Err = + cuModuleLoadDataEx(&Module, Image->ImageStart, 0, nullptr, nullptr); if (!checkResult(Err, "Error returned from cuModuleLoadDataEx\n")) return nullptr; @@ -1004,13 +971,8 @@ const int64_t Size, __tgt_async_info *AsyncInfo) const { assert(AsyncInfo && "AsyncInfo is nullptr"); - CUresult Err = cuCtxSetCurrent(DeviceData[DeviceId].Context); - if (!checkResult(Err, "Error returned from cuCtxSetCurrent\n")) - return OFFLOAD_FAIL; - CUstream Stream = getStream(DeviceId, AsyncInfo); - - Err = cuMemcpyHtoDAsync((CUdeviceptr)TgtPtr, HstPtr, Size, Stream); + CUresult Err = cuMemcpyHtoDAsync((CUdeviceptr)TgtPtr, HstPtr, Size, Stream); if (Err != CUDA_SUCCESS) { DP("Error when copying data from host to device. Pointers: host " "= " DPxMOD ", device = " DPxMOD ", size = %" PRId64 "\n", @@ -1026,13 +988,8 @@ const int64_t Size, __tgt_async_info *AsyncInfo) const { assert(AsyncInfo && "AsyncInfo is nullptr"); - CUresult Err = cuCtxSetCurrent(DeviceData[DeviceId].Context); - if (!checkResult(Err, "Error returned from cuCtxSetCurrent\n")) - return OFFLOAD_FAIL; - CUstream Stream = getStream(DeviceId, AsyncInfo); - - Err = cuMemcpyDtoHAsync(HstPtr, (CUdeviceptr)TgtPtr, Size, Stream); + CUresult Err = cuMemcpyDtoHAsync(HstPtr, (CUdeviceptr)TgtPtr, Size, Stream); if (Err != CUDA_SUCCESS) { DP("Error when copying data from device to host. Pointers: host " "= " DPxMOD ", device = " DPxMOD ", size = %" PRId64 "\n", @@ -1048,10 +1005,7 @@ int64_t Size, __tgt_async_info *AsyncInfo) const { assert(AsyncInfo && "AsyncInfo is nullptr"); - CUresult Err = cuCtxSetCurrent(DeviceData[SrcDevId].Context); - if (!checkResult(Err, "Error returned from cuCtxSetCurrent\n")) - return OFFLOAD_FAIL; - + CUresult Err; CUstream Stream = getStream(SrcDevId, AsyncInfo); // If they are two devices, we try peer to peer copy first @@ -1107,10 +1061,6 @@ const int TeamNum, const int ThreadLimit, const unsigned int LoopTripCount, __tgt_async_info *AsyncInfo) const { - CUresult Err = cuCtxSetCurrent(DeviceData[DeviceId].Context); - if (!checkResult(Err, "Error returned from cuCtxSetCurrent\n")) - return OFFLOAD_FAIL; - // All args are references. std::vector Args(ArgNum); std::vector Ptrs(ArgNum); @@ -1150,6 +1100,7 @@ CudaThreadsPerBlock = DeviceData[DeviceId].ThreadsPerBlock; } + CUresult Err; if (!KernelInfo->MaxThreadsPerBlock) { Err = cuFuncGetAttribute(&KernelInfo->MaxThreadsPerBlock, CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK, @@ -1476,10 +1427,6 @@ } int initAsyncInfo(int DeviceId, __tgt_async_info **AsyncInfo) const { - CUresult Err = cuCtxSetCurrent(DeviceData[DeviceId].Context); - if (!checkResult(Err, "error returned from cuCtxSetCurrent")) - return OFFLOAD_FAIL; - *AsyncInfo = new __tgt_async_info; getStream(DeviceId, *AsyncInfo); return OFFLOAD_SUCCESS; @@ -1503,6 +1450,16 @@ } return OFFLOAD_SUCCESS; } + + int setContext(int DeviceId) { + assert(InitializedFlags[DeviceId] && "Device is not initialized"); + + CUresult Err = cuCtxSetCurrent(DeviceData[DeviceId].Context); + if (!checkResult(Err, "error returned from cuCtxSetCurrent")) + return OFFLOAD_FAIL; + + return OFFLOAD_SUCCESS; + } }; DeviceRTLTy DeviceRTL; @@ -1535,12 +1492,14 @@ int32_t __tgt_rtl_init_device(int32_t device_id) { assert(DeviceRTL.isValidDeviceId(device_id) && "device_id is invalid"); + // Context is set when init the device. return DeviceRTL.initDevice(device_id); } int32_t __tgt_rtl_deinit_device(int32_t device_id) { assert(DeviceRTL.isValidDeviceId(device_id) && "device_id is invalid"); + // Context is set when deinit the device. return DeviceRTL.deinitDevice(device_id); } @@ -1549,6 +1508,9 @@ __tgt_device_image *image) { assert(DeviceRTL.isValidDeviceId(device_id) && "device_id is invalid"); + if (!DeviceRTL.setContext(device_id)) + return nullptr; + return DeviceRTL.loadBinary(device_id, image); } @@ -1556,12 +1518,16 @@ int32_t kind) { assert(DeviceRTL.isValidDeviceId(device_id) && "device_id is invalid"); + if (!DeviceRTL.setContext(device_id)) + return nullptr; + return DeviceRTL.dataAlloc(device_id, size, (TargetAllocTy)kind); } int32_t __tgt_rtl_data_submit(int32_t device_id, void *tgt_ptr, void *hst_ptr, int64_t size) { assert(DeviceRTL.isValidDeviceId(device_id) && "device_id is invalid"); + // Context is set in __tgt_rtl_data_submit_async. __tgt_async_info AsyncInfo; const int32_t rc = __tgt_rtl_data_submit_async(device_id, tgt_ptr, hst_ptr, @@ -1578,6 +1544,9 @@ assert(DeviceRTL.isValidDeviceId(device_id) && "device_id is invalid"); assert(async_info_ptr && "async_info_ptr is nullptr"); + if (!DeviceRTL.setContext(device_id)) + return OFFLOAD_FAIL; + return DeviceRTL.dataSubmit(device_id, tgt_ptr, hst_ptr, size, async_info_ptr); } @@ -1585,6 +1554,7 @@ int32_t __tgt_rtl_data_retrieve(int32_t device_id, void *hst_ptr, void *tgt_ptr, int64_t size) { assert(DeviceRTL.isValidDeviceId(device_id) && "device_id is invalid"); + // Context is set in __tgt_rtl_data_retrieve_async. __tgt_async_info AsyncInfo; const int32_t rc = __tgt_rtl_data_retrieve_async(device_id, hst_ptr, tgt_ptr, @@ -1601,6 +1571,9 @@ assert(DeviceRTL.isValidDeviceId(device_id) && "device_id is invalid"); assert(async_info_ptr && "async_info_ptr is nullptr"); + if (!DeviceRTL.setContext(device_id)) + return OFFLOAD_FAIL; + return DeviceRTL.dataRetrieve(device_id, hst_ptr, tgt_ptr, size, async_info_ptr); } @@ -1612,7 +1585,8 @@ assert(DeviceRTL.isValidDeviceId(src_dev_id) && "src_dev_id is invalid"); assert(DeviceRTL.isValidDeviceId(dst_dev_id) && "dst_dev_id is invalid"); assert(AsyncInfo && "AsyncInfo is nullptr"); - + // NOTE: We don't need to set context for data exchange as the device contexts + // are passed to CUDA function directly. return DeviceRTL.dataExchange(src_dev_id, src_ptr, dst_dev_id, dst_ptr, size, AsyncInfo); } @@ -1622,6 +1596,7 @@ int64_t size) { assert(DeviceRTL.isValidDeviceId(src_dev_id) && "src_dev_id is invalid"); assert(DeviceRTL.isValidDeviceId(dst_dev_id) && "dst_dev_id is invalid"); + // Context is set in __tgt_rtl_data_exchange_async. __tgt_async_info AsyncInfo; const int32_t rc = __tgt_rtl_data_exchange_async( @@ -1635,6 +1610,9 @@ int32_t __tgt_rtl_data_delete(int32_t device_id, void *tgt_ptr) { assert(DeviceRTL.isValidDeviceId(device_id) && "device_id is invalid"); + if (!DeviceRTL.setContext(device_id)) + return OFFLOAD_FAIL; + return DeviceRTL.dataDelete(device_id, tgt_ptr); } @@ -1645,6 +1623,7 @@ int32_t thread_limit, uint64_t loop_tripcount) { assert(DeviceRTL.isValidDeviceId(device_id) && "device_id is invalid"); + // Context is set in __tgt_rtl_run_target_team_region_async. __tgt_async_info AsyncInfo; const int32_t rc = __tgt_rtl_run_target_team_region_async( @@ -1663,6 +1642,9 @@ __tgt_async_info *async_info_ptr) { assert(DeviceRTL.isValidDeviceId(device_id) && "device_id is invalid"); + if (!DeviceRTL.setContext(device_id)) + return OFFLOAD_FAIL; + return DeviceRTL.runTargetTeamRegion( device_id, tgt_entry_ptr, tgt_args, tgt_offsets, arg_num, team_num, thread_limit, loop_tripcount, async_info_ptr); @@ -1672,6 +1654,7 @@ void **tgt_args, ptrdiff_t *tgt_offsets, int32_t arg_num) { assert(DeviceRTL.isValidDeviceId(device_id) && "device_id is invalid"); + // Context is set in __tgt_rtl_run_target_region_async. __tgt_async_info AsyncInfo; const int32_t rc = __tgt_rtl_run_target_region_async( @@ -1688,7 +1671,7 @@ int32_t arg_num, __tgt_async_info *async_info_ptr) { assert(DeviceRTL.isValidDeviceId(device_id) && "device_id is invalid"); - + // Context is set in __tgt_rtl_run_target_team_region_async. return __tgt_rtl_run_target_team_region_async( device_id, tgt_entry_ptr, tgt_args, tgt_offsets, arg_num, /* team num*/ 1, /* thread_limit */ 1, /* loop_tripcount */ 0, @@ -1700,7 +1683,7 @@ assert(DeviceRTL.isValidDeviceId(device_id) && "device_id is invalid"); assert(async_info_ptr && "async_info_ptr is nullptr"); assert(async_info_ptr->Queue && "async_info_ptr->Queue is nullptr"); - + // NOTE: We don't need to set context for stream sync. return DeviceRTL.synchronize(device_id, async_info_ptr); } @@ -1711,11 +1694,16 @@ void __tgt_rtl_print_device_info(int32_t device_id) { assert(DeviceRTL.isValidDeviceId(device_id) && "device_id is invalid"); + // NOTE: We don't need to set context for print device info. DeviceRTL.printDeviceInfo(device_id); } int32_t __tgt_rtl_create_event(int32_t device_id, void **event) { assert(event && "event is nullptr"); + + if (!DeviceRTL.setContext(device_id)) + return OFFLOAD_FAIL; + return DeviceRTL.createEvent(device_id, event); } @@ -1724,7 +1712,7 @@ assert(async_info_ptr && "async_info_ptr is nullptr"); assert(async_info_ptr->Queue && "async_info_ptr->Queue is nullptr"); assert(event_ptr && "event_ptr is nullptr"); - + // NOTE: We might not need to set context for event record. return recordEvent(event_ptr, async_info_ptr); } @@ -1733,19 +1721,22 @@ assert(DeviceRTL.isValidDeviceId(device_id) && "device_id is invalid"); assert(async_info_ptr && "async_info_ptr is nullptr"); assert(event_ptr && "event is nullptr"); - + // NOTE: We might not need to set context for event sync. return DeviceRTL.waitEvent(device_id, async_info_ptr, event_ptr); } int32_t __tgt_rtl_sync_event(int32_t device_id, void *event_ptr) { assert(event_ptr && "event is nullptr"); - + // NOTE: We might not need to set context for event sync. return syncEvent(event_ptr); } int32_t __tgt_rtl_destroy_event(int32_t device_id, void *event_ptr) { assert(event_ptr && "event is nullptr"); + if (!DeviceRTL.setContext(device_id)) + return OFFLOAD_FAIL; + return DeviceRTL.destroyEvent(device_id, event_ptr); } @@ -1754,6 +1745,9 @@ assert(DeviceRTL.isValidDeviceId(device_id) && "device_id is invalid"); assert(async_info && "async_info is nullptr"); + if (!DeviceRTL.setContext(device_id)) + return OFFLOAD_FAIL; + return DeviceRTL.releaseAsyncInfo(device_id, async_info); } @@ -1762,6 +1756,9 @@ assert(DeviceRTL.isValidDeviceId(device_id) && "device_id is invalid"); assert(async_info && "async_info is nullptr"); + if (!DeviceRTL.setContext(device_id)) + return OFFLOAD_FAIL; + return DeviceRTL.initAsyncInfo(device_id, async_info); } @@ -1771,6 +1768,9 @@ assert(DeviceRTL.isValidDeviceId(device_id) && "device_id is invalid"); assert(device_info_ptr && "device_info_ptr is nullptr"); + if (!DeviceRTL.setContext(device_id)) + return OFFLOAD_FAIL; + return DeviceRTL.initDeviceInfo(device_id, device_info_ptr, err_str); }