diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
@@ -151,13 +151,25 @@
                                       token, dynamicSizes, ValueRange());
 }
 
+// Allocates a typed buffer on the host with given size.
+static Value genHostBuffer(OpBuilder &builder, Location loc, Type type,
+                           Value size) {
+  const auto memTp = MemRefType::get({ShapedType::kDynamic}, type);
+  return builder.create<memref::AllocOp>(loc, memTp, size).getResult();
+}
+
+// Allocates a typed buffer on the device with given size.
+static gpu::AllocOp genAllocBuffer(OpBuilder &builder, Location loc, Type type,
+                                   Value size, Value token) {
+  const auto memTp = MemRefType::get({ShapedType::kDynamic}, type);
+  return builder.create<gpu::AllocOp>(loc, TypeRange({memTp, token.getType()}),
+                                      token, size, ValueRange());
+}
+
 // Allocates a void buffer on the device with given size.
 static gpu::AllocOp genAllocBuffer(OpBuilder &builder, Location loc, Value size,
                                    Value token) {
-  const auto memTp =
-      MemRefType::get({ShapedType::kDynamic}, builder.getI8Type());
-  return builder.create<gpu::AllocOp>(loc, TypeRange({memTp, token.getType()}),
-                                      token, size, ValueRange());
+  return genAllocBuffer(builder, loc, builder.getI8Type(), size, token);
 }
 
 /// Deallocates memory from the device.
@@ -198,7 +210,6 @@
 /// assume that the first buffer is the one allocated for output. We create
 /// a set of properly chained asynchronous allocation/copy pairs to increase
 /// overlap before launching the kernel.
-/// TODO: the output assumption may be a bit too brittle
 static Value genParametersIn(OpBuilder &builder, Location loc,
                              SmallVectorImpl<Value> &scalars,
                              SmallVectorImpl<Value> &buffers,
@@ -571,6 +582,7 @@
   token = genDeallocMemRef(rewriter, loc, vecY, token);
   tokens.push_back(token);
   genBlockingWait(rewriter, loc, tokens);
+  tokens.clear();
   if (gpuDataTransferStrategy != GPUDataTransferStrategy::kRegularDMA) {
     genHostUnregisterMemref(rewriter, loc, castR);
     if (memC)
@@ -579,7 +591,6 @@
     genHostUnregisterMemref(rewriter, loc, castX);
     genHostUnregisterMemref(rewriter, loc, castY);
   }
-  tokens.clear();
 
   // Done.
   rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, memY);
@@ -630,7 +641,6 @@
     castB = genHostRegisterMemref(rewriter, loc, bufB);
     castBufC = genHostRegisterMemref(rewriter, loc, bufC);
   }
-
   Value rowA = genAllocCopy(rewriter, loc, memR, tokens);
   Value colA = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value();
   Value valA = genAllocCopy(rewriter, loc, memV, tokens);
@@ -702,6 +712,7 @@
   token = genDeallocMemRef(rewriter, loc, matC, token);
   tokens.push_back(token);
   genBlockingWait(rewriter, loc, tokens);
+  tokens.clear();
   if (gpuDataTransferStrategy != GPUDataTransferStrategy::kRegularDMA) {
     genHostUnregisterMemref(rewriter, loc, castR);
     if (memC)
@@ -710,14 +721,179 @@
     genHostUnregisterMemref(rewriter, loc, castB);
     genHostUnregisterMemref(rewriter, loc, castC);
   }
-  tokens.clear();
 
   // Done.
   rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, bufC);
   return success();
 }
 
-// Match and rewrite 2:4 SpMM kernels.
+// Match and rewrite SpGEMM kernel.
+static LogicalResult
+rewriteSpGEMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
+              GPUDataTransferStrategy gpuDataTransferStrategy) {
+  Location loc = op.getLoc();
+  Value a = op.getOperand(0);
+  Value b = op.getOperand(1);
+  Value c = op.getOperand(2); // we have C = AB
+  SmallVector<Value> tokens;
+
+  // Only CSR <- CSR x CSR supported.
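+  // Note: other level-type combinations simply do not match here and keep
+  // their existing (non-SpGEMM) lowering; the cuSPARSE SpGEMM routines that
+  // this path eventually targets expect CSR operands.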
+  bool isCOO = false;
+  SparseTensorType aTp = getSparseTensorType(a);
+  SparseTensorType bTp = getSparseTensorType(b);
+  SparseTensorType cTp = getSparseTensorType(c);
+  if (!isAdmissibleCSR(aTp) || !isAdmissibleCSR(bTp) || !isAdmissibleCSR(cTp))
+    return failure();
+
+  // Start sparse kernel and copy data from host to device.
+  //   a : amemR/amemC/amemV -> rowA,colA,valA
+  //   b : bmemR/bmemC/bmemV -> rowB,colB,valB
+  //   c : materializes
+  auto dnCType = cTp.getElementType();
+  Value nseA = rewriter.create<NumberOfEntriesOp>(loc, a);
+  Value nseB = rewriter.create<NumberOfEntriesOp>(loc, b);
+  Value szm = linalg::createOrFoldDimOp(rewriter, loc, a, 0);
+  Value szk = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
+  Value szn = linalg::createOrFoldDimOp(rewriter, loc, b, 1);
+  Value amemR = genFirstPosOrCrds(rewriter, loc, a, isCOO, enableRT);
+  Value amemC = genSecondCrds(rewriter, loc, a, isCOO, enableRT);
+  Value amemV = genToValues(rewriter, loc, a);
+  Value bmemR = genFirstPosOrCrds(rewriter, loc, b, isCOO, enableRT);
+  Value bmemC = genSecondCrds(rewriter, loc, b, isCOO, enableRT);
+  Value bmemV = genToValues(rewriter, loc, b);
+  Value rowA = genAllocCopy(rewriter, loc, amemR, tokens);
+  Value colA = genAllocCopy(rewriter, loc, amemC, tokens);
+  Value valA = genAllocCopy(rewriter, loc, amemV, tokens);
+  Value rowB = genAllocCopy(rewriter, loc, bmemR, tokens);
+  Value colB = genAllocCopy(rewriter, loc, bmemC, tokens);
+  Value valB = genAllocCopy(rewriter, loc, bmemV, tokens);
+  genBlockingWait(rewriter, loc, tokens);
+  tokens.clear();
+
+  // Create sparse environment and sparse matrix/dense vector handles.
+  Type indexTp = rewriter.getIndexType();
+  Type spmatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>();
+  Type descTp = rewriter.getType<gpu::SparseSpGEMMOpHandleType>();
+  Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
+  Value token = genFirstWait(rewriter, loc);
+  Operation *spGenA =
+      genSpMat(rewriter, loc, spmatHandleTp, tokenTp, token, szm, szk, nseA,
+               rowA, colA, valA, isCOO, enableRT);
+  Value spMatA = spGenA->getResult(0);
+  token = spGenA->getResult(1);
+  Operation *spGenB =
+      genSpMat(rewriter, loc, spmatHandleTp, tokenTp, token, szk, szn, nseB,
+               rowB, colB, valB, isCOO, enableRT);
+  Value spMatB = spGenB->getResult(0);
+  token = spGenB->getResult(1);
+
+  // Sparse matrix C materializes (also assumes beta == 0).
+  Value zero = constantIndex(rewriter, loc, 0);
+  Value one = constantIndex(rewriter, loc, 1);
+  Value mplus1 = rewriter.create<arith::AddIOp>(loc, szm, one);
+  auto e1 = genAllocBuffer(rewriter, loc, cTp.getPosType(), mplus1, token);
+  Value rowC = e1.getResult(0);
+  token = e1.getAsyncToken();
+  auto e2 = genAllocBuffer(rewriter, loc, cTp.getCrdType(), zero, token);
+  Value colC = e2.getResult(0);
+  token = e2.getAsyncToken();
+  auto e3 = genAllocBuffer(rewriter, loc, dnCType, zero, token);
+  Value valC = e3.getResult(0);
+  token = e3.getAsyncToken();
+  Operation *spGenC =
+      genSpMat(rewriter, loc, spmatHandleTp, tokenTp, token, szm, szn, zero,
+               rowC, colC, valC, isCOO, enableRT);
+  Value spMatC = spGenC->getResult(0);
+  token = spGenC->getResult(1);
+
+  // Precompute buffersizes for SpGEMM.
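+  // The work-estimation and compute phases below follow a two-call protocol:
+  // a first call with an empty scratch buffer only queries the required
+  // buffer size, after which the buffer is allocated on the device and the
+  // call is repeated to perform the actual work.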
+  Operation *descOp =
+      rewriter.create<gpu::SpGEMMCreateDescrOp>(loc, descTp, tokenTp, token);
+  Value desc = descOp->getResult(0);
+  token = descOp->getResult(1);
+  Operation *work1 = rewriter.create<gpu::SpGEMMWorkEstimationOrComputeOp>(
+      loc, indexTp, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE,
+      gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType, zero,
+      valC, gpu::SpGEMMWorkEstimationOrComputeKind::WORK_ESTIMATION);
+  Value bufferSz1 = work1->getResult(0);
+  token = work1->getResult(1);
+  auto buf1 = genAllocBuffer(rewriter, loc, bufferSz1, token);
+  Value buffer1 = buf1.getResult(0);
+  token = buf1.getAsyncToken();
+  Operation *work2 = rewriter.create<gpu::SpGEMMWorkEstimationOrComputeOp>(
+      loc, indexTp, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE,
+      gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType,
+      bufferSz1, buffer1,
+      gpu::SpGEMMWorkEstimationOrComputeKind::WORK_ESTIMATION);
+  token = work2->getResult(1);
+
+  // Compute step.
+  Operation *compute1 = rewriter.create<gpu::SpGEMMWorkEstimationOrComputeOp>(
+      loc, indexTp, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE,
+      gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType, zero,
+      valC, gpu::SpGEMMWorkEstimationOrComputeKind::COMPUTE);
+  Value bufferSz2 = compute1->getResult(0);
+  token = compute1->getResult(1);
+  auto buf2 = genAllocBuffer(rewriter, loc, bufferSz2, token);
+  Value buffer2 = buf2.getResult(0);
+  token = buf2.getAsyncToken();
+  Operation *compute2 = rewriter.create<gpu::SpGEMMWorkEstimationOrComputeOp>(
+      loc, indexTp, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE,
+      gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType,
+      bufferSz2, buffer2, gpu::SpGEMMWorkEstimationOrComputeKind::COMPUTE);
+  token = compute2->getResult(1);
+
+  // Get sizes.
+  Operation *sizes = rewriter.create<gpu::SpMatGetSizeOp>(
+      loc, indexTp, indexTp, indexTp, tokenTp, token, spMatC);
+  Value nnz = sizes->getResult(2);
+  token = sizes->getResult(3);
+  auto a2 = genAllocBuffer(rewriter, loc, cTp.getCrdType(), nnz, token);
+  colC = a2.getResult(0);
+  token = a2.getAsyncToken();
+  auto a3 = genAllocBuffer(rewriter, loc, dnCType, nnz, token);
+  valC = a3.getResult(0);
+  token = a3.getAsyncToken();
+
+  // Update C with new pointers and copy final product back into C.
+  Operation *update = rewriter.create<gpu::SetCsrPointersOp>(
+      loc, tokenTp, token, spMatC, rowC, colC, valC);
+  token = update->getResult(0);
+  Operation *copy = rewriter.create<gpu::SpGEMMCopyOp>(
+      loc, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE,
+      gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType);
+  token = copy->getResult(0);
+
+  // Allocate buffers on host.
+  Value rowH = genHostBuffer(rewriter, loc, cTp.getPosType(), mplus1);
+  Value colH = genHostBuffer(rewriter, loc, cTp.getCrdType(), nnz);
+  Value valH = genHostBuffer(rewriter, loc, dnCType, nnz);
+
+  // Copy data back to host and free all the resources.
+  token = rewriter.create<gpu::SpGEMMDestroyDescrOp>(loc, tokenTp, token, desc)
+              .getAsyncToken();
+  token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatA)
+              .getAsyncToken();
+  token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatB)
+              .getAsyncToken();
+  token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatC)
+              .getAsyncToken();
+  token = genCopyMemRef(rewriter, loc, rowH, rowC, token);
+  token = genCopyMemRef(rewriter, loc, colH, colC, token);
+  token = genCopyMemRef(rewriter, loc, valH, valC, token);
+  tokens.push_back(token);
+  genBlockingWait(rewriter, loc, tokens);
+  tokens.clear();
+
+  // Done.
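+  // Assemble the output CSR tensor from the three host buffers computed
+  // above (positions in rowH, coordinates in colH, values in valH).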
+  Value vt = rewriter.create<bufferization::ToTensorOp>(loc, valH);
+  Value rt = rewriter.create<bufferization::ToTensorOp>(loc, rowH);
+  Value ct = rewriter.create<bufferization::ToTensorOp>(loc, colH);
+  rewriter.replaceOpWithNewOp<PackOp>(op, c.getType(), vt, ValueRange{rt, ct});
+  return success();
+}
+
+// Match and rewrite 2:4 SpMM kernel.
 static LogicalResult
 rewrite2To4SpMM(PatternRewriter &rewriter, linalg::GenericOp op,
                 GPUDataTransferStrategy gpuDataTransferStrategy) {
@@ -748,7 +924,6 @@
     castB = genHostRegisterMemref(rewriter, loc, bufB);
     castC = genHostRegisterMemref(rewriter, loc, bufC);
   }
-
   if (isZeroCopy) {
     matA = bufA;
     matB = bufB;
@@ -756,10 +931,11 @@
   Value matC = genAllocCopy(rewriter, loc, bufC, tokens);
   genBlockingWait(rewriter, loc, tokens);
   tokens.clear();
+
+  // Create sparse environment and sparse matrix/dense vector handles.
   Value szm = linalg::createOrFoldDimOp(rewriter, loc, matA, 0);
   Value szk = linalg::createOrFoldDimOp(rewriter, loc, matB, 0);
   Value szn = linalg::createOrFoldDimOp(rewriter, loc, matC, 1);
-
   Type indexTp = rewriter.getIndexType();
   Type dnTensorHandleTp = rewriter.getType<gpu::SparseDnTensorHandleType>();
   Type spMatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>();
@@ -768,7 +944,6 @@
   Operation *spGenA = rewriter.create<gpu::Create2To4SpMatOp>(
       loc, spMatHandleTp, tokenTp, token, szm, szk,
       gpu::Prune2To4SpMatFlag::PRUNE_AND_CHECK, matA);
-
   Value spMatA = spGenA->getResult(0);
   token = spGenA->getResult(1);
   auto dmatB = rewriter.create<gpu::CreateDnTensorOp>(
@@ -781,7 +956,6 @@
                                  SmallVector<Value>{szm, szn});
   Value dnC = dmatC.getResult(0);
   token = dmatC.getAsyncToken();
-
   auto dmatCType = llvm::cast<ShapedType>(matC.getType()).getElementType();
 
   // Precompute buffersize for SpMM.
@@ -791,8 +965,8 @@
       loc, bufferTypes, tokenTp, token, gpu::TransposeMode::NON_TRANSPOSE,
       gpu::TransposeMode::NON_TRANSPOSE, spMatA, dnB, dnC,
       /*computeType=*/dmatCType);
-
   token = bufferComp.getAsyncToken();
+
   Value bufferSz = bufferComp.getResult(0);
   auto buf = genAllocBuffer(rewriter, loc, bufferSz, token);
   Value buffer = buf.getResult(0);
@@ -824,11 +998,9 @@
   token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnC)
               .getAsyncToken();
   SmallVector<Value> newDynamicSizes;
-
   token = genDeallocMemRef(rewriter, loc, buffer, token);
   token = genDeallocMemRef(rewriter, loc, buffer2, token);
   token = genDeallocMemRef(rewriter, loc, buffer3, token);
-
   if (!isZeroCopy)
     token = genDeallocMemRef(rewriter, loc, matA, token);
   if (!isZeroCopy)
@@ -837,12 +1009,14 @@
   token = genDeallocMemRef(rewriter, loc, matC, token);
   tokens.push_back(token);
   genBlockingWait(rewriter, loc, tokens);
+  tokens.clear();
   if (gpuDataTransferStrategy != GPUDataTransferStrategy::kRegularDMA) {
     genHostUnregisterMemref(rewriter, loc, castA);
     genHostUnregisterMemref(rewriter, loc, castB);
     genHostUnregisterMemref(rewriter, loc, castC);
   }
-  tokens.clear();
+
+  // Done.
   rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, bufC);
   return success();
 }
@@ -889,7 +1063,6 @@
   Value memR = genFirstPosOrCrds(rewriter, loc, c, isCOO, enableRT);
   Value memC = genSecondCrds(rewriter, loc, c, isCOO, enableRT);
   Value memV = genToValues(rewriter, loc, c);
-
   Value castB, castA, castR, castC, castV;
   if (gpuDataTransferStrategy != GPUDataTransferStrategy::kRegularDMA) {
     castB = genHostRegisterMemref(rewriter, loc, bufB);
@@ -899,7 +1072,6 @@
     castC = genHostRegisterMemref(rewriter, loc, memC);
     castV = genHostRegisterMemref(rewriter, loc, memV);
   }
-
   if (isZeroCopy) {
     matA = bufA;
     matB = bufB;
@@ -930,8 +1102,8 @@
                              rowC, colC, valC, isCOO, enableRT);
   Value spMatC = spGenC->getResult(0);
   token = spGenC->getResult(1);
-
   auto dnCType = llvm::cast<ShapedType>(c.getType()).getElementType();
 
+  // Precompute buffersize for SDDMM.
   auto bufferComp = rewriter.create<gpu::SDDMMBufferSizeOp>(
       loc, indexTp, tokenTp, token, dnA, dnB, spMatC, dnCType);
@@ -965,6 +1137,7 @@
   token = genDeallocMemRef(rewriter, loc, valC, token);
   tokens.push_back(token);
   genBlockingWait(rewriter, loc, tokens);
+  tokens.clear();
   if (gpuDataTransferStrategy != GPUDataTransferStrategy::kRegularDMA) {
     genHostUnregisterMemref(rewriter, loc, castB);
     genHostUnregisterMemref(rewriter, loc, castA);
@@ -973,7 +1146,6 @@
     genHostUnregisterMemref(rewriter, loc, castC);
     genHostUnregisterMemref(rewriter, loc, castV);
   }
-  tokens.clear();
 
   // Done.
   rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, c);
@@ -986,7 +1158,7 @@
 
 /// Proof-of-concept rewriter. This rule generates a GPU implementation
 /// for each outermost forall loop generated by the sparse compiler.
-/// TODO: right works with parallelization-strategy=dense-outer-loop
+/// TODO: right now works with parallelization-strategy=dense-outer-loop
 /// but give this its own flags in the future
 struct ForallRewriter : public OpRewritePattern<scf::ParallelOp> {
   using OpRewritePattern<scf::ParallelOp>::OpRewritePattern;
@@ -1109,29 +1281,27 @@
     AffineExpr i, j, k;
     bindDims(getContext(), i, j, k);
 
-    // TODO: more robust patterns, tranposed versions, more kernels...
-    // TODO: identify alpha and beta and pass them to the CUDA calls
+    // TODO: more robust patterns, transposed versions, more kernels,
+    //       identify alpha and beta and pass them to the CUDA calls.
 
     // Recognize a SpMV kernel.
     if (numLoops == 2 && numTensors == 3 &&
        linalg::isParallelIterator(iteratorTypes[0]) &&
        linalg::isReductionIterator(iteratorTypes[1]) &&
-        // TODO: add transposed {i, j}
        maps == infer({{i, j}, {j}, {i}}) && matchSumOfMultOfArgs(op)) {
      return rewriteSpMV(rewriter, op, enableRT, gpuDataTransferStrategy);
    }
 
-    // Recognize a SpMM kernel.
+    // Recognize a SpGEMM, 2:4-SpMM, or SpMM kernel.
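+    // All three kernels share the same index maps and iterator types, so
+    // dispatch on the operands instead: two sparse inputs select SpGEMM,
+    // the DENSE24 attribute selects the 2:4 structured path, and anything
+    // else falls through to the regular SpMM lowering.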
    if (numLoops == 3 && numTensors == 3 &&
        linalg::isParallelIterator(iteratorTypes[0]) &&
        linalg::isParallelIterator(iteratorTypes[1]) &&
        linalg::isReductionIterator(iteratorTypes[2]) &&
-        // TODO: add transposed {i, k}, {k, j}
-        // TODO: maybe add transposed {i, j} in future
        maps == infer({{i, k}, {k, j}, {i, j}}) && matchSumOfMultOfArgs(op)) {
+      if (!isDenseTensor(op.getOperand(0)) && !isDenseTensor(op.getOperand(1)))
+        return rewriteSpGEMM(rewriter, op, enableRT, gpuDataTransferStrategy);
      if (op->getAttr("DENSE24"))
        return rewrite2To4SpMM(rewriter, op, gpuDataTransferStrategy);
-
      return rewriteSpMM(rewriter, op, enableRT, gpuDataTransferStrategy);
    }
 
@@ -1140,8 +1310,6 @@
        linalg::isParallelIterator(iteratorTypes[0]) &&
        linalg::isParallelIterator(iteratorTypes[1]) &&
        linalg::isReductionIterator(iteratorTypes[2]) &&
-        // TODO: add transposed {i, k}, {k, j}
-        // TODO: maybe add transposed {i, j} in future
        maps == infer({{i, k}, {k, j}, {i, j}}) &&
        matchSumReductionOfMulUnary(op)) {
      return rewriteSDDMM(rewriter, op, enableRT, gpuDataTransferStrategy);
diff --git a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-gemm-lib.mlir b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-gemm-lib.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-gemm-lib.mlir
@@ -0,0 +1,81 @@
+//
+// NOTE: this test requires gpu-sm80
+//
+// without RT lib:
+//
+// RUN: mlir-opt %s \
+// RUN:   --sparse-compiler="enable-runtime-library=false enable-gpu-libgen gpu-triple=nvptx64-nvidia-cuda gpu-chip=sm_80 gpu-features=+ptx71" \
+// RUN: | mlir-cpu-runner \
+// RUN:   --shared-libs=%mlir_cuda_runtime \
+// RUN:   --shared-libs=%mlir_runner_utils \
+// RUN:   --e main --entry-point-result=void \
+// RUN: | FileCheck %s
+
+#CSR = #sparse_tensor.encoding<{
+  lvlTypes = [ "dense", "compressed" ],
+  posWidth = 32,
+  crdWidth = 32
+}>
+
+module {
+  llvm.func @mgpuCreateSparseEnv()
+  llvm.func @mgpuDestroySparseEnv()
+
+  // Computes C = A x B with A,B,C sparse CSR.
+  func.func @matmulCSR(%A: tensor<8x8xf32, #CSR>,
+                       %B: tensor<8x8xf32, #CSR>) -> tensor<8x8xf32, #CSR> {
+    %init = bufferization.alloc_tensor() : tensor<8x8xf32, #CSR>
+    %C = linalg.matmul
+      ins(%A, %B: tensor<8x8xf32, #CSR>,
+                  tensor<8x8xf32, #CSR>)
+      outs(%init: tensor<8x8xf32, #CSR>) -> tensor<8x8xf32, #CSR>
+    return %C: tensor<8x8xf32, #CSR>
+  }
+
+  //
+  // Main driver.
+  //
+  func.func @main() {
+    llvm.call @mgpuCreateSparseEnv(): () -> ()
+
+    %c0 = arith.constant 0 : index
+    %f0 = arith.constant 0.0 : f32
+
+    %t = arith.constant dense<[
+      [ 1.0, 0.0, 2.0, 0.0, 0.0, 0.0, 0.0, 3.0],
+      [ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+      [ 0.0, 0.0, 4.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+      [ 0.0, 0.0, 0.0, 5.0, 0.0, 0.0, 0.0, 0.0],
+      [ 0.0, 0.0, 0.0, 0.0, 6.0, 0.0, 0.0, 0.0],
+      [ 0.0, 7.0, 8.0, 0.0, 0.0, 0.0, 0.0, 9.0],
+      [ 0.0, 0.0, 10.0, 0.0, 0.0, 0.0, 11.0, 12.0],
+      [ 0.0, 13.0, 14.0, 0.0, 0.0, 0.0, 15.0, 16.0]
+    ]> : tensor<8x8xf32>
+    %Acsr = sparse_tensor.convert %t : tensor<8x8xf32> to tensor<8x8xf32, #CSR>
+
+    %Ccsr = call @matmulCSR(%Acsr, %Acsr) : (tensor<8x8xf32, #CSR>,
+                                             tensor<8x8xf32, #CSR>) -> tensor<8x8xf32, #CSR>
+
+    //
+    // Verify computed result (expected output, with only 20 nonzeros).
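+    // The product is converted back to a dense tensor so the full 8x8
+    // result can be printed, followed by the number of stored entries.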
+ // + // CHECK: ( ( 1, 39, 52, 0, 0, 0, 45, 51 ), + // CHECK-SAME: ( 0, 0, 0, 0, 0, 0, 0, 0 ), + // CHECK-SAME: ( 0, 0, 16, 0, 0, 0, 0, 0 ), + // CHECK-SAME: ( 0, 0, 0, 25, 0, 0, 0, 0 ), + // CHECK-SAME: ( 0, 0, 0, 0, 36, 0, 0, 0 ), + // CHECK-SAME: ( 0, 117, 158, 0, 0, 0, 135, 144 ), + // CHECK-SAME: ( 0, 156, 318, 0, 0, 0, 301, 324 ), + // CHECK-SAME: ( 0, 208, 430, 0, 0, 0, 405, 436 ) ) + // CHECK-NEXT: 20 + %d = sparse_tensor.convert %Ccsr : tensor<8x8xf32, #CSR> to tensor<8x8xf32> + %v = vector.transfer_read %d[%c0, %c0], %f0: tensor<8x8xf32>, vector<8x8xf32> + vector.print %v : vector<8x8xf32> + %nnz = sparse_tensor.number_of_entries %Ccsr : tensor<8x8xf32, #CSR> + %x = sparse_tensor.number_of_entries %Ccsr : tensor<8x8xf32, #CSR> + vector.print %nnz : index + + llvm.call @mgpuDestroySparseEnv(): () -> () + return + } +}