diff --git a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp --- a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp +++ b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp @@ -37,25 +37,26 @@ fprintf(stderr, "'%s' failed with '%s'\n", #expr, name); \ }(expr) -#pragma clang diagnostic push -#pragma clang diagnostic ignored "-Wglobal-constructors" -// Static reference to CUDA primary context for device ordinal 0. -static CUcontext Context = [] { - CUDA_REPORT_IF_ERROR(cuInit(/*flags=*/0)); - CUdevice device; - CUDA_REPORT_IF_ERROR(cuDeviceGet(&device, /*ordinal=*/0)); - CUcontext context; - CUDA_REPORT_IF_ERROR(cuDevicePrimaryCtxRetain(&context, device)); - return context; -}(); -#pragma clang diagnostic pop - // Sets the `Context` for the duration of the instance and restores the previous // context on destruction. class ScopedContext { public: ScopedContext() { + static std::once_flag CuInitFlag; + std::call_once(CuInitFlag, + [] { CUDA_REPORT_IF_ERROR(cuInit(/*flags=*/0)); }); + CUDA_REPORT_IF_ERROR(cuCtxGetCurrent(&previous)); + + // Static reference to CUDA primary context for device ordinal 0. + static CUcontext Context = [] { + CUdevice device; + CUDA_REPORT_IF_ERROR(cuDeviceGet(&device, /*ordinal=*/0)); + CUcontext context; + CUDA_REPORT_IF_ERROR(cuDevicePrimaryCtxRetain(&context, device)); + return context; + }(); + CUDA_REPORT_IF_ERROR(cuCtxSetCurrent(Context)); }