diff --git a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
--- a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
+++ b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
@@ -79,6 +79,71 @@
   ~ScopedContext() { CUDA_REPORT_IF_ERROR(cuCtxPopCurrent(nullptr)); }
 };
 
+#ifdef MLIR_ENABLE_CUDA_CUSPARSE
+// Gives access to one process-wide cuSPARSE handle.
+//
+// The handle is created on the first construction of any
+// ScopedCuSparseHandleStorage and shared by every later instance; C++11
+// guarantees the function-local static is initialized exactly once, even
+// under concurrent first use. It is deliberately never passed to
+// cusparseDestroy: instances share it, so none of them may destroy it, and a
+// static destructor at process exit could run after CUDA has been torn down.
+// NOTE(review): cusparseCreate presumably needs a current CUDA context (e.g.
+// an enclosing ScopedContext) on first construction -- confirm at call sites.
+class ScopedCuSparseHandleStorage {
+public:
+  ScopedCuSparseHandleStorage() : env(getOrCreateHandle()) {}
+  ~ScopedCuSparseHandleStorage() = default;
+
+  // Returns the shared handle. It may be nullptr if cusparseCreate failed;
+  // the failure is reported through CUSPARSE_REPORT_IF_ERROR.
+  cusparseHandle_t get() const { return env; }
+
+private:
+  static cusparseHandle_t getOrCreateHandle() {
+    static cusparseHandle_t handle = [] {
+      cusparseHandle_t h = nullptr;
+      CUSPARSE_REPORT_IF_ERROR(cusparseCreate(&h));
+      return h;
+    }();
+    return handle;
+  }
+
+  cusparseHandle_t env;
+};
+
+#ifdef MLIR_ENABLE_CUDA_CUSPARSELT
+// Gives access to one process-wide cuSPARSELt handle, with the same lifetime
+// policy as ScopedCuSparseHandleStorage: initialized once on first
+// construction, shared by all instances, never passed to cusparseLtDestroy.
+class ScopedCuSparseLtHandleStorage {
+public:
+  ScopedCuSparseLtHandleStorage() : env(getOrCreateHandle()) {}
+  ~ScopedCuSparseLtHandleStorage() = default;
+
+  // Returns a pointer to the shared handle, the form cuSPARSELt APIs accept.
+  const cusparseLtHandle_t *get() const { return env; }
+
+private:
+  static cusparseLtHandle_t *getOrCreateHandle() {
+    // cusparseLtHandle_t is an opaque struct held by value, so initialize a
+    // single static instance in place and hand out its address rather than
+    // copying the initialized struct around.
+    static cusparseLtHandle_t handle;
+    static const bool initialized = [] {
+      CUSPARSE_REPORT_IF_ERROR(cusparseLtInit(&handle));
+      return true;
+    }();
+    (void)initialized;
+    return &handle;
+  }
+
+  cusparseLtHandle_t *env;
+};
+
+#endif // MLIR_ENABLE_CUDA_CUSPARSELT
+#endif // MLIR_ENABLE_CUDA_CUSPARSE
+
 extern "C" MLIR_CUDA_WRAPPERS_EXPORT CUmodule mgpuModuleLoad(void *data) {
   ScopedContext scopedContext;
   CUmodule module = nullptr;