diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp @@ -249,6 +249,13 @@ } } +static bool isDenseTensorType(Type type) { + auto ttp = llvm::dyn_cast(type); + if (llvm::dyn_cast_or_null(ttp.getEncoding())) + return false; + return true; +} + /// Constructs code for new GPU kernel. static void genGPUCode(PatternRewriter &rewriter, gpu::GPUFuncOp gpuFunc, scf::ParallelOp forallOp, @@ -656,6 +663,112 @@ return success(); } +// Match and rewrite 2:4 SpMM kernels. +static LogicalResult rewrite2To4SpMM(PatternRewriter &rewriter, + linalg::GenericOp op) { + Location loc = op.getLoc(); + Value A = op.getOperand(0); + Value B = op.getOperand(1); + Value C = op.getOperand(2); // we have C = AB + SmallVector tokens; + + // All input should be dense tensors. + if (!isDenseTensorType(A.getType()) || !isDenseTensorType(B.getType()) || + !isDenseTensorType(C.getType())) + return failure(); + + // For now, lower to GPU directly. 
+ + Value bufA = genTensorToMemref(rewriter, loc, A); + Value matA = genAllocCopy(rewriter, loc, bufA, tokens); + Value bufB = genTensorToMemref(rewriter, loc, B); + Value matB = genAllocCopy(rewriter, loc, bufB, tokens); + Value bufC = genTensorToMemref(rewriter, loc, C); + Value matC = genAllocCopy(rewriter, loc, bufC, tokens); + genBlockingWait(rewriter, loc, tokens); + tokens.clear(); + Value szm = linalg::createOrFoldDimOp(rewriter, loc, matA, 0); + Value szk = linalg::createOrFoldDimOp(rewriter, loc, matB, 0); + Value szn = linalg::createOrFoldDimOp(rewriter, loc, matC, 1); + + Type indexTp = rewriter.getIndexType(); + Type dnTensorHandleTp = rewriter.getType(); + Type spMatHandleTp = rewriter.getType(); + Type tokenTp = rewriter.getType(); + Value token = genFirstWait(rewriter, loc); + Operation *spGenA = rewriter.create( + loc, spMatHandleTp, tokenTp, token, szm, szk, matA); + + Value spMatA = spGenA->getResult(0); + token = spGenA->getResult(1); + auto dmatB = rewriter.create( + loc, dnTensorHandleTp, tokenTp, token, matB, + SmallVector{szk, szn}); + Value dnB = dmatB.getResult(0); + token = dmatB.getAsyncToken(); + auto dmatC = rewriter.create( + loc, dnTensorHandleTp, tokenTp, token, matC, + SmallVector{szm, szn}); + Value dnC = dmatC.getResult(0); + token = dmatC.getAsyncToken(); + + auto dmatCType = llvm::cast(matC.getType()).getElementType(); + + // Precompute buffersize for SpMM. 
+ SmallVector bufferTypes_{indexTp, indexTp, indexTp}; + TypeRange bufferTypes(bufferTypes_); + auto bufferComp = rewriter.create( + loc, bufferTypes, tokenTp, token, gpu::TransposeMode::NON_TRANSPOSE, + gpu::TransposeMode::NON_TRANSPOSE, spMatA, dnB, dnC, + /*computeType=*/dmatCType); + + token = bufferComp.getAsyncToken(); + Value bufferSz = bufferComp.getResult(0); + auto buf = genAllocBuffer(rewriter, loc, bufferSz, token); + Value buffer = buf.getResult(0); + token = buf.getAsyncToken(); + + Value bufferSz2 = bufferComp.getResult(1); + auto buf2 = genAllocBuffer(rewriter, loc, bufferSz2, token); + Value buffer2 = buf2.getResult(0); + token = buf2.getAsyncToken(); + + Value bufferSz3 = bufferComp.getResult(2); + auto buf3 = genAllocBuffer(rewriter, loc, bufferSz3, token); + Value buffer3 = buf3.getResult(0); + token = buf3.getAsyncToken(); + + auto dnCType = llvm::cast(matC.getType()).getElementType(); + + // Perform the SpMM. + auto spmmComp = rewriter.create( + loc, tokenTp, token, spMatA, dnB, dnC, /*computeType=*/dnCType, + SmallVector{buffer, buffer2, buffer3}); + token = spmmComp.getAsyncToken(); + + // Copy data back to host and free all the resources. 
+ token = rewriter.create(loc, tokenTp, token, spMatA) + .getAsyncToken(); + token = rewriter.create(loc, tokenTp, token, dnB) + .getAsyncToken(); + token = rewriter.create(loc, tokenTp, token, dnC) + .getAsyncToken(); + SmallVector newDynamicSizes; + + token = genDeallocMemRef(rewriter, loc, buffer, token); + token = genDeallocMemRef(rewriter, loc, buffer2, token); + token = genDeallocMemRef(rewriter, loc, buffer3, token); + token = genDeallocMemRef(rewriter, loc, matA, token); + token = genDeallocMemRef(rewriter, loc, matB, token); + token = genCopyMemRef(rewriter, loc, bufC, matC, token); + token = genDeallocMemRef(rewriter, loc, matC, token); + tokens.push_back(token); + genBlockingWait(rewriter, loc, tokens); + tokens.clear(); + rewriter.replaceOpWithNewOp(op, bufC); + return success(); +} + /// Match and rewrite SDDMM kernel. static LogicalResult rewriteSDDMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT) { @@ -906,7 +1019,12 @@ // TODO: add transposed {i, k}, {k, j} // TODO: maybe add transposed {i, j} in future maps == infer({{i, k}, {k, j}, {i, j}}) && matchSumOfMultOfArgs(op)) { - return rewriteSpMM(rewriter, op, enableRT); + // TODO: match 2:4 + if (op->getAttr("DENSE24")) { + return rewrite2To4SpMM(rewriter, op); + } else { + return rewriteSpMM(rewriter, op, enableRT); + } } // Recognize a SDDMM kernel. 
diff --git a/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul_lib_2to4.mlir b/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul_lib_2to4.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul_lib_2to4.mlir
@@ -0,0 +1,69 @@
+// RUN: mlir-opt %s --linalg-generalize-named-ops \
+// RUN:             --sparsification="enable-gpu-libgen" | FileCheck %s
+
+// CHECK-LABEL:   func.func @matmul(
+// CHECK-SAME:      %[[VAL_0:.*0]]: tensor<?x?xf16>,
+// CHECK-SAME:      %[[VAL_1:.*1]]: tensor<?x?xf16>,
+// CHECK-SAME:      %[[VAL_2:.*2]]: tensor<?x?xf16>) -> tensor<?x?xf16> {
+// CHECK:           %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_4:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_5:.*]] = bufferization.to_memref %[[VAL_0]] : memref<?x?xf16>
+// CHECK:           %[[VAL_6:.*]] = gpu.wait async
+// CHECK:           %[[VAL_7:.*]] = memref.dim %[[VAL_5]], %[[VAL_3]] : memref<?x?xf16>
+// CHECK:           %[[VAL_8:.*]] = memref.dim %[[VAL_5]], %[[VAL_4]] : memref<?x?xf16>
+// CHECK:           %[[VAL_9:.*]], %[[VAL_10:.*]] = gpu.alloc async {{\[}}%[[VAL_6]]] (%[[VAL_7]], %[[VAL_8]]) : memref<?x?xf16>
+// CHECK:           %[[VAL_11:.*]] = gpu.memcpy async {{\[}}%[[VAL_10]]] %[[VAL_9]], %[[VAL_5]] : memref<?x?xf16>, memref<?x?xf16>
+// CHECK:           %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_1]] : memref<?x?xf16>
+// CHECK:           %[[VAL_13:.*]] = gpu.wait async
+// CHECK:           %[[VAL_14:.*]] = memref.dim %[[VAL_12]], %[[VAL_3]] : memref<?x?xf16>
+// CHECK:           %[[VAL_15:.*]] = memref.dim %[[VAL_12]], %[[VAL_4]] : memref<?x?xf16>
+// CHECK:           %[[VAL_16:.*]], %[[VAL_17:.*]] = gpu.alloc async {{\[}}%[[VAL_13]]] (%[[VAL_14]], %[[VAL_15]]) : memref<?x?xf16>
+// CHECK:           %[[VAL_18:.*]] = gpu.memcpy async {{\[}}%[[VAL_17]]] %[[VAL_16]], %[[VAL_12]] : memref<?x?xf16>, memref<?x?xf16>
+// CHECK:           %[[VAL_19:.*]] = bufferization.to_memref %[[VAL_2]] : memref<?x?xf16>
+// CHECK:           %[[VAL_20:.*]] = gpu.wait async
+// CHECK:           %[[VAL_21:.*]] = memref.dim %[[VAL_19]], %[[VAL_3]] : memref<?x?xf16>
+// CHECK:           %[[VAL_22:.*]] = memref.dim %[[VAL_19]], %[[VAL_4]] : memref<?x?xf16>
+// CHECK:           %[[VAL_23:.*]], %[[VAL_24:.*]] = gpu.alloc async {{\[}}%[[VAL_20]]] (%[[VAL_21]], %[[VAL_22]]) : memref<?x?xf16>
+// CHECK:           %[[VAL_25:.*]] = gpu.memcpy async {{\[}}%[[VAL_24]]] %[[VAL_23]], %[[VAL_19]] : memref<?x?xf16>, memref<?x?xf16>
+// CHECK:           gpu.wait {{\[}}%[[VAL_11]], %[[VAL_18]], %[[VAL_25]]]
+// CHECK:           %[[VAL_26:.*]] = memref.dim %[[VAL_9]], %[[VAL_3]] : memref<?x?xf16>
+// CHECK:           %[[VAL_27:.*]] = memref.dim %[[VAL_16]], %[[VAL_3]] : memref<?x?xf16>
+// CHECK:           %[[VAL_28:.*]] = memref.dim %[[VAL_23]], %[[VAL_4]] : memref<?x?xf16>
+// CHECK:           %[[VAL_29:.*]] = gpu.wait async
+// CHECK:           %[[VAL_30:.*]], %[[VAL_31:.*]] = gpu.create_2to4_spmat async {{\[}}%[[VAL_29]]] %[[VAL_26]], %[[VAL_27]], %[[VAL_9]] : memref<?x?xf16>
+// CHECK:           %[[VAL_32:.*]], %[[VAL_33:.*]] = gpu.create_dn_tensor async {{\[}}%[[VAL_31]]] %[[VAL_16]], %[[VAL_27]], %[[VAL_28]] : index, index into memref<?x?xf16>
+// CHECK:           %[[VAL_34:.*]], %[[VAL_35:.*]] = gpu.create_dn_tensor async {{\[}}%[[VAL_33]]] %[[VAL_23]], %[[VAL_26]], %[[VAL_28]] : index, index into memref<?x?xf16>
+// CHECK:           %[[VAL_36:.*]]:3, %[[VAL_37:.*]] = gpu.spmm_buffer_size async {{\[}}%[[VAL_35]]] %[[VAL_30]], %[[VAL_32]], %[[VAL_34]] : index, index, index into f16
+// CHECK:           %[[VAL_38:.*]], %[[VAL_39:.*]] = gpu.alloc async {{\[}}%[[VAL_37]]] (%[[VAL_36]]#0) : memref<?xi8>
+// CHECK:           %[[VAL_40:.*]], %[[VAL_41:.*]] = gpu.alloc async {{\[}}%[[VAL_39]]] (%[[VAL_36]]#1) : memref<?xi8>
+// CHECK:           %[[VAL_42:.*]], %[[VAL_43:.*]] = gpu.alloc async {{\[}}%[[VAL_41]]] (%[[VAL_36]]#2) : memref<?xi8>
+// CHECK:           %[[VAL_44:.*]] = gpu.spmm async {{\[}}%[[VAL_43]]] %[[VAL_30]], %[[VAL_32]], %[[VAL_34]], %[[VAL_38]], %[[VAL_40]], %[[VAL_42]] : memref<?xi8>, memref<?xi8>, memref<?xi8> into f16
+// CHECK:           %[[VAL_45:.*]] = gpu.destroy_sp_mat async {{\[}}%[[VAL_44]]] %[[VAL_30]]
+// CHECK:           %[[VAL_46:.*]] = gpu.destroy_dn_tensor async {{\[}}%[[VAL_45]]] %[[VAL_32]]
+// CHECK:           %[[VAL_47:.*]] = gpu.destroy_dn_tensor async {{\[}}%[[VAL_46]]] %[[VAL_34]]
+// CHECK:           %[[VAL_48:.*]] = gpu.dealloc async {{\[}}%[[VAL_47]]] %[[VAL_38]] : memref<?xi8>
+// CHECK:           %[[VAL_49:.*]] = gpu.dealloc async {{\[}}%[[VAL_48]]] %[[VAL_40]] : memref<?xi8>
+// CHECK:           %[[VAL_50:.*]] = gpu.dealloc async {{\[}}%[[VAL_49]]] %[[VAL_42]] : memref<?xi8>
+// CHECK:           %[[VAL_51:.*]] = gpu.dealloc async {{\[}}%[[VAL_50]]] %[[VAL_9]] : memref<?x?xf16>
+// CHECK:           %[[VAL_52:.*]] = gpu.dealloc async {{\[}}%[[VAL_51]]] %[[VAL_16]] : memref<?x?xf16>
+// CHECK:           %[[VAL_53:.*]] = gpu.memcpy async {{\[}}%[[VAL_52]]] %[[VAL_19]], %[[VAL_23]] : memref<?x?xf16>, memref<?x?xf16>
+// CHECK:           %[[VAL_54:.*]] = gpu.dealloc async {{\[}}%[[VAL_53]]] %[[VAL_23]] : memref<?x?xf16>
+// CHECK:           gpu.wait {{\[}}%[[VAL_54]]]
+// CHECK:           %[[VAL_55:.*]] = bufferization.to_tensor %[[VAL_19]] : memref<?x?xf16>
+// CHECK:           return %[[VAL_55]] : tensor<?x?xf16>
+// CHECK:         }
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+module {
+  func.func @matmul(%arg0: tensor<?x?xf16>, %arg1: tensor<?x?xf16>, %arg2: tensor<?x?xf16>) -> tensor<?x?xf16> {
+    %0 = linalg.generic { DENSE24, indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<?x?xf16>, tensor<?x?xf16>) outs(%arg2 : tensor<?x?xf16>) {
+    ^bb0(%in: f16, %in_0: f16, %out: f16):
+      %1 = arith.mulf %in, %in_0 : f16
+      %2 = arith.addf %out, %1 : f16
+      linalg.yield %2 : f16
+    } -> tensor<?x?xf16>
+    return %0 : tensor<?x?xf16>
+  }
+}