diff --git a/openmp/libomptarget/plugins/cuda/src/rtl.cpp b/openmp/libomptarget/plugins/cuda/src/rtl.cpp
--- a/openmp/libomptarget/plugins/cuda/src/rtl.cpp
+++ b/openmp/libomptarget/plugins/cuda/src/rtl.cpp
@@ -167,30 +167,17 @@
 /// Functions \p create and \p destroy return OFFLOAD_SUCCESS and OFFLOAD_FAIL
 /// accordingly. The implementation should not raise any exception.
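+/// Both create and destroy assume the CUDA context of the owning device has
+/// already been set as current by the caller (see DeviceRTLTy::setContext).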
 template <typename T> struct AllocatorTy {
-  AllocatorTy(CUcontext C) noexcept : Context(C) {}
   using ElementTy = T;
-
-  virtual ~AllocatorTy() {}
-
   /// Create a resource and assign to R.
   virtual int create(T &R) noexcept = 0;
   /// Destroy the resource.
   virtual int destroy(T) noexcept = 0;
-
-protected:
-  CUcontext Context;
 };
 
 /// Allocator for CUstream.
 struct StreamAllocatorTy final : public AllocatorTy<CUstream> {
-  StreamAllocatorTy(CUcontext C) noexcept : AllocatorTy<CUstream>(C) {}
-
   /// See AllocatorTy<T>::create.
   int create(CUstream &Stream) noexcept override {
-    if (!checkResult(cuCtxSetCurrent(Context),
-                     "Error returned from cuCtxSetCurrent\n"))
-      return OFFLOAD_FAIL;
-
     if (!checkResult(cuStreamCreate(&Stream, CU_STREAM_NON_BLOCKING),
                      "Error returned from cuStreamCreate\n"))
       return OFFLOAD_FAIL;
@@ -200,9 +187,6 @@
 
   /// See AllocatorTy<T>::destroy.
   int destroy(CUstream Stream) noexcept override {
-    if (!checkResult(cuCtxSetCurrent(Context),
-                     "Error returned from cuCtxSetCurrent\n"))
-      return OFFLOAD_FAIL;
     if (!checkResult(cuStreamDestroy(Stream),
                      "Error returned from cuStreamDestroy\n"))
       return OFFLOAD_FAIL;
@@ -213,8 +197,6 @@
 
 /// Allocator for CUevent.
 struct EventAllocatorTy final : public AllocatorTy<CUevent> {
-  EventAllocatorTy(CUcontext C) noexcept : AllocatorTy<CUevent>(C) {}
-
   /// See AllocatorTy<T>::create.
   int create(CUevent &Event) noexcept override {
     if (!checkResult(cuEventCreate(&Event, CU_EVENT_DEFAULT),
@@ -363,23 +345,15 @@
   /// A class responsible for interacting with device native runtime library to
   /// allocate and free memory.
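+  /// Note that allocate and free expect the CUDA context of the device to
+  /// already be current (see setContext).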
   class CUDADeviceAllocatorTy : public DeviceAllocatorTy {
-    const int DeviceId;
-    const std::vector<DeviceDataTy> &DeviceData;
     std::unordered_map<void *, TargetAllocTy> HostPinnedAllocs;
 
   public:
-    CUDADeviceAllocatorTy(int DeviceId, std::vector<DeviceDataTy> &DeviceData)
-        : DeviceId(DeviceId), DeviceData(DeviceData) {}
-
     void *allocate(size_t Size, void *, TargetAllocTy Kind) override {
       if (Size == 0)
         return nullptr;
 
-      CUresult Err = cuCtxSetCurrent(DeviceData[DeviceId].Context);
-      if (!checkResult(Err, "Error returned from cuCtxSetCurrent\n"))
-        return nullptr;
-
       void *MemAlloc = nullptr;
+      CUresult Err;
       switch (Kind) {
       case TARGET_ALLOC_DEFAULT:
       case TARGET_ALLOC_DEVICE:
@@ -410,10 +384,7 @@
     }
 
     int free(void *TgtPtr) override {
-      CUresult Err = cuCtxSetCurrent(DeviceData[DeviceId].Context);
-      if (!checkResult(Err, "Error returned from cuCtxSetCurrent\n"))
-        return OFFLOAD_FAIL;
-
+      CUresult Err;
       // Host pinned memory must be freed differently.
       TargetAllocTy Kind =
           (HostPinnedAllocs.find(TgtPtr) == HostPinnedAllocs.end())
@@ -566,7 +537,7 @@
     }
 
     for (int I = 0; I < NumberOfDevices; ++I)
-      DeviceAllocators.emplace_back(I, DeviceData);
+      DeviceAllocators.emplace_back();
 
     // Get the size threshold from environment variable
     std::pair<size_t, bool> Res = MemoryManagerTy::getSizeThresholdFromEnv();
@@ -641,13 +612,13 @@
 
     // Initialize the stream pool.
     if (!StreamPool[DeviceId])
-      StreamPool[DeviceId] = std::make_unique<StreamPoolTy>(
-          StreamAllocatorTy(DeviceData[DeviceId].Context), NumInitialStreams);
+      StreamPool[DeviceId] = std::make_unique<StreamPoolTy>(StreamAllocatorTy(),
+                                                            NumInitialStreams);
 
     // Initialize the event pool.
     if (!EventPool[DeviceId])
-      EventPool[DeviceId] = std::make_unique<EventPoolTy>(
-          EventAllocatorTy(DeviceData[DeviceId].Context), NumInitialEvents);
+      EventPool[DeviceId] =
+          std::make_unique<EventPoolTy>(EventAllocatorTy(), NumInitialEvents);
 
     // Query attributes to determine number of threads/block and blocks/grid.
     int MaxGridDimX;
@@ -806,18 +777,14 @@
 
   __tgt_target_table *loadBinary(const int DeviceId,
                                  const __tgt_device_image *Image) {
-    // Set the context we are using
-    CUresult Err = cuCtxSetCurrent(DeviceData[DeviceId].Context);
-    if (!checkResult(Err, "Error returned from cuCtxSetCurrent\n"))
-      return nullptr;
-
     // Clear the offload table as we are going to create a new one.
     clearOffloadEntriesTable(DeviceId);
 
     // Create the module and extract the function pointers.
     CUmodule Module;
     DP("Load data from image " DPxMOD "\n", DPxPTR(Image->ImageStart));
-    Err = cuModuleLoadDataEx(&Module, Image->ImageStart, 0, nullptr, nullptr);
+    CUresult Err =
+        cuModuleLoadDataEx(&Module, Image->ImageStart, 0, nullptr, nullptr);
     if (!checkResult(Err, "Error returned from cuModuleLoadDataEx\n"))
       return nullptr;
 
@@ -1004,13 +971,8 @@
                  const int64_t Size, __tgt_async_info *AsyncInfo) const {
     assert(AsyncInfo && "AsyncInfo is nullptr");
 
-    CUresult Err = cuCtxSetCurrent(DeviceData[DeviceId].Context);
-    if (!checkResult(Err, "Error returned from cuCtxSetCurrent\n"))
-      return OFFLOAD_FAIL;
-
     CUstream Stream = getStream(DeviceId, AsyncInfo);
-
-    Err = cuMemcpyHtoDAsync((CUdeviceptr)TgtPtr, HstPtr, Size, Stream);
+    CUresult Err = cuMemcpyHtoDAsync((CUdeviceptr)TgtPtr, HstPtr, Size, Stream);
     if (Err != CUDA_SUCCESS) {
       DP("Error when copying data from host to device. Pointers: host "
          "= " DPxMOD ", device = " DPxMOD ", size = %" PRId64 "\n",
@@ -1026,13 +988,8 @@
                    const int64_t Size, __tgt_async_info *AsyncInfo) const {
     assert(AsyncInfo && "AsyncInfo is nullptr");
 
-    CUresult Err = cuCtxSetCurrent(DeviceData[DeviceId].Context);
-    if (!checkResult(Err, "Error returned from cuCtxSetCurrent\n"))
-      return OFFLOAD_FAIL;
-
     CUstream Stream = getStream(DeviceId, AsyncInfo);
-
-    Err = cuMemcpyDtoHAsync(HstPtr, (CUdeviceptr)TgtPtr, Size, Stream);
+    CUresult Err = cuMemcpyDtoHAsync(HstPtr, (CUdeviceptr)TgtPtr, Size, Stream);
     if (Err != CUDA_SUCCESS) {
       DP("Error when copying data from device to host. Pointers: host "
          "= " DPxMOD ", device = " DPxMOD ", size = %" PRId64 "\n",
@@ -1048,10 +1005,7 @@
                    int64_t Size, __tgt_async_info *AsyncInfo) const {
     assert(AsyncInfo && "AsyncInfo is nullptr");
 
-    CUresult Err = cuCtxSetCurrent(DeviceData[SrcDevId].Context);
-    if (!checkResult(Err, "Error returned from cuCtxSetCurrent\n"))
-      return OFFLOAD_FAIL;
-
+    CUresult Err;
     CUstream Stream = getStream(SrcDevId, AsyncInfo);
 
     // If they are two devices, we try peer to peer copy first
@@ -1107,10 +1061,6 @@
                           const int TeamNum, const int ThreadLimit,
                           const unsigned int LoopTripCount,
                           __tgt_async_info *AsyncInfo) const {
-    CUresult Err = cuCtxSetCurrent(DeviceData[DeviceId].Context);
-    if (!checkResult(Err, "Error returned from cuCtxSetCurrent\n"))
-      return OFFLOAD_FAIL;
-
     // All args are references.
     std::vector<void *> Args(ArgNum);
     std::vector<void *> Ptrs(ArgNum);
@@ -1150,6 +1100,7 @@
       CudaThreadsPerBlock = DeviceData[DeviceId].ThreadsPerBlock;
     }
 
+    CUresult Err;
     if (!KernelInfo->MaxThreadsPerBlock) {
       Err = cuFuncGetAttribute(&KernelInfo->MaxThreadsPerBlock,
                                CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK,
@@ -1476,10 +1427,6 @@
   }
 
   int initAsyncInfo(int DeviceId, __tgt_async_info **AsyncInfo) const {
-    CUresult Err = cuCtxSetCurrent(DeviceData[DeviceId].Context);
-    if (!checkResult(Err, "error returned from cuCtxSetCurrent"))
-      return OFFLOAD_FAIL;
-
     *AsyncInfo = new __tgt_async_info;
     getStream(DeviceId, *AsyncInfo);
     return OFFLOAD_SUCCESS;
@@ -1503,6 +1450,16 @@
     }
     return OFFLOAD_SUCCESS;
   }
+
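+  /// Make the CUDA context of device \p DeviceId the current context so that
+  /// subsequent CUDA driver calls operate on that device. The plugin
+  /// interface functions call this before issuing any CUDA call.
+  ///
+  /// \return OFFLOAD_SUCCESS on success, OFFLOAD_FAIL otherwise.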
+  int setContext(int DeviceId) {
+    assert(InitializedFlags[DeviceId] && "Device is not initialized");
+
+    CUresult Err = cuCtxSetCurrent(DeviceData[DeviceId].Context);
+    if (!checkResult(Err, "error returned from cuCtxSetCurrent"))
+      return OFFLOAD_FAIL;
+
+    return OFFLOAD_SUCCESS;
+  }
 };
 
 DeviceRTLTy DeviceRTL;
@@ -1535,12 +1492,14 @@
 
 int32_t __tgt_rtl_init_device(int32_t device_id) {
   assert(DeviceRTL.isValidDeviceId(device_id) && "device_id is invalid");
+  // Context is set when the device is initialized.
 
   return DeviceRTL.initDevice(device_id);
 }
 
 int32_t __tgt_rtl_deinit_device(int32_t device_id) {
   assert(DeviceRTL.isValidDeviceId(device_id) && "device_id is invalid");
+  // Context is set when the device is deinitialized.
 
   return DeviceRTL.deinitDevice(device_id);
 }
@@ -1549,6 +1508,9 @@
                                           __tgt_device_image *image) {
   assert(DeviceRTL.isValidDeviceId(device_id) && "device_id is invalid");
 
+  if (DeviceRTL.setContext(device_id) != OFFLOAD_SUCCESS)
+    return nullptr;
+
   return DeviceRTL.loadBinary(device_id, image);
 }
 
@@ -1556,12 +1518,16 @@
                            int32_t kind) {
   assert(DeviceRTL.isValidDeviceId(device_id) && "device_id is invalid");
 
+  if (DeviceRTL.setContext(device_id) != OFFLOAD_SUCCESS)
+    return nullptr;
+
   return DeviceRTL.dataAlloc(device_id, size, (TargetAllocTy)kind);
 }
 
 int32_t __tgt_rtl_data_submit(int32_t device_id, void *tgt_ptr, void *hst_ptr,
                               int64_t size) {
   assert(DeviceRTL.isValidDeviceId(device_id) && "device_id is invalid");
+  // Context is set in __tgt_rtl_data_submit_async.
 
   __tgt_async_info AsyncInfo;
   const int32_t rc = __tgt_rtl_data_submit_async(device_id, tgt_ptr, hst_ptr,
@@ -1578,6 +1544,9 @@
   assert(DeviceRTL.isValidDeviceId(device_id) && "device_id is invalid");
   assert(async_info_ptr && "async_info_ptr is nullptr");
 
+  if (DeviceRTL.setContext(device_id) != OFFLOAD_SUCCESS)
+    return OFFLOAD_FAIL;
+
   return DeviceRTL.dataSubmit(device_id, tgt_ptr, hst_ptr, size,
                               async_info_ptr);
 }
@@ -1585,6 +1554,7 @@
 int32_t __tgt_rtl_data_retrieve(int32_t device_id, void *hst_ptr, void *tgt_ptr,
                                 int64_t size) {
   assert(DeviceRTL.isValidDeviceId(device_id) && "device_id is invalid");
+  // Context is set in __tgt_rtl_data_retrieve_async.
 
   __tgt_async_info AsyncInfo;
   const int32_t rc = __tgt_rtl_data_retrieve_async(device_id, hst_ptr, tgt_ptr,
@@ -1601,6 +1571,9 @@
   assert(DeviceRTL.isValidDeviceId(device_id) && "device_id is invalid");
   assert(async_info_ptr && "async_info_ptr is nullptr");
 
+  if (DeviceRTL.setContext(device_id) != OFFLOAD_SUCCESS)
+    return OFFLOAD_FAIL;
+
   return DeviceRTL.dataRetrieve(device_id, hst_ptr, tgt_ptr, size,
                                 async_info_ptr);
 }
@@ -1612,7 +1585,8 @@
   assert(DeviceRTL.isValidDeviceId(src_dev_id) && "src_dev_id is invalid");
   assert(DeviceRTL.isValidDeviceId(dst_dev_id) && "dst_dev_id is invalid");
   assert(AsyncInfo && "AsyncInfo is nullptr");
-
+  // NOTE: We don't need to set context for data exchange as the device
+  // contexts are passed directly to the CUDA functions.
   return DeviceRTL.dataExchange(src_dev_id, src_ptr, dst_dev_id, dst_ptr, size,
                                 AsyncInfo);
 }
@@ -1622,6 +1596,7 @@
                                 int64_t size) {
   assert(DeviceRTL.isValidDeviceId(src_dev_id) && "src_dev_id is invalid");
   assert(DeviceRTL.isValidDeviceId(dst_dev_id) && "dst_dev_id is invalid");
+  // Context is set in __tgt_rtl_data_exchange_async.
 
   __tgt_async_info AsyncInfo;
   const int32_t rc = __tgt_rtl_data_exchange_async(
@@ -1635,6 +1610,9 @@
 int32_t __tgt_rtl_data_delete(int32_t device_id, void *tgt_ptr) {
   assert(DeviceRTL.isValidDeviceId(device_id) && "device_id is invalid");
 
+  if (DeviceRTL.setContext(device_id) != OFFLOAD_SUCCESS)
+    return OFFLOAD_FAIL;
+
   return DeviceRTL.dataDelete(device_id, tgt_ptr);
 }
 
@@ -1645,6 +1623,7 @@
                                          int32_t thread_limit,
                                          uint64_t loop_tripcount) {
   assert(DeviceRTL.isValidDeviceId(device_id) && "device_id is invalid");
+  // Context is set in __tgt_rtl_run_target_team_region_async.
 
   __tgt_async_info AsyncInfo;
   const int32_t rc = __tgt_rtl_run_target_team_region_async(
@@ -1663,6 +1642,9 @@
     __tgt_async_info *async_info_ptr) {
   assert(DeviceRTL.isValidDeviceId(device_id) && "device_id is invalid");
 
+  if (DeviceRTL.setContext(device_id) != OFFLOAD_SUCCESS)
+    return OFFLOAD_FAIL;
+
   return DeviceRTL.runTargetTeamRegion(
       device_id, tgt_entry_ptr, tgt_args, tgt_offsets, arg_num, team_num,
       thread_limit, loop_tripcount, async_info_ptr);
@@ -1672,6 +1654,7 @@
                                     void **tgt_args, ptrdiff_t *tgt_offsets,
                                     int32_t arg_num) {
   assert(DeviceRTL.isValidDeviceId(device_id) && "device_id is invalid");
+  // Context is set in __tgt_rtl_run_target_region_async.
 
   __tgt_async_info AsyncInfo;
   const int32_t rc = __tgt_rtl_run_target_region_async(
@@ -1688,7 +1671,7 @@
                                           int32_t arg_num,
                                           __tgt_async_info *async_info_ptr) {
   assert(DeviceRTL.isValidDeviceId(device_id) && "device_id is invalid");
-
+  // Context is set in __tgt_rtl_run_target_team_region_async.
   return __tgt_rtl_run_target_team_region_async(
       device_id, tgt_entry_ptr, tgt_args, tgt_offsets, arg_num,
       /* team num*/ 1, /* thread_limit */ 1, /* loop_tripcount */ 0,
@@ -1700,7 +1683,7 @@
   assert(DeviceRTL.isValidDeviceId(device_id) && "device_id is invalid");
   assert(async_info_ptr && "async_info_ptr is nullptr");
   assert(async_info_ptr->Queue && "async_info_ptr->Queue is nullptr");
-
+  // NOTE: We don't need to set context for stream sync.
   return DeviceRTL.synchronize(device_id, async_info_ptr);
 }
 
@@ -1711,11 +1694,16 @@
 
 void __tgt_rtl_print_device_info(int32_t device_id) {
   assert(DeviceRTL.isValidDeviceId(device_id) && "device_id is invalid");
+  // NOTE: We don't need to set context for printing device info.
   DeviceRTL.printDeviceInfo(device_id);
 }
 
 int32_t __tgt_rtl_create_event(int32_t device_id, void **event) {
   assert(event && "event is nullptr");
+
+  if (DeviceRTL.setContext(device_id) != OFFLOAD_SUCCESS)
+    return OFFLOAD_FAIL;
+
   return DeviceRTL.createEvent(device_id, event);
 }
 
@@ -1724,7 +1712,7 @@
   assert(async_info_ptr && "async_info_ptr is nullptr");
   assert(async_info_ptr->Queue && "async_info_ptr->Queue is nullptr");
   assert(event_ptr && "event_ptr is nullptr");
-
+  // NOTE: We might not need to set context for event record.
   return recordEvent(event_ptr, async_info_ptr);
 }
 
@@ -1733,19 +1721,22 @@
   assert(DeviceRTL.isValidDeviceId(device_id) && "device_id is invalid");
   assert(async_info_ptr && "async_info_ptr is nullptr");
   assert(event_ptr && "event is nullptr");
-
+  // NOTE: We might not need to set context for event wait.
   return DeviceRTL.waitEvent(device_id, async_info_ptr, event_ptr);
 }
 
 int32_t __tgt_rtl_sync_event(int32_t device_id, void *event_ptr) {
   assert(event_ptr && "event is nullptr");
-
+  // NOTE: We might not need to set context for event sync.
   return syncEvent(event_ptr);
 }
 
 int32_t __tgt_rtl_destroy_event(int32_t device_id, void *event_ptr) {
   assert(event_ptr && "event is nullptr");
 
+  if (DeviceRTL.setContext(device_id) != OFFLOAD_SUCCESS)
+    return OFFLOAD_FAIL;
+
   return DeviceRTL.destroyEvent(device_id, event_ptr);
 }
 
@@ -1754,6 +1745,9 @@
   assert(DeviceRTL.isValidDeviceId(device_id) && "device_id is invalid");
   assert(async_info && "async_info is nullptr");
 
+  if (DeviceRTL.setContext(device_id) != OFFLOAD_SUCCESS)
+    return OFFLOAD_FAIL;
+
   return DeviceRTL.releaseAsyncInfo(device_id, async_info);
 }
 
@@ -1762,6 +1756,9 @@
   assert(DeviceRTL.isValidDeviceId(device_id) && "device_id is invalid");
   assert(async_info && "async_info is nullptr");
 
+  if (DeviceRTL.setContext(device_id) != OFFLOAD_SUCCESS)
+    return OFFLOAD_FAIL;
+
   return DeviceRTL.initAsyncInfo(device_id, async_info);
 }
 
@@ -1771,6 +1768,9 @@
   assert(DeviceRTL.isValidDeviceId(device_id) && "device_id is invalid");
   assert(device_info_ptr && "device_info_ptr is nullptr");
 
+  if (DeviceRTL.setContext(device_id) != OFFLOAD_SUCCESS)
+    return OFFLOAD_FAIL;
+
   return DeviceRTL.initDeviceInfo(device_id, device_info_ptr, err_str);
 }