diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
@@ -249,6 +249,13 @@
   }
 }
 
+// Returns true when the given type is a tensor without a sparse encoding.
+// NOTE(review): assumes `type` is a RankedTensorType (true for the linalg
+// operands this is called on); a non-tensor type would yield a null handle.
+static bool isDenseTensorType(Type type) {
+  auto ttp = llvm::dyn_cast<RankedTensorType>(type);
+  if (llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(ttp.getEncoding()))
+    return false;
+  return true;
+}
+
 /// Constructs code for new GPU kernel.
 static void genGPUCode(PatternRewriter &rewriter, gpu::GPUFuncOp gpuFunc,
                        scf::ParallelOp forallOp,
@@ -656,6 +663,112 @@
   return success();
 }
 
+// Match and rewrite 2:4 SpMM kernels.
+static LogicalResult rewrite2To4SpMM(PatternRewriter &rewriter,
+                                     linalg::GenericOp op) {
+  Location loc = op.getLoc();
+  Value A = op.getOperand(0);
+  Value B = op.getOperand(1);
+  Value C = op.getOperand(2); // we have C = AB
+  SmallVector<Value> tokens;
+
+  // All input should be dense tensors.
+  if (!isDenseTensorType(A.getType()) || !isDenseTensorType(B.getType()) ||
+      !isDenseTensorType(C.getType()))
+    return failure();
+
+  // For now, lower to GPU directly.
+
+  Value bufA = genTensorToMemref(rewriter, loc, A);
+  Value matA = genAllocCopy(rewriter, loc, bufA, tokens);
+  Value bufB = genTensorToMemref(rewriter, loc, B);
+  Value matB = genAllocCopy(rewriter, loc, bufB, tokens);
+  Value bufC = genTensorToMemref(rewriter, loc, C);
+  Value matC = genAllocCopy(rewriter, loc, bufC, tokens);
+  genBlockingWait(rewriter, loc, tokens);
+  tokens.clear();
+  Value szm = linalg::createOrFoldDimOp(rewriter, loc, matA, 0);
+  Value szk = linalg::createOrFoldDimOp(rewriter, loc, matB, 0);
+  Value szn = linalg::createOrFoldDimOp(rewriter, loc, matC, 1);
+
+  Type indexTp = rewriter.getIndexType();
+  Type dnTensorHandleTp = rewriter.getType<gpu::SparseDnTensorHandleType>();
+  Type spMatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>();
+  Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
+  Value token = genFirstWait(rewriter, loc);
+  Operation *spGenA = rewriter.create<gpu::Create2To4SpMatOp>(
+      loc, spMatHandleTp, tokenTp, token, szm, szk, matA);
+
+  Value spMatA = spGenA->getResult(0);
+  token = spGenA->getResult(1);
+  auto dmatB = rewriter.create<gpu::CreateDnTensorOp>(
+      loc, dnTensorHandleTp, tokenTp, token, matB,
+      SmallVector<Value>{szk, szn});
+  Value dnB = dmatB.getResult(0);
+  token = dmatB.getAsyncToken();
+  auto dmatC = rewriter.create<gpu::CreateDnTensorOp>(
+      loc, dnTensorHandleTp, tokenTp, token, matC,
+      SmallVector<Value>{szm, szn});
+  Value dnC = dmatC.getResult(0);
+  token = dmatC.getAsyncToken();
+
+  auto dmatCType = llvm::cast<ShapedType>(matC.getType()).getElementType();
+
+  // Precompute buffersize for SpMM.
+  SmallVector<Type> bufferTypes_{indexTp, indexTp, indexTp};
+  TypeRange bufferTypes(bufferTypes_);
+  auto bufferComp = rewriter.create<gpu::SpMMBufferSizeOp>(
+      loc, bufferTypes, tokenTp, token, gpu::TransposeMode::NON_TRANSPOSE,
+      gpu::TransposeMode::NON_TRANSPOSE, spMatA, dnB, dnC,
+      /*computeType=*/dmatCType);
+
+  token = bufferComp.getAsyncToken();
+  Value bufferSz = bufferComp.getResult(0);
+  auto buf = genAllocBuffer(rewriter, loc, bufferSz, token);
+  Value buffer = buf.getResult(0);
+  token = buf.getAsyncToken();
+
+  Value bufferSz2 = bufferComp.getResult(1);
+  auto buf2 = genAllocBuffer(rewriter, loc, bufferSz2, token);
+  Value buffer2 = buf2.getResult(0);
+  token = buf2.getAsyncToken();
+
+  Value bufferSz3 = bufferComp.getResult(2);
+  auto buf3 = genAllocBuffer(rewriter, loc, bufferSz3, token);
+  Value buffer3 = buf3.getResult(0);
+  token = buf3.getAsyncToken();
+
+  auto dnCType = llvm::cast<ShapedType>(matC.getType()).getElementType();
+
+  // Perform the SpMM.
+  auto spmmComp = rewriter.create<gpu::SpMMOp>(
+      loc, tokenTp, token, spMatA, dnB, dnC, /*computeType=*/dnCType,
+      SmallVector<Value>{buffer, buffer2, buffer3});
+  token = spmmComp.getAsyncToken();
+
+  // Copy data back to host and free all the resources.
+  token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatA)
+              .getAsyncToken();
+  token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnB)
+              .getAsyncToken();
+  token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnC)
+              .getAsyncToken();
+  SmallVector<Value> newDynamicSizes;
+
+  token = genDeallocMemRef(rewriter, loc, buffer, token);
+  token = genDeallocMemRef(rewriter, loc, buffer2, token);
+  token = genDeallocMemRef(rewriter, loc, buffer3, token);
+  token = genDeallocMemRef(rewriter, loc, matA, token);
+  token = genDeallocMemRef(rewriter, loc, matB, token);
+  token = genCopyMemRef(rewriter, loc, bufC, matC, token);
+  token = genDeallocMemRef(rewriter, loc, matC, token);
+  tokens.push_back(token);
+  genBlockingWait(rewriter, loc, tokens);
+  tokens.clear();
+  rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, bufC);
+  return success();
+}
+
 /// Match and rewrite SDDMM kernel.
 static LogicalResult rewriteSDDMM(PatternRewriter &rewriter,
                                   linalg::GenericOp op, bool enableRT) {
@@ -906,7 +1019,12 @@
       // TODO: add transposed {i, k}, {k, j}
       // TODO: maybe add transposed {i, j} in future
       maps == infer({{i, k}, {k, j}, {i, j}}) && matchSumOfMultOfArgs(op)) {
-      return rewriteSpMM(rewriter, op, enableRT);
+      // TODO: match 2:4
+      if (op->getAttr("DENSE24")) {
+        return rewrite2To4SpMM(rewriter, op);
+      } else {
+        return rewriteSpMM(rewriter, op, enableRT);
+      }
   }
 
   // Recognize a SDDMM kernel.