diff --git a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
--- a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
+++ b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
@@ -18,6 +18,12 @@
 
 #include "cuda.h"
 
+// We need to know the CUDA version to determine how to map some of the runtime
+// calls below.
+#if !defined(CUDA_VERSION)
+#error "cuda.h did not define CUDA_VERSION"
+#endif
+
 #ifdef _WIN32
 #define MLIR_CUDA_WRAPPERS_EXPORT __declspec(dllexport)
 #else
@@ -134,15 +140,32 @@
   CUDA_REPORT_IF_ERROR(cuEventRecord(event, stream));
 }
 
-extern "C" void *mgpuMemAlloc(uint64_t sizeBytes, CUstream /*stream*/) {
+extern "C" void *mgpuMemAlloc(uint64_t sizeBytes, CUstream stream) {
   ScopedContext scopedContext;
-  CUdeviceptr ptr;
+  // Initialize ptr so the value returned after a failed allocation is never
+  // indeterminate.
+  CUdeviceptr ptr = 0;
+#if CUDA_VERSION >= 11020
+  // Use the stream-ordered version that has been available since CUDA 11.2.
+  // CUDA_VERSION describes the cuda.h used at build time, not the driver
+  // found at run time.
+  CUDA_REPORT_IF_ERROR(cuMemAllocAsync(&ptr, sizeBytes, stream));
+#else
   CUDA_REPORT_IF_ERROR(cuMemAlloc(&ptr, sizeBytes));
+  (void)stream;
+#endif
   return reinterpret_cast<void *>(ptr);
 }
 
-extern "C" void mgpuMemFree(void *ptr, CUstream /*stream*/) {
+extern "C" void mgpuMemFree(void *ptr, CUstream stream) {
+#if CUDA_VERSION >= 11020
+  // Use the stream-ordered version that has been available since CUDA 11.2.
+  CUDA_REPORT_IF_ERROR(
+      cuMemFreeAsync(reinterpret_cast<CUdeviceptr>(ptr), stream));
+#else
   CUDA_REPORT_IF_ERROR(cuMemFree(reinterpret_cast<CUdeviceptr>(ptr)));
+  (void)stream;
+#endif
 }
 
 extern "C" void mgpuMemcpy(void *dst, void *src, size_t sizeBytes,