diff --git a/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp b/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
--- a/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
+++ b/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
@@ -32,17 +32,33 @@
     llvm::errs() << "'" << #expr << "' failed with '" << name << "'\n";        \
   }(expr)
 
-// Static initialization of CUDA context for device ordinal 0.
-static auto InitializeCtx = [] {
+// Static reference to CUDA primary context for device ordinal 0.
+static CUcontext Context = [] {
   CUDA_REPORT_IF_ERROR(cuInit(/*flags=*/0));
   CUdevice device;
   CUDA_REPORT_IF_ERROR(cuDeviceGet(&device, /*ordinal=*/0));
   CUcontext context;
-  CUDA_REPORT_IF_ERROR(cuCtxCreate(&context, /*flags=*/0, device));
-  return 0;
+  CUDA_REPORT_IF_ERROR(cuDevicePrimaryCtxRetain(&context, device));
+  return context;
 }();
 
+// Sets the `Context` for the duration of the instance and restores the previous
+// context on destruction.
+class ScopedContext {
+public:
+  ScopedContext() {
+    CUDA_REPORT_IF_ERROR(cuCtxGetCurrent(&previous));
+    CUDA_REPORT_IF_ERROR(cuCtxSetCurrent(Context));
+  }
+
+  ~ScopedContext() { CUDA_REPORT_IF_ERROR(cuCtxSetCurrent(previous)); }
+
+private:
+  CUcontext previous = nullptr; // Stays null if the query above fails.
+};
+
 extern "C" CUmodule mgpuModuleLoad(void *data) {
+  ScopedContext scopedContext;
   CUmodule module = nullptr;
   CUDA_REPORT_IF_ERROR(cuModuleLoadData(&module, data));
   return module;
@@ -66,12 +82,14 @@
                                  intptr_t blockX, intptr_t blockY,
                                  intptr_t blockZ, int32_t smem, CUstream stream,
                                  void **params, void **extra) {
+  ScopedContext scopedContext;
   CUDA_REPORT_IF_ERROR(cuLaunchKernel(function, gridX, gridY, gridZ, blockX,
                                       blockY, blockZ, smem, stream, params,
                                       extra));
 }
 
 extern "C" CUstream mgpuStreamCreate() {
+  ScopedContext scopedContext;
   CUstream stream = nullptr;
   CUDA_REPORT_IF_ERROR(cuStreamCreate(&stream, CU_STREAM_NON_BLOCKING));
   return stream;
@@ -90,6 +108,7 @@
 }
 
 extern "C" CUevent mgpuEventCreate() {
+  ScopedContext scopedContext;
   CUevent event = nullptr;
   CUDA_REPORT_IF_ERROR(cuEventCreate(&event, CU_EVENT_DISABLE_TIMING));
   return event;
@@ -108,6 +127,7 @@
 }
 
 extern "C" void *mgpuMemAlloc(uint64_t sizeBytes, CUstream /*stream*/) {
+  ScopedContext scopedContext;
   CUdeviceptr ptr;
   CUDA_REPORT_IF_ERROR(cuMemAlloc(&ptr, sizeBytes));
   return reinterpret_cast<void *>(ptr);
diff --git a/mlir/tools/mlir-rocm-runner/rocm-runtime-wrappers.cpp b/mlir/tools/mlir-rocm-runner/rocm-runtime-wrappers.cpp
--- a/mlir/tools/mlir-rocm-runner/rocm-runtime-wrappers.cpp
+++ b/mlir/tools/mlir-rocm-runner/rocm-runtime-wrappers.cpp
@@ -31,17 +31,33 @@
     llvm::errs() << "'" << #expr << "' failed with '" << name << "'\n";        \
   }(expr)
 
-// Static initialization of HIP context for device ordinal 0.
-static auto InitializeCtx = [] {
+// Static reference to HIP primary context for device ordinal 0.
+static hipCtx_t Context = [] {
   HIP_REPORT_IF_ERROR(hipInit(/*flags=*/0));
   hipDevice_t device;
   HIP_REPORT_IF_ERROR(hipDeviceGet(&device, /*ordinal=*/0));
-  hipContext_t context;
-  HIP_REPORT_IF_ERROR(hipCtxCreate(&context, /*flags=*/0, device));
-  return 0;
+  hipCtx_t context;
+  HIP_REPORT_IF_ERROR(hipDevicePrimaryCtxRetain(&context, device));
+  return context;
 }();
 
+// Sets the `Context` for the duration of the instance and restores the previous
+// context on destruction.
+class ScopedContext {
+public:
+  ScopedContext() {
+    HIP_REPORT_IF_ERROR(hipCtxGetCurrent(&previous));
+    HIP_REPORT_IF_ERROR(hipCtxSetCurrent(Context));
+  }
+
+  ~ScopedContext() { HIP_REPORT_IF_ERROR(hipCtxSetCurrent(previous)); }
+
+private:
+  hipCtx_t previous = nullptr; // Stays null if the query above fails.
+};
+
 extern "C" hipModule_t mgpuModuleLoad(void *data) {
+  ScopedContext scopedContext;
   hipModule_t module = nullptr;
   HIP_REPORT_IF_ERROR(hipModuleLoadData(&module, data));
   return module;
@@ -67,12 +83,14 @@
                                  intptr_t blockZ, int32_t smem,
                                  hipStream_t stream, void **params,
                                  void **extra) {
+  ScopedContext scopedContext;
   HIP_REPORT_IF_ERROR(hipModuleLaunchKernel(function, gridX, gridY, gridZ,
                                             blockX, blockY, blockZ, smem,
                                             stream, params, extra));
 }
 
 extern "C" hipStream_t mgpuStreamCreate() {
+  ScopedContext scopedContext;
   hipStream_t stream = nullptr;
   HIP_REPORT_IF_ERROR(hipStreamCreate(&stream));
   return stream;
@@ -91,6 +109,7 @@
 }
 
 extern "C" hipEvent_t mgpuEventCreate() {
+  ScopedContext scopedContext;
   hipEvent_t event = nullptr;
   HIP_REPORT_IF_ERROR(hipEventCreateWithFlags(&event, hipEventDisableTiming));
   return event;
@@ -109,6 +128,7 @@
 }
 
 extern "C" void *mgpuMemAlloc(uint64_t sizeBytes, hipStream_t /*stream*/) {
+  ScopedContext scopedContext;
   void *ptr;
   HIP_REPORT_IF_ERROR(hipMemAlloc(&ptr, sizeBytes));
   return ptr;