diff --git a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
--- a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
+++ b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
@@ -246,6 +246,23 @@
   }
 }
 
+// Some macro magic to get float/double alpha and beta on host. Declares the
+// scaling factors alpha = beta = 1 in both float and double precision and
+// sets `alphap`/`betap` to the pair selected by the data width `w` (32 picks
+// float, any other width picks double). Expands to several declarations and
+// an if/else, so use it at statement level and at most once per scope.
+#define ALPHABETA(w)                                                           \
+  float alphaf = 1.0f, betaf = 1.0f;                                           \
+  double alphad = 1.0, betad = 1.0;                                            \
+  void *alphap, *betap;                                                        \
+  if ((w) == 32) {                                                             \
+    alphap = reinterpret_cast<void *>(&alphaf);                                \
+    betap = reinterpret_cast<void *>(&betaf);                                  \
+  } else {                                                                     \
+    alphap = reinterpret_cast<void *>(&alphad);                                \
+    betap = reinterpret_cast<void *>(&betad);                                  \
+  }
+
 extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *
 mgpuCreateSparseEnv(CUstream /*stream*/) {
   cusparseHandle_t handle = nullptr;
@@ -329,11 +346,10 @@
   cusparseDnVecDescr_t vecX = reinterpret_cast<cusparseDnVecDescr_t>(x);
   cusparseDnVecDescr_t vecY = reinterpret_cast<cusparseDnVecDescr_t>(y);
   cudaDataType_t dtp = dataTp(dw);
-  double alpha = 1.0;
-  double beta = 1.0;
+  ALPHABETA(dw)
   size_t bufferSize = 0;
   CUSPARSE_REPORT_IF_ERROR(cusparseSpMV_bufferSize(
-      handle, CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, matA, vecX, &beta, vecY,
+      handle, CUSPARSE_OPERATION_NON_TRANSPOSE, alphap, matA, vecX, betap, vecY,
       dtp, CUSPARSE_SPMV_ALG_DEFAULT, &bufferSize))
   return bufferSize == 0 ? 1 : bufferSize; // avoid zero-alloc
 }
@@ -347,11 +363,10 @@
   cusparseDnVecDescr_t vecX = reinterpret_cast<cusparseDnVecDescr_t>(x);
   cusparseDnVecDescr_t vecY = reinterpret_cast<cusparseDnVecDescr_t>(y);
   cudaDataType_t dtp = dataTp(dw);
-  double alpha = 1.0;
-  double beta = 1.0;
+  ALPHABETA(dw)
   CUSPARSE_REPORT_IF_ERROR(
-      cusparseSpMV(handle, CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, matA, vecX,
-                   &beta, vecY, dtp, CUSPARSE_SPMV_ALG_DEFAULT, buf))
+      cusparseSpMV(handle, CUSPARSE_OPERATION_NON_TRANSPOSE, alphap, matA, vecX,
+                   betap, vecY, dtp, CUSPARSE_SPMV_ALG_DEFAULT, buf))
 }
 
 extern "C" MLIR_CUDA_WRAPPERS_EXPORT intptr_t mgpuSpMMBufferSize(
@@ -361,12 +376,11 @@
   cusparseDnMatDescr_t matB = reinterpret_cast<cusparseDnMatDescr_t>(b);
   cusparseDnMatDescr_t matC = reinterpret_cast<cusparseDnMatDescr_t>(c);
   cudaDataType_t dtp = dataTp(dw);
-  double alpha = 1.0;
-  double beta = 1.0;
+  ALPHABETA(dw)
   size_t bufferSize = 0;
   CUSPARSE_REPORT_IF_ERROR(cusparseSpMM_bufferSize(
       handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
-      CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, matA, matB, &beta, matC, dtp,
+      CUSPARSE_OPERATION_NON_TRANSPOSE, alphap, matA, matB, betap, matC, dtp,
       CUSPARSE_SPMM_ALG_DEFAULT, &bufferSize))
   return bufferSize == 0 ? 1 : bufferSize; // avoid zero-alloc
 }
@@ -380,10 +394,9 @@
   cusparseDnMatDescr_t matB = reinterpret_cast<cusparseDnMatDescr_t>(b);
   cusparseDnMatDescr_t matC = reinterpret_cast<cusparseDnMatDescr_t>(c);
   cudaDataType_t dtp = dataTp(dw);
-  double alpha = 1.0;
-  double beta = 1.0;
+  ALPHABETA(dw)
   CUSPARSE_REPORT_IF_ERROR(
       cusparseSpMM(handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
-                   CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, matA, matB, &beta,
+                   CUSPARSE_OPERATION_NON_TRANSPOSE, alphap, matA, matB, betap,
                    matC, dtp, CUSPARSE_SPMM_ALG_DEFAULT, buf))
 }