Index: openmp/libomptarget/plugins/cuda/src/rtl.cpp =================================================================== --- openmp/libomptarget/plugins/cuda/src/rtl.cpp +++ openmp/libomptarget/plugins/cuda/src/rtl.cpp @@ -265,6 +265,7 @@ class DeviceRTLTy { int NumberOfDevices; + bool UsePrimaryCtx; // OpenMP environment properties int EnvNumTeams; int EnvTeamLimit; @@ -333,8 +334,8 @@ DeviceRTLTy(DeviceRTLTy &&) = delete; DeviceRTLTy() - : NumberOfDevices(0), EnvNumTeams(-1), EnvTeamLimit(-1), - RequiresFlags(OMP_REQ_UNDEFINED) { + : NumberOfDevices(0), UsePrimaryCtx(true), EnvNumTeams(-1), + EnvTeamLimit(-1), RequiresFlags(OMP_REQ_UNDEFINED) { #ifdef OMPTARGET_DEBUG if (const char *EnvStr = getenv("LIBOMPTARGET_DEBUG")) DebugLevel = std::stoi(EnvStr); @@ -370,6 +371,18 @@ DP("Parsed OMP_NUM_TEAMS=%d\n", EnvNumTeams); } + // Get environment variables regarding using primary or independent context + if (const char *EnvStr = getenv("LIBOMPTARGET_CUDA_USE_PRIMARY_CONTEXT")) { + std::string Str(EnvStr); + if (!Str.compare("TRUE") || !Str.compare("true")) + UsePrimaryCtx = true; + else if (!Str.compare("FALSE") || !Str.compare("false")) + UsePrimaryCtx = false; + else + DP("Input LIBOMPTARGET_CUDA_USE_PRIMARY_CONTEXT ignored.\n"); + DP("Parsed LIBOMPTARGET_CUDA_USE_PRIMARY_CONTEXT=%s\n", EnvStr); + } + StreamManager = std::make_unique<StreamManagerTy>(NumberOfDevices, DeviceData); } @@ -385,9 +398,19 @@ for (DeviceDataTy &D : DeviceData) { // Destroy context - if (D.Context) - checkResult(cuCtxDestroy(D.Context), - "Error returned from cuCtxDestroy\n"); + if (D.Context) { + if (UsePrimaryCtx) { + checkResult(cuCtxSetCurrent(D.Context), + "Error returned from cuCtxSetCurrent\n"); + CUdevice Device; + checkResult(cuCtxGetDevice(&Device), + "Error returned from cuCtxGetDevice\n"); + checkResult(cuDevicePrimaryCtxRelease(Device), + "Error returned from cuDevicePrimaryCtxRelease\n"); + } else + checkResult(cuCtxDestroy(D.Context), + "Error returned from cuCtxDestroy\n"); + } } } @@ -408,11 
+431,18 @@ if (!checkResult(Err, "Error returned from cuDeviceGet\n")) return OFFLOAD_FAIL; - // Create the context and save it to use whenever this device is selected. - Err = cuCtxCreate(&DeviceData[DeviceId].Context, CU_CTX_SCHED_BLOCKING_SYNC, - Device); - if (!checkResult(Err, "Error returned from cuCtxCreate\n")) - return OFFLOAD_FAIL; + // Create or retain a context and save it to use whenever this device is + // selected. + if (UsePrimaryCtx) { + Err = cuDevicePrimaryCtxRetain(&DeviceData[DeviceId].Context, Device); + if (!checkResult(Err, "Error returned from cuDevicePrimaryCtxRetain\n")) + return OFFLOAD_FAIL; + } else { + Err = cuCtxCreate(&DeviceData[DeviceId].Context, + CU_CTX_SCHED_BLOCKING_SYNC, Device); + if (!checkResult(Err, "Error returned from cuCtxCreate\n")) + return OFFLOAD_FAIL; + } Err = cuCtxSetCurrent(DeviceData[DeviceId].Context); if (!checkResult(Err, "Error returned from cuCtxSetCurrent\n"))