Pass cuSPARSE alpha/beta scalars through width-selected host pointers.

Adds the ALPHABETA(w, alpha, beta) macro, which declares float and double
copies of alpha and beta (all initialised to 1.0) together with void*
pointers <alpha>p and <beta>p. The pointers refer to the float pair when
w == 32 and to the double pair for every other width. The SpMV/SpMM
buffer-size and compute wrappers now pass alphap/betap instead of the
addresses of two local doubles.

diff --git a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
--- a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
+++ b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
@@ -246,6 +246,19 @@
   }
 }
 
+// Some macro magic to get float/double alpha and beta on host.
+#define ALPHABETA(w, alpha, beta) \
+  float (alpha##f) = 1.0, (beta##f) = 1.0; \
+  double (alpha##d) = 1.0, (beta##d) = 1.0; \
+  void *(alpha##p), *(beta##p); \
+  if ((w) == 32) { \
+    (alpha##p) = reinterpret_cast<void *>(&(alpha##f)); \
+    (beta##p) = reinterpret_cast<void *>(&(beta##f)); \
+  } else { \
+    (alpha##p) = reinterpret_cast<void *>(&(alpha##d)); \
+    (beta##p) = reinterpret_cast<void *>(&(beta##d)); \
+  }
+
 extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *
 mgpuCreateSparseEnv(CUstream /*stream*/) {
   cusparseHandle_t handle = nullptr;
@@ -329,11 +342,10 @@
   cusparseDnVecDescr_t vecX = reinterpret_cast<cusparseDnVecDescr_t>(x);
   cusparseDnVecDescr_t vecY = reinterpret_cast<cusparseDnVecDescr_t>(y);
   cudaDataType_t dtp = dataTp(dw);
-  double alpha = 1.0;
-  double beta = 1.0;
+  ALPHABETA(dw, alpha, beta)
   size_t bufferSize = 0;
   CUSPARSE_REPORT_IF_ERROR(cusparseSpMV_bufferSize(
-      handle, CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, matA, vecX, &beta, vecY,
+      handle, CUSPARSE_OPERATION_NON_TRANSPOSE, alphap, matA, vecX, betap, vecY,
       dtp, CUSPARSE_SPMV_ALG_DEFAULT, &bufferSize))
   return bufferSize == 0 ? 1 : bufferSize; // avoid zero-alloc
 }
@@ -347,11 +359,10 @@
   cusparseDnVecDescr_t vecX = reinterpret_cast<cusparseDnVecDescr_t>(x);
   cusparseDnVecDescr_t vecY = reinterpret_cast<cusparseDnVecDescr_t>(y);
   cudaDataType_t dtp = dataTp(dw);
-  double alpha = 1.0;
-  double beta = 1.0;
+  ALPHABETA(dw, alpha, beta)
   CUSPARSE_REPORT_IF_ERROR(
-      cusparseSpMV(handle, CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, matA, vecX,
-                   &beta, vecY, dtp, CUSPARSE_SPMV_ALG_DEFAULT, buf))
+      cusparseSpMV(handle, CUSPARSE_OPERATION_NON_TRANSPOSE, alphap, matA, vecX,
+                   betap, vecY, dtp, CUSPARSE_SPMV_ALG_DEFAULT, buf))
 }
 
 extern "C" MLIR_CUDA_WRAPPERS_EXPORT intptr_t mgpuSpMMBufferSize(
@@ -361,12 +372,11 @@
   cusparseDnMatDescr_t matB = reinterpret_cast<cusparseDnMatDescr_t>(b);
   cusparseDnMatDescr_t matC = reinterpret_cast<cusparseDnMatDescr_t>(c);
   cudaDataType_t dtp = dataTp(dw);
-  double alpha = 1.0;
-  double beta = 1.0;
+  ALPHABETA(dw, alpha, beta)
   size_t bufferSize = 0;
   CUSPARSE_REPORT_IF_ERROR(cusparseSpMM_bufferSize(
       handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
-      CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, matA, matB, &beta, matC, dtp,
+      CUSPARSE_OPERATION_NON_TRANSPOSE, alphap, matA, matB, betap, matC, dtp,
       CUSPARSE_SPMM_ALG_DEFAULT, &bufferSize))
   return bufferSize == 0 ? 1 : bufferSize; // avoid zero-alloc
 }
@@ -380,10 +390,9 @@
   cusparseDnMatDescr_t matB = reinterpret_cast<cusparseDnMatDescr_t>(b);
   cusparseDnMatDescr_t matC = reinterpret_cast<cusparseDnMatDescr_t>(c);
   cudaDataType_t dtp = dataTp(dw);
-  double alpha = 1.0;
-  double beta = 1.0;
+  ALPHABETA(dw, alpha, beta)
   CUSPARSE_REPORT_IF_ERROR(
       cusparseSpMM(handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
-                   CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, matA, matB, &beta,
+                   CUSPARSE_OPERATION_NON_TRANSPOSE, alphap, matA, matB, betap,
                    matC, dtp, CUSPARSE_SPMM_ALG_DEFAULT, buf))
 }