diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -1860,7 +1860,7 @@
     Example:

     ```mlir
-    %buffersz, %token = gpu.spmv_buffersize async [%dep] %env, %spmatA{TRANSPOSE}, %dnX, %dnY
+    %buffersz, %token = gpu.spmv_buffer_size async [%dep] %env, %spmatA{TRANSPOSE}, %dnX, %dnY
     ```
   }];
   let arguments = (ins Variadic<GPU_AsyncToken>:$asyncDependencies,
@@ -1868,24 +1868,27 @@
                        GPU_TransposeModeAttr:$modeA,
                        GPU_SparseSpMatHandle:$spmatA,
                        GPU_SparseDnVecHandle:$dnX,
-                       GPU_SparseDnVecHandle:$dnY);
+                       GPU_SparseDnVecHandle:$dnY,
+                       OptionalAttr<TypeAttr>:$computeType);
   let results = (outs Res<Index>:$bufferSz,
                       Optional<GPU_AsyncToken>:$asyncToken);
   let builders = [OpBuilder<(ins
-      "::mlir::Type":$bufferSz,
-      "::mlir::Type":$asyncToken,
-      "::mlir::ValueRange":$asyncDependencies,
-      "::mlir::Value":$env,
-      "::mlir::Value":$spmatA,
-      "::mlir::Value":$dnX,
-      "::mlir::Value":$dnY), [{
+      "Type":$bufferSz,
+      "Type":$asyncToken,
+      "ValueRange":$asyncDependencies,
+      "Value":$env,
+      "Value":$spmatA,
+      "Value":$dnX,
+      "Value":$dnY), [{
     auto modeA = gpu::TransposeMode::NON_TRANSPOSE;
-    return build($_builder, $_state, bufferSz, asyncToken, asyncDependencies, env,
-                 modeA, spmatA, dnX, dnY);}]>
+    return build($_builder, $_state, bufferSz, asyncToken, asyncDependencies,
+                 env, modeA, spmatA, dnX, dnY, {});}]>
   ];
   let assemblyFormat = [{
+    (`{` $computeType^ `}`)?
    custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
    $env `,` $spmatA (`{` $modeA^ `}`)? `,` $dnX `,` $dnY attr-dict
   }];
@@ -1919,23 +1922,25 @@
                        GPU_SparseSpMatHandle:$spmatA,
                        GPU_SparseDnVecHandle:$dnX,
                        GPU_SparseDnVecHandle:$dnY,
+                       OptionalAttr<TypeAttr>:$computeType,
                        AnyMemRef:$buffer);
   let results = (outs Optional<GPU_AsyncToken>:$asyncToken);
   let builders = [OpBuilder<(ins
-      "::mlir::Type":$asyncToken,
-      "::mlir::ValueRange":$asyncDependencies,
-      "::mlir::Value":$env,
-      "::mlir::Value":$spmatA,
-      "::mlir::Value":$dnX,
-      "::mlir::Value":$dnY,
-      "::mlir::Value":$buffer), [{
+      "Type":$asyncToken,
+      "ValueRange":$asyncDependencies,
+      "Value":$env,
+      "Value":$spmatA,
+      "Value":$dnX,
+      "Value":$dnY,
+      "Value":$buffer), [{
     auto modeA = gpu::TransposeMode::NON_TRANSPOSE;
     return build($_builder, $_state, asyncToken, asyncDependencies, env, modeA,
-                 spmatA, dnX, dnY, buffer);}]>
+                 spmatA, dnX, dnY, {}, buffer);}]>
   ];
   let assemblyFormat = [{
+    (`{` $computeType^ `}`)?
    custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
    $env `,` $spmatA (`{` $modeA^ `}`)? `,` $dnX `,` $dnY `,` $buffer attr-dict
    `:` type($buffer)
   }];
@@ -1970,25 +1975,27 @@
                        GPU_TransposeModeAttr:$modeB,
                        GPU_SparseSpMatHandle:$spmatA,
                        GPU_SparseDnMatHandle:$dnmatB,
-                       GPU_SparseDnMatHandle:$dnmatC);
+                       GPU_SparseDnMatHandle:$dnmatC,
+                       OptionalAttr<TypeAttr>:$computeType);
   let results = (outs Res<Index>:$bufferSz,
                       Optional<GPU_AsyncToken>:$asyncToken);
   let builders = [OpBuilder<(ins
-      "::mlir::Type":$bufferSz,
-      "::mlir::Type":$asyncToken,
-      "::mlir::ValueRange":$asyncDependencies,
-      "::mlir::Value":$env,
-      "::mlir::Value":$spmatA,
-      "::mlir::Value":$dnmatB,
-      "::mlir::Value":$dnmatC), [{
+      "Type":$bufferSz,
+      "Type":$asyncToken,
+      "ValueRange":$asyncDependencies,
+      "Value":$env,
+      "Value":$spmatA,
+      "Value":$dnmatB,
+      "Value":$dnmatC), [{
     auto modeA = gpu::TransposeMode::NON_TRANSPOSE;
     auto modeB = gpu::TransposeMode::NON_TRANSPOSE;
     return build($_builder, $_state, bufferSz, asyncToken, asyncDependencies,
-                 env, modeA, modeB, spmatA, dnmatB, dnmatC);}]>
+                 env, modeA, modeB, spmatA, dnmatB, dnmatC, {});}]>
   ];
   let assemblyFormat = [{
+    (`{` $computeType^ `}`)?
    custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
    $env `,` $spmatA (`{` $modeA^ `}`)? `,` $dnmatB (`{` $modeB^ `}`)? `,` $dnmatC attr-dict
   }];
@@ -2024,24 +2031,26 @@
                        GPU_SparseSpMatHandle:$spmatA,
                        GPU_SparseDnMatHandle:$dnmatB,
                        GPU_SparseDnMatHandle:$dnmatC,
+                       OptionalAttr<TypeAttr>:$computeType,
                        AnyMemRef:$buffer);
   let results = (outs Optional<GPU_AsyncToken>:$asyncToken);
   let builders = [OpBuilder<(ins
-      "::mlir::Type":$asyncToken,
-      "::mlir::ValueRange":$asyncDependencies,
-      "::mlir::Value":$env,
-      "::mlir::Value":$spmatA,
-      "::mlir::Value":$dnmatB,
-      "::mlir::Value":$dnmatC,
-      "::mlir::Value":$buffer), [{
+      "Type":$asyncToken,
+      "ValueRange":$asyncDependencies,
+      "Value":$env,
+      "Value":$spmatA,
+      "Value":$dnmatB,
+      "Value":$dnmatC,
+      "Value":$buffer), [{
     auto modeA = gpu::TransposeMode::NON_TRANSPOSE;
     auto modeB = gpu::TransposeMode::NON_TRANSPOSE;
     return build($_builder, $_state, asyncToken, asyncDependencies, env, modeA,
-                 modeB, spmatA, dnmatB, dnmatC, buffer);}]>
+                 modeB, spmatA, dnmatB, dnmatC, {}, buffer);}]>
   ];
   let assemblyFormat = [{
+    (`{` $computeType^ `}`)?
    custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
    $env `,` $spmatA (`{` $modeA^ `}`)? `,` $dnmatB (`{` $modeB^ `}`)? `,` $dnmatC `,` $buffer attr-dict
    `:` type($buffer)
   }];
@@ -2076,7 +2085,8 @@
                        GPU_TransposeModeAttr:$modeB,
                        GPU_SparseDnMatHandle:$dnmatA,
                        GPU_SparseDnMatHandle:$dnmatB,
-                       GPU_SparseSpMatHandle:$spmatC);
+                       GPU_SparseSpMatHandle:$spmatC,
+                       OptionalAttr<TypeAttr>:$computeType);
   let results = (outs Res<Index>:$bufferSz,
                       Optional<GPU_AsyncToken>:$asyncToken);
   let builders = [OpBuilder<(ins
@@ -2090,10 +2100,11 @@
     auto modeA = gpu::TransposeMode::NON_TRANSPOSE;
     auto modeB = gpu::TransposeMode::NON_TRANSPOSE;
     return build($_builder, $_state, bufferSz, asyncToken, asyncDependencies,
-                 env, modeA, modeB, dnmatA, dnmatB, spmatC);}]>
+                 env, modeA, modeB, dnmatA, dnmatB, spmatC, {});}]>
   ];
   let assemblyFormat = [{
+    (`{` $computeType^ `}`)?
    custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
    $env `,` $dnmatA (`{` $modeA^ `}`)? `,` $dnmatB (`{` $modeB^ `}`)? `,` $spmatC attr-dict
   }];
@@ -2129,6 +2140,7 @@
                        GPU_SparseDnMatHandle:$dnmatA,
                        GPU_SparseDnMatHandle:$dnmatB,
                        GPU_SparseSpMatHandle:$spmatC,
+                       OptionalAttr<TypeAttr>:$computeType,
                        AnyMemRef:$buffer);
   let results = (outs Optional<GPU_AsyncToken>:$asyncToken);
@@ -2143,10 +2155,11 @@
     auto modeA = gpu::TransposeMode::NON_TRANSPOSE;
     auto modeB = gpu::TransposeMode::NON_TRANSPOSE;
     return build($_builder, $_state, asyncToken, asyncDependencies, env, modeA,
-                 modeB, dnmatA, dnmatB, spmatC, buffer);}]>
+                 modeB, dnmatA, dnmatB, spmatC, {}, buffer);}]>
   ];
   let assemblyFormat = [{
+    (`{` $computeType^ `}`)?
    custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
    $env `,` $dnmatA (`{` $modeA^ `}`)? `,` $dnmatB (`{` $modeB^ `}`)?
    `,` $spmatC `,` $buffer attr-dict `:` type($buffer)
  }];
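For reference, this is how the new optional clause reads in IR. A sketch based on the declarative format above (not a round-tripped test case; the compute type may differ from the operands' element type, e.g. accumulating f16 data in f32):

```mlir
%bufferSz, %token = gpu.spmv_buffer_size {f32} async [%dep] %env, %spmatA{TRANSPOSE}, %dnX, %dnY
```

When the clause is omitted, the lowering below falls back to the element type of the destination operand (`dnY`, `dnmatC`, or `spmatC`, respectively).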
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -238,37 +238,38 @@
       "mgpuSpMVBufferSize",
       llvmIntPtrType,
       {llvmPointerType, llvmInt32Type, llvmPointerType, llvmPointerType,
-       llvmPointerType, llvmInt32Type, llvmPointerType /* void *stream */}};
+       llvmPointerType, llvmInt32Type, llvmInt32Type,
+       llvmPointerType /* void *stream */}};
   FunctionCallBuilder spMVCallBuilder = {
       "mgpuSpMV",
       llvmVoidType,
       {llvmPointerType, llvmInt32Type, llvmPointerType, llvmPointerType,
-       llvmPointerType, llvmInt32Type, llvmPointerType,
+       llvmPointerType, llvmInt32Type, llvmInt32Type, llvmPointerType,
        llvmPointerType /* void *stream */}};
   FunctionCallBuilder spMMBufferSizeCallBuilder = {
       "mgpuSpMMBufferSize",
       llvmIntPtrType,
       {llvmPointerType, llvmInt32Type, llvmInt32Type, llvmPointerType,
-       llvmPointerType, llvmPointerType, llvmInt32Type,
+       llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
        llvmPointerType /* void *stream */}};
   FunctionCallBuilder spMMCallBuilder = {
       "mgpuSpMM",
       llvmVoidType,
       {llvmPointerType, llvmInt32Type, llvmInt32Type, llvmPointerType,
-       llvmPointerType, llvmPointerType, llvmInt32Type, llvmPointerType,
-       llvmPointerType /* void *stream */}};
+       llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
+       llvmPointerType, llvmPointerType /* void *stream */}};
   FunctionCallBuilder SDDMMBufferSizeCallBuilder = {
       "mgpuSDDMMBufferSize",
       llvmIntPtrType,
       {llvmPointerType, llvmInt32Type, llvmInt32Type, llvmPointerType,
-       llvmPointerType, llvmPointerType, llvmInt32Type,
+       llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
        llvmPointerType /* void *stream */}};
   FunctionCallBuilder SDDMMCallBuilder = {
       "mgpuSDDMM",
       llvmVoidType,
       {llvmPointerType, llvmInt32Type, llvmInt32Type, llvmPointerType,
-       llvmPointerType, llvmPointerType, llvmInt32Type, llvmPointerType,
-       llvmPointerType /* void *stream */}};
+       llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
+       llvmPointerType, llvmPointerType /* void *stream */}};
 };

 /// A rewrite pattern to convert gpu.host_register operations into a GPU runtime
@@ -688,6 +689,55 @@
   return builder.create<LLVM::CallOp>(loc, function, arguments);
 }

+// Corresponding to cudaDataType_t as defined in library_types.h.
+// TODO: add support for complex types.
+static int32_t getCuSparseFloatingTypeFrom(unsigned dataWidth) {
+  if (dataWidth == 16)
+    return 2; // CUDA_R_16F
+  if (dataWidth == 32)
+    return 0; // CUDA_R_32F
+  if (dataWidth == 64)
+    return 1; // CUDA_R_64F
+  llvm_unreachable("unsupported data type");
+}
+
+static int32_t getCuSparseDataTypeFrom(Type type) {
+  if (llvm::isa<ComplexType>(type)) {
+    // Dispatch on the element type of the complex type.
+    auto elementType = llvm::cast<ComplexType>(type).getElementType();
+    if (elementType.isBF16())
+      return 15; // CUDA_C_16BF
+    if (elementType.isF16())
+      return 6; // CUDA_C_16F
+    if (elementType.isF32())
+      return 4; // CUDA_C_32F
+    if (elementType.isF64())
+      return 5; // CUDA_C_64F
+    if (elementType.isInteger(8))
+      return 7; // CUDA_C_8I
+    if (elementType.isInteger(16))
+      return 21; // CUDA_C_16I
+    if (elementType.isInteger(32))
+      return 11; // CUDA_C_32I
+  }
+  if (type.isBF16())
+    return 14; // CUDA_R_16BF
+  if (type.isF16())
+    return 2; // CUDA_R_16F
+  if (type.isF32())
+    return 0; // CUDA_R_32F
+  if (type.isF64())
+    return 1; // CUDA_R_64F
+  if (type.isInteger(8))
+    return 3; // CUDA_R_8I
+  if (type.isInteger(16))
+    return 20; // CUDA_R_16I
+  if (type.isInteger(32))
+    return 10; // CUDA_R_32I
+
+  llvm_unreachable("unsupported element type");
+}
 // Returns whether all operands are of LLVM type.
 static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands,
                                      ConversionPatternRewriter &rewriter) {
@@ -1237,11 +1287,37 @@
   llvm_unreachable("cannot find spmat def");
 }

-static Value genConstFrom(OpBuilder &builder, Location loc,
-                          gpu::TransposeMode mode) {
+// Returns the element type of the defining dnmat or dnvec op.
+static Type getDnElemType(Value dn) {
+  if (auto op = dn.getDefiningOp<gpu::CreateDnMatOp>())
+    return op.getMemref().getType().getElementType();
+  if (auto op = dn.getDefiningOp<gpu::CreateDnVecOp>())
+    return op.getMemref().getType().getElementType();
+  llvm_unreachable("cannot find dn def");
+}
+
+template <typename T>
+static Value genConstFrom(OpBuilder &builder, Location loc, T tValue) {
   Type llvmInt32Type = builder.getIntegerType(32);
   return builder.create<LLVM::ConstantOp>(loc, llvmInt32Type,
-                                          static_cast<int32_t>(mode));
+                                          static_cast<int32_t>(tValue));
+}
+
+// Materializes the compute type as an i32 constant, falling back to the
+// given default type when the optional attribute is absent.
+static Value
+genConstFromOptionalComputeMode(OpBuilder &builder, Location loc,
+                                std::optional<Type> computeTypeOptional,
+                                Type defaultType) {
+  auto computeTypeInt =
+      getCuSparseDataTypeFrom(computeTypeOptional.value_or(defaultType));
+  auto computeType = genConstFrom(builder, loc, computeTypeInt);
+  return computeType;
 }

 LogicalResult ConvertCreateSparseEnvOpToGpuRuntimeCallPattern::matchAndRewrite(
@@ -1283,10 +1359,8 @@
       MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
   if (!getTypeConverter()->useOpaquePointers())
     pVec = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pVec);
-  Type dType =
-      llvm::cast<MemRefType>(op.getMemref().getType()).getElementType();
-  auto dw = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
-                                              dType.getIntOrFloatBitWidth());
+  Type dType = op.getMemref().getType().getElementType();
+  auto dw = genConstFrom(rewriter, loc, dType.getIntOrFloatBitWidth());
   auto handle =
       createDnVecCallBuilder
           .create(loc, rewriter, {adaptor.getSize(), pVec, dw, stream})
@@ -1320,10 +1394,8 @@
       MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
   if (!getTypeConverter()->useOpaquePointers())
     pMat = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pMat);
-  Type dType =
-      llvm::cast<MemRefType>(op.getMemref().getType()).getElementType();
-  auto dw = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
-                                              dType.getIntOrFloatBitWidth());
+  Type dType = op.getMemref().getType().getElementType();
+  auto dw = genConstFrom(rewriter, loc, dType.getIntOrFloatBitWidth());
   auto handle =
       createDnMatCallBuilder
           .create(loc, rewriter,
@@ -1371,8 +1443,7 @@
       llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
   auto iw = rewriter.create<LLVM::ConstantOp>(
       loc, llvmInt32Type, iType.isIndex() ? 64 : iType.getIntOrFloatBitWidth());
-  auto dw = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
-                                              dType.getIntOrFloatBitWidth());
+  auto dw = genConstFrom(rewriter, loc, dType.getIntOrFloatBitWidth());
   auto handle =
       createCooCallBuilder
           .create(loc, rewriter,
@@ -1412,8 +1483,7 @@
       loc, llvmInt32Type, pType.isIndex() ? 64 : pType.getIntOrFloatBitWidth());
   auto iw = rewriter.create<LLVM::ConstantOp>(
       loc, llvmInt32Type, iType.isIndex() ? 64 : iType.getIntOrFloatBitWidth());
-  auto dw = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
-                                              dType.getIntOrFloatBitWidth());
+  auto dw = genConstFrom(rewriter, loc, dType.getIntOrFloatBitWidth());
   auto handle =
       createCsrCallBuilder
           .create(loc, rewriter,
@@ -1446,14 +1516,16 @@
   Location loc = op.getLoc();
   auto modeA = genConstFrom(rewriter, loc, op.getModeA());
   Type dType = getSpMatElemType(op.getSpmatA());
-  auto dw = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
-                                              dType.getIntOrFloatBitWidth());
+  auto dw = genConstFrom(rewriter, loc, dType.getIntOrFloatBitWidth());
+  // Retrieve the compute type; note that the attribute is optional.
+  auto computeType = genConstFromOptionalComputeMode(
+      rewriter, loc, adaptor.getComputeType(), getDnElemType(op.getDnY()));
   auto stream = adaptor.getAsyncDependencies().front();
   auto bufferSize =
       spMVBufferSizeCallBuilder
           .create(loc, rewriter,
                   {adaptor.getEnv(), modeA, adaptor.getSpmatA(),
-                   adaptor.getDnX(), adaptor.getDnY(), dw, stream})
+                   adaptor.getDnX(), adaptor.getDnY(), dw, computeType, stream})
           .getResult();
   rewriter.replaceOp(op, {bufferSize, stream});
   return success();
@@ -1468,8 +1540,10 @@
   Location loc = op.getLoc();
   Type dType = getSpMatElemType(op.getSpmatA());
   auto modeA = genConstFrom(rewriter, loc, adaptor.getModeA());
-  auto dw = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
-                                              dType.getIntOrFloatBitWidth());
+  auto dw = genConstFrom(rewriter, loc, dType.getIntOrFloatBitWidth());
+  // Retrieve the compute type; note that the attribute is optional.
+  auto computeType = genConstFromOptionalComputeMode(
+      rewriter, loc, adaptor.getComputeType(), getDnElemType(op.getDnY()));
   auto stream = adaptor.getAsyncDependencies().front();
   Value pBuf =
       MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
@@ -1477,8 +1551,8 @@
     pBuf = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pBuf);
   spMVCallBuilder.create(loc, rewriter,
                          {adaptor.getEnv(), modeA, adaptor.getSpmatA(),
-                          adaptor.getDnX(), adaptor.getDnY(), dw, pBuf,
-                          stream});
+                          adaptor.getDnX(), adaptor.getDnY(), dw, computeType,
+                          pBuf, stream});
   rewriter.replaceOp(op, {stream});
   return success();
 }
@@ -1493,15 +1567,18 @@
   auto modeA = genConstFrom(rewriter, loc, adaptor.getModeA());
   auto modeB = genConstFrom(rewriter, loc, adaptor.getModeB());
   Type dType = getSpMatElemType(op.getSpmatA());
-  auto dw = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
-                                              dType.getIntOrFloatBitWidth());
+  auto dw = genConstFrom(rewriter, loc, dType.getIntOrFloatBitWidth());
   auto stream = adaptor.getAsyncDependencies().front();
-  auto bufferSize =
-      spMMBufferSizeCallBuilder
-          .create(loc, rewriter,
-                  {adaptor.getEnv(), modeA, modeB, adaptor.getSpmatA(),
-                   adaptor.getDnmatB(), adaptor.getDnmatC(), dw, stream})
-          .getResult();
+  // Retrieve the compute type; note that the attribute is optional.
+  auto computeType = genConstFromOptionalComputeMode(
+      rewriter, loc, adaptor.getComputeType(), getDnElemType(op.getDnmatC()));
+
+  auto bufferSize = spMMBufferSizeCallBuilder
+                        .create(loc, rewriter,
+                                {adaptor.getEnv(), modeA, modeB,
+                                 adaptor.getSpmatA(), adaptor.getDnmatB(),
+                                 adaptor.getDnmatC(), dw, computeType, stream})
+                        .getResult();
   rewriter.replaceOp(op, {bufferSize, stream});
   return success();
 }
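As a quick illustration of the fallback rule these patterns share (a sketch, not part of the patch; assumes `rewriter` and `loc` are in scope inside a pattern):

```cpp
// Attribute omitted: the compute type falls back to the destination element
// type; for f64 this materializes the i32 constant 1 (CUDA_R_64F).
Value dtp = genConstFromOptionalComputeMode(rewriter, loc, std::nullopt,
                                            rewriter.getF64Type());
// Attribute present: it takes precedence over the default; f32 yields the
// i32 constant 0 (CUDA_R_32F).
Value dtp2 = genConstFromOptionalComputeMode(rewriter, loc,
                                             rewriter.getF32Type(),
                                             rewriter.getF64Type());
```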
@@ -1516,15 +1593,17 @@
   auto modeA = genConstFrom(rewriter, loc, adaptor.getModeA());
   auto modeB = genConstFrom(rewriter, loc, adaptor.getModeB());
   Type dType = getSpMatElemType(op.getSpmatC());
-  auto dw = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
-                                              dType.getIntOrFloatBitWidth());
+  auto dw = genConstFrom(rewriter, loc, dType.getIntOrFloatBitWidth());
+  auto computeType =
+      genConstFromOptionalComputeMode(rewriter, loc, adaptor.getComputeType(),
+                                      getSpMatElemType(op.getSpmatC()));
   auto stream = adaptor.getAsyncDependencies().front();
-  auto bufferSize =
-      SDDMMBufferSizeCallBuilder
-          .create(loc, rewriter,
-                  {adaptor.getEnv(), modeA, modeB, adaptor.getDnmatA(),
-                   adaptor.getDnmatB(), adaptor.getSpmatC(), dw, stream})
-          .getResult();
+  auto bufferSize = SDDMMBufferSizeCallBuilder
+                        .create(loc, rewriter,
+                                {adaptor.getEnv(), modeA, modeB,
+                                 adaptor.getDnmatA(), adaptor.getDnmatB(),
+                                 adaptor.getSpmatC(), dw, computeType, stream})
+                        .getResult();
   rewriter.replaceOp(op, {bufferSize, stream});
   return success();
 }
@@ -1539,8 +1618,11 @@
   auto modeA = genConstFrom(rewriter, loc, adaptor.getModeA());
   auto modeB = genConstFrom(rewriter, loc, adaptor.getModeB());
   Type dType = getSpMatElemType(op.getSpmatA());
-  auto dw = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
-                                              dType.getIntOrFloatBitWidth());
+  auto dw = genConstFrom(rewriter, loc, dType.getIntOrFloatBitWidth());
+  // Retrieve the compute type; note that the attribute is optional.
+  auto computeType = genConstFromOptionalComputeMode(
+      rewriter, loc, adaptor.getComputeType(), getDnElemType(op.getDnmatC()));
+
   auto stream = adaptor.getAsyncDependencies().front();
   Value pBuf =
       MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
@@ -1548,8 +1630,8 @@
     pBuf = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pBuf);
   spMMCallBuilder.create(loc, rewriter,
                          {adaptor.getEnv(), modeA, modeB, adaptor.getSpmatA(),
-                          adaptor.getDnmatB(), adaptor.getDnmatC(), dw, pBuf,
-                          stream});
+                          adaptor.getDnmatB(), adaptor.getDnmatC(), dw,
+                          computeType, pBuf, stream});
   rewriter.replaceOp(op, {stream});
   return success();
 }
@@ -1570,8 +1652,10 @@
     return failure();
   Location loc = op.getLoc();
   Type dType = getSpMatElemType(op.getSpmatC());
-  auto dw = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
-                                              dType.getIntOrFloatBitWidth());
+  auto dw = genConstFrom(rewriter, loc, dType.getIntOrFloatBitWidth());
+  auto computeType =
+      genConstFromOptionalComputeMode(rewriter, loc, adaptor.getComputeType(),
+                                      getSpMatElemType(op.getSpmatC()));
   auto modeA = genConstFrom(rewriter, loc, adaptor.getModeA());
   auto modeB = genConstFrom(rewriter, loc, adaptor.getModeB());
   auto stream = adaptor.getAsyncDependencies().front();
@@ -1581,8 +1665,8 @@
     pBuf = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pBuf);
   SDDMMCallBuilder.create(loc, rewriter,
                           {adaptor.getEnv(), modeA, modeB, adaptor.getDnmatA(),
-                           adaptor.getDnmatB(), adaptor.getSpmatC(), dw, pBuf,
-                           stream});
+                           adaptor.getDnmatB(), adaptor.getSpmatC(), dw,
+                           computeType, pBuf, stream});
   rewriter.replaceOp(op, {stream});
   return success();
 }
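The conversion above hard-codes the `cudaDataType_t` values as plain `int32_t` constants, and the wrappers below cast them back. A compile-time cross-check against the CUDA toolkit's `library_types.h` (a sketch, not part of the patch; assumes the CUDA headers are on the include path):

```cpp
#include "library_types.h" // CUDA toolkit header defining cudaDataType_t

// Real-valued codes, as returned by getCuSparseDataTypeFrom.
static_assert(CUDA_R_16BF == 14 && CUDA_R_16F == 2 && CUDA_R_32F == 0 &&
                  CUDA_R_64F == 1 && CUDA_R_8I == 3 && CUDA_R_16I == 20 &&
                  CUDA_R_32I == 10,
              "real cudaDataType_t codes changed");
// Complex-valued codes, for the ComplexType branch.
static_assert(CUDA_C_16BF == 15 && CUDA_C_16F == 6 && CUDA_C_32F == 4 &&
                  CUDA_C_64F == 5 && CUDA_C_8I == 7 && CUDA_C_16I == 21 &&
                  CUDA_C_32I == 11,
              "complex cudaDataType_t codes changed");
```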
diff --git a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
--- a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
+++ b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
@@ -18,6 +18,8 @@

 #include "cuda.h"
+#include "cuda_bf16.h"
+#include "cuda_fp16.h"
 #include "cusparse.h"

 #ifdef _WIN32
 #define MLIR_CUDA_WRAPPERS_EXPORT __declspec(dllexport)
@@ -246,14 +248,37 @@
   }
 }

 // Some macro magic to get float/double alpha and beta on host.
 #define ALPHABETA(w, alpha, beta)                                            \
+  __nv_bfloat16(alpha##16bf) = 1.0f;                                         \
+  __nv_bfloat16(beta##16bf) = 1.0f;                                          \
+  __half(alpha##16f) = 1.0f;                                                 \
+  __half(beta##16f) = 1.0f;                                                  \
   float(alpha##f) = 1.0f;                                                    \
   float(beta##f) = 1.0f;                                                     \
   double(alpha##d) = 1.0;                                                    \
   double(beta##d) = 1.0;                                                     \
   const void *(alpha##p) = nullptr;                                          \
   const void *(beta##p) = nullptr;                                           \
+  /* TODO: drop the width-based dispatch below once all callers set dTp. */  \
+  if (dTp == CUDA_R_16BF || dTp == CUDA_C_16BF) {                            \
+    (alpha##p) = reinterpret_cast<const void *>(&(alpha##16bf));             \
+    (beta##p) = reinterpret_cast<const void *>(&(beta##16bf));               \
+  } else if (dTp == CUDA_R_16F || dTp == CUDA_C_16F) {                       \
+    (alpha##p) = reinterpret_cast<const void *>(&(alpha##16f));              \
+    (beta##p) = reinterpret_cast<const void *>(&(beta##16f));                \
+  } else if (dTp == CUDA_R_32F || dTp == CUDA_C_32F) {                       \
+    (alpha##p) = reinterpret_cast<const void *>(&(alpha##f));                \
+    (beta##p) = reinterpret_cast<const void *>(&(beta##f));                  \
+  } else if (dTp == CUDA_R_64F || dTp == CUDA_C_64F) {                       \
+    (alpha##p) = reinterpret_cast<const void *>(&(alpha##d));                \
+    (beta##p) = reinterpret_cast<const void *>(&(beta##d));                  \
+  } else {                                                                   \
+    fprintf(stderr, "unsupported compute type\n");                           \
+  }                                                                          \
   if ((w) == 32) {                                                           \
     (alpha##p) = reinterpret_cast<const void *>(&(alpha##f));                \
     (beta##p) = reinterpret_cast<const void *>(&(beta##f));                  \
@@ -305,6 +330,7 @@
   CUSPARSE_REPORT_IF_ERROR(cusparseDestroyDnMat(mat))
 }

+// TODO: pass in the index type and data type instead of their widths.
 extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *
 mgpuCreateCoo(intptr_t rows, intptr_t cols, intptr_t nnz, void *rowIdxs,
               void *colIdxs, void *values, int32_t iw, int32_t dw,
@@ -318,6 +344,7 @@
   return reinterpret_cast<void *>(mat);
 }

+// TODO: pass in the index type and data type instead of their widths.
 extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *
 mgpuCreateCsr(intptr_t rows, intptr_t cols, intptr_t nnz, void *rowPos,
               void *colIdxs, void *values, int32_t pw, int32_t iw, int32_t dw,
@@ -338,73 +365,78 @@
   CUSPARSE_REPORT_IF_ERROR(cusparseDestroySpMat(mat))
 }

+// TODO: drop dw and use the compute type for alpha and beta as well.
 extern "C" MLIR_CUDA_WRAPPERS_EXPORT intptr_t
 mgpuSpMVBufferSize(void *h, int32_t ma, void *a, void *x, void *y, int32_t dw,
-                   CUstream /*stream*/) {
+                   int32_t dtp, CUstream /*stream*/) {
   cusparseHandle_t handle = reinterpret_cast<cusparseHandle_t>(h);
   cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
   cusparseSpMatDescr_t matA = reinterpret_cast<cusparseSpMatDescr_t>(a);
   cusparseDnVecDescr_t vecX = reinterpret_cast<cusparseDnVecDescr_t>(x);
   cusparseDnVecDescr_t vecY = reinterpret_cast<cusparseDnVecDescr_t>(y);
-  cudaDataType_t dtp = dataTp(dw);
+  cudaDataType_t dTp = static_cast<cudaDataType_t>(dtp);
   ALPHABETA(dw, alpha, beta)
   size_t bufferSize = 0;
   CUSPARSE_REPORT_IF_ERROR(
       cusparseSpMV_bufferSize(handle, modeA, alphap, matA, vecX, betap, vecY,
-                              dtp, CUSPARSE_SPMV_ALG_DEFAULT, &bufferSize))
+                              dTp, CUSPARSE_SPMV_ALG_DEFAULT, &bufferSize))
   return bufferSize == 0 ? 1 : bufferSize; // avoid zero-alloc
 }

+// TODO: drop dw and use the compute type for alpha and beta as well.
 extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
 mgpuSpMV(void *h, int32_t ma, void *a, void *x, void *y, int32_t dw,
-         void *buf, CUstream /*stream*/) {
+         int32_t dtp, void *buf, CUstream /*stream*/) {
   cusparseHandle_t handle = reinterpret_cast<cusparseHandle_t>(h);
   cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
   cusparseSpMatDescr_t matA = reinterpret_cast<cusparseSpMatDescr_t>(a);
   cusparseDnVecDescr_t vecX = reinterpret_cast<cusparseDnVecDescr_t>(x);
   cusparseDnVecDescr_t vecY = reinterpret_cast<cusparseDnVecDescr_t>(y);
-  cudaDataType_t dtp = dataTp(dw);
+  cudaDataType_t dTp = static_cast<cudaDataType_t>(dtp);
   ALPHABETA(dw, alpha, beta)
   CUSPARSE_REPORT_IF_ERROR(cusparseSpMV(handle, modeA, alphap, matA, vecX,
-                                        betap, vecY, dtp,
+                                        betap, vecY, dTp,
                                         CUSPARSE_SPMV_ALG_DEFAULT, buf))
 }
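The token-pasting in `ALPHABETA` is easy to get wrong: the declared suffixes must match the referenced ones exactly (the `bf16`/`16bf` mismatch is fixed above). A type-dispatched alternative, sketched under the assumption that the same CUDA headers are available; `UnitScalars` and `unitScalarFor` are hypothetical names, not part of the patch:

```cpp
#include "cuda_bf16.h"
#include "cuda_fp16.h"
#include "library_types.h"

// Host-side storage for a 1.0 scalar in each supported compute type.
struct UnitScalars {
  __nv_bfloat16 bf16 = __float2bfloat16(1.0f);
  __half f16 = __float2half(1.0f);
  float f32 = 1.0f;
  double f64 = 1.0;
};

// Returns a pointer into `s` holding a 1.0 value with the in-memory layout
// that `dTp` expects, or nullptr for an unsupported compute type.
static const void *unitScalarFor(cudaDataType_t dTp, const UnitScalars &s) {
  switch (dTp) {
  case CUDA_R_16BF:
  case CUDA_C_16BF:
    return &s.bf16;
  case CUDA_R_16F:
  case CUDA_C_16F:
    return &s.f16;
  case CUDA_R_32F:
  case CUDA_C_32F:
    return &s.f32;
  case CUDA_R_64F:
  case CUDA_C_64F:
    return &s.f64;
  default:
    return nullptr;
  }
}
```

Like the macro, this treats a complex compute type as its real scalar; a true complex `alpha`/`beta` would need a two-component value.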
+// TODO: drop dw and use the compute type for alpha and beta as well.
 extern "C" MLIR_CUDA_WRAPPERS_EXPORT intptr_t
 mgpuSpMMBufferSize(void *h, int32_t ma, int32_t mb, void *a, void *b, void *c,
-                   int32_t dw, CUstream /*stream*/) {
+                   int32_t dw, int32_t dtp, CUstream /*stream*/) {
   cusparseHandle_t handle = reinterpret_cast<cusparseHandle_t>(h);
   cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
   cusparseOperation_t modeB = static_cast<cusparseOperation_t>(mb);
   cusparseSpMatDescr_t matA = reinterpret_cast<cusparseSpMatDescr_t>(a);
   cusparseDnMatDescr_t matB = reinterpret_cast<cusparseDnMatDescr_t>(b);
   cusparseDnMatDescr_t matC = reinterpret_cast<cusparseDnMatDescr_t>(c);
-  cudaDataType_t dtp = dataTp(dw);
+  cudaDataType_t dTp = static_cast<cudaDataType_t>(dtp);
   ALPHABETA(dw, alpha, beta)
   size_t bufferSize = 0;
   CUSPARSE_REPORT_IF_ERROR(cusparseSpMM_bufferSize(
-      handle, modeA, modeB, alphap, matA, matB, betap, matC, dtp,
+      handle, modeA, modeB, alphap, matA, matB, betap, matC, dTp,
       CUSPARSE_SPMM_ALG_DEFAULT, &bufferSize))
   return bufferSize == 0 ? 1 : bufferSize; // avoid zero-alloc
 }

+// TODO: drop dw and use the compute type for alpha and beta as well.
 extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
 mgpuSpMM(void *h, int32_t ma, int32_t mb, void *a, void *b, void *c, int32_t dw,
-         void *buf, CUstream /*stream*/) {
+         int32_t dtp, void *buf, CUstream /*stream*/) {
   cusparseHandle_t handle = reinterpret_cast<cusparseHandle_t>(h);
   cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
   cusparseOperation_t modeB = static_cast<cusparseOperation_t>(mb);
   cusparseSpMatDescr_t matA = reinterpret_cast<cusparseSpMatDescr_t>(a);
   cusparseDnMatDescr_t matB = reinterpret_cast<cusparseDnMatDescr_t>(b);
   cusparseDnMatDescr_t matC = reinterpret_cast<cusparseDnMatDescr_t>(c);
-  cudaDataType_t dtp = dataTp(dw);
+  cudaDataType_t dTp = static_cast<cudaDataType_t>(dtp);
   ALPHABETA(dw, alpha, beta)
   CUSPARSE_REPORT_IF_ERROR(cusparseSpMM(handle, modeA, modeB, alphap, matA,
-                                        matB, betap, matC, dtp,
+                                        matB, betap, matC, dTp,
                                         CUSPARSE_SPMM_ALG_DEFAULT, buf))
 }

+// TODO: pass in the compute type instead of dw and use it for alpha and beta.
 extern "C" MLIR_CUDA_WRAPPERS_EXPORT intptr_t
 mgpuSDDMMBufferSize(void *h, int32_t ma, int32_t mb, void *a, void *b, void *c,
                     int32_t dw, CUstream /*stream*/) {
@@ -423,6 +455,7 @@
   return bufferSize == 0 ? 1 : bufferSize; // avoid zero-alloc
 }

+// TODO: pass in the compute type instead of dw and use it for alpha and beta.
 extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
 mgpuSDDMM(void *h, int32_t ma, int32_t mb, void *a, void *b, void *c,
           int32_t dw, void *buf, CUstream /*stream*/) {
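Putting the new ABI together, a hypothetical host-side caller of the updated SpMV wrappers (a sketch only: the extern declarations mirror the signatures above with `CUstream` spelled as `void *`, `mgpuMemAlloc` is assumed to be the pre-existing allocation wrapper in this file, and descriptor/stream creation is elided):

```cpp
#include <cstdint>

extern "C" intptr_t mgpuSpMVBufferSize(void *h, int32_t ma, void *a, void *x,
                                       void *y, int32_t dw, int32_t dtp,
                                       void *stream);
extern "C" void mgpuSpMV(void *h, int32_t ma, void *a, void *x, void *y,
                         int32_t dw, int32_t dtp, void *buf, void *stream);
extern "C" void *mgpuMemAlloc(uint64_t sizeBytes, void *stream);

void runSpMV(void *env, void *spmatA, void *dnX, void *dnY, void *stream) {
  constexpr int32_t kNonTranspose = 0; // CUSPARSE_OPERATION_NON_TRANSPOSE
  constexpr int32_t kCudaR32F = 0;     // CUDA_R_32F: compute in f32
  // Query the workspace size, allocate it, then run the actual SpMV.
  intptr_t sz = mgpuSpMVBufferSize(env, kNonTranspose, spmatA, dnX, dnY,
                                   /*dw=*/32, /*dtp=*/kCudaR32F, stream);
  void *buf = mgpuMemAlloc(sz, stream);
  mgpuSpMV(env, kNonTranspose, spmatA, dnX, dnY, /*dw=*/32,
           /*dtp=*/kCudaR32F, buf, stream);
}
```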