diff --git a/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp b/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
--- a/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
+++ b/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
@@ -33,16 +33,20 @@
   }(expr)
 
-// Static initialization of CUDA context for device ordinal 0.
-static auto InitializeCtx = [] {
+// Static initialization of CUDA context for device ordinal 0. Wrappers that
+// create driver objects or launch kernels make this context current first.
+static auto Context = [] {
   CUDA_REPORT_IF_ERROR(cuInit(/*flags=*/0));
   CUdevice device;
   CUDA_REPORT_IF_ERROR(cuDeviceGet(&device, /*ordinal=*/0));
-  CUcontext context;
+  // Initialized so that a failed cuCtxCreate leaves a null handle rather than
+  // an indeterminate one, which the wrappers would pass to cuCtxSetCurrent.
+  CUcontext context = nullptr;
   CUDA_REPORT_IF_ERROR(cuCtxCreate(&context, /*flags=*/0, device));
-  return 0;
+  return context;
 }();
 
 extern "C" CUmodule mgpuModuleLoad(void *data) {
+  CUDA_REPORT_IF_ERROR(cuCtxSetCurrent(Context));
   CUmodule module = nullptr;
   CUDA_REPORT_IF_ERROR(cuModuleLoadData(&module, data));
   return module;
@@ -66,12 +70,14 @@
                                  intptr_t blockX, intptr_t blockY,
                                  intptr_t blockZ, int32_t smem, CUstream stream,
                                  void **params, void **extra) {
+  CUDA_REPORT_IF_ERROR(cuCtxSetCurrent(Context));
   CUDA_REPORT_IF_ERROR(cuLaunchKernel(function, gridX, gridY, gridZ, blockX,
                                       blockY, blockZ, smem, stream, params,
                                       extra));
 }
 
 extern "C" CUstream mgpuStreamCreate() {
+  CUDA_REPORT_IF_ERROR(cuCtxSetCurrent(Context));
   CUstream stream = nullptr;
   CUDA_REPORT_IF_ERROR(cuStreamCreate(&stream, CU_STREAM_NON_BLOCKING));
   return stream;
@@ -90,6 +96,7 @@
 }
 
 extern "C" CUevent mgpuEventCreate() {
+  CUDA_REPORT_IF_ERROR(cuCtxSetCurrent(Context));
   CUevent event = nullptr;
   CUDA_REPORT_IF_ERROR(cuEventCreate(&event, CU_EVENT_DISABLE_TIMING));
   return event;
@@ -108,6 +115,7 @@
 }
 
 extern "C" void *mgpuMemAlloc(uint64_t sizeBytes, CUstream /*stream*/) {
+  CUDA_REPORT_IF_ERROR(cuCtxSetCurrent(Context));
   CUdeviceptr ptr;
   CUDA_REPORT_IF_ERROR(cuMemAlloc(&ptr, sizeBytes));
   return reinterpret_cast<void *>(ptr);
diff --git a/mlir/tools/mlir-rocm-runner/rocm-runtime-wrappers.cpp b/mlir/tools/mlir-rocm-runner/rocm-runtime-wrappers.cpp
--- a/mlir/tools/mlir-rocm-runner/rocm-runtime-wrappers.cpp
+++ b/mlir/tools/mlir-rocm-runner/rocm-runtime-wrappers.cpp
@@ -32,16 +32,20 @@
   }(expr)
 
-// Static initialization of HIP context for device ordinal 0.
-static auto InitializeCtx = [] {
+// Static initialization of HIP context for device ordinal 0. Wrappers that
+// create driver objects or launch kernels make this context current first.
+static auto Context = [] {
   HIP_REPORT_IF_ERROR(hipInit(/*flags=*/0));
   hipDevice_t device;
   HIP_REPORT_IF_ERROR(hipDeviceGet(&device, /*ordinal=*/0));
-  hipContext_t context;
+  // Initialized so that a failed hipCtxCreate leaves a null handle rather than
+  // an indeterminate one, which the wrappers would pass to hipCtxSetCurrent.
+  hipContext_t context = nullptr;
   HIP_REPORT_IF_ERROR(hipCtxCreate(&context, /*flags=*/0, device));
-  return 0;
+  return context;
 }();
 
 extern "C" hipModule_t mgpuModuleLoad(void *data) {
+  HIP_REPORT_IF_ERROR(hipCtxSetCurrent(Context));
   hipModule_t module = nullptr;
   HIP_REPORT_IF_ERROR(hipModuleLoadData(&module, data));
   return module;
@@ -67,12 +71,14 @@
                                  intptr_t blockZ, int32_t smem,
                                  hipStream_t stream, void **params,
                                  void **extra) {
+  HIP_REPORT_IF_ERROR(hipCtxSetCurrent(Context));
   HIP_REPORT_IF_ERROR(hipModuleLaunchKernel(function, gridX, gridY, gridZ,
                                             blockX, blockY, blockZ, smem,
                                             stream, params, extra));
 }
 
 extern "C" hipStream_t mgpuStreamCreate() {
+  HIP_REPORT_IF_ERROR(hipCtxSetCurrent(Context));
   hipStream_t stream = nullptr;
   HIP_REPORT_IF_ERROR(hipStreamCreate(&stream));
   return stream;
@@ -91,6 +97,7 @@
 }
 
 extern "C" hipEvent_t mgpuEventCreate() {
+  HIP_REPORT_IF_ERROR(hipCtxSetCurrent(Context));
   hipEvent_t event = nullptr;
   HIP_REPORT_IF_ERROR(hipEventCreateWithFlags(&event, hipEventDisableTiming));
   return event;
@@ -109,6 +116,7 @@
 }
 
 extern "C" void *mgpuMemAlloc(uint64_t sizeBytes, hipStream_t /*stream*/) {
+  HIP_REPORT_IF_ERROR(hipCtxSetCurrent(Context));
   void *ptr;
   HIP_REPORT_IF_ERROR(hipMemAlloc(&ptr, sizeBytes));
   return ptr;