diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td --- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td +++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td @@ -1860,7 +1860,7 @@ Example: ```mlir - %buffersz, %token = gpu.spmv_buffersize async [%dep] %env, %spmatA{TRANSPOSE}, %dnX, %dnY + %buffersz, %token = gpu.spmv_buffer_size async [%dep] %env, %spmatA{TRANSPOSE}, %dnX, %dnY ``` }]; let arguments = (ins Variadic:$asyncDependencies, @@ -1868,24 +1868,27 @@ GPU_TransposeModeAttr:$modeA, GPU_SparseSpMatHandle:$spmatA, GPU_SparseDnVecHandle:$dnX, - GPU_SparseDnVecHandle:$dnY); + GPU_SparseDnVecHandle:$dnY, + OptionalAttr:$computeType); let results = (outs Res:$bufferSz, Optional:$asyncToken); let builders = [OpBuilder<(ins - "::mlir::Type":$bufferSz, - "::mlir::Type":$asyncToken, - "::mlir::ValueRange":$asyncDependencies, - "::mlir::Value":$env, - "::mlir::Value":$spmatA, - "::mlir::Value":$dnX, - "::mlir::Value":$dnY), [{ + "Type":$bufferSz, + "Type":$asyncToken, + "ValueRange":$asyncDependencies, + "Value":$env, + "Value":$spmatA, + "Value":$dnX, + "Value":$dnY) + , [{ auto modeA = gpu::TransposeMode::NON_TRANSPOSE; - return build($_builder, $_state, bufferSz, asyncToken, asyncDependencies, env, - modeA, spmatA, dnX, dnY);}]> + return build($_builder, $_state, bufferSz, asyncToken, asyncDependencies, + env, modeA, spmatA, dnX, dnY, {});}]> ]; let assemblyFormat = [{ + (`{` $computeType^ `}`)? custom(type($asyncToken), $asyncDependencies) $env `,` $spmatA (`{` $modeA^ `}`)? 
`,` $dnX `,` $dnY attr-dict }]; @@ -1919,23 +1922,25 @@ GPU_SparseSpMatHandle:$spmatA, GPU_SparseDnVecHandle:$dnX, GPU_SparseDnVecHandle:$dnY, + OptionalAttr:$computeType, AnyMemRef:$buffer); let results = (outs Optional:$asyncToken); let builders = [OpBuilder<(ins - "::mlir::Type":$asyncToken, - "::mlir::ValueRange":$asyncDependencies, - "::mlir::Value":$env, - "::mlir::Value":$spmatA, - "::mlir::Value":$dnX, - "::mlir::Value":$dnY, - "::mlir::Value":$buffer), [{ + "Type":$asyncToken, + "ValueRange":$asyncDependencies, + "Value":$env, + "Value":$spmatA, + "Value":$dnX, + "Value":$dnY, + "Value":$buffer), [{ auto modeA = gpu::TransposeMode::NON_TRANSPOSE; return build($_builder, $_state, asyncToken, asyncDependencies, env, modeA, - spmatA, dnX, dnY, buffer);}]> + spmatA, dnX, dnY, {}, buffer);}]> ]; let assemblyFormat = [{ + (`{` $computeType^ `}`)? custom(type($asyncToken), $asyncDependencies) $env `,` $spmatA (`{` $modeA^ `}`)? `,` $dnX `,` $dnY `,` $buffer attr-dict `:` type($buffer) }]; @@ -1970,25 +1975,27 @@ GPU_TransposeModeAttr:$modeB, GPU_SparseSpMatHandle:$spmatA, GPU_SparseDnMatHandle:$dnmatB, - GPU_SparseDnMatHandle:$dnmatC); + GPU_SparseDnMatHandle:$dnmatC, + OptionalAttr:$computeType); let results = (outs Res:$bufferSz, Optional:$asyncToken); let builders = [OpBuilder<(ins - "::mlir::Type":$bufferSz, - "::mlir::Type":$asyncToken, - "::mlir::ValueRange":$asyncDependencies, - "::mlir::Value":$env, - "::mlir::Value":$spmatA, - "::mlir::Value":$dnmatB, - "::mlir::Value":$dnmatC), [{ + "Type":$bufferSz, + "Type":$asyncToken, + "ValueRange":$asyncDependencies, + "Value":$env, + "Value":$spmatA, + "Value":$dnmatB, + "Value":$dnmatC), [{ auto modeA = gpu::TransposeMode::NON_TRANSPOSE; auto modeB = gpu::TransposeMode::NON_TRANSPOSE; return build($_builder, $_state, bufferSz, asyncToken, asyncDependencies, - env, modeA, modeB, spmatA, dnmatB, dnmatC);}]> + env, modeA, modeB, spmatA, dnmatB, dnmatC, {});}]> ]; let assemblyFormat = [{ + (`{` $computeType^ `}`)? 
custom(type($asyncToken), $asyncDependencies) $env `,` $spmatA (`{` $modeA^ `}`)? `,` $dnmatB (`{` $modeB^ `}`)? `,` $dnmatC attr-dict }]; @@ -2024,24 +2031,26 @@ GPU_SparseSpMatHandle:$spmatA, GPU_SparseDnMatHandle:$dnmatB, GPU_SparseDnMatHandle:$dnmatC, + OptionalAttr:$computeType, AnyMemRef:$buffer); let results = (outs Optional:$asyncToken); let builders = [OpBuilder<(ins - "::mlir::Type":$asyncToken, - "::mlir::ValueRange":$asyncDependencies, - "::mlir::Value":$env, - "::mlir::Value":$spmatA, - "::mlir::Value":$dnmatB, - "::mlir::Value":$dnmatC, - "::mlir::Value":$buffer), [{ + "Type":$asyncToken, + "ValueRange":$asyncDependencies, + "Value":$env, + "Value":$spmatA, + "Value":$dnmatB, + "Value":$dnmatC, + "Value":$buffer), [{ auto modeA = gpu::TransposeMode::NON_TRANSPOSE; auto modeB = gpu::TransposeMode::NON_TRANSPOSE; return build($_builder, $_state, asyncToken, asyncDependencies, env, modeA, - modeB, spmatA, dnmatB, dnmatC, buffer);}]> + modeB, spmatA, dnmatB, dnmatC, {}, buffer);}]> ]; let assemblyFormat = [{ + (`{` $computeType^ `}`)? custom(type($asyncToken), $asyncDependencies) $env `,` $spmatA (`{` $modeA^ `}`)? `,` $dnmatB (`{` $modeB^ `}`)? `,` $dnmatC `,` $buffer attr-dict `:` type($buffer) }]; @@ -2076,7 +2085,8 @@ GPU_TransposeModeAttr:$modeB, GPU_SparseDnMatHandle:$dnmatA, GPU_SparseDnMatHandle:$dnmatB, - GPU_SparseSpMatHandle:$spmatC); + GPU_SparseSpMatHandle:$spmatC, + OptionalAttr:$computeType); let results = (outs Res:$bufferSz, Optional:$asyncToken); let builders = [OpBuilder<(ins @@ -2090,10 +2100,11 @@ auto modeA = gpu::TransposeMode::NON_TRANSPOSE; auto modeB = gpu::TransposeMode::NON_TRANSPOSE; return build($_builder, $_state, bufferSz, asyncToken, asyncDependencies, - env, modeA, modeB, dnmatA, dnmatB, spmatC);}]> + env, modeA, modeB, dnmatA, dnmatB, spmatC, {});}]> ]; let assemblyFormat = [{ + (`{` $computeType^ `}`)? custom(type($asyncToken), $asyncDependencies) $env `,` $dnmatA (`{` $modeA^ `}`)? `,` $dnmatB (`{` $modeB^ `}`)? 
`,` $spmatC attr-dict }]; @@ -2129,6 +2140,7 @@ GPU_SparseDnMatHandle:$dnmatA, GPU_SparseDnMatHandle:$dnmatB, GPU_SparseSpMatHandle:$spmatC, + OptionalAttr:$computeType, AnyMemRef:$buffer); let results = (outs Optional:$asyncToken); @@ -2143,10 +2155,11 @@ auto modeA = gpu::TransposeMode::NON_TRANSPOSE; auto modeB = gpu::TransposeMode::NON_TRANSPOSE; return build($_builder, $_state, asyncToken, asyncDependencies, env, modeA, - modeB, dnmatA, dnmatB, spmatC, buffer);}]> + modeB, dnmatA, dnmatB, spmatC, {}, buffer);}]> ]; let assemblyFormat = [{ + (`{` $computeType^ `}`)? custom(type($asyncToken), $asyncDependencies) $env `,` $dnmatA (`{` $modeA^ `}`)? `,` $dnmatB (`{` $modeB^ `}`)? `,` $spmatC `,` $buffer attr-dict `:` type($buffer) }]; diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp --- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp @@ -238,37 +238,38 @@ "mgpuSpMVBufferSize", llvmIntPtrType, {llvmPointerType, llvmInt32Type, llvmPointerType, llvmPointerType, - llvmPointerType, llvmInt32Type, llvmPointerType /* void *stream */}}; + llvmPointerType, llvmInt32Type, llvmInt32Type, + llvmPointerType /* void *stream */}}; FunctionCallBuilder spMVCallBuilder = { "mgpuSpMV", llvmVoidType, {llvmPointerType, llvmInt32Type, llvmPointerType, llvmPointerType, - llvmPointerType, llvmInt32Type, llvmPointerType, + llvmPointerType, llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType /* void *stream */}}; FunctionCallBuilder spMMBufferSizeCallBuilder = { "mgpuSpMMBufferSize", llvmIntPtrType, {llvmPointerType, llvmInt32Type, llvmInt32Type, llvmPointerType, - llvmPointerType, llvmPointerType, llvmInt32Type, + llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type, llvmPointerType /* void *stream */}}; FunctionCallBuilder spMMCallBuilder = { "mgpuSpMM", llvmVoidType, {llvmPointerType, llvmInt32Type, llvmInt32Type, 
llvmPointerType, - llvmPointerType, llvmPointerType, llvmInt32Type, llvmPointerType, - llvmPointerType /* void *stream */}}; + llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type, + llvmPointerType, llvmPointerType /* void *stream */}}; FunctionCallBuilder SDDMMBufferSizeCallBuilder = { "mgpuSDDMMBufferSize", llvmIntPtrType, {llvmPointerType, llvmInt32Type, llvmInt32Type, llvmPointerType, - llvmPointerType, llvmPointerType, llvmInt32Type, + llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type, llvmPointerType /* void *stream */}}; FunctionCallBuilder SDDMMCallBuilder = { "mgpuSDDMM", llvmVoidType, {llvmPointerType, llvmInt32Type, llvmInt32Type, llvmPointerType, - llvmPointerType, llvmPointerType, llvmInt32Type, llvmPointerType, - llvmPointerType /* void *stream */}}; + llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type, + llvmPointerType, llvmPointerType /* void *stream */}}; }; /// A rewrite pattern to convert gpu.host_register operations into a GPU runtime @@ -688,6 +689,24 @@ return builder.create(loc, function, arguments); } +// Maps an element type to its cudaDataType_t value (see library_types.h). +// TODO: add support for complex types. +static int32_t getCUSparseDataTypeEnumFrom(Type type) { + if (type.isBF16()) + return 14; + if (type.isF16()) + return 2; + if (type.isF32()) + return 0; + if (type.isF64()) + return 1; + if (type.isInteger(8)) + return 3; + if (type.isInteger(32)) + return 10; + llvm_unreachable("unsupported element type"); +} + // Returns whether all operands are of LLVM type. 
static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands, ConversionPatternRewriter &rewriter) { @@ -1234,9 +1253,70 @@ return llvm::cast(op.getValues().getType()).getElementType(); if (auto op = spMat.getDefiningOp()) return llvm::cast(op.getValues().getType()).getElementType(); + // After partial lowering the def may be an llvm.call; only dw 32/64 handled. + if (auto op = spMat.getDefiningOp()) { + if (op.getCallee() == "mgpuCreateCsr") { + mlir::Attribute dw = + op.getOperand(8).getDefiningOp().getValue(); + if (!getConstantIntValue(dw).has_value()) { + llvm_unreachable("expecting dw to be a constant but not"); + } + auto dw_ = getConstantIntValue(dw).value(); + if (dw_ == 32) + return FloatType::getF32(spMat.getContext()); + else if (dw_ == 64) + return FloatType::getF64(spMat.getContext()); + } else if (op.getCallee() == "mgpuCreateCoo") { + mlir::Attribute dw = + op.getOperand(7).getDefiningOp().getValue(); + if (!getConstantIntValue(dw).has_value()) { + llvm_unreachable("expecting dw to be a constant but not"); + } + auto dw_ = getConstantIntValue(dw).value(); + if (dw_ == 32) + return FloatType::getF32(spMat.getContext()); + else if (dw_ == 64) + return FloatType::getF64(spMat.getContext()); + } + } llvm_unreachable("cannot find spmat def"); } +// Returns the element type of the op (or lowered llvm.call) that defines dn. 
+static Type getDnElemType(Value dn) { + if (auto op = dn.getDefiningOp()) + return llvm::cast(op.getMemref().getType()).getElementType(); + if (auto op = dn.getDefiningOp()) + return llvm::cast(op.getMemref().getType()).getElementType(); + // After partial lowering the def may be an llvm.call; only dw 32/64 handled. + if (auto op = dn.getDefiningOp()) { + if (op.getCallee() == "mgpuCreateDnVec") { + mlir::Attribute dw = + op.getOperand(2).getDefiningOp().getValue(); + if (!getConstantIntValue(dw).has_value()) { + llvm_unreachable("expecting dw to be a constant but not"); + } + auto dw_ = getConstantIntValue(dw).value(); + if (dw_ == 32) + return FloatType::getF32(dn.getContext()); + else if (dw_ == 64) + return FloatType::getF64(dn.getContext()); + } else if (op.getCallee() == "mgpuCreateDnMat") { + mlir::Attribute dw = + op.getOperand(3).getDefiningOp().getValue(); + if (!getConstantIntValue(dw).has_value()) { + llvm_unreachable("expecting dw to be a constant but not"); + } + auto dw_ = getConstantIntValue(dw).value(); + if (dw_ == 32) + return FloatType::getF32(dn.getContext()); + else if (dw_ == 64) + return FloatType::getF64(dn.getContext()); + } + } + llvm_unreachable("cannot find dn def"); +} + static Value genConstFrom(OpBuilder &builder, Location loc, gpu::TransposeMode mode) { Type llvmInt32Type = builder.getIntegerType(32); @@ -1244,6 +1324,23 @@ static_cast(mode)); } +static Value genConstFrom(OpBuilder &builder, Location loc, + int32_t computeTypeInt) { + Type llvmInt32Type = builder.getIntegerType(32); + return builder.create(loc, llvmInt32Type, computeTypeInt); +} + +static Value +genConstFromOptionalComputeMode(OpBuilder &builder, Location loc, + std::optional computeTypeOptional, + Type defaultType) { + auto computeTypeInt = getCUSparseDataTypeEnumFrom( + computeTypeOptional.has_value() ? 
computeTypeOptional.value() + : defaultType); + auto computeType = genConstFrom(builder, loc, computeTypeInt); + return computeType; +} + LogicalResult ConvertCreateSparseEnvOpToGpuRuntimeCallPattern::matchAndRewrite( gpu::CreateSparseEnvOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { @@ -1448,12 +1545,15 @@ Type dType = getSpMatElemType(op.getSpmatA()); auto dw = rewriter.create(loc, llvmInt32Type, dType.getIntOrFloatBitWidth()); + // Compute type as an i32 constant; defaults to dnY's element type if unset. + auto computeType = genConstFromOptionalComputeMode( + rewriter, loc, adaptor.getComputeType(), getDnElemType(adaptor.getDnY())); auto stream = adaptor.getAsyncDependencies().front(); auto bufferSize = spMVBufferSizeCallBuilder .create(loc, rewriter, {adaptor.getEnv(), modeA, adaptor.getSpmatA(), - adaptor.getDnX(), adaptor.getDnY(), dw, stream}) + adaptor.getDnX(), adaptor.getDnY(), dw, computeType, stream}) .getResult(); rewriter.replaceOp(op, {bufferSize, stream}); return success(); @@ -1470,6 +1570,9 @@ auto modeA = genConstFrom(rewriter, loc, adaptor.getModeA()); auto dw = rewriter.create(loc, llvmInt32Type, dType.getIntOrFloatBitWidth()); + // Compute type as an i32 constant; defaults to dnY's element type if unset. 
+ auto computeType = genConstFromOptionalComputeMode( + rewriter, loc, adaptor.getComputeType(), getDnElemType(adaptor.getDnY())); auto stream = adaptor.getAsyncDependencies().front(); Value pBuf = MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc); @@ -1477,8 +1580,8 @@ pBuf = rewriter.create(loc, llvmPointerType, pBuf); spMVCallBuilder.create(loc, rewriter, {adaptor.getEnv(), modeA, adaptor.getSpmatA(), - adaptor.getDnX(), adaptor.getDnY(), dw, pBuf, - stream}); + adaptor.getDnX(), adaptor.getDnY(), dw, computeType, + pBuf, stream}); rewriter.replaceOp(op, {stream}); return success(); } @@ -1496,12 +1599,17 @@ auto dw = rewriter.create(loc, llvmInt32Type, dType.getIntOrFloatBitWidth()); auto stream = 
adaptor.getAsyncDependencies().front(); - auto bufferSize = - spMMBufferSizeCallBuilder - .create(loc, rewriter, - {adaptor.getEnv(), modeA, modeB, adaptor.getSpmatA(), - adaptor.getDnmatB(), adaptor.getDnmatC(), dw, stream}) - .getResult(); + // Compute type as an i32 constant; defaults to dnmatC's element type. + auto computeType = + genConstFromOptionalComputeMode(rewriter, loc, adaptor.getComputeType(), + getDnElemType(adaptor.getDnmatC())); + + auto bufferSize = spMMBufferSizeCallBuilder + .create(loc, rewriter, + {adaptor.getEnv(), modeA, modeB, + adaptor.getSpmatA(), adaptor.getDnmatB(), + adaptor.getDnmatC(), dw, computeType, stream}) + .getResult(); rewriter.replaceOp(op, {bufferSize, stream}); return success(); } @@ -1518,13 +1626,16 @@ Type dType = getSpMatElemType(op.getSpmatC()); auto dw = rewriter.create(loc, llvmInt32Type, dType.getIntOrFloatBitWidth()); + auto computeType = + genConstFromOptionalComputeMode(rewriter, loc, adaptor.getComputeType(), + getSpMatElemType(adaptor.getSpmatC())); auto stream = adaptor.getAsyncDependencies().front(); - auto bufferSize = - SDDMMBufferSizeCallBuilder - .create(loc, rewriter, - {adaptor.getEnv(), modeA, modeB, adaptor.getDnmatA(), - adaptor.getDnmatB(), adaptor.getSpmatC(), dw, stream}) - .getResult(); + auto bufferSize = SDDMMBufferSizeCallBuilder + .create(loc, rewriter, + {adaptor.getEnv(), modeA, modeB, + adaptor.getDnmatA(), adaptor.getDnmatB(), + adaptor.getSpmatC(), dw, computeType, stream}) + .getResult(); rewriter.replaceOp(op, {bufferSize, stream}); return success(); } @@ -1541,6 +1652,11 @@ Type dType = getSpMatElemType(op.getSpmatA()); auto dw = rewriter.create(loc, llvmInt32Type, dType.getIntOrFloatBitWidth()); + // Compute type as an i32 constant; defaults to dnmatC's element type. 
+ auto computeType = + genConstFromOptionalComputeMode(rewriter, loc, adaptor.getComputeType(), + getDnElemType(adaptor.getDnmatC())); + auto stream = adaptor.getAsyncDependencies().front(); Value pBuf = 
MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc); @@ -1548,8 +1664,8 @@ pBuf = rewriter.create(loc, llvmPointerType, pBuf); spMMCallBuilder.create(loc, rewriter, {adaptor.getEnv(), modeA, modeB, adaptor.getSpmatA(), - adaptor.getDnmatB(), adaptor.getDnmatC(), dw, pBuf, - stream}); + adaptor.getDnmatB(), adaptor.getDnmatC(), dw, + computeType, pBuf, stream}); rewriter.replaceOp(op, {stream}); return success(); } @@ -1572,6 +1688,9 @@ Type dType = getSpMatElemType(op.getSpmatC()); auto dw = rewriter.create(loc, llvmInt32Type, dType.getIntOrFloatBitWidth()); + auto computeType = + genConstFromOptionalComputeMode(rewriter, loc, adaptor.getComputeType(), + getSpMatElemType(adaptor.getSpmatC())); auto modeA = genConstFrom(rewriter, loc, adaptor.getModeA()); auto modeB = genConstFrom(rewriter, loc, adaptor.getModeB()); auto stream = adaptor.getAsyncDependencies().front(); @@ -1581,8 +1700,8 @@ pBuf = rewriter.create(loc, llvmPointerType, pBuf); SDDMMCallBuilder.create(loc, rewriter, {adaptor.getEnv(), modeA, modeB, adaptor.getDnmatA(), - adaptor.getDnmatB(), adaptor.getSpmatC(), dw, pBuf, - stream}); + adaptor.getDnmatB(), adaptor.getSpmatC(), dw, + computeType, pBuf, stream}); rewriter.replaceOp(op, {stream}); return success(); } diff --git a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp --- a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp +++ b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp @@ -340,68 +340,68 @@ extern "C" MLIR_CUDA_WRAPPERS_EXPORT intptr_t mgpuSpMVBufferSize(void *h, int32_t ma, void *a, void *x, void *y, int32_t dw, - CUstream /*stream*/) { + int32_t dtp, CUstream /*stream*/) { cusparseHandle_t handle = reinterpret_cast(h); cusparseOperation_t modeA = static_cast(ma); cusparseSpMatDescr_t matA = reinterpret_cast(a); cusparseDnVecDescr_t vecX = reinterpret_cast(x); cusparseDnVecDescr_t vecY = reinterpret_cast(y); - cudaDataType_t dtp = dataTp(dw); + cudaDataType_t dTp = 
static_cast(dtp); ALPHABETA(dw, alpha, beta) size_t bufferSize = 0; CUSPARSE_REPORT_IF_ERROR( cusparseSpMV_bufferSize(handle, modeA, alphap, matA, vecX, betap, vecY, - dtp, CUSPARSE_SPMV_ALG_DEFAULT, &bufferSize)) + dTp, CUSPARSE_SPMV_ALG_DEFAULT, &bufferSize)) return bufferSize == 0 ? 1 : bufferSize; // avoid zero-alloc } extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuSpMV(void *h, int32_t ma, void *a, void *x, void *y, int32_t dw, - void *buf, + int32_t dtp, void *buf, CUstream /*stream*/) { cusparseHandle_t handle = reinterpret_cast(h); cusparseOperation_t modeA = static_cast(ma); cusparseSpMatDescr_t matA = reinterpret_cast(a); cusparseDnVecDescr_t vecX = reinterpret_cast(x); cusparseDnVecDescr_t vecY = reinterpret_cast(y); - cudaDataType_t dtp = dataTp(dw); + cudaDataType_t dTp = static_cast(dtp); ALPHABETA(dw, alpha, beta) CUSPARSE_REPORT_IF_ERROR(cusparseSpMV(handle, modeA, alphap, matA, vecX, - betap, vecY, dtp, + betap, vecY, dTp, CUSPARSE_SPMV_ALG_DEFAULT, buf)) } extern "C" MLIR_CUDA_WRAPPERS_EXPORT intptr_t mgpuSpMMBufferSize(void *h, int32_t ma, int32_t mb, void *a, void *b, void *c, - int32_t dw, CUstream /*stream*/) { + int32_t dw, int32_t dtp, CUstream /*stream*/) { cusparseHandle_t handle = reinterpret_cast(h); cusparseOperation_t modeA = static_cast(ma); cusparseOperation_t modeB = static_cast(mb); cusparseSpMatDescr_t matA = reinterpret_cast(a); cusparseDnMatDescr_t matB = reinterpret_cast(b); cusparseDnMatDescr_t matC = reinterpret_cast(c); - cudaDataType_t dtp = dataTp(dw); + cudaDataType_t dTp = static_cast(dtp); ALPHABETA(dw, alpha, beta) size_t bufferSize = 0; CUSPARSE_REPORT_IF_ERROR(cusparseSpMM_bufferSize( - handle, modeA, modeB, alphap, matA, matB, betap, matC, dtp, + handle, modeA, modeB, alphap, matA, matB, betap, matC, dTp, CUSPARSE_SPMM_ALG_DEFAULT, &bufferSize)) return bufferSize == 0 ? 
1 : bufferSize; // avoid zero-alloc } extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuSpMM(void *h, int32_t ma, int32_t mb, void *a, void *b, void *c, int32_t dw, - void *buf, CUstream /*stream*/) { + int32_t dtp, void *buf, CUstream /*stream*/) { cusparseHandle_t handle = reinterpret_cast(h); cusparseOperation_t modeA = static_cast(ma); cusparseOperation_t modeB = static_cast(mb); cusparseSpMatDescr_t matA = reinterpret_cast(a); cusparseDnMatDescr_t matB = reinterpret_cast(b); cusparseDnMatDescr_t matC = reinterpret_cast(c); - cudaDataType_t dtp = dataTp(dw); + cudaDataType_t dTp = static_cast(dtp); ALPHABETA(dw, alpha, beta) CUSPARSE_REPORT_IF_ERROR(cusparseSpMM(handle, modeA, modeB, alphap, matA, - matB, betap, matC, dtp, + matB, betap, matC, dTp, CUSPARSE_SPMM_ALG_DEFAULT, buf)) }