diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp @@ -347,6 +347,19 @@ aTp.getCrdWidth() == 64); } +/// Test for admissible types on operands (only sets `isCOO` for COO input). +static bool areAdmissibleTypes(SparseTensorType aTp, SparseTensorType bTp, + SparseTensorType cTp, bool enableRT, + bool &isCOO) { + if (bTp.hasEncoding() || cTp.hasEncoding()) + return false; + if (isAdmissibleCOO(aTp)) { + isCOO = true; + return enableRT; // TODO: CreateCooAoSOp was deprecated, find another way + } + return isAdmissibleCSR(aTp); +} + /// Generates the first positions/coordinates of a sparse matrix. static Value genFirstPosOrCrds(OpBuilder &builder, Location loc, Value a, bool isCOO, bool enableRT) { @@ -393,23 +406,13 @@ Value y = op.getOperand(2); // we have y = Ax SmallVector tokens; - // Only admissible sparse matrix format and dense vectors for now. + // Only admissible sparse matrix format and dense vectors. bool isCOO = false; SparseTensorType aTp = getSparseTensorType(a); SparseTensorType xTp = getSparseTensorType(x); SparseTensorType yTp = getSparseTensorType(y); - if (xTp.hasEncoding() || yTp.hasEncoding()) - return failure(); - if (isAdmissibleCOO(aTp)) { - isCOO = true; - // TODO: CreateCooAoSOp was deprecated, find another way - if (!enableRT) - return failure(); - } else if (isAdmissibleCSR(aTp)) { - isCOO = false; - } else { + if (!areAdmissibleTypes(aTp, xTp, yTp, enableRT, isCOO)) return failure(); - } // Start sparse kernel and copy data from host to device. // a : memR/memC/memV -> rowA,colA,valA @@ -500,7 +503,105 @@ /// Match and rewrite SpMM kernel. 
static LogicalResult rewriteSpMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT) { - return failure(); // TODO: implement + Location loc = op.getLoc(); + Value a = op.getOperand(0); + Value b = op.getOperand(1); + Value c = op.getOperand(2); // we have C = AB + SmallVector tokens; + + // Only admissible sparse matrix format and dense matrices. + bool isCOO = false; + SparseTensorType aTp = getSparseTensorType(a); + SparseTensorType bTp = getSparseTensorType(b); + SparseTensorType cTp = getSparseTensorType(c); + if (!areAdmissibleTypes(aTp, bTp, cTp, enableRT, isCOO)) + return failure(); + + // Start sparse kernel and copy data from host to device. + // a : memR/memC/memV -> rowA,colA,valA + // b : bufB -> matB + // c : bufC -> matC + Value nnzA = rewriter.create(loc, a); + Value szm = linalg::createOrFoldDimOp(rewriter, loc, a, 0); + Value szn = linalg::createOrFoldDimOp(rewriter, loc, b, 1); + Value szk = linalg::createOrFoldDimOp(rewriter, loc, a, 1); + Value memR = genFirstPosOrCrds(rewriter, loc, a, isCOO, enableRT); + Value memC = genSecondCrds(rewriter, loc, a, isCOO, enableRT); + Value memV = genToValues(rewriter, loc, a); + Value rowA = genAllocCopy(rewriter, loc, memR, tokens); + Value colA = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value(); + Value valA = genAllocCopy(rewriter, loc, memV, tokens); + Value bufB = genTensorToMemref(rewriter, loc, b); + Value matB = genAllocCopy(rewriter, loc, bufB, tokens); + Value bufC = genTensorToMemref(rewriter, loc, c); + Value matC = genAllocCopy(rewriter, loc, bufC, tokens); + genBlockingWait(rewriter, loc, tokens); + tokens.clear(); + + // Create sparse environment and sparse matrix/dense matrix handles. 
+ Type indexTp = rewriter.getIndexType(); + Type handleTp = rewriter.getType(); + Type tokenTp = rewriter.getType(); + Value token = genFirstWait(rewriter, loc); + auto env = + rewriter.create(loc, handleTp, tokenTp, token); + Value handle = env.getResult(0); + token = env.getAsyncToken(); + Operation *spGenA = genSpMat(rewriter, loc, handleTp, tokenTp, token, szm, + szk, nnzA, rowA, colA, valA, isCOO, enableRT); + Value spMatA = spGenA->getResult(0); + token = spGenA->getResult(1); + auto dmatB = rewriter.create(loc, handleTp, tokenTp, + token, szk, szn, matB); + Value dnB = dmatB.getResult(0); + token = dmatB.getAsyncToken(); + auto dmatC = rewriter.create(loc, handleTp, tokenTp, + token, szm, szn, matC); + Value dnC = dmatC.getResult(0); + token = dmatC.getAsyncToken(); + + // Precompute buffersize for SpMM. + auto bufferComp = rewriter.create( + loc, indexTp, tokenTp, token, handle, spMatA, dnB, dnC); + Value bufferSz = bufferComp.getResult(0); + token = bufferComp.getAsyncToken(); + auto buf = genAllocBuffer(rewriter, loc, bufferSz, token); + Value buffer = buf.getResult(0); + token = buf.getAsyncToken(); + + // Perform the SpMM. + auto spmmComp = rewriter.create(loc, tokenTp, token, handle, + spMatA, dnB, dnC, buffer); + token = spmmComp.getAsyncToken(); + + // Copy data back to host and free all the resources. 
+ token = rewriter.create(loc, tokenTp, token, spMatA) + .getAsyncToken(); + token = rewriter.create(loc, tokenTp, token, dnB) + .getAsyncToken(); + token = rewriter.create(loc, tokenTp, token, dnC) + .getAsyncToken(); + token = rewriter.create(loc, tokenTp, token, handle) + .getAsyncToken(); + tokens.push_back(token); + genBlockingWait(rewriter, loc, tokens); + tokens.clear(); + token = genFirstWait(rewriter, loc); + token = genCopyMemRef(rewriter, loc, bufC, matC, token); + token = genDeallocMemRef(rewriter, loc, rowA, token); + if (colA) + token = genDeallocMemRef(rewriter, loc, colA, token); + token = genDeallocMemRef(rewriter, loc, valA, token); + token = genDeallocMemRef(rewriter, loc, buffer, token); + token = genDeallocMemRef(rewriter, loc, matB, token); + token = genDeallocMemRef(rewriter, loc, matC, token); + tokens.push_back(token); + genBlockingWait(rewriter, loc, tokens); + tokens.clear(); + + // Done. + rewriter.replaceOp(op, op.getDpsInitOperand(0)->get()); + return success(); } //===----------------------------------------------------------------------===// @@ -610,7 +711,7 @@ //===----------------------------------------------------------------------===// /// Proof-of-concept rewriter. This rule recognizes certain math kernels -/// and replaces these with corresponding calls into the sparse library. +/// and replaces these with corresponding calls into a sparse library. 
struct LinalgOpRewriter : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; diff --git a/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul_lib.mlir b/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul_lib.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul_lib.mlir @@ -0,0 +1,76 @@ +// RUN: mlir-opt %s --linalg-generalize-named-ops \ +// RUN: --sparsification="enable-gpu-libgen" | FileCheck %s + +#CSR = #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }> + +// +// Compute matrix matrix C = AB +// +// CHECK-LABEL: func.func @matmul( +// CHECK-SAME: %[[VAL_0:.*]]: tensor>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor, +// CHECK-SAME: %[[VAL_2:.*]]: tensor) -> tensor { +// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[VAL_5:.*]] = sparse_tensor.number_of_entries %[[VAL_0]] : tensor> +// CHECK-DAG: %[[VAL_6:.*]] = tensor.dim %[[VAL_0]], %[[VAL_3]] : tensor> +// CHECK-DAG: %[[VAL_7:.*]] = tensor.dim %[[VAL_1]], %[[VAL_4]] : tensor +// CHECK-DAG: %[[VAL_8:.*]] = tensor.dim %[[VAL_0]], %[[VAL_4]] : tensor> +// CHECK-DAG: %[[VAL_9:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor> to memref +// CHECK-DAG: %[[VAL_10:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor> to memref> +// CHECK: %[[VAL_11:.*]] = sparse_tensor.values %[[VAL_0]] : tensor> to memref +// CHECK: %[[VAL_12:.*]] = gpu.wait async +// CHECK: %[[VAL_13:.*]] = memref.dim %[[VAL_9]], %[[VAL_3]] : memref +// CHECK: %[[VAL_14:.*]], %[[VAL_15:.*]] = gpu.alloc async {{\[}}%[[VAL_12]]] (%[[VAL_13]]) : memref +// CHECK: %[[VAL_16:.*]] = gpu.memcpy async {{\[}}%[[VAL_15]]] %[[VAL_14]], %[[VAL_9]] : memref, memref +// CHECK: %[[VAL_17:.*]] = gpu.wait async +// CHECK: %[[VAL_18:.*]] = memref.dim %[[VAL_10]], %[[VAL_3]] : memref> +// CHECK: %[[VAL_19:.*]], %[[VAL_20:.*]] = gpu.alloc async {{\[}}%[[VAL_17]]] (%[[VAL_18]]) : memref +// 
CHECK: %[[VAL_21:.*]] = gpu.memcpy async {{\[}}%[[VAL_20]]] %[[VAL_19]], %[[VAL_10]] : memref, memref> +// CHECK: %[[VAL_22:.*]] = gpu.wait async +// CHECK: %[[VAL_23:.*]] = memref.dim %[[VAL_11]], %[[VAL_3]] : memref +// CHECK: %[[VAL_24:.*]], %[[VAL_25:.*]] = gpu.alloc async {{\[}}%[[VAL_22]]] (%[[VAL_23]]) : memref +// CHECK: %[[VAL_26:.*]] = gpu.memcpy async {{\[}}%[[VAL_25]]] %[[VAL_24]], %[[VAL_11]] : memref, memref +// CHECK: %[[VAL_27:.*]] = bufferization.to_memref %[[VAL_1]] : memref +// CHECK: %[[VAL_28:.*]] = gpu.wait async +// CHECK: %[[VAL_29:.*]] = memref.dim %[[VAL_27]], %[[VAL_3]] : memref +// CHECK: %[[VAL_30:.*]] = memref.dim %[[VAL_27]], %[[VAL_4]] : memref +// CHECK: %[[VAL_31:.*]], %[[VAL_32:.*]] = gpu.alloc async {{\[}}%[[VAL_28]]] (%[[VAL_29]], %[[VAL_30]]) : memref +// CHECK: %[[VAL_33:.*]] = gpu.memcpy async {{\[}}%[[VAL_32]]] %[[VAL_31]], %[[VAL_27]] : memref, memref +// CHECK: %[[VAL_34:.*]] = bufferization.to_memref %[[VAL_2]] : memref +// CHECK: %[[VAL_35:.*]] = gpu.wait async +// CHECK: %[[VAL_36:.*]] = memref.dim %[[VAL_34]], %[[VAL_3]] : memref +// CHECK: %[[VAL_37:.*]] = memref.dim %[[VAL_34]], %[[VAL_4]] : memref +// CHECK: %[[VAL_38:.*]], %[[VAL_39:.*]] = gpu.alloc async {{\[}}%[[VAL_35]]] (%[[VAL_36]], %[[VAL_37]]) : memref +// CHECK: %[[VAL_40:.*]] = gpu.memcpy async {{\[}}%[[VAL_39]]] %[[VAL_38]], %[[VAL_34]] : memref, memref +// CHECK: gpu.wait {{\[}}%[[VAL_16]], %[[VAL_21]], %[[VAL_26]], %[[VAL_33]], %[[VAL_40]]] +// CHECK: %[[VAL_41:.*]] = gpu.wait async +// CHECK: %[[VAL_42:.*]], %[[VAL_43:.*]] = gpu.create_sparse_env async {{\[}}%[[VAL_41]]] +// CHECK: %[[VAL_44:.*]], %[[VAL_45:.*]] = gpu.create_csr async {{\[}}%[[VAL_43]]] %[[VAL_6]], %[[VAL_8]], %[[VAL_5]], %[[VAL_14]], %[[VAL_19]], %[[VAL_24]] : memref, memref, memref +// CHECK: %[[VAL_46:.*]], %[[VAL_47:.*]] = gpu.create_dn_mat async {{\[}}%[[VAL_45]]] %[[VAL_8]], %[[VAL_7]], %[[VAL_31]] : memref +// CHECK: %[[VAL_48:.*]], %[[VAL_49:.*]] = gpu.create_dn_mat async 
{{\[}}%[[VAL_47]]] %[[VAL_6]], %[[VAL_7]], %[[VAL_38]] : memref +// CHECK: %[[VAL_50:.*]], %[[VAL_51:.*]] = gpu.spmm_buffer_size async {{\[}}%[[VAL_49]]] %[[VAL_42]], %[[VAL_44]], %[[VAL_46]], %[[VAL_48]] +// CHECK: %[[VAL_52:.*]], %[[VAL_53:.*]] = gpu.alloc async {{\[}}%[[VAL_51]]] (%[[VAL_50]]) : memref +// CHECK: %[[VAL_54:.*]] = gpu.spmm async {{\[}}%[[VAL_53]]] %[[VAL_42]], %[[VAL_44]], %[[VAL_46]], %[[VAL_48]], %[[VAL_52]] : memref +// CHECK: %[[VAL_55:.*]] = gpu.destroy_sp_mat async {{\[}}%[[VAL_54]]] %[[VAL_44]] +// CHECK: %[[VAL_56:.*]] = gpu.destroy_dn_mat async {{\[}}%[[VAL_55]]] %[[VAL_46]] +// CHECK: %[[VAL_57:.*]] = gpu.destroy_dn_mat async {{\[}}%[[VAL_56]]] %[[VAL_48]] +// CHECK: %[[VAL_58:.*]] = gpu.destroy_sparse_env async {{\[}}%[[VAL_57]]] %[[VAL_42]] +// CHECK: gpu.wait {{\[}}%[[VAL_58]]] +// CHECK: %[[VAL_59:.*]] = gpu.wait async +// CHECK: %[[VAL_60:.*]] = gpu.memcpy async {{\[}}%[[VAL_59]]] %[[VAL_34]], %[[VAL_38]] : memref, memref +// CHECK: %[[VAL_61:.*]] = gpu.dealloc async {{\[}}%[[VAL_60]]] %[[VAL_14]] : memref +// CHECK: %[[VAL_62:.*]] = gpu.dealloc async {{\[}}%[[VAL_61]]] %[[VAL_19]] : memref +// CHECK: %[[VAL_63:.*]] = gpu.dealloc async {{\[}}%[[VAL_62]]] %[[VAL_24]] : memref +// CHECK: %[[VAL_64:.*]] = gpu.dealloc async {{\[}}%[[VAL_63]]] %[[VAL_52]] : memref +// CHECK: %[[VAL_65:.*]] = gpu.dealloc async {{\[}}%[[VAL_64]]] %[[VAL_31]] : memref +// CHECK: %[[VAL_66:.*]] = gpu.dealloc async {{\[}}%[[VAL_65]]] %[[VAL_38]] : memref +// CHECK: gpu.wait {{\[}}%[[VAL_66]]] +// CHECK: return %[[VAL_2]] : tensor +// CHECK: } +func.func @matmul(%A: tensor, %B: tensor, %C_in: tensor) -> tensor { + %C_out = linalg.matmul + ins(%A, %B: tensor, tensor) + outs(%C_in: tensor) -> tensor + return %C_out : tensor +}