diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
@@ -301,12 +301,25 @@
 // Library helper methods.
 //===----------------------------------------------------------------------===//
 
-/// Helper to detect a * b.
-static bool matchMulOfArgs(linalg::GenericOp op, Value val) {
+/// Helper to detect a + b with arguments taken from given block.
+static bool matchAddOfArgs(Block *block, Value val) {
   if (auto *def = val.getDefiningOp()) {
-    if (isa<arith::MulFOp>(def) || isa<arith::MulIOp>(def)) {
-      Value a = op.getBlock()->getArguments()[0];
-      Value b = op.getBlock()->getArguments()[1];
+    if (isa<arith::AddFOp, arith::AddIOp>(def)) {
+      Value a = block->getArguments()[0];
+      Value b = block->getArguments()[1];
+      return (def->getOperand(0) == a && def->getOperand(1) == b) ||
+             (def->getOperand(0) == b && def->getOperand(1) == a);
+    }
+  }
+  return false;
+}
+
+/// Helper to detect a * b with arguments taken from given block.
+static bool matchMulOfArgs(Block *block, Value val) {
+  if (auto *def = val.getDefiningOp()) {
+    if (isa<arith::MulFOp, arith::MulIOp>(def)) {
+      Value a = block->getArguments()[0];
+      Value b = block->getArguments()[1];
       return (def->getOperand(0) == a && def->getOperand(1) == b) ||
              (def->getOperand(0) == b && def->getOperand(1) == a);
     }
@@ -318,67 +331,47 @@
 static bool matchSumOfMultOfArgs(linalg::GenericOp op) {
   auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
   if (auto *def = yieldOp.getOperand(0).getDefiningOp()) {
-    if (isa<arith::AddFOp>(def) || isa<arith::AddIOp>(def)) {
+    if (isa<arith::AddFOp, arith::AddIOp>(def)) {
       Value x = op.getBlock()->getArguments()[2];
       return (def->getOperand(0) == x &&
-              matchMulOfArgs(op, def->getOperand(1))) ||
+              matchMulOfArgs(op.getBlock(), def->getOperand(1))) ||
             (def->getOperand(1) == x &&
-              matchMulOfArgs(op, def->getOperand(0)));
+              matchMulOfArgs(op.getBlock(), def->getOperand(0)));
     }
   }
   return false;
 }
 
-// Helper to detect c = c \spy (a * b)
+// Helper to detect c += spy(s) x (a * b)
 static bool matchSumReductionOfMulUnary(linalg::GenericOp op) {
   auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
-  auto def = yieldOp.getOperand(0).getDefiningOp();
-  if (!def)
-    return false;
+  // The linalg yields a custom reduce result.
   Value s_out = op.getBlock()->getArguments()[2];
-  for (auto *use : s_out.getUsers()) {
-    if (!(isa<sparse_tensor::ReduceOp>(use) ||
-          isa<sparse_tensor::UnaryOp>(use)))
-      return false;
-    // The sparse matrix should be specified as the pattern in the two
-    // operators.
-    if (s_out != use->getOperand(0))
+  if (auto redOp =
+          yieldOp.getOperand(0).getDefiningOp<sparse_tensor::ReduceOp>()) {
+    // The reduce consumes the output.
+    Value other;
+    if (s_out == redOp->getOperand(0))
+      other = redOp->getOperand(1);
+    else if (s_out == redOp->getOperand(1))
+      other = redOp->getOperand(0);
+    else
       return false;
-
-    // the above logic makes sure the pattern involves reduction and unary,
-    // i.e.,
-    // %1 = sparse_tensor.unary
-    // %2 = sparse_tensor.reduce
-    // we need to make sure %1 produces A*B and %2 uses summation as the
-    // reduction operator.
-    if (isa<sparse_tensor::ReduceOp>(use)) {
-      auto reduceSpOp = cast<sparse_tensor::ReduceOp>(use);
-      auto yieldSpOp = cast<sparse_tensor::YieldOp>(
-          reduceSpOp.getRegion().front().getTerminator());
-      auto *reduce = yieldSpOp.getOperand(0).getDefiningOp();
-      if (!isa_and_nonnull<arith::AddFOp>(reduce) &&
-          !isa_and_nonnull<arith::AddIOp>(reduce))
-        return false;
-    }
-    if (isa<sparse_tensor::UnaryOp>(use)) {
-      auto unarySpOp = cast<sparse_tensor::UnaryOp>(use);
-      auto yieldSpOp = cast<sparse_tensor::YieldOp>(
-          unarySpOp.getRegion(0).front().getTerminator());
-      auto *unary = yieldSpOp.getOperand(0).getDefiningOp();
-      if (!isa_and_nonnull<arith::MulFOp>(unary) &&
-          !isa_and_nonnull<arith::MulIOp>(unary))
+    // The reduce op also consumes an unary which also consumes the output
+    // and does not define an absent value.
+    if (auto unOp = other.getDefiningOp<sparse_tensor::UnaryOp>()) {
+      if (s_out != unOp->getOperand(0) || !unOp.getAbsentRegion().empty())
         return false;
-
-      // we also need to make sure the unary operation is used by the reduction
-      // operation.
-      for (auto *useUnary : unarySpOp->getUsers()) {
-        if (!isa<sparse_tensor::ReduceOp>(useUnary)) {
-          return false;
-        }
-      }
+      // And the bodies are as expected.
+      auto yieldUn = cast<sparse_tensor::YieldOp>(
+          unOp.getRegion(0).front().getTerminator());
+      auto yieldRed = cast<sparse_tensor::YieldOp>(
+          redOp.getRegion().front().getTerminator());
+      return matchMulOfArgs(op.getBlock(), yieldUn.getOperand(0)) &&
+             matchAddOfArgs(&redOp.getRegion().front(), yieldRed.getOperand(0));
     }
   }
-  return true;
+  return false;
 }
 
 /// Test for sorted COO with suitable data and coordinates types.
@@ -679,37 +672,24 @@
   return success();
 }
 
-// TODO: identify alpha and beta and pass them to the CUDA calls
 /// Match and rewrite SDDMM kernel.
 static LogicalResult rewriteSDDMM(PatternRewriter &rewriter,
                                   linalg::GenericOp op, bool enableRT) {
-
-  // For now, this pass reuses C and copies the result non-zero elements to
-  // overwrite C's.
-  // As an ad hoc solution, this pass also assumes the linalg takes a,b,c as
-  // input argument, and c as the output. It recognizes this pattern and rewrite
-  // it.
-
   Location loc = op.getLoc();
   Value a = op.getOperand(0);
   Value b = op.getOperand(1);
   Value c = op.getOperand(2);
-
   SmallVector<Value> tokens;
 
-  // Only admissible sparse matrix format and dense matrices.
+  // Only admissible sparse matrix format and dense matrices, no COO.
   bool isCOO = false;
   SparseTensorType aTp = getSparseTensorType(a);
   SparseTensorType bTp = getSparseTensorType(b);
   SparseTensorType cTp = getSparseTensorType(c);
-
   if (!areAdmissibleTypes(cTp, bTp, aTp, enableRT, false, isCOO))
     return failure();
-
-  // cusparse currently does not support COO in its SDDMM kernel.
-  if (isCOO) {
+  if (isCOO)
     return failure();
-  }
 
   // The SDDMM does the in-place operation.
   // Start sparse kernel and copy data from host to device.
@@ -798,8 +778,8 @@
   genBlockingWait(rewriter, loc, tokens);
   tokens.clear();
 
+  // Done.
   rewriter.replaceOpWithNewOp<sparse_tensor::LoadOp>(op, c);
-
   return success();
 }
 
@@ -933,6 +913,7 @@
     bindDims(getContext(), i, j, k);
 
     // TODO: more robust patterns, tranposed versions, more kernels...
+    // TODO: identify alpha and beta and pass them to the CUDA calls
    // Recognize a SpMV kernel.
     if (numLoops == 2 && numTensors == 3 &&
diff --git a/mlir/test/Dialect/SparseTensor/GPU/gpu_sampled_matmul_lib.mlir b/mlir/test/Dialect/SparseTensor/GPU/gpu_sampled_matmul_lib.mlir
--- a/mlir/test/Dialect/SparseTensor/GPU/gpu_sampled_matmul_lib.mlir
+++ b/mlir/test/Dialect/SparseTensor/GPU/gpu_sampled_matmul_lib.mlir
@@ -1,5 +1,4 @@
-// RUN: mlir-opt %s --linalg-generalize-named-ops \
-// RUN:   --sparsification="enable-gpu-libgen" | FileCheck %s
+// RUN: mlir-opt %s --sparsification="enable-gpu-libgen" | FileCheck %s
 
 #trait_sampled_dense_dense = {
   indexing_maps = [
@@ -22,8 +21,6 @@
 
 #CSR = #sparse_tensor.encoding<{ lvlTypes = [ "dense", "compressed" ] }>
 
-module {
-
 // CHECK-LABEL: func.func @sparse_sampled_dd(
 // CHECK-SAME: %[[VAL_0:.*]]: tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "dense", "compressed" ] }>>,
 // CHECK-SAME: %[[VAL_1:.*]]: tensor<8x8xf64>,
@@ -82,7 +79,7 @@
 // A kernel that computes a direct sampled matrix matrix multiplication
 // (with sparse result).
 // Compute SDDMM C = C\spy AB
-// 
+//
 func.func @sparse_sampled_dd(%argS: tensor<8x8xf64, #CSR>,
                              %argA: tensor<8x8xf64>,
                              %argB: tensor<8x8xf64>) -> tensor<8x8xf64, #CSR> {
@@ -106,6 +103,4 @@
       linalg.yield %r : f64
     } -> tensor<8x8xf64, #CSR>
   return %result : tensor<8x8xf64, #CSR>
-  }
-
 }
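
For reference, the linalg.generic body shape that the new matchSumReductionOfMulUnary helper accepts is the one exercised by the sparse_sampled_dd test above: a sparse_tensor.unary that consumes the sparse output argument, has an empty absent region, and yields a * b, feeding a sparse_tensor.reduce whose region adds its own two block arguments. A minimal sketch follows; value names and the zero identity are illustrative, not copied verbatim from the test file.

  ^bb(%a: f64, %b: f64, %s: f64):
    %f0 = arith.constant 0.0 : f64
    // Unary consumes the output %s, defines no absent value, and yields a * b.
    %u = sparse_tensor.unary %s : f64 to f64
      present={
        ^bb0(%p: f64):
          %mul = arith.mulf %a, %b : f64
          sparse_tensor.yield %mul : f64
      }
      absent={}
    // Reduce consumes the output %s and the unary result; its body adds its
    // own block arguments, which matchAddOfArgs checks.
    %r = sparse_tensor.reduce %s, %u, %f0 : f64 {
        ^bb0(%x: f64, %y: f64):
          %add = arith.addf %x, %y : f64
          sparse_tensor.yield %add : f64
      }
    linalg.yield %r : f64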