diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td --- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td +++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td @@ -1860,7 +1860,7 @@ Example: ```mlir - %buffersz, %token = gpu.spmv_buffersize async [%dep] %env, %spmatA{TRANSPOSE}, %dnX, %dnY + %buffersz, %token = gpu.spmv_buffer_size async [%dep] %env, %spmatA{TRANSPOSE}, %dnX, %dnY ``` }]; let arguments = (ins Variadic:$asyncDependencies, @@ -1868,24 +1868,27 @@ GPU_TransposeModeAttr:$modeA, GPU_SparseSpMatHandle:$spmatA, GPU_SparseDnVecHandle:$dnX, - GPU_SparseDnVecHandle:$dnY); + GPU_SparseDnVecHandle:$dnY, + OptionalAttr:$computeType); let results = (outs Res:$bufferSz, Optional:$asyncToken); let builders = [OpBuilder<(ins - "::mlir::Type":$bufferSz, - "::mlir::Type":$asyncToken, - "::mlir::ValueRange":$asyncDependencies, - "::mlir::Value":$env, - "::mlir::Value":$spmatA, - "::mlir::Value":$dnX, - "::mlir::Value":$dnY), [{ + "Type":$bufferSz, + "Type":$asyncToken, + "ValueRange":$asyncDependencies, + "Value":$env, + "Value":$spmatA, + "Value":$dnX, + "Value":$dnY) + , [{ auto modeA = gpu::TransposeMode::NON_TRANSPOSE; - return build($_builder, $_state, bufferSz, asyncToken, asyncDependencies, env, - modeA, spmatA, dnX, dnY);}]> + return build($_builder, $_state, bufferSz, asyncToken, asyncDependencies, + env, modeA, spmatA, dnX, dnY, {});}]> ]; let assemblyFormat = [{ + (`{` $computeType^ `}`)? custom(type($asyncToken), $asyncDependencies) $env `,` $spmatA (`{` $modeA^ `}`)? `,` $dnX `,` $dnY attr-dict }]; @@ -1919,23 +1922,25 @@ GPU_SparseSpMatHandle:$spmatA, GPU_SparseDnVecHandle:$dnX, GPU_SparseDnVecHandle:$dnY, + OptionalAttr:$computeType, AnyMemRef:$buffer); let results = (outs Optional:$asyncToken); let builders = [OpBuilder<(ins - "::mlir::Type":$asyncToken, - "::mlir::ValueRange":$asyncDependencies, - "::mlir::Value":$env, - "::mlir::Value":$spmatA, - "::mlir::Value":$dnX, - "::mlir::Value":$dnY, - "::mlir::Value":$buffer), [{ + "Type":$asyncToken, + "ValueRange":$asyncDependencies, + "Value":$env, + "Value":$spmatA, + "Value":$dnX, + "Value":$dnY, + "Value":$buffer), [{ auto modeA = gpu::TransposeMode::NON_TRANSPOSE; return build($_builder, $_state, asyncToken, asyncDependencies, env, modeA, - spmatA, dnX, dnY, buffer);}]> + spmatA, dnX, dnY, {}, buffer);}]> ]; let assemblyFormat = [{ + (`{` $computeType^ `}`)? custom(type($asyncToken), $asyncDependencies) $env `,` $spmatA (`{` $modeA^ `}`)? `,` $dnX `,` $dnY `,` $buffer attr-dict `:` type($buffer) }]; @@ -1970,25 +1975,27 @@ GPU_TransposeModeAttr:$modeB, GPU_SparseSpMatHandle:$spmatA, GPU_SparseDnMatHandle:$dnmatB, - GPU_SparseDnMatHandle:$dnmatC); + GPU_SparseDnMatHandle:$dnmatC, + OptionalAttr:$computeType); let results = (outs Res:$bufferSz, Optional:$asyncToken); let builders = [OpBuilder<(ins - "::mlir::Type":$bufferSz, - "::mlir::Type":$asyncToken, - "::mlir::ValueRange":$asyncDependencies, - "::mlir::Value":$env, - "::mlir::Value":$spmatA, - "::mlir::Value":$dnmatB, - "::mlir::Value":$dnmatC), [{ + "Type":$bufferSz, + "Type":$asyncToken, + "ValueRange":$asyncDependencies, + "Value":$env, + "Value":$spmatA, + "Value":$dnmatB, + "Value":$dnmatC), [{ auto modeA = gpu::TransposeMode::NON_TRANSPOSE; auto modeB = gpu::TransposeMode::NON_TRANSPOSE; return build($_builder, $_state, bufferSz, asyncToken, asyncDependencies, - env, modeA, modeB, spmatA, dnmatB, dnmatC);}]> + env, modeA, modeB, spmatA, dnmatB, dnmatC, {});}]> ]; let assemblyFormat = [{ + (`{` $computeType^ `}`)? custom(type($asyncToken), $asyncDependencies) $env `,` $spmatA (`{` $modeA^ `}`)? `,` $dnmatB (`{` $modeB^ `}`)? `,` $dnmatC attr-dict }]; @@ -2024,24 +2031,26 @@ GPU_SparseSpMatHandle:$spmatA, GPU_SparseDnMatHandle:$dnmatB, GPU_SparseDnMatHandle:$dnmatC, + OptionalAttr:$computeType, AnyMemRef:$buffer); let results = (outs Optional:$asyncToken); let builders = [OpBuilder<(ins - "::mlir::Type":$asyncToken, - "::mlir::ValueRange":$asyncDependencies, - "::mlir::Value":$env, - "::mlir::Value":$spmatA, - "::mlir::Value":$dnmatB, - "::mlir::Value":$dnmatC, - "::mlir::Value":$buffer), [{ + "Type":$asyncToken, + "ValueRange":$asyncDependencies, + "Value":$env, + "Value":$spmatA, + "Value":$dnmatB, + "Value":$dnmatC, + "Value":$buffer), [{ auto modeA = gpu::TransposeMode::NON_TRANSPOSE; auto modeB = gpu::TransposeMode::NON_TRANSPOSE; return build($_builder, $_state, asyncToken, asyncDependencies, env, modeA, - modeB, spmatA, dnmatB, dnmatC, buffer);}]> + modeB, spmatA, dnmatB, dnmatC, {}, buffer);}]> ]; let assemblyFormat = [{ + (`{` $computeType^ `}`)? custom(type($asyncToken), $asyncDependencies) $env `,` $spmatA (`{` $modeA^ `}`)? `,` $dnmatB (`{` $modeB^ `}`)? `,` $dnmatC `,` $buffer attr-dict `:` type($buffer) }]; diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp --- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp @@ -238,25 +238,26 @@ "mgpuSpMVBufferSize", llvmIntPtrType, {llvmPointerType, llvmInt32Type, llvmPointerType, llvmPointerType, - llvmPointerType, llvmInt32Type, llvmPointerType /* void *stream */}}; + llvmPointerType, llvmInt32Type, llvmInt32Type, + llvmPointerType /* void *stream */}}; FunctionCallBuilder spMVCallBuilder = { "mgpuSpMV", llvmVoidType, {llvmPointerType, llvmInt32Type, llvmPointerType, llvmPointerType, - llvmPointerType, llvmInt32Type, llvmPointerType, + llvmPointerType, llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType /* void *stream */}}; FunctionCallBuilder spMMBufferSizeCallBuilder = { "mgpuSpMMBufferSize", llvmIntPtrType, {llvmPointerType, llvmInt32Type, llvmInt32Type, llvmPointerType, - llvmPointerType, llvmPointerType, llvmInt32Type, + llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type, llvmPointerType /* void *stream */}}; FunctionCallBuilder spMMCallBuilder = { "mgpuSpMM", llvmVoidType, {llvmPointerType, llvmInt32Type, llvmInt32Type, llvmPointerType, - llvmPointerType, llvmPointerType, llvmInt32Type, llvmPointerType, - llvmPointerType /* void *stream */}}; + llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type, + llvmPointerType, llvmPointerType /* void *stream */}}; }; /// A rewrite pattern to convert gpu.host_register operations into a GPU runtime @@ -650,6 +651,24 @@ return builder.create(loc, function, arguments); } +// corresponding to cudaDataType_t defined in library_types.h +// TODO: add support to complex types +static int32_t getCUSparseDataTypeEnumFrom(Type type) { + if (type.isBF16()) + return 14; + if (type.isF16()) + return 2; + if (type.isF32()) + return 0; + if (type.isF64()) + return 1; + if (type.isInteger(8)) + return 3; + if (type.isInteger(32)) + return 10; + llvm_unreachable("unsupported element type"); +} + // Returns whether all operands are of LLVM type. static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands, ConversionPatternRewriter &rewriter) { @@ -1199,6 +1218,34 @@ llvm_unreachable("cannot find spmat def"); } +// Returns the element type of the defining dnmat or dnvec op. +static Type getDnElemType(Value dn, OpBuilder &builder) { + Type llvmInt32Type = builder.getIntegerType(32); + if (auto op = dn.getDefiningOp()) + return llvm::cast(op.getMemref().getType()).getElementType(); + if (auto op = dn.getDefiningOp()) + return llvm::cast(op.getMemref().getType()).getElementType(); + // the defining op may also be llvm.call after partial lowering + if (auto op = dn.getDefiningOp()) { + if (op.getCallee() == "mgpuCreateDnVec") { + mlir::Attribute dw = + op.getOperand(2).getDefiningOp().getValue(); + if (dw == (mlir::Attribute)builder.getIntegerAttr(llvmInt32Type, 32)) + return builder.getF32Type(); + else if (dw == (mlir::Attribute)builder.getIntegerAttr(llvmInt32Type, 64)) + return builder.getF64Type(); + } else if (op.getCallee() == "mgpuCreateDnMat") { + mlir::Attribute dw = + op.getOperand(3).getDefiningOp().getValue(); + if (dw == (mlir::Attribute)builder.getIntegerAttr(llvmInt32Type, 32)) + return builder.getF32Type(); + else if (dw == (mlir::Attribute)builder.getIntegerAttr(llvmInt32Type, 64)) + return builder.getF64Type(); + } + } + llvm_unreachable("cannot find dn def"); +} + static Value genConstFrom(OpBuilder &builder, Location loc, gpu::TransposeMode mode) { Type llvmInt32Type = builder.getIntegerType(32); @@ -1206,6 +1253,23 @@ static_cast(mode)); } +static Value genConstFrom(OpBuilder &builder, Location loc, + int32_t computeTypeInt) { + Type llvmInt32Type = builder.getIntegerType(32); + return builder.create(loc, llvmInt32Type, computeTypeInt); +} + +static Value +genConstFromOptionalComputeMode(OpBuilder &builder, Location loc, + std::optional computeTypeOptional, + Value result) { + auto computeTypeInt = getCUSparseDataTypeEnumFrom( + computeTypeOptional.has_value() ? computeTypeOptional.value() + : getDnElemType(result, builder)); + auto computeType = genConstFrom(builder, loc, computeTypeInt); + return computeType; +} + LogicalResult ConvertCreateSparseEnvOpToGpuRuntimeCallPattern::matchAndRewrite( gpu::CreateSparseEnvOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { @@ -1245,7 +1309,8 @@ MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc); if (!getTypeConverter()->useOpaquePointers()) pVec = rewriter.create(loc, llvmPointerType, pVec); - Type dType = llvm::cast(op.getMemref().getType()).getElementType(); + Type dType = + llvm::cast(op.getMemref().getType()).getElementType(); auto dw = rewriter.create(loc, llvmInt32Type, dType.getIntOrFloatBitWidth()); auto handle = @@ -1281,7 +1346,8 @@ MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc); if (!getTypeConverter()->useOpaquePointers()) pMat = rewriter.create(loc, llvmPointerType, pMat); - Type dType = llvm::cast(op.getMemref().getType()).getElementType(); + Type dType = + llvm::cast(op.getMemref().getType()).getElementType(); auto dw = rewriter.create(loc, llvmInt32Type, dType.getIntOrFloatBitWidth()); auto handle = @@ -1325,8 +1391,10 @@ pColIdxs = rewriter.create(loc, llvmPointerType, pColIdxs); pValues = rewriter.create(loc, llvmPointerType, pValues); } - Type iType = llvm::cast(op.getColIdxs().getType()).getElementType(); - Type dType = llvm::cast(op.getValues().getType()).getElementType(); + Type iType = + llvm::cast(op.getColIdxs().getType()).getElementType(); + Type dType = + llvm::cast(op.getValues().getType()).getElementType(); auto iw = rewriter.create( loc, llvmInt32Type, iType.isIndex() ? 64 : iType.getIntOrFloatBitWidth()); auto dw = rewriter.create(loc, llvmInt32Type, @@ -1360,9 +1428,12 @@ pColIdxs = rewriter.create(loc, llvmPointerType, pColIdxs); pValues = rewriter.create(loc, llvmPointerType, pValues); } - Type pType = llvm::cast(op.getRowPos().getType()).getElementType(); - Type iType = llvm::cast(op.getColIdxs().getType()).getElementType(); - Type dType = llvm::cast(op.getValues().getType()).getElementType(); + Type pType = + llvm::cast(op.getRowPos().getType()).getElementType(); + Type iType = + llvm::cast(op.getColIdxs().getType()).getElementType(); + Type dType = + llvm::cast(op.getValues().getType()).getElementType(); auto pw = rewriter.create( loc, llvmInt32Type, pType.isIndex() ? 64 : pType.getIntOrFloatBitWidth()); auto iw = rewriter.create( @@ -1403,12 +1474,15 @@ Type dType = getSpMatElemType(op.getSpmatA()); auto dw = rewriter.create(loc, llvmInt32Type, dType.getIntOrFloatBitWidth()); + // retrieve the compute type, notice that it may be optional + auto computeType = genConstFromOptionalComputeMode( + rewriter, loc, adaptor.getComputeType(), adaptor.getDnY()); auto stream = adaptor.getAsyncDependencies().front(); auto bufferSize = spMVBufferSizeCallBuilder .create(loc, rewriter, {adaptor.getEnv(), modeA, adaptor.getSpmatA(), - adaptor.getDnX(), adaptor.getDnY(), dw, stream}) + adaptor.getDnX(), adaptor.getDnY(), dw, computeType, stream}) .getResult(); rewriter.replaceOp(op, {bufferSize, stream}); return success(); @@ -1425,6 +1499,9 @@ auto modeA = genConstFrom(rewriter, loc, adaptor.getModeA()); auto dw = rewriter.create(loc, llvmInt32Type, dType.getIntOrFloatBitWidth()); + // retrieve the compute type, notice that it may be optional + auto computeType = genConstFromOptionalComputeMode( + rewriter, loc, adaptor.getComputeType(), adaptor.getDnY()); auto stream = adaptor.getAsyncDependencies().front(); Value pBuf = MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc); @@ -1432,8 +1509,8 @@ pBuf = rewriter.create(loc, llvmPointerType, pBuf); spMVCallBuilder.create(loc, rewriter, {adaptor.getEnv(), modeA, adaptor.getSpmatA(), - adaptor.getDnX(), adaptor.getDnY(), dw, pBuf, - stream}); + adaptor.getDnX(), adaptor.getDnY(), dw, computeType, + pBuf, stream}); rewriter.replaceOp(op, {stream}); return success(); } @@ -1451,12 +1528,16 @@ auto dw = rewriter.create(loc, llvmInt32Type, dType.getIntOrFloatBitWidth()); auto stream = adaptor.getAsyncDependencies().front(); - auto bufferSize = - spMMBufferSizeCallBuilder - .create(loc, rewriter, - {adaptor.getEnv(), modeA, modeB, adaptor.getSpmatA(), - adaptor.getDnmatB(), adaptor.getDnmatC(), dw, stream}) - .getResult(); + // retrieve the compute type, notice that it may be optional + auto computeType = genConstFromOptionalComputeMode( + rewriter, loc, adaptor.getComputeType(), adaptor.getDnmatB()); + + auto bufferSize = spMMBufferSizeCallBuilder + .create(loc, rewriter, + {adaptor.getEnv(), modeA, modeB, + adaptor.getSpmatA(), adaptor.getDnmatB(), + adaptor.getDnmatC(), dw, computeType, stream}) + .getResult(); rewriter.replaceOp(op, {bufferSize, stream}); return success(); } @@ -1473,6 +1554,10 @@ dType.getIntOrFloatBitWidth()); auto modeA = genConstFrom(rewriter, loc, adaptor.getModeA()); auto modeB = genConstFrom(rewriter, loc, adaptor.getModeB()); + // retrieve the compute type, notice that it may be optional + auto computeType = genConstFromOptionalComputeMode( + rewriter, loc, adaptor.getComputeType(), adaptor.getDnmatB()); + auto stream = adaptor.getAsyncDependencies().front(); Value pBuf = MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc); @@ -1480,8 +1565,8 @@ pBuf = rewriter.create(loc, llvmPointerType, pBuf); spMMCallBuilder.create(loc, rewriter, {adaptor.getEnv(), modeA, modeB, adaptor.getSpmatA(), - adaptor.getDnmatB(), adaptor.getDnmatC(), dw, pBuf, - stream}); + adaptor.getDnmatB(), adaptor.getDnmatC(), dw, + computeType, pBuf, stream}); rewriter.replaceOp(op, {stream}); return success(); } diff --git a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp --- a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp +++ b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp @@ -340,67 +340,67 @@ extern "C" MLIR_CUDA_WRAPPERS_EXPORT intptr_t mgpuSpMVBufferSize(void *h, int32_t ma, void *a, void *x, void *y, int32_t dw, - CUstream /*stream*/) { + int32_t dtp, CUstream /*stream*/) { cusparseHandle_t handle = reinterpret_cast(h); cusparseOperation_t modeA = static_cast(ma); cusparseSpMatDescr_t matA = reinterpret_cast(a); cusparseDnVecDescr_t vecX = reinterpret_cast(x); cusparseDnVecDescr_t vecY = reinterpret_cast(y); - cudaDataType_t dtp = dataTp(dw); + cudaDataType_t dTp = static_cast(dtp); ALPHABETA(dw, alpha, beta) size_t bufferSize = 0; CUSPARSE_REPORT_IF_ERROR( cusparseSpMV_bufferSize(handle, modeA, alphap, matA, vecX, betap, vecY, - dtp, CUSPARSE_SPMV_ALG_DEFAULT, &bufferSize)) + dTp, CUSPARSE_SPMV_ALG_DEFAULT, &bufferSize)) return bufferSize == 0 ? 1 : bufferSize; // avoid zero-alloc } extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuSpMV(void *h, int32_t ma, void *a, void *x, void *y, int32_t dw, - void *buf, + int32_t dtp, void *buf, CUstream /*stream*/) { cusparseHandle_t handle = reinterpret_cast(h); cusparseOperation_t modeA = static_cast(ma); cusparseSpMatDescr_t matA = reinterpret_cast(a); cusparseDnVecDescr_t vecX = reinterpret_cast(x); cusparseDnVecDescr_t vecY = reinterpret_cast(y); - cudaDataType_t dtp = dataTp(dw); + cudaDataType_t dTp = static_cast(dtp); ALPHABETA(dw, alpha, beta) CUSPARSE_REPORT_IF_ERROR(cusparseSpMV(handle, modeA, alphap, matA, vecX, - betap, vecY, dtp, + betap, vecY, dTp, CUSPARSE_SPMV_ALG_DEFAULT, buf)) } extern "C" MLIR_CUDA_WRAPPERS_EXPORT intptr_t mgpuSpMMBufferSize(void *h, int32_t ma, int32_t mb, void *a, void *b, void *c, - int32_t dw, CUstream /*stream*/) { + int32_t dw, int32_t dtp, CUstream /*stream*/) { cusparseHandle_t handle = reinterpret_cast(h); cusparseOperation_t modeA = static_cast(ma); cusparseOperation_t modeB = static_cast(mb); cusparseSpMatDescr_t matA = reinterpret_cast(a); cusparseDnMatDescr_t matB = reinterpret_cast(b); cusparseDnMatDescr_t matC = reinterpret_cast(c); - cudaDataType_t dtp = dataTp(dw); + cudaDataType_t dTp = static_cast(dtp); ALPHABETA(dw, alpha, beta) size_t bufferSize = 0; CUSPARSE_REPORT_IF_ERROR(cusparseSpMM_bufferSize( - handle, modeA, modeB, alphap, matA, matB, betap, matC, dtp, + handle, modeA, modeB, alphap, matA, matB, betap, matC, dTp, CUSPARSE_SPMM_ALG_DEFAULT, &bufferSize)) return bufferSize == 0 ? 1 : bufferSize; // avoid zero-alloc } extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuSpMM(void *h, int32_t ma, int32_t mb, void *a, void *b, void *c, int32_t dw, - void *buf, CUstream /*stream*/) { + int32_t dtp, void *buf, CUstream /*stream*/) { cusparseHandle_t handle = reinterpret_cast(h); cusparseOperation_t modeA = static_cast(ma); cusparseOperation_t modeB = static_cast(mb); cusparseSpMatDescr_t matA = reinterpret_cast(a); cusparseDnMatDescr_t matB = reinterpret_cast(b); cusparseDnMatDescr_t matC = reinterpret_cast(c); - cudaDataType_t dtp = dataTp(dw); + cudaDataType_t dTp = static_cast(dtp); ALPHABETA(dw, alpha, beta) CUSPARSE_REPORT_IF_ERROR(cusparseSpMM(handle, modeA, modeB, alphap, matA, - matB, betap, matC, dtp, + matB, betap, matC, dTp, CUSPARSE_SPMM_ALG_DEFAULT, buf)) }