CUDA offload plugin (rtl.cpp): use the per-device primary context instead of
a context created by the plugin.

  * Per-device setup (second hunk): query the primary context with
    cuDevicePrimaryCtxGetState. If it is inactive, set its flags to
    CU_CTX_SCHED_BLOCKING_SYNC; if it is already active, leave the flags
    untouched and only emit a debug message when the scheduling flags differ.
    The context is then obtained with cuDevicePrimaryCtxRetain rather than
    cuCtxCreate.
  * Context teardown loop (first hunk): make the saved context current, look
    up its device, and call cuDevicePrimaryCtxRelease rather than
    cuCtxDestroy.

NOTE(review): in the teardown hunk, Device is passed to
cuDevicePrimaryCtxRelease even when cuCtxGetDevice reported an error, in which
case Device may be uninitialized. Consider guarding the release on the result
of checkResult.

diff --git a/openmp/libomptarget/plugins/cuda/src/rtl.cpp b/openmp/libomptarget/plugins/cuda/src/rtl.cpp
--- a/openmp/libomptarget/plugins/cuda/src/rtl.cpp
+++ b/openmp/libomptarget/plugins/cuda/src/rtl.cpp
@@ -385,9 +385,15 @@
 
     for (DeviceDataTy &D : DeviceData) {
       // Destroy context
-      if (D.Context)
-        checkResult(cuCtxDestroy(D.Context),
-                    "Error returned from cuCtxDestroy\n");
+      if (D.Context) {
+        checkResult(cuCtxSetCurrent(D.Context),
+                    "Error returned from cuCtxSetCurrent\n");
+        CUdevice Device;
+        checkResult(cuCtxGetDevice(&Device),
+                    "Error returned from cuCtxGetDevice\n");
+        checkResult(cuDevicePrimaryCtxRelease(Device),
+                    "Error returned from cuDevicePrimaryCtxRelease\n");
+      }
     }
   }
 
@@ -408,10 +414,32 @@
     if (!checkResult(Err, "Error returned from cuDeviceGet\n"))
       return OFFLOAD_FAIL;
 
-    // Create the context and save it to use whenever this device is selected.
-    Err = cuCtxCreate(&DeviceData[DeviceId].Context, CU_CTX_SCHED_BLOCKING_SYNC,
-                      Device);
-    if (!checkResult(Err, "Error returned from cuCtxCreate\n"))
+    // Query the current flags of the primary context and set its flags if
+    // it is inactive
+    unsigned int FormerPrimaryCtxFlags = 0;
+    int FormerPrimaryCtxIsActive = 0;
+    Err = cuDevicePrimaryCtxGetState(Device, &FormerPrimaryCtxFlags,
+                                     &FormerPrimaryCtxIsActive);
+    if (!checkResult(Err, "Error returned from cuDevicePrimaryCtxGetState\n"))
+      return OFFLOAD_FAIL;
+
+    if (FormerPrimaryCtxIsActive) {
+      DP("The primary context is active, no change to its flags\n");
+      if ((FormerPrimaryCtxFlags & CU_CTX_SCHED_MASK) !=
+          CU_CTX_SCHED_BLOCKING_SYNC)
+        DP("Warning the current flags are not CU_CTX_SCHED_BLOCKING_SYNC\n");
+    } else {
+      DP("The primary context is inactive, set its flags to "
+         "CU_CTX_SCHED_BLOCKING_SYNC\n");
+      Err = cuDevicePrimaryCtxSetFlags(Device, CU_CTX_SCHED_BLOCKING_SYNC);
+      if (!checkResult(Err, "Error returned from cuDevicePrimaryCtxSetFlags\n"))
+        return OFFLOAD_FAIL;
+    }
+
+    // Retain the per device primary context and save it to use whenever this
+    // device is selected.
+    Err = cuDevicePrimaryCtxRetain(&DeviceData[DeviceId].Context, Device);
+    if (!checkResult(Err, "Error returned from cuDevicePrimaryCtxRetain\n"))
       return OFFLOAD_FAIL;
 
     Err = cuCtxSetCurrent(DeviceData[DeviceId].Context);