diff --git a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
--- a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
+++ b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
@@ -113,10 +113,28 @@
       }};
   FunctionCallBuilder streamCreateCallBuilder = {
       "mgpuStreamCreate", llvmPointerType /* void *stream */, {}};
+  FunctionCallBuilder streamDestroyCallBuilder = {
+      "mgpuStreamDestroy", llvmVoidType, {llvmPointerType /* void *stream */}};
   FunctionCallBuilder streamSynchronizeCallBuilder = {
       "mgpuStreamSynchronize",
       llvmVoidType,
       {llvmPointerType /* void *stream */}};
+  FunctionCallBuilder streamWaitEventCallBuilder = {
+      "mgpuStreamWaitEvent",
+      llvmVoidType,
+      {llvmPointerType /* void *stream */, llvmPointerType /* void *event */}};
+  FunctionCallBuilder eventCreateCallBuilder = {
+      "mgpuEventCreate", llvmPointerType /* void *event */, {}};
+  FunctionCallBuilder eventDestroyCallBuilder = {
+      "mgpuEventDestroy", llvmVoidType, {llvmPointerType /* void *event */}};
+  FunctionCallBuilder eventSynchronizeCallBuilder = {
+      "mgpuEventSynchronize",
+      llvmVoidType,
+      {llvmPointerType /* void *event */}};
+  FunctionCallBuilder eventRecordCallBuilder = {
+      "mgpuEventRecord",
+      llvmVoidType,
+      {llvmPointerType /* void *event */, llvmPointerType /* void *stream */}};
   FunctionCallBuilder hostRegisterCallBuilder = {
       "mgpuMemHostRegisterMemRef",
       llvmVoidType,
diff --git a/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp b/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
--- a/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
+++ b/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
@@ -32,6 +32,17 @@
     llvm::errs() << "'" << #expr << "' failed with '" << name << "'\n";        \
   }(expr)
 
+// Static initialization of CUDA context for device ordinal 0.
+static auto InitializeCtx = [] {
+  // cuInit must be called before any other CUDA driver API function.
+  CUDA_REPORT_IF_ERROR(cuInit(/*flags=*/0));
+  CUdevice device;
+  CUDA_REPORT_IF_ERROR(cuDeviceGet(&device, /*ordinal=*/0));
+  CUcontext context;
+  CUDA_REPORT_IF_ERROR(cuCtxCreate(&context, /*flags=*/0, device));
+  return 0;
+}();
+
 extern "C" CUmodule mgpuModuleLoad(void *data) {
   CUmodule module = nullptr;
   CUDA_REPORT_IF_ERROR(cuModuleLoadData(&module, data));
@@ -63,10 +74,36 @@
   return stream;
 }
 
+extern "C" void mgpuStreamDestroy(CUstream stream) {
+  CUDA_REPORT_IF_ERROR(cuStreamDestroy(stream));
+}
+
 extern "C" void mgpuStreamSynchronize(CUstream stream) {
   CUDA_REPORT_IF_ERROR(cuStreamSynchronize(stream));
 }
 
+extern "C" void mgpuStreamWaitEvent(CUstream stream, CUevent event) {
+  CUDA_REPORT_IF_ERROR(cuStreamWaitEvent(stream, event, /*flags=*/0));
+}
+
+extern "C" CUevent mgpuEventCreate() {
+  CUevent event = nullptr;
+  CUDA_REPORT_IF_ERROR(cuEventCreate(&event, CU_EVENT_DISABLE_TIMING));
+  return event;
+}
+
+extern "C" void mgpuEventDestroy(CUevent event) {
+  CUDA_REPORT_IF_ERROR(cuEventDestroy(event));
+}
+
+extern "C" void mgpuEventSynchronize(CUevent event) {
+  CUDA_REPORT_IF_ERROR(cuEventSynchronize(event));
+}
+
+extern "C" void mgpuEventRecord(CUevent event, CUstream stream) {
+  CUDA_REPORT_IF_ERROR(cuEventRecord(event, stream));
+}
+
 /// Helper functions for writing mlir example code
 
 // Allows to register byte array with the CUDA runtime. Helpful until we have
diff --git a/mlir/tools/mlir-rocm-runner/rocm-runtime-wrappers.cpp b/mlir/tools/mlir-rocm-runner/rocm-runtime-wrappers.cpp
--- a/mlir/tools/mlir-rocm-runner/rocm-runtime-wrappers.cpp
+++ b/mlir/tools/mlir-rocm-runner/rocm-runtime-wrappers.cpp
@@ -31,6 +31,17 @@
     llvm::errs() << "'" << #expr << "' failed with '" << name << "'\n";        \
   }(expr)
 
+// Static initialization of HIP context for device ordinal 0.
+static auto InitializeCtx = [] {
+  // Initialize the HIP runtime explicitly before querying devices.
+  HIP_REPORT_IF_ERROR(hipInit(/*flags=*/0));
+  hipDevice_t device;
+  HIP_REPORT_IF_ERROR(hipDeviceGet(&device, /*ordinal=*/0));
+  hipCtx_t context;
+  HIP_REPORT_IF_ERROR(hipCtxCreate(&context, /*flags=*/0, device));
+  return 0;
+}();
+
 extern "C" hipModule_t mgpuModuleLoad(void *data) {
   hipModule_t module = nullptr;
   HIP_REPORT_IF_ERROR(hipModuleLoadData(&module, data));
@@ -58,16 +69,42 @@
                                             stream, params, extra));
 }
 
-extern "C" void *mgpuStreamCreate() {
+extern "C" hipStream_t mgpuStreamCreate() {
   hipStream_t stream = nullptr;
   HIP_REPORT_IF_ERROR(hipStreamCreate(&stream));
   return stream;
 }
 
+extern "C" void mgpuStreamDestroy(hipStream_t stream) {
+  HIP_REPORT_IF_ERROR(hipStreamDestroy(stream));
+}
+
 extern "C" void mgpuStreamSynchronize(hipStream_t stream) {
   return HIP_REPORT_IF_ERROR(hipStreamSynchronize(stream));
 }
 
+extern "C" void mgpuStreamWaitEvent(hipStream_t stream, hipEvent_t event) {
+  HIP_REPORT_IF_ERROR(hipStreamWaitEvent(stream, event, /*flags=*/0));
+}
+
+extern "C" hipEvent_t mgpuEventCreate() {
+  hipEvent_t event = nullptr;
+  HIP_REPORT_IF_ERROR(hipEventCreateWithFlags(&event, hipEventDisableTiming));
+  return event;
+}
+
+extern "C" void mgpuEventDestroy(hipEvent_t event) {
+  HIP_REPORT_IF_ERROR(hipEventDestroy(event));
+}
+
+extern "C" void mgpuEventSynchronize(hipEvent_t event) {
+  HIP_REPORT_IF_ERROR(hipEventSynchronize(event));
+}
+
+extern "C" void mgpuEventRecord(hipEvent_t event, hipStream_t stream) {
+  HIP_REPORT_IF_ERROR(hipEventRecord(event, stream));
+}
+
 /// Helper functions for writing mlir example code
 
 // Allows to register byte array with the ROCM runtime. Helpful until we have