diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td --- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td +++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td @@ -1935,4 +1935,69 @@ }]; } +def GPU_SDDMMBufferSizeOp : GPU_Op<"sddmm_buffer_size", [GPU_AsyncOpInterface]> { + let summary = "Precompute buffersize for SDDMM operation"; + let description = [{ + The `gpu.sddmm_buffer_size` operation returns the buffer size required + to perform the SDDMM operation on the given sparse and dense matrices. + The operation expects handles returned by previous sparse operations + to construct an environment and the operands for SDDMM. + + If the `async` keyword is present, the op is executed asynchronously (i.e. + it does not block until the execution has finished on the device). In + that case, it returns a !gpu.async.token in addition to the buffer size. + + Example: + + ```mlir + %buffersz, %token = gpu.sddmm_buffer_size async [%dep] %env, %dnmatA, %dnmatB, %spmatC + ``` + }]; + + let arguments = (ins Variadic:$asyncDependencies, + GPU_SparseHandle:$env, + GPU_SparseHandle:$dnmatA, + GPU_SparseHandle:$dnmatB, + GPU_SparseHandle:$spmatC); + let results = (outs Res:$bufferSz, Optional:$asyncToken); + + let assemblyFormat = [{ + custom(type($asyncToken), $asyncDependencies) + $env `,` $dnmatA `,` $dnmatB `,` $spmatC attr-dict + }]; +} + +def GPU_SDDMMOp : GPU_Op<"sddmm", [GPU_AsyncOpInterface]> { + let summary = "SDDMM operation"; + let description = [{ + The `gpu.sddmm` operation performs the SDDMM operation on the given sparse and + dense matrices, and buffer. The operation expects handles returned by previous + sparse operations to construct an environment and the operands for SDDMM. The + buffer must have been allocated on the device. + + If the `async` keyword is present, the op is executed asynchronously (i.e. + it does not block until the execution has finished on the device). 
In + that case, it returns a !gpu.async.token in addition to the environment. + + Example: + + ```mlir + %token = gpu.sddmm async [%dep] %env, %dnmatA, %dnmatB, %spmatC, %buffer + ``` + }]; + + let arguments = (ins Variadic:$asyncDependencies, + GPU_SparseHandle:$env, + GPU_SparseHandle:$dnmatA, + GPU_SparseHandle:$dnmatB, + GPU_SparseHandle:$spmatC, + AnyMemRef:$buffer); + let results = (outs Optional:$asyncToken); + + let assemblyFormat = [{ + custom(type($asyncToken), $asyncDependencies) + $env `,` $dnmatA `,` $dnmatB `,` $spmatC `,` $buffer attr-dict `:` type($buffer) + }]; +} + #endif // GPU_OPS diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp --- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp @@ -249,11 +249,21 @@ llvmIntPtrType, {llvmPointerType, llvmPointerType, llvmPointerType, llvmPointerType, llvmInt32Type, llvmPointerType /* void *stream */}}; + FunctionCallBuilder SDDMMBufferSizeCallBuilder = { + "mgpuSDDMMBufferSize", + llvmIntPtrType, + {llvmPointerType, llvmPointerType, llvmPointerType, llvmPointerType, + llvmInt32Type, llvmPointerType /* void *stream */}}; FunctionCallBuilder spMMCallBuilder = { "mgpuSpMM", llvmVoidType, {llvmPointerType, llvmPointerType, llvmPointerType, llvmPointerType, llvmInt32Type, llvmPointerType, llvmPointerType /* void *stream */}}; + FunctionCallBuilder SDDMMCallBuilder = { + "mgpuSDDMM", + llvmVoidType, + {llvmPointerType, llvmPointerType, llvmPointerType, llvmPointerType, + llvmInt32Type, llvmPointerType, llvmPointerType /* void *stream */}}; }; /// A rewrite pattern to convert gpu.host_register operations into a GPU runtime @@ -596,6 +606,20 @@ ConversionPatternRewriter &rewriter) const override; }; +class ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern + : public ConvertOpToGpuRuntimeCallPattern { +public: + ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern( + LLVMTypeConverter 
&typeConverter) + : ConvertOpToGpuRuntimeCallPattern( + typeConverter) {} + +private: + LogicalResult + matchAndRewrite(gpu::SDDMMBufferSizeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + class ConvertSpMMOpToGpuRuntimeCallPattern : public ConvertOpToGpuRuntimeCallPattern { public: @@ -608,6 +632,18 @@ ConversionPatternRewriter &rewriter) const override; }; +class ConvertSDDMMOpToGpuRuntimeCallPattern + : public ConvertOpToGpuRuntimeCallPattern { +public: + ConvertSDDMMOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter) + : ConvertOpToGpuRuntimeCallPattern(typeConverter) {} + +private: + LogicalResult + matchAndRewrite(gpu::SDDMMOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + } // namespace void GpuToLLVMConversionPass::runOnOperation() { @@ -1447,6 +1483,27 @@ return success(); } +LogicalResult ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite( + gpu::SDDMMBufferSizeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) || + failed(isAsyncWithOneDependency(rewriter, op))) + return failure(); + Location loc = op.getLoc(); + Type dType = getSpMatElemType(op.getSpmatC()); + auto dw = rewriter.create(loc, llvmInt32Type, + dType.getIntOrFloatBitWidth()); + auto stream = adaptor.getAsyncDependencies().front(); + auto bufferSize = + SDDMMBufferSizeCallBuilder + .create(loc, rewriter, + {adaptor.getEnv(), adaptor.getDnmatA(), adaptor.getDnmatB(), + adaptor.getSpmatC(), dw, stream}) + .getResult(); + rewriter.replaceOp(op, {bufferSize, stream}); + return success(); +} + LogicalResult ConvertSpMMOpToGpuRuntimeCallPattern::matchAndRewrite( gpu::SpMMOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { @@ -1470,6 +1527,29 @@ return success(); } +LogicalResult ConvertSDDMMOpToGpuRuntimeCallPattern::matchAndRewrite( + gpu::SDDMMOp op, OpAdaptor adaptor, + 
ConversionPatternRewriter &rewriter) const { + if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) || + failed(isAsyncWithOneDependency(rewriter, op))) + return failure(); + Location loc = op.getLoc(); + Type dType = getSpMatElemType(op.getSpmatC()); + auto dw = rewriter.create(loc, llvmInt32Type, + dType.getIntOrFloatBitWidth()); + auto stream = adaptor.getAsyncDependencies().front(); + Value pBuf = + MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc); + if (!getTypeConverter()->useOpaquePointers()) + pBuf = rewriter.create(loc, llvmPointerType, pBuf); + SDDMMCallBuilder.create(loc, rewriter, + {adaptor.getEnv(), adaptor.getDnmatA(), + adaptor.getDnmatB(), adaptor.getSpmatC(), dw, pBuf, + stream}); + rewriter.replaceOp(op, {stream}); + return success(); +} + void mlir::populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, StringRef gpuBinaryAnnotation, diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp @@ -604,6 +604,115 @@ return success(); } +/// Match and rewrite SDDMM kernel. +static LogicalResult rewriteSDDMM(PatternRewriter &rewriter, + linalg::GenericOp op, bool enableRT) { + Location loc = op.getLoc(); + Value a = op.getOperand(0); + Value b = op.getOperand(1); + Value c = op.getOperand(2); // we have C = AB + SmallVector tokens; + + // Only admissible sparse matrix format and dense matrices. 
+ bool isCOO = false; + SparseTensorType aTp = getSparseTensorType(a); + SparseTensorType bTp = getSparseTensorType(b); + SparseTensorType cTp = getSparseTensorType(c); + if (!areAdmissibleTypes(aTp, bTp, cTp, enableRT, isCOO)) + return failure(); + + // TODO: spmm is sparse A, dense B, dense C; SDDMM is dense A, dense B, sparse + // C Start sparse kernel and copy data from host to device. + // a : bufA -> matA + // b : bufB -> matB + // c : memR/memC/memV -> rowC,colC,valC + Value nseC = rewriter.create(loc, a); + Value szm = linalg::createOrFoldDimOp(rewriter, loc, a, 0); + Value szk = linalg::createOrFoldDimOp(rewriter, loc, a, 1); + Value szn = linalg::createOrFoldDimOp(rewriter, loc, b, 1); + Value bufA = genTensorToMemref(rewriter, loc, a); + Value matA = genAllocCopy(rewriter, loc, bufA, tokens); + Value bufB = genTensorToMemref(rewriter, loc, b); + Value matB = genAllocCopy(rewriter, loc, bufB, tokens); + Value memR = genFirstPosOrCrds(rewriter, loc, c, isCOO, enableRT); + Value memC = genSecondCrds(rewriter, loc, c, isCOO, enableRT); + Value memV = genToValues(rewriter, loc, c); + Value rowC = genAllocCopy(rewriter, loc, memR, tokens); + Value colC = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value(); + Value valC = genAllocCopy(rewriter, loc, memV, tokens); + genBlockingWait(rewriter, loc, tokens); + tokens.clear(); + + // Create sparse environment and sparse matrix/dense matrix handles. 
+ Type indexTp = rewriter.getIndexType(); + Type handleTp = rewriter.getType(); + Type tokenTp = rewriter.getType(); + Value token = genFirstWait(rewriter, loc); + auto env = + rewriter.create(loc, handleTp, tokenTp, token); + Value handle = env.getResult(0); + token = env.getAsyncToken(); + + auto dmatA = rewriter.create(loc, handleTp, tokenTp, + token, szm, szk, matA); + Value dnA = dmatA.getResult(0); + token = dmatA.getAsyncToken(); + auto dmatB = rewriter.create(loc, handleTp, tokenTp, + token, szk, szn, matB); + Value dnB = dmatB.getResult(0); + token = dmatB.getAsyncToken(); + + Operation *spGenC = genSpMat(rewriter, loc, handleTp, tokenTp, token, szm, + szn, nseC, rowC, colC, valC, isCOO, enableRT); + Value spMatC = spGenC->getResult(0); + token = spGenC->getResult(1); + + // Precompute buffersize for SDDMM. + auto bufferComp = rewriter.create( + loc, indexTp, tokenTp, token, handle, dnA, dnB, spMatC); + Value bufferSz = bufferComp.getResult(0); + token = bufferComp.getAsyncToken(); + auto buf = genAllocBuffer(rewriter, loc, bufferSz, token); + Value buffer = buf.getResult(0); + token = buf.getAsyncToken(); + + // Perform the SDDMM. + auto sddmmComp = rewriter.create(loc, tokenTp, token, handle, + dnA, dnB, spMatC, buffer); + token = sddmmComp.getAsyncToken(); + + // Copy data back to host and free all the resources. 
+ token = rewriter.create(loc, tokenTp, token, dnA) + .getAsyncToken(); + token = rewriter.create(loc, tokenTp, token, dnB) + .getAsyncToken(); + token = rewriter.create(loc, tokenTp, token, spMatC) + .getAsyncToken(); + token = rewriter.create(loc, tokenTp, token, handle) + .getAsyncToken(); + tokens.push_back(token); + genBlockingWait(rewriter, loc, tokens); + tokens.clear(); + token = genFirstWait(rewriter, loc); + token = genCopyMemRef(rewriter, loc, memR, rowC, token); + token = genCopyMemRef(rewriter, loc, memC, colC, token); + token = genCopyMemRef(rewriter, loc, memV, valC, token); + token = genDeallocMemRef(rewriter, loc, buffer, token); + token = genDeallocMemRef(rewriter, loc, matA, token); + token = genDeallocMemRef(rewriter, loc, matB, token); + token = genDeallocMemRef(rewriter, loc, rowC, token); + if (colC) + token = genDeallocMemRef(rewriter, loc, colC, token); + token = genDeallocMemRef(rewriter, loc, valC, token); + tokens.push_back(token); + genBlockingWait(rewriter, loc, tokens); + tokens.clear(); + + // Done. + rewriter.replaceOp(op, op.getDpsInitOperand(0)->get()); + return success(); +} + //===----------------------------------------------------------------------===// // Rewriting rules for direct code generation. //===----------------------------------------------------------------------===//