diff --git a/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp b/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp --- a/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp +++ b/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp @@ -32,17 +32,44 @@ llvm::errs() << "'" << #expr << "' failed with '" << name << "'\n"; \ }(expr) -// Static initialization of CUDA context for device ordinal 0. -static auto InitializeCtx = [] { - CUDA_REPORT_IF_ERROR(cuInit(/*flags=*/0)); - CUdevice device; - CUDA_REPORT_IF_ERROR(cuDeviceGet(&device, /*ordinal=*/0)); - CUcontext context; - CUDA_REPORT_IF_ERROR(cuCtxCreate(&context, /*flags=*/0, device)); - return 0; -}(); +// Context to use in functions below. +// Can be set with 'mgpuSetContext()', otherwise will be created lazily. +static CUcontext Context = nullptr; + +// Sets the CUDA context for the duration of the instance; the destructor restores the context that was current at construction. +class ScopedContext { +public: + ScopedContext() { + static CUcontext context = MaybeCreateContext(); + CUDA_REPORT_IF_ERROR(cuCtxGetCurrent(&previous)); + CUDA_REPORT_IF_ERROR(cuCtxSetCurrent(context)); + } + + ~ScopedContext() { CUDA_REPORT_IF_ERROR(cuCtxSetCurrent(previous)); } + +private: + // Initializes 'Context' if necessary and returns it. Restores the previously current context after creating the new one. + static CUcontext MaybeCreateContext() { + if (Context != nullptr) + return Context; + CUDA_REPORT_IF_ERROR(cuInit(/*flags=*/0)); + CUcontext previous = nullptr; + CUDA_REPORT_IF_ERROR(cuCtxGetCurrent(&previous)); + CUdevice device; + CUDA_REPORT_IF_ERROR(cuDeviceGet(&device, /*ordinal=*/0)); + CUDA_REPORT_IF_ERROR(cuCtxCreate(&Context, /*flags=*/0, device)); + CUDA_REPORT_IF_ERROR(cuCtxSetCurrent(previous)); + return Context; + } + + CUcontext previous = nullptr; +}; + +// Sets the CUDA context to use in functions below. Not thread safe. Must be called before the first wrapper call: ScopedContext caches the context in a function-local static on first use, so later calls have no effect. 
+extern "C" void mgpuSetContext(CUcontext context) { Context = context; } extern "C" CUmodule mgpuModuleLoad(void *data) { + ScopedContext scopedContext; CUmodule module = nullptr; CUDA_REPORT_IF_ERROR(cuModuleLoadData(&module, data)); return module; @@ -66,12 +93,14 @@ intptr_t blockX, intptr_t blockY, intptr_t blockZ, int32_t smem, CUstream stream, void **params, void **extra) { + ScopedContext scopedContext; CUDA_REPORT_IF_ERROR(cuLaunchKernel(function, gridX, gridY, gridZ, blockX, blockY, blockZ, smem, stream, params, extra)); } extern "C" CUstream mgpuStreamCreate() { + ScopedContext scopedContext; CUstream stream = nullptr; CUDA_REPORT_IF_ERROR(cuStreamCreate(&stream, CU_STREAM_NON_BLOCKING)); return stream; @@ -90,6 +119,7 @@ } extern "C" CUevent mgpuEventCreate() { + ScopedContext scopedContext; CUevent event = nullptr; CUDA_REPORT_IF_ERROR(cuEventCreate(&event, CU_EVENT_DISABLE_TIMING)); return event; @@ -108,6 +138,7 @@ } extern "C" void *mgpuMemAlloc(uint64_t sizeBytes, CUstream /*stream*/) { + ScopedContext scopedContext; CUdeviceptr ptr; CUDA_REPORT_IF_ERROR(cuMemAlloc(&ptr, sizeBytes)); return reinterpret_cast<void *>(ptr); diff --git a/mlir/tools/mlir-rocm-runner/rocm-runtime-wrappers.cpp b/mlir/tools/mlir-rocm-runner/rocm-runtime-wrappers.cpp --- a/mlir/tools/mlir-rocm-runner/rocm-runtime-wrappers.cpp +++ b/mlir/tools/mlir-rocm-runner/rocm-runtime-wrappers.cpp @@ -31,17 +31,44 @@ llvm::errs() << "'" << #expr << "' failed with '" << name << "'\n"; \ }(expr) -// Static initialization of HIP context for device ordinal 0. -static auto InitializeCtx = [] { - HIP_REPORT_IF_ERROR(hipInit(/*flags=*/0)); - hipDevice_t device; - HIP_REPORT_IF_ERROR(hipDeviceGet(&device, /*ordinal=*/0)); - hipContext_t context; - HIP_REPORT_IF_ERROR(hipCtxCreate(&context, /*flags=*/0, device)); - return 0; -}(); +// Context to use in functions below. +// Can be set with 'mgpuSetContext()', otherwise will be created lazily. 
+static hipCtx_t Context = nullptr; + +// Sets the HIP context for the duration of the instance; the destructor restores the context that was current at construction. +class ScopedContext { +public: + ScopedContext() { + static hipCtx_t context = MaybeCreateContext(); + HIP_REPORT_IF_ERROR(hipCtxGetCurrent(&previous)); + HIP_REPORT_IF_ERROR(hipCtxSetCurrent(context)); + } + + ~ScopedContext() { HIP_REPORT_IF_ERROR(hipCtxSetCurrent(previous)); } + +private: + // Initializes 'Context' if necessary and returns it. Restores the previously current context after creating the new one. + static hipCtx_t MaybeCreateContext() { + if (Context != nullptr) + return Context; + HIP_REPORT_IF_ERROR(hipInit(/*flags=*/0)); + hipCtx_t previous = nullptr; + HIP_REPORT_IF_ERROR(hipCtxGetCurrent(&previous)); + hipDevice_t device; + HIP_REPORT_IF_ERROR(hipDeviceGet(&device, /*ordinal=*/0)); + HIP_REPORT_IF_ERROR(hipCtxCreate(&Context, /*flags=*/0, device)); + HIP_REPORT_IF_ERROR(hipCtxSetCurrent(previous)); + return Context; + } + + hipCtx_t previous = nullptr; +}; + +// Sets the HIP context to use in functions below. Not thread safe. Must be called before the first wrapper call: ScopedContext caches the context in a function-local static on first use, so later calls have no effect. +extern "C" void mgpuSetContext(hipCtx_t context) { Context = context; } extern "C" hipModule_t mgpuModuleLoad(void *data) { + ScopedContext scopedContext; hipModule_t module = nullptr; HIP_REPORT_IF_ERROR(hipModuleLoadData(&module, data)); return module; @@ -67,12 +94,14 @@ intptr_t blockZ, int32_t smem, hipStream_t stream, void **params, void **extra) { + ScopedContext scopedContext; HIP_REPORT_IF_ERROR(hipModuleLaunchKernel(function, gridX, gridY, gridZ, blockX, blockY, blockZ, smem, stream, params, extra)); } extern "C" hipStream_t mgpuStreamCreate() { + ScopedContext scopedContext; hipStream_t stream = nullptr; HIP_REPORT_IF_ERROR(hipStreamCreate(&stream)); return stream; @@ -91,6 +120,7 @@ } extern "C" hipEvent_t mgpuEventCreate() { + ScopedContext scopedContext; hipEvent_t event = nullptr; HIP_REPORT_IF_ERROR(hipEventCreateWithFlags(&event, hipEventDisableTiming)); return event; @@ -109,6 +139,7 @@ } extern "C" void *mgpuMemAlloc(uint64_t sizeBytes, hipStream_t /*stream*/) { + ScopedContext 
scopedContext; void *ptr; HIP_REPORT_IF_ERROR(hipMemAlloc(&ptr, sizeBytes)); return ptr;