diff --git a/mlir/lib/Dialect/GPU/Transforms/SerializeToCubin.cpp b/mlir/lib/Dialect/GPU/Transforms/SerializeToCubin.cpp --- a/mlir/lib/Dialect/GPU/Transforms/SerializeToCubin.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/SerializeToCubin.cpp @@ -110,7 +110,11 @@ CUdevice device; RETURN_ON_CUDA_ERROR(cuDeviceGet(&device, 0)); CUcontext context; - RETURN_ON_CUDA_ERROR(cuCtxCreate(&context, 0, device)); + // Use the primary context. + RETURN_ON_CUDA_ERROR(cuDevicePrimaryCtxRetain(&context, device)); + // Push the primary context so that the next CUDA operations + // actually use it. + RETURN_ON_CUDA_ERROR(cuCtxPushCurrent(context)); CUlinkState linkState; CUjit_option jitOptions[] = {CU_JIT_ERROR_LOG_BUFFER, @@ -146,7 +150,10 @@ // This will also destroy the cubin data. RETURN_ON_CUDA_ERROR(cuLinkDestroy(linkState)); - RETURN_ON_CUDA_ERROR(cuCtxDestroy(context)); + // Pop and release the primary context. + CUcontext poppedContext; + RETURN_ON_CUDA_ERROR(cuCtxPopCurrent(&poppedContext)); + RETURN_ON_CUDA_ERROR(cuDevicePrimaryCtxRelease(device)); return result; }