diff --git a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
--- a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
+++ b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
@@ -79,6 +79,44 @@
   ~ScopedContext() { CUDA_REPORT_IF_ERROR(cuCtxPopCurrent(nullptr)); }
 };
 
+#ifdef MLIR_ENABLE_CUDA_CUSPARSE
+// Process-wide storage for a single, lazily created cuSPARSE handle.
+// Constructing an instance attempts creation if no handle is live; `env` and
+// `initiated` are static, so every instance observes the same handle.
+// NOTE(review): the lazy initialization is unsynchronized; confirm that these
+// wrappers are only driven from one host thread at a time.
+class ScopedCuSparseHandleStorage {
+public:
+  static cusparseHandle_t env; // Shared handle; nullptr while none is live.
+  static bool initiated;       // True while `env` holds a live handle.
+
+  ScopedCuSparseHandleStorage() {
+    if (!initiated) {
+      CUSPARSE_REPORT_IF_ERROR(cusparseCreate(&env));
+      // Latch only when a handle was produced, so a failed creation is
+      // retried by the next instance instead of leaving a null handle in
+      // use. Assumes cusparseCreate leaves `env` null on failure (TODO
+      // confirm).
+      initiated = env != nullptr;
+    }
+  }
+
+  // Destroys the shared handle if one is live and resets the storage, so the
+  // next constructed instance creates a fresh handle. No-op otherwise.
+  static void release() {
+    if (!initiated)
+      return;
+    CUSPARSE_REPORT_IF_ERROR(cusparseDestroy(env));
+    env = nullptr;
+    initiated = false;
+  }
+};
+
+cusparseHandle_t ScopedCuSparseHandleStorage::env = nullptr;
+bool ScopedCuSparseHandleStorage::initiated = false;
+
+#endif // MLIR_ENABLE_CUDA_CUSPARSE
+
 extern "C" MLIR_CUDA_WRAPPERS_EXPORT CUmodule mgpuModuleLoad(void *data) {
   ScopedContext scopedContext;
   CUmodule module = nullptr;
@@ -272,15 +310,17 @@
 
 extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *
 mgpuCreateSparseEnv(CUstream /*stream*/) {
-  cusparseHandle_t handle = nullptr;
-  CUSPARSE_REPORT_IF_ERROR(cusparseCreate(&handle))
-  return reinterpret_cast<void *>(handle);
+  // Constructing the storage creates the shared handle on first use.
+  ScopedCuSparseHandleStorage hstorage;
+  return reinterpret_cast<void *>(hstorage.env);
 }
 
 extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
 mgpuDestroySparseEnv(void *h, CUstream /*stream*/) {
-  cusparseHandle_t handle = reinterpret_cast<cusparseHandle_t>(h);
-  CUSPARSE_REPORT_IF_ERROR(cusparseDestroy(handle))
+  // `h` is ignored: the handle lives in the shared storage. Releasing it
+  // directly avoids creating a handle only to destroy it when none is live,
+  // and clears the stale `env` value.
+  ScopedCuSparseHandleStorage::release();
 }
 
 extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *
@@ -362,7 +402,8 @@
 extern "C" MLIR_CUDA_WRAPPERS_EXPORT intptr_t
 mgpuSpMVBufferSize(void *h, int32_t ma, void *a, void *x, void *y, int32_t ctp,
                    CUstream /*stream*/) {
-  cusparseHandle_t handle = reinterpret_cast<cusparseHandle_t>(h);
+  ScopedCuSparseHandleStorage hstorage;
+
   cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
   cusparseSpMatDescr_t matA = reinterpret_cast<cusparseSpMatDescr_t>(a);
   cusparseDnVecDescr_t vecX = reinterpret_cast<cusparseDnVecDescr_t>(x);
@@ -370,9 +411,9 @@
   cudaDataType_t cTp = static_cast<cudaDataType_t>(ctp);
   ALPHABETA(cTp, alpha, beta)
   size_t bufferSize = 0;
-  CUSPARSE_REPORT_IF_ERROR(
-      cusparseSpMV_bufferSize(handle, modeA, alphap, matA, vecX, betap, vecY,
-                              cTp, CUSPARSE_SPMV_ALG_DEFAULT, &bufferSize))
+  CUSPARSE_REPORT_IF_ERROR(cusparseSpMV_bufferSize(
+      hstorage.env, modeA, alphap, matA, vecX, betap, vecY, cTp,
+      CUSPARSE_SPMV_ALG_DEFAULT, &bufferSize))
   return bufferSize == 0 ? 1 : bufferSize; // avoid zero-alloc
 }
 
@@ -380,14 +421,14 @@
                                                    void *x, void *y,
                                                    int32_t ctp, void *buf,
                                                    CUstream /*stream*/) {
-  cusparseHandle_t handle = reinterpret_cast<cusparseHandle_t>(h);
+  ScopedCuSparseHandleStorage hstorage;
   cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
   cusparseSpMatDescr_t matA = reinterpret_cast<cusparseSpMatDescr_t>(a);
   cusparseDnVecDescr_t vecX = reinterpret_cast<cusparseDnVecDescr_t>(x);
   cusparseDnVecDescr_t vecY = reinterpret_cast<cusparseDnVecDescr_t>(y);
   cudaDataType_t cTp = static_cast<cudaDataType_t>(ctp);
   ALPHABETA(cTp, alpha, beta)
-  CUSPARSE_REPORT_IF_ERROR(cusparseSpMV(handle, modeA, alphap, matA, vecX,
+  CUSPARSE_REPORT_IF_ERROR(cusparseSpMV(hstorage.env, modeA, alphap, matA, vecX,
                                         betap, vecY, cTp,
                                         CUSPARSE_SPMV_ALG_DEFAULT, buf))
 }
@@ -395,7 +436,7 @@
 extern "C" MLIR_CUDA_WRAPPERS_EXPORT intptr_t
 mgpuSpMMBufferSize(void *h, int32_t ma, int32_t mb, void *a, void *b, void *c,
                    int32_t ctp, CUstream /*stream*/) {
-  cusparseHandle_t handle = reinterpret_cast<cusparseHandle_t>(h);
+  ScopedCuSparseHandleStorage hstorage;
   cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
   cusparseOperation_t modeB = static_cast<cusparseOperation_t>(mb);
   cusparseSpMatDescr_t matA = reinterpret_cast<cusparseSpMatDescr_t>(a);
@@ -405,7 +446,7 @@
   ALPHABETA(cTp, alpha, beta)
   size_t bufferSize = 0;
   CUSPARSE_REPORT_IF_ERROR(cusparseSpMM_bufferSize(
-      handle, modeA, modeB, alphap, matA, matB, betap, matC, cTp,
+      hstorage.env, modeA, modeB, alphap, matA, matB, betap, matC, cTp,
       CUSPARSE_SPMM_ALG_DEFAULT, &bufferSize))
   return bufferSize == 0 ? 1 : bufferSize; // avoid zero-alloc
 }
@@ -413,7 +454,7 @@
 extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
 mgpuSpMM(void *h, int32_t ma, int32_t mb, void *a, void *b, void *c,
          int32_t ctp, void *buf, CUstream /*stream*/) {
-  cusparseHandle_t handle = reinterpret_cast<cusparseHandle_t>(h);
+  ScopedCuSparseHandleStorage hstorage;
   cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
   cusparseOperation_t modeB = static_cast<cusparseOperation_t>(mb);
   cusparseSpMatDescr_t matA = reinterpret_cast<cusparseSpMatDescr_t>(a);
@@ -421,8 +462,8 @@
   cusparseDnMatDescr_t matC = reinterpret_cast<cusparseDnMatDescr_t>(c);
   cudaDataType_t cTp = static_cast<cudaDataType_t>(ctp);
   ALPHABETA(cTp, alpha, beta)
-  CUSPARSE_REPORT_IF_ERROR(cusparseSpMM(handle, modeA, modeB, alphap, matA,
-                                        matB, betap, matC, cTp,
+  CUSPARSE_REPORT_IF_ERROR(cusparseSpMM(hstorage.env, modeA, modeB, alphap,
+                                        matA, matB, betap, matC, cTp,
                                         CUSPARSE_SPMM_ALG_DEFAULT, buf))
 }
 
@@ -430,7 +471,7 @@
 extern "C" MLIR_CUDA_WRAPPERS_EXPORT intptr_t
 mgpuSDDMMBufferSize(void *h, int32_t ma, int32_t mb, void *a, void *b, void *c,
                     int32_t ctp, CUstream /*stream*/) {
-  cusparseHandle_t handle = reinterpret_cast<cusparseHandle_t>(h);
+  ScopedCuSparseHandleStorage hstorage;
   cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
   cusparseOperation_t modeB = static_cast<cusparseOperation_t>(mb);
   cusparseDnMatDescr_t matA = reinterpret_cast<cusparseDnMatDescr_t>(a);
@@ -440,7 +481,7 @@
   ALPHABETA(cTp, alpha, beta)
   size_t bufferSize = 0;
   CUSPARSE_REPORT_IF_ERROR(cusparseSDDMM_bufferSize(
-      handle, modeA, modeB, alphap, matA, matB, betap, matC, cTp,
+      hstorage.env, modeA, modeB, alphap, matA, matB, betap, matC, cTp,
       CUSPARSE_SDDMM_ALG_DEFAULT, &bufferSize))
   return bufferSize == 0 ? 1 : bufferSize; // avoid zero-alloc
 }
@@ -448,7 +489,7 @@
 extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
 mgpuSDDMM(void *h, int32_t ma, int32_t mb, void *a, void *b, void *c,
           int32_t ctp, void *buf, CUstream /*stream*/) {
-  cusparseHandle_t handle = reinterpret_cast<cusparseHandle_t>(h);
+  ScopedCuSparseHandleStorage hstorage;
   cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
   cusparseOperation_t modeB = static_cast<cusparseOperation_t>(mb);
   cusparseDnMatDescr_t matA = reinterpret_cast<cusparseDnMatDescr_t>(a);
@@ -456,8 +497,8 @@
   cusparseSpMatDescr_t matC = reinterpret_cast<cusparseSpMatDescr_t>(c);
   auto cTp = static_cast<cudaDataType_t>(ctp);
   ALPHABETA(cTp, alpha, beta)
-  CUSPARSE_REPORT_IF_ERROR(cusparseSDDMM(handle, modeA, modeB, alphap, matA,
-                                         matB, betap, matC, cTp,
+  CUSPARSE_REPORT_IF_ERROR(cusparseSDDMM(hstorage.env, modeA, modeB, alphap,
+                                         matA, matB, betap, matC, cTp,
                                          CUSPARSE_SDDMM_ALG_DEFAULT, buf))
 }
 