diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -238,22 +238,22 @@
       "mgpuSpMVBufferSize",
       llvmIntPtrType,
       {llvmPointerType, llvmPointerType, llvmPointerType, llvmPointerType,
-       llvmPointerType /* void *stream */}};
+       llvmInt32Type, llvmPointerType /* void *stream */}};
   FunctionCallBuilder spMVCallBuilder = {
       "mgpuSpMV",
       llvmVoidType,
       {llvmPointerType, llvmPointerType, llvmPointerType, llvmPointerType,
-       llvmPointerType, llvmPointerType /* void *stream */}};
+       llvmInt32Type, llvmPointerType, llvmPointerType /* void *stream */}};
   FunctionCallBuilder spMMBufferSizeCallBuilder = {
       "mgpuSpMMBufferSize",
       llvmIntPtrType,
       {llvmPointerType, llvmPointerType, llvmPointerType, llvmPointerType,
-       llvmPointerType /* void *stream */}};
+       llvmInt32Type, llvmPointerType /* void *stream */}};
   FunctionCallBuilder spMMCallBuilder = {
       "mgpuSpMM",
       llvmVoidType,
       {llvmPointerType, llvmPointerType, llvmPointerType, llvmPointerType,
-       llvmPointerType, llvmPointerType /* void *stream */}};
+       llvmInt32Type, llvmPointerType, llvmPointerType /* void *stream */}};
 };
 
 /// A rewrite pattern to convert gpu.host_register operations into a GPU runtime
@@ -1186,6 +1186,19 @@
   return success();
 }
 
+// Returns the element type of the values buffer of the CreateCooOp or
+// CreateCsrOp that defines `spMat`, or a null Type when `spMat` is defined by
+// anything else (e.g. it is a block argument), so that callers can fail the
+// match instead of crashing.
+// TODO: safer to store data type in actual op instead?
+static Type getSpMatType(Value spMat) {
+  if (auto op = spMat.getDefiningOp<gpu::CreateCooOp>())
+    return op.getValues().getType().cast<MemRefType>().getElementType();
+  if (auto op = spMat.getDefiningOp<gpu::CreateCsrOp>())
+    return op.getValues().getType().cast<MemRefType>().getElementType();
+  return Type();
+}
+
 LogicalResult ConvertCreateSparseEnvOpToGpuRuntimeCallPattern::matchAndRewrite(
     gpu::CreateSparseEnvOp op, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
@@ -1379,12 +1392,18 @@
       failed(isAsyncWithOneDependency(rewriter, op)))
     return failure();
   Location loc = op.getLoc();
+  Type dType = getSpMatType(op.getSpmatA());
+  if (!dType || !dType.isIntOrFloat())
+    return failure();
+  auto dw = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
+                                              dType.getIntOrFloatBitWidth());
   auto stream = adaptor.getAsyncDependencies().front();
-  auto bufferSize = spMVBufferSizeCallBuilder
-                        .create(loc, rewriter,
-                                {adaptor.getEnv(), adaptor.getSpmatA(),
-                                 adaptor.getDnX(), adaptor.getDnY(), stream})
-                        .getResult();
+  auto bufferSize =
+      spMVBufferSizeCallBuilder
+          .create(loc, rewriter,
+                  {adaptor.getEnv(), adaptor.getSpmatA(), adaptor.getDnX(),
+                   adaptor.getDnY(), dw, stream})
+          .getResult();
   rewriter.replaceOp(op, {bufferSize, stream});
   return success();
 }
@@ -1396,6 +1415,11 @@
       failed(isAsyncWithOneDependency(rewriter, op)))
     return failure();
   Location loc = op.getLoc();
+  Type dType = getSpMatType(op.getSpmatA());
+  if (!dType || !dType.isIntOrFloat())
+    return failure();
+  auto dw = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
+                                              dType.getIntOrFloatBitWidth());
   auto stream = adaptor.getAsyncDependencies().front();
   Value pBuf =
       MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
@@ -1403,7 +1427,8 @@
     pBuf = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pBuf);
   spMVCallBuilder.create(loc, rewriter,
                          {adaptor.getEnv(), adaptor.getSpmatA(),
-                          adaptor.getDnX(), adaptor.getDnY(), pBuf, stream});
+                          adaptor.getDnX(), adaptor.getDnY(), dw, pBuf,
+                          stream});
   rewriter.replaceOp(op, {stream});
   return success();
 }
@@ -1415,12 +1440,17 @@
       failed(isAsyncWithOneDependency(rewriter, op)))
     return failure();
   Location loc = op.getLoc();
+  Type dType = getSpMatType(op.getSpmatA());
+  if (!dType || !dType.isIntOrFloat())
+    return failure();
+  auto dw = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
+                                              dType.getIntOrFloatBitWidth());
   auto stream = adaptor.getAsyncDependencies().front();
   auto bufferSize = spMMBufferSizeCallBuilder
                         .create(loc, rewriter,
                                 {adaptor.getEnv(), adaptor.getSpmatA(),
                                  adaptor.getDnmatB(),
-                                 adaptor.getDnmatC(), stream})
+                                 adaptor.getDnmatC(), dw, stream})
                         .getResult();
   rewriter.replaceOp(op, {bufferSize, stream});
   return success();
@@ -1433,6 +1463,11 @@
       failed(isAsyncWithOneDependency(rewriter, op)))
     return failure();
   Location loc = op.getLoc();
+  Type dType = getSpMatType(op.getSpmatA());
+  if (!dType || !dType.isIntOrFloat())
+    return failure();
+  auto dw = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
+                                              dType.getIntOrFloatBitWidth());
   auto stream = adaptor.getAsyncDependencies().front();
   Value pBuf =
       MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
@@ -1440,7 +1475,7 @@
     pBuf = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pBuf);
   spMMCallBuilder.create(loc, rewriter,
                          {adaptor.getEnv(), adaptor.getSpmatA(),
-                          adaptor.getDnmatB(), adaptor.getDnmatC(), pBuf,
+                          adaptor.getDnmatB(), adaptor.getDnmatC(), dw, pBuf,
                           stream});
   rewriter.replaceOp(op, {stream});
   return success();
diff --git a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
--- a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
+++ b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
@@ -322,60 +322,81 @@
   CUSPARSE_REPORT_IF_ERROR(cusparseDestroySpMat(mat))
 }
 
+// cuSPARSE reads the `const void *` alpha/beta arguments of SpMV/SpMM as
+// scalars of the compute type, so their host storage must match `dtp` (the
+// leading bytes of a double 1.0 read as a float are 0.0f on little-endian
+// hosts). Returns a pointer to a constant 1 of the matching width; compute
+// types other than CUDA_R_32F get a double -- keep in sync with dataTp().
+static const void *oneOfType(cudaDataType_t dtp) {
+  static const float oneF32 = 1.0f;
+  static const double oneF64 = 1.0;
+  if (dtp == CUDA_R_32F)
+    return &oneF32;
+  return &oneF64;
+}
+
-extern "C" MLIR_CUDA_WRAPPERS_EXPORT intptr_t
-mgpuSpMVBufferSize(void *h, void *a, void *x, void *y, CUstream /*stream*/) {
+extern "C" MLIR_CUDA_WRAPPERS_EXPORT intptr_t mgpuSpMVBufferSize(
+    void *h, void *a, void *x, void *y, int32_t dw, CUstream /*stream*/) {
   cusparseHandle_t handle = reinterpret_cast<cusparseHandle_t>(h);
   cusparseSpMatDescr_t matA = reinterpret_cast<cusparseSpMatDescr_t>(a);
   cusparseDnVecDescr_t vecX = reinterpret_cast<cusparseDnVecDescr_t>(x);
   cusparseDnVecDescr_t vecY = reinterpret_cast<cusparseDnVecDescr_t>(y);
-  double alpha = 1.0;
-  double beta = 1.0;
+  cudaDataType_t dtp = dataTp(dw);
+  const void *alpha = oneOfType(dtp);
+  const void *beta = oneOfType(dtp);
   size_t bufferSize = 0;
   CUSPARSE_REPORT_IF_ERROR(cusparseSpMV_bufferSize(
-      handle, CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, matA, vecX, &beta, vecY,
-      CUDA_R_64F, CUSPARSE_SPMV_ALG_DEFAULT, &bufferSize))
+      handle, CUSPARSE_OPERATION_NON_TRANSPOSE, alpha, matA, vecX, beta, vecY,
+      dtp, CUSPARSE_SPMV_ALG_DEFAULT, &bufferSize))
   return bufferSize == 0 ? 1 : bufferSize; // avoid zero-alloc
 }
 
-extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
-mgpuSpMV(void *h, void *a, void *x, void *y, void *buf, CUstream /*stream*/) {
+extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuSpMV(void *h, void *a, void *x,
+                                                   void *y, int32_t dw,
+                                                   void *buf,
+                                                   CUstream /*stream*/) {
   cusparseHandle_t handle = reinterpret_cast<cusparseHandle_t>(h);
   cusparseSpMatDescr_t matA = reinterpret_cast<cusparseSpMatDescr_t>(a);
   cusparseDnVecDescr_t vecX = reinterpret_cast<cusparseDnVecDescr_t>(x);
   cusparseDnVecDescr_t vecY = reinterpret_cast<cusparseDnVecDescr_t>(y);
-  double alpha = 1.0;
-  double beta = 1.0;
+  cudaDataType_t dtp = dataTp(dw);
+  const void *alpha = oneOfType(dtp);
+  const void *beta = oneOfType(dtp);
   CUSPARSE_REPORT_IF_ERROR(
-      cusparseSpMV(handle, CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, matA, vecX,
-                   &beta, vecY, CUDA_R_64F, CUSPARSE_SPMV_ALG_DEFAULT, buf))
+      cusparseSpMV(handle, CUSPARSE_OPERATION_NON_TRANSPOSE, alpha, matA, vecX,
+                   beta, vecY, dtp, CUSPARSE_SPMV_ALG_DEFAULT, buf))
 }
 
-extern "C" MLIR_CUDA_WRAPPERS_EXPORT intptr_t
-mgpuSpMMBufferSize(void *h, void *a, void *b, void *c, CUstream /*stream*/) {
+extern "C" MLIR_CUDA_WRAPPERS_EXPORT intptr_t mgpuSpMMBufferSize(
+    void *h, void *a, void *b, void *c, int32_t dw, CUstream /*stream*/) {
   cusparseHandle_t handle = reinterpret_cast<cusparseHandle_t>(h);
   cusparseSpMatDescr_t matA = reinterpret_cast<cusparseSpMatDescr_t>(a);
   cusparseDnMatDescr_t matB = reinterpret_cast<cusparseDnMatDescr_t>(b);
   cusparseDnMatDescr_t matC = reinterpret_cast<cusparseDnMatDescr_t>(c);
-  double alpha = 1.0;
-  double beta = 1.0;
+  cudaDataType_t dtp = dataTp(dw);
+  const void *alpha = oneOfType(dtp);
+  const void *beta = oneOfType(dtp);
   size_t bufferSize = 0;
   CUSPARSE_REPORT_IF_ERROR(cusparseSpMM_bufferSize(
       handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
-      CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, matA, matB, &beta, matC,
-      CUDA_R_64F, CUSPARSE_SPMM_ALG_DEFAULT, &bufferSize))
+      CUSPARSE_OPERATION_NON_TRANSPOSE, alpha, matA, matB, beta, matC, dtp,
+      CUSPARSE_SPMM_ALG_DEFAULT, &bufferSize))
   return bufferSize == 0 ? 1 : bufferSize; // avoid zero-alloc
 }
 
-extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
-mgpuSpMM(void *h, void *a, void *b, void *c, void *buf, CUstream /*stream*/) {
+extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuSpMM(void *h, void *a, void *b,
+                                                   void *c, int32_t dw,
+                                                   void *buf,
+                                                   CUstream /*stream*/) {
   cusparseHandle_t handle = reinterpret_cast<cusparseHandle_t>(h);
   cusparseSpMatDescr_t matA = reinterpret_cast<cusparseSpMatDescr_t>(a);
   cusparseDnMatDescr_t matB = reinterpret_cast<cusparseDnMatDescr_t>(b);
   cusparseDnMatDescr_t matC = reinterpret_cast<cusparseDnMatDescr_t>(c);
-  double alpha = 1.0;
-  double beta = 1.0;
+  cudaDataType_t dtp = dataTp(dw);
+  const void *alpha = oneOfType(dtp);
+  const void *beta = oneOfType(dtp);
   CUSPARSE_REPORT_IF_ERROR(
       cusparseSpMM(handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
-                   CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, matA, matB, &beta,
-                   matC, CUDA_R_64F, CUSPARSE_SPMM_ALG_DEFAULT, buf))
+                   CUSPARSE_OPERATION_NON_TRANSPOSE, alpha, matA, matB, beta,
+                   matC, dtp, CUSPARSE_SPMM_ALG_DEFAULT, buf))
 }