diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -1963,7 +1963,7 @@
   }];
 }
 
-def GPU_SpMMBufferSizeOp : GPU_Op<"spmm_buffer_size", [GPU_AsyncOpInterface]> {
+def GPU_SpMMBufferSizeOp : GPU_Op<"spmm_buffer_size", [GPU_AsyncOpInterface, AttrSizedResultSegments]> {
   let summary = "Precompute buffersize for SpMM operation";
   let description = [{
     The `gpu.spmm_buffer_size` operation returns the buffer size required
@@ -1994,8 +1994,7 @@
                        GPU_SparseDnTensorHandle:$dnmatB,
                        GPU_SparseDnTensorHandle:$dnmatC,
                        TypeAttr:$computeType);
-  let results = (outs Res<AnyTypeOf<[Index,
-                      TupleOf<[Index, Index, Index]>]>>:$bufferSzs,
+  let results = (outs Variadic<Index>:$bufferSzs,
                       Optional<GPU_AsyncToken>:$asyncToken);
   let builders = [OpBuilder<(ins
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -1747,6 +1747,21 @@
                                                      rewriter.getIndexAttr(3));
       bufferSize = rewriter.create<LLVM::AllocaOp>(loc, llvmInt64PointerType,
                                                    llvmInt64Type, three);
+
+      auto bufferSize0 =
+          rewriter.create<LLVM::PtrToIntOp>(loc, llvmInt64Type, bufferSize);
+
+      auto bufferSize1 = rewriter.create<LLVM::AddOp>(
+          loc, llvmInt64Type, bufferSize0,
+          rewriter.create<LLVM::ConstantOp>(
+              loc, llvmInt64Type,
+              rewriter.getIntegerAttr(llvmInt64Type, 8 * sizeof(int))));
+      auto bufferSize2 = rewriter.create<LLVM::AddOp>(
+          loc, llvmInt64Type, bufferSize1,
+          rewriter.create<LLVM::ConstantOp>(
+              loc, llvmInt64Type,
+              rewriter.getIntegerAttr(llvmInt64Type, 8 * sizeof(int))));
+
       bufferSize =
           rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, bufferSize);
@@ -1756,7 +1771,7 @@
                       adaptor.getSpmatA(), adaptor.getDnmatB(),
                       adaptor.getDnmatC(), computeType, stream})
               .getResult();
-      rewriter.replaceOp(op, {bufferSize, stream});
+      rewriter.replaceOp(op, {bufferSize0, bufferSize1, bufferSize2, stream});
     } else {
       auto computeType = genConstInt32From(
           rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
diff --git a/mlir/test/Conversion/GPUCommon/lower-2to4-sparse-to-gpu-runtime-calls.mlir b/mlir/test/Conversion/GPUCommon/lower-2to4-sparse-to-gpu-runtime-calls.mlir
--- a/mlir/test/Conversion/GPUCommon/lower-2to4-sparse-to-gpu-runtime-calls.mlir
+++ b/mlir/test/Conversion/GPUCommon/lower-2to4-sparse-to-gpu-runtime-calls.mlir
@@ -23,7 +23,7 @@
   %env, %token3 = gpu.create_sparse_env async [%token2]
   %spmat, %token4 = gpu.create_2to4_spmat async [%token3] %env, %arg0, %arg0, %mem1: memref<?xf16>
   %dnmat, %token5 = gpu.create_dn_tensor async [%token4] %env, %mem2, %arg0, %arg0 : index, index into memref<?xf16>
-  %bufferSzs, %token6 = gpu.spmm_buffer_size async [%token5] %env, %spmat, %dnmat, %dnmat : tuple<index, index, index> into f16
+  %bufferSz0, %bufferSz1, %bufferSz2, %token6 = gpu.spmm_buffer_size async [%token5] %env, %spmat, %dnmat, %dnmat : index, index, index into f16
   %token7 = gpu.spmm async [%token6] %env, %spmat, %dnmat, %dnmat, %mem2, %mem2, %mem2 : memref<?xf16>, memref<?xf16>, memref<?xf16> into f16
   %token8 = gpu.destroy_sp_mat async [%token7] %spmat
   %token9 = gpu.destroy_dn_tensor async [%token8] %dnmat
diff --git a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-lib.mlir b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-lib.mlir
--- a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-lib.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-lib.mlir
@@ -39,7 +39,7 @@
   %spmat, %token11 = gpu.create_2to4_spmat async [%token10] %env, %c16, %c32, %d_a: memref<16x32xf16>
   %dnmat, %token12 = gpu.create_dn_tensor async [%token11] %env, %d_b, %c32, %c16: index, index into memref<32x16xf16>
   %dnmat2, %token13 = gpu.create_dn_tensor async [%token12] %env, %d_c, %c16, %c16: index, index into memref<16x16xf16>
-  %bufferSzs, %token14 = gpu.spmm_buffer_size async [%token13] %env, %spmat{NON_TRANSPOSE}, %dnmat{NON_TRANSPOSE}, %dnmat2 : tuple<index, index, index> into f16
+  %bufferSz0, %bufferSz1, %bufferSz2, %token14 = gpu.spmm_buffer_size async [%token13] %env, %spmat{NON_TRANSPOSE}, %dnmat{NON_TRANSPOSE}, %dnmat2 : index, index, index into f16
   %token15 = gpu.spmm async [%token14] %env, %spmat{NON_TRANSPOSE}, %dnmat{NON_TRANSPOSE}, %dnmat2, %mem1, %mem2, %mem3 : memref<?xf16>, memref<?xf16>, memref<?xf16> into f16
   %token16 = gpu.destroy_sp_mat async [%token15] %spmat
   %token17 = gpu.destroy_dn_tensor async [%token16] %dnmat
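
Note: with this change, `gpu.spmm_buffer_size` yields three separate `index` results, one per cuSPARSELt workspace buffer for 2:4 SpMM, instead of a single `tuple<index, index, index>`, which is awkward to consume because upstream MLIR provides no builtin op for unpacking a tuple. The `AttrSizedResultSegments` trait is needed because the op now has two variable-length result groups (the variadic buffer sizes and the optional async token). A minimal usage sketch under the new form, assuming each size is fed to a dynamically sized `gpu.alloc` before `gpu.spmm`, mirroring the updated tests; SSA names here are illustrative, not taken from the patch:

  %bufferSz0, %bufferSz1, %bufferSz2, %t0 = gpu.spmm_buffer_size async [%dep] %env, %spmat{NON_TRANSPOSE}, %dnmat{NON_TRANSPOSE}, %dnmat2 : index, index, index into f16
  %buf0, %t1 = gpu.alloc async [%t0] (%bufferSz0) : memref<?xf16>
  %buf1, %t2 = gpu.alloc async [%t1] (%bufferSz1) : memref<?xf16>
  %buf2, %t3 = gpu.alloc async [%t2] (%bufferSz2) : memref<?xf16>
  %t4 = gpu.spmm async [%t3] %env, %spmat{NON_TRANSPOSE}, %dnmat{NON_TRANSPOSE}, %dnmat2, %buf0, %buf1, %buf2 : memref<?xf16>, memref<?xf16>, memref<?xf16> into f16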