diff --git a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
--- a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
+++ b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
@@ -37,32 +37,37 @@
     fprintf(stderr, "'%s' failed with '%s'\n", #expr, name);                   \
   }(expr)
 
-#pragma clang diagnostic push
-#pragma clang diagnostic ignored "-Wglobal-constructors"
-// Static reference to CUDA primary context for device ordinal 0.
-static CUcontext Context = [] {
-  CUDA_REPORT_IF_ERROR(cuInit(/*flags=*/0));
-  CUdevice device;
-  CUDA_REPORT_IF_ERROR(cuDeviceGet(&device, /*ordinal=*/0));
-  CUcontext context;
-  CUDA_REPORT_IF_ERROR(cuDevicePrimaryCtxRetain(&context, device));
-  return context;
-}();
-#pragma clang diagnostic pop
-
-// Sets the `Context` for the duration of the instance and restores the previous
-// context on destruction.
+// Makes the primary context of device 0 current for the duration of the
+// instance and restores the previous context on destruction.
+//
+// The constructor pushes the context onto the context stack and the destructor
+// pops it again, so every instance must be destroyed exactly once; copying is
+// therefore disabled.
 class ScopedContext {
 public:
   ScopedContext() {
-    CUDA_REPORT_IF_ERROR(cuCtxGetCurrent(&previous));
-    CUDA_REPORT_IF_ERROR(cuCtxSetCurrent(Context));
+    // Static reference to CUDA primary context for device ordinal 0, retained
+    // once on first use.
+    static CUcontext context = [] {
+      CUDA_REPORT_IF_ERROR(cuInit(/*flags=*/0));
+      // CUDA_REPORT_IF_ERROR only prints a diagnostic and continues, so give
+      // the out-parameters defined values in case the call filling them fails.
+      CUdevice device = 0;
+      CUDA_REPORT_IF_ERROR(cuDeviceGet(&device, /*ordinal=*/0));
+      CUcontext ctx = nullptr;
+      // Note: this does not affect the current context.
+      CUDA_REPORT_IF_ERROR(cuDevicePrimaryCtxRetain(&ctx, device));
+      return ctx;
+    }();
+
+    CUDA_REPORT_IF_ERROR(cuCtxPushCurrent(context));
   }
 
-  ~ScopedContext() { CUDA_REPORT_IF_ERROR(cuCtxSetCurrent(previous)); }
-
-private:
-  CUcontext previous;
+  // A copy would pop the context stack a second time.
+  ScopedContext(const ScopedContext &) = delete;
+  ScopedContext &operator=(const ScopedContext &) = delete;
+
+  ~ScopedContext() { CUDA_REPORT_IF_ERROR(cuCtxPopCurrent(nullptr)); }
 };
 
 extern "C" MLIR_CUDA_WRAPPERS_EXPORT CUmodule mgpuModuleLoad(void *data) {