diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -84,6 +84,7 @@
   kShrU, // unsigned
   kShlI,
   kBinary, // semiring binary op
+  kReduce, // semiring reduction op
 };
 
 /// Children subexpressions of tensor operations.
@@ -115,8 +116,8 @@
   /// this field may be used to cache "hoisted" loop invariant tensor loads.
   Value val;
 
-  /// Code blocks used by semirings. For the case of kUnary and
-  /// kBinary, this holds the original operation with all regions. For
+  /// Code blocks used by semirings. For the case of kUnary, kBinary, and
+  /// kReduce, this holds the original operation with all regions. For
   /// kBinaryBranch, this holds the YieldOp for the left or right half
   /// to be merged into a nested scf loop.
   Operation *op;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -50,7 +50,7 @@
 };
 
 // Reduction kinds.
-enum Reduction { kNoReduc, kSum, kProduct, kAnd, kOr, kXor };
+enum Reduction { kNoReduc, kSum, kProduct, kAnd, kOr, kXor, kCustom };
 
 // Code generation.
 struct CodeGen {
@@ -87,6 +87,7 @@
   unsigned redExp = -1u;
   Value redVal;
   Reduction redKind = kNoReduc;
+  unsigned redCustom = -1u;
   // Sparse tensor as output. Implemented either through direct injective
   // insertion in lexicographic index order (where indices are updated
   // in the temporary array `lexIdx`) or through access pattern expansion
@@ -373,6 +374,7 @@
 static vector::CombiningKind getCombiningKind(Reduction kind) {
   switch (kind) {
   case kNoReduc:
+  case kCustom:
     break;
   case kSum:
     return vector::CombiningKind::ADD;
@@ -408,6 +410,8 @@
     return kOr;
   case Kind::kXorI:
     return kXor;
+  case Kind::kReduce:
+    return kCustom;
   default:
     llvm_unreachable("unexpected reduction operator");
   }
@@ -422,6 +426,7 @@
   Value r = codegen.redVal;
   switch (codegen.redKind) {
   case kNoReduc:
+  case kCustom:
     break;
   case kSum:
   case kXor:
@@ -454,6 +459,11 @@
   codegen.redVal = merger.exp(codegen.redExp).val = reduc;
 }
 
+/// Extracts identity from custom reduce.
+static Value getCustomRedId(Operation *op) {
+  return dyn_cast<sparse_tensor::ReduceOp>(op).getIdentity();
+}
+
 //===----------------------------------------------------------------------===//
 // Sparse compiler synthesis methods (statements and expressions).
 //===----------------------------------------------------------------------===//
@@ -726,6 +736,25 @@
   return builder.create<memref::LoadOp>(loc, codegen.expValues, index);
 }
 
+/// Generates insertion code to implement dynamic tensor load for reduction.
+static Value genInsertionLoadReduce(Merger &merger, CodeGen &codegen,
+                                    OpBuilder &builder, linalg::GenericOp op,
+                                    OpOperand *t) {
+  Location loc = op.getLoc();
+  Value identity = getCustomRedId(merger.exp(codegen.redCustom).op);
+  // Direct lexicographic index order, tensor loads as identity.
+  if (!codegen.expValues) {
+    return identity;
+  }
+  // Load from expanded access pattern if filled, identity otherwise.
+  Value index = genIndex(codegen, op, t);
+  Value isFilled =
+      builder.create<memref::LoadOp>(loc, codegen.expFilled, index);
+  Value valAtIndex =
+      builder.create<memref::LoadOp>(loc, codegen.expValues, index);
+  return builder.create<arith::SelectOp>(loc, isFilled, valAtIndex, identity);
+}
+
 /// Generates insertion code to implement dynamic tensor store.
 static void genInsertionStore(CodeGen &codegen, OpBuilder &builder,
                               linalg::GenericOp op, OpOperand *t, Value rhs) {
@@ -780,8 +809,11 @@
   }
   // Load during insertion.
   OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor];
-  if (t == codegen.sparseOut)
+  if (t == codegen.sparseOut) {
+    if (codegen.redCustom != -1u)
+      return genInsertionLoadReduce(merger, codegen, builder, op, t);
     return genInsertionLoad(codegen, builder, op, t);
+  }
   // Actual load.
   SmallVector<Value> args;
   Value ptr = genSubscript(codegen, builder, op, t, args);
@@ -953,6 +985,11 @@
     return genInvariantValue(merger, codegen, rewriter, exp);
   if (merger.exp(exp).kind == Kind::kIndex)
     return genIndexValue(codegen, rewriter, merger.exp(exp).index, ldx);
+  if (merger.exp(exp).kind == Kind::kReduce) {
+    // Make custom reduction identity accessible for expanded access pattern.
+    assert(codegen.redCustom == -1u);
+    codegen.redCustom = exp;
+  }
   Value v0 =
       genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e0, ldx);
   Value v1 =
@@ -960,8 +997,11 @@
   Value ee = merger.buildExp(rewriter, loc, exp, v0, v1);
   if (ee && (merger.exp(exp).kind == Kind::kUnary ||
              merger.exp(exp).kind == Kind::kBinary ||
-             merger.exp(exp).kind == Kind::kBinaryBranch))
+             merger.exp(exp).kind == Kind::kBinaryBranch ||
+             merger.exp(exp).kind == Kind::kReduce))
     ee = relinkBranch(codegen, rewriter, ee.getParentBlock(), ee, ldx);
+  if (merger.exp(exp).kind == Kind::kReduce)
+    codegen.redCustom = -1u;
   return ee;
 }
 
@@ -989,7 +1029,7 @@
 /// Hoists loop invariant tensor loads for which indices have been exhausted.
 static void genInvariants(Merger &merger, CodeGen &codegen, OpBuilder &builder,
                           linalg::GenericOp op, unsigned exp, unsigned ldx,
-                          bool atStart, Kind last = Kind::kTensor) {
+                          bool atStart, unsigned last = 0) {
   if (exp == -1u)
     return;
   if (merger.exp(exp).kind == Kind::kTensor) {
@@ -1010,8 +1050,11 @@
     if (lhs == t) {
       // Start or end a scalarized reduction
       if (atStart) {
-        Value load = genTensorLoad(merger, codegen, builder, op, exp);
-        codegen.redKind = getReduction(last);
+        Kind kind = merger.exp(last).kind;
+        Value load = kind == Kind::kReduce
+                         ? getCustomRedId(merger.exp(last).op)
+                         : genTensorLoad(merger, codegen, builder, op, exp);
+        codegen.redKind = getReduction(kind);
         codegen.redExp = exp;
         updateReduc(merger, codegen, load);
       } else {
@@ -1031,11 +1074,10 @@
     // Traverse into the binary operations. Note that we only hoist
     // tensor loads, since subsequent MLIR/LLVM passes know how to
     // deal with all other kinds of derived loop invariants.
-    Kind last = merger.exp(exp).kind;
     unsigned e0 = merger.exp(exp).children.e0;
     unsigned e1 = merger.exp(exp).children.e1;
-    genInvariants(merger, codegen, builder, op, e0, ldx, atStart, last);
-    genInvariants(merger, codegen, builder, op, e1, ldx, atStart, last);
+    genInvariants(merger, codegen, builder, op, e0, ldx, atStart, exp);
+    genInvariants(merger, codegen, builder, op, e1, ldx, atStart, exp);
   }
 }
 
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -114,6 +114,7 @@
     children.e1 = y;
     break;
   case kBinary:
+  case kReduce:
    assert(x != -1u && y != -1u && !v && o);
     children.e0 = x;
     children.e1 = y;
@@ -376,6 +377,7 @@
   case kOrI:
   case kXorI:
   case kBinary:
+  case kReduce:
     return false;
   }
   llvm_unreachable("unexpected kind");
@@ -476,6 +478,8 @@
     return "<<";
   case kBinary:
     return "binary";
+  case kReduce:
+    return "reduce";
   }
   llvm_unreachable("unexpected kind for symbol");
 }
@@ -554,6 +558,7 @@
   case kShrU:
   case kShlI:
   case kBinary:
+  case kReduce:
     llvm::dbgs() << "(";
     dumpExp(tensorExps[e].children.e0);
     llvm::dbgs() << " " << kindToOpSymbol(tensorExps[e].kind) << " ";
@@ -794,6 +799,11 @@
                      kBinaryBranch, leftYield, includeRight, kBinaryBranch,
                      rightYield);
   }
+  case kReduce:
+    // A custom reduce operation.
+    return takeConj(kind, buildLattices(tensorExps[e].children.e0, i),
+                    buildLattices(tensorExps[e].children.e1, i),
+                    tensorExps[e].op);
   }
   llvm_unreachable("unexpected expression kind");
 }
@@ -965,7 +975,7 @@
   }
   // Construct binary operations if subexpressions can be built.
   // See buildLattices() for an explanation of rejecting certain
-  // division and shift operations
+  // division and shift operations.
   if (def->getNumOperands() == 2) {
     auto x = buildTensorExp(op, def->getOperand(0));
     auto y = buildTensorExp(op, def->getOperand(1));
@@ -1020,6 +1030,21 @@
       }
     }
   }
+  // Construct ternary operations if subexpressions can be built.
+  if (def->getNumOperands() == 3) {
+    auto x = buildTensorExp(op, def->getOperand(0));
+    auto y = buildTensorExp(op, def->getOperand(1));
+    auto z = buildTensorExp(op, def->getOperand(2));
+    if (x.has_value() && y.has_value() && z.has_value()) {
+      unsigned e0 = x.value();
+      unsigned e1 = y.value();
+      // unsigned e2 = z.getValue();
+      if (auto redop = dyn_cast<sparse_tensor::ReduceOp>(def)) {
+        if (isAdmissableBranch(redop, redop.getRegion()))
+          return addExp(kReduce, e0, e1, Value(), def);
+      }
+    }
+  }
   // Cannot build.
   return None;
 }
 
@@ -1199,6 +1224,10 @@
     return buildUnaryPresent(rewriter, loc, tensorExps[e].op, v0);
   case kBinary:
     return buildBinaryOverlap(rewriter, loc, tensorExps[e].op, v0, v1);
+  case kReduce: {
+    ReduceOp redOp = cast<ReduceOp>(tensorExps[e].op);
+    return insertYieldOp(rewriter, loc, redOp.getRegion(), {v0, v1});
+  }
   }
   llvm_unreachable("unexpected expression kind in build");
 }
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reduce_custom.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reduce_custom.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reduce_custom.mlir
@@ -0,0 +1,234 @@
+// RUN: mlir-opt %s --sparse-compiler | \
+// RUN: mlir-cpu-runner \
+// RUN:  -e entry -entry-point-result=void \
+// RUN:  -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
+#CSR = #sparse_tensor.encoding<{dimLevelType = ["dense", "compressed"]}>
+#CSC = #sparse_tensor.encoding<{
+  dimLevelType = [ "dense", "compressed" ],
+  dimOrdering = affine_map<(i,j) -> (j,i)>
+}>
+
+//
+// Traits for tensor operations.
+//
+#trait_matmul = {
+  indexing_maps = [
+    affine_map<(i,j,k) -> (i,k)>, // A
+    affine_map<(i,j,k) -> (k,j)>, // B
+    affine_map<(i,j,k) -> (i,j)>  // C (out)
+  ],
+  iterator_types = ["parallel", "parallel", "reduction"],
+  doc = "C(i,j) = SUM_k A(i,k) * B(k,j)"
+}
+
+#trait_mat_reduce_rowwise = {
+  indexing_maps = [
+    affine_map<(i,j) -> (i,j)>, // A (in)
+    affine_map<(i,j) -> (i)>    // X (out)
+  ],
+  iterator_types = ["parallel", "reduction"],
+  doc = "X(i) = PROD_j A(i,j)"
+}
+
+#trait_mat_reduce_colwise = {
+  indexing_maps = [
+    affine_map<(i,j) -> (i,j)>, // A (in)
+    affine_map<(i,j) -> (j)>    // X (out)
+  ],
+  iterator_types = ["reduction", "parallel"],
+  doc = "X(j) = PROD_i A(i,j)"
+}
+
+module {
+  func.func @redProdLex(%arga: tensor<?x?xf64, #CSR>) -> tensor<?xf64, #SparseVector> {
+    %c0 = arith.constant 0 : index
+    %cf1 = arith.constant 1.0 : f64
+    %d0 = tensor.dim %arga, %c0 : tensor<?x?xf64, #CSR>
+    %xv = bufferization.alloc_tensor(%d0): tensor<?xf64, #SparseVector>
+    %0 = linalg.generic #trait_mat_reduce_rowwise
+      ins(%arga: tensor<?x?xf64, #CSR>)
+      outs(%xv: tensor<?xf64, #SparseVector>) {
+      ^bb(%a: f64, %b: f64):
+        %1 = sparse_tensor.reduce %a, %b, %cf1 : f64 {
+            ^bb0(%x: f64, %y: f64):
+              %2 = arith.mulf %x, %y : f64
+              sparse_tensor.yield %2 : f64
+          }
+        linalg.yield %1 : f64
+    } -> tensor<?xf64, #SparseVector>
+    return %0 : tensor<?xf64, #SparseVector>
+  }
+
+  func.func @redProdExpand(%arga: tensor<?x?xf64, #CSC>) -> tensor<?xf64, #SparseVector> {
+    %c0 = arith.constant 0 : index
+    %cf1 = arith.constant 1.0 : f64
+    %d0 = tensor.dim %arga, %c0 : tensor<?x?xf64, #CSC>
+    %xv = bufferization.alloc_tensor(%d0): tensor<?xf64, #SparseVector>
+    %0 = linalg.generic #trait_mat_reduce_rowwise
+      ins(%arga: tensor<?x?xf64, #CSC>)
+      outs(%xv: tensor<?xf64, #SparseVector>) {
+      ^bb(%a: f64, %b: f64):
+        %1 = sparse_tensor.reduce %a, %b, %cf1 : f64 {
+            ^bb0(%x: f64, %y: f64):
+              %2 = arith.mulf %x, %y : f64
+              sparse_tensor.yield %2 : f64
+          }
+        linalg.yield %1 : f64
+    } -> tensor<?xf64, #SparseVector>
+    return %0 : tensor<?xf64, #SparseVector>
+  }
+
+  func.func @min_plus_csrcsr(%arga: tensor<?x?xf64, #CSR>,
+                             %argb: tensor<?x?xf64, #CSR>) -> tensor<?x?xf64, #CSR> {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %maxf = arith.constant 1.0e999 : f64
+    %d0 = tensor.dim %arga, %c0 : tensor<?x?xf64, #CSR>
+    %d1 = tensor.dim %argb, %c1 : tensor<?x?xf64, #CSR>
+    %xm = bufferization.alloc_tensor(%d0, %d1) : tensor<?x?xf64, #CSR>
+    %0 = linalg.generic #trait_matmul
+      ins(%arga, %argb: tensor<?x?xf64, #CSR>, tensor<?x?xf64, #CSR>)
+      outs(%xm: tensor<?x?xf64, #CSR>) {
+      ^bb(%a: f64, %b: f64, %output: f64):
+        %1 = sparse_tensor.binary %a, %b : f64, f64 to f64
+          overlap = {
+            ^bb0(%x: f64, %y: f64):
+              %3 = arith.addf %x, %y : f64
+              sparse_tensor.yield %3 : f64
+          }
+          left={}
+          right={}
+        %2 = sparse_tensor.reduce %1, %output, %maxf : f64 {
+            ^bb0(%x: f64, %y: f64):
+              %cmp = arith.cmpf "olt", %x, %y : f64
+              %3 = arith.select %cmp, %x, %y : f64
+              sparse_tensor.yield %3 : f64
+          }
+        linalg.yield %2 : f64
+    } -> tensor<?x?xf64, #CSR>
+    return %0 : tensor<?x?xf64, #CSR>
+  }
+
+  func.func @min_plus_csrcsc(%arga: tensor<?x?xf64, #CSR>,
+                             %argb: tensor<?x?xf64, #CSC>) -> tensor<?x?xf64, #CSR> {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %maxf = arith.constant 1.0e999 : f64
+    %d0 = tensor.dim %arga, %c0 : tensor<?x?xf64, #CSR>
+    %d1 = tensor.dim %argb, %c1 : tensor<?x?xf64, #CSC>
+    %xm = bufferization.alloc_tensor(%d0, %d1) : tensor<?x?xf64, #CSR>
+    %0 = linalg.generic #trait_matmul
+      ins(%arga, %argb: tensor<?x?xf64, #CSR>, tensor<?x?xf64, #CSC>)
+      outs(%xm: tensor<?x?xf64, #CSR>) {
+      ^bb(%a: f64, %b: f64, %output: f64):
+        %1 = sparse_tensor.binary %a, %b : f64, f64 to f64
+          overlap = {
+            ^bb0(%x: f64, %y: f64):
+              %3 = arith.addf %x, %y : f64
+              sparse_tensor.yield %3 : f64
+          }
+          left={}
+          right={}
+        %2 = sparse_tensor.reduce %1, %output, %maxf : f64 {
+            ^bb0(%x: f64, %y: f64):
+              %cmp = arith.cmpf "olt", %x, %y : f64
+              %3 = arith.select %cmp, %x, %y : f64
+              sparse_tensor.yield %3 : f64
+          }
+        linalg.yield %2 : f64
+    } -> tensor<?x?xf64, #CSR>
+    return %0 : tensor<?x?xf64, #CSR>
+  }
+
+  // Dumps a sparse vector of type f64.
+  func.func @dump_vec(%arg0: tensor<?xf64, #SparseVector>) {
+    // Dump the values array to verify only sparse contents are stored.
+    %c0 = arith.constant 0 : index
+    %d0 = arith.constant -1.0 : f64
+    %0 = sparse_tensor.values %arg0 : tensor<?xf64, #SparseVector> to memref<?xf64>
+    %1 = vector.transfer_read %0[%c0], %d0: memref<?xf64>, vector<8xf64>
+    vector.print %1 : vector<8xf64>
+    // Dump the dense vector to verify structure is correct.
+    %dv = sparse_tensor.convert %arg0 : tensor<?xf64, #SparseVector> to tensor<?xf64>
+    %2 = vector.transfer_read %dv[%c0], %d0: tensor<?xf64>, vector<16xf64>
+    vector.print %2 : vector<16xf64>
+    return
+  }
+
+  // Dump a sparse matrix.
+  func.func @dump_mat(%arg0: tensor<?x?xf64, #CSR>) {
+    // Dump the values array to verify only sparse contents are stored.
+    %c0 = arith.constant 0 : index
+    %d0 = arith.constant -1.0 : f64
+    %0 = sparse_tensor.values %arg0 : tensor<?x?xf64, #CSR> to memref<?xf64>
+    %1 = vector.transfer_read %0[%c0], %d0: memref<?xf64>, vector<16xf64>
+    vector.print %1 : vector<16xf64>
+    %dm = sparse_tensor.convert %arg0 : tensor<?x?xf64, #CSR> to tensor<?x?xf64>
+    %2 = vector.transfer_read %dm[%c0, %c0], %d0: tensor<?x?xf64>, vector<5x5xf64>
+    vector.print %2 : vector<5x5xf64>
+    return
+  }
+
+  // Driver method to call and verify vector kernels.
+  func.func @entry() {
+    %c0 = arith.constant 0 : index
+
+    // Setup sparse matrices.
+    %m1 = arith.constant sparse<
+       [ [0,0], [0,1], [1,0], [2,2], [2,3], [2,4], [3,0], [3,2], [3,3] ],
+         [ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0 ]
+    > : tensor<4x5xf64>
+    %m2 = arith.constant sparse<
+       [ [0,0], [1,3], [2,0], [2,3], [3,1], [4,1] ],
+         [6.0, 5.0, 4.0, 3.0, 2.0, 11.0 ]
+    > : tensor<5x4xf64>
+    %sm1 = sparse_tensor.convert %m1 : tensor<4x5xf64> to tensor<?x?xf64, #CSR>
+    %sm2r = sparse_tensor.convert %m2 : tensor<5x4xf64> to tensor<?x?xf64, #CSR>
+    %sm2c = sparse_tensor.convert %m2 : tensor<5x4xf64> to tensor<?x?xf64, #CSC>
+
+    // Call sparse matrix kernels.
+    %1 = call @redProdLex(%sm1) : (tensor<?x?xf64, #CSR>) -> tensor<?xf64, #SparseVector>
+    %2 = call @redProdExpand(%sm2c) : (tensor<?x?xf64, #CSC>) -> tensor<?xf64, #SparseVector>
+    %5 = call @min_plus_csrcsr(%sm1, %sm2r)
+       : (tensor<?x?xf64, #CSR>, tensor<?x?xf64, #CSR>) -> tensor<?x?xf64, #CSR>
+    %6 = call @min_plus_csrcsc(%sm1, %sm2c)
+       : (tensor<?x?xf64, #CSR>, tensor<?x?xf64, #CSC>) -> tensor<?x?xf64, #CSR>
+
+    //
+    // Verify the results.
+    //
+    // CHECK: ( 1, 2, 3, 4, 5, 6, 7, 8, 9, -1, -1, -1, -1, -1, -1, -1 )
+    // CHECK-NEXT: ( ( 1, 2, 0, 0, 0 ), ( 3, 0, 0, 0, 0 ), ( 0, 0, 4, 5, 6 ), ( 7, 0, 8, 9, 0 ), ( -1, -1, -1, -1, -1 ) )
+    // CHECK-NEXT: ( 6, 5, 4, 3, 2, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 )
+    // CHECK-NEXT: ( ( 6, 0, 0, 0, -1 ), ( 0, 0, 0, 5, -1 ), ( 4, 0, 0, 3, -1 ), ( 0, 2, 0, 0, -1 ), ( 0, 11, 0, 0, -1 ) )
+    // CHECK-NEXT: ( 2, 3, 120, 504, -1, -1, -1, -1 )
+    // CHECK-NEXT: ( 2, 3, 120, 504, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 )
+    // CHECK-NEXT: ( 6, 5, 12, 2, 11, -1, -1, -1 )
+    // CHECK-NEXT: ( 6, 5, 12, 2, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 )
+    // CHECK-NEXT: ( 7, 7, 9, 8, 7, 7, 12, 11, 11, -1, -1, -1, -1, -1, -1, -1 )
+    // CHECK-NEXT: ( ( 7, 0, 0, 7, -1 ), ( 9, 0, 0, 0, -1 ), ( 8, 7, 0, 7, -1 ), ( 12, 11, 0, 11, -1 ), ( -1, -1, -1, -1, -1 ) )
+    // TODO: Update once identity values are no longer inserted for non-overlapping dot product
+    // CHECK-NEXT: ( 7, inf, inf, 7, 9, inf, inf, inf, 8, 7, inf, 7, 12, 11, inf, 11 )
+    // CHECK-NEXT: ( ( 7, inf, inf, 7, -1 ), ( 9, inf, inf, inf, -1 ), ( 8, 7, inf, 7, -1 ), ( 12, 11, inf, 11, -1 ), ( -1, -1, -1, -1, -1 ) )
+    //
+    call @dump_mat(%sm1) : (tensor<?x?xf64, #CSR>) -> ()
+    call @dump_mat(%sm2r) : (tensor<?x?xf64, #CSR>) -> ()
+    call @dump_vec(%1) : (tensor<?xf64, #SparseVector>) -> ()
+    call @dump_vec(%2) : (tensor<?xf64, #SparseVector>) -> ()
+    call @dump_mat(%5) : (tensor<?x?xf64, #CSR>) -> ()
+    call @dump_mat(%6) : (tensor<?x?xf64, #CSR>) -> ()
+
+    // Release the resources.
+    bufferization.dealloc_tensor %sm1 : tensor<?x?xf64, #CSR>
+    bufferization.dealloc_tensor %sm2r : tensor<?x?xf64, #CSR>
+    bufferization.dealloc_tensor %sm2c : tensor<?x?xf64, #CSC>
+    bufferization.dealloc_tensor %1 : tensor<?xf64, #SparseVector>
+    bufferization.dealloc_tensor %2 : tensor<?xf64, #SparseVector>
+    bufferization.dealloc_tensor %5 : tensor<?x?xf64, #CSR>
+    bufferization.dealloc_tensor %6 : tensor<?x?xf64, #CSR>
+    return
+  }
+}
diff --git a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
--- a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
+++ b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
@@ -283,6 +283,7 @@
   case kShrU:
   case kShlI:
   case kBinary:
+  case kReduce:
     return compareExpression(tensorExp.children.e0, pattern->e0) &&
            compareExpression(tensorExp.children.e1, pattern->e1);
   }
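
Reviewer note (not part of the patch): the sketch below illustrates, in isolation, the semantics the new kReduce path lowers. A linalg.generic wraps its reduction in sparse_tensor.reduce with an explicit identity; per the changes above, that identity seeds the scalarized reduction (genInvariants) and stands in for missing elements under the expanded access pattern (genInsertionLoadReduce). The kernel name @vec_min, its trait, and the scalar output shape are illustrative assumptions, not taken from the patch.

  #SV = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
  #trait_vec_reduce = {
    indexing_maps = [
      affine_map<(i) -> (i)>, // a (in)
      affine_map<(i) -> ()>   // x (out, scalar)
    ],
    iterator_types = ["reduction"]
  }
  // Minimum over the stored values of a sparse vector; +inf is the identity,
  // so a vector with no stored entries reduces to +inf.
  func.func @vec_min(%arga: tensor<?xf64, #SV>, %argx: tensor<f64>) -> tensor<f64> {
    %inf = arith.constant 1.0e999 : f64
    %0 = linalg.generic #trait_vec_reduce
      ins(%arga: tensor<?xf64, #SV>)
      outs(%argx: tensor<f64>) {
      ^bb(%a: f64, %b: f64):
        %1 = sparse_tensor.reduce %a, %b, %inf : f64 {
            ^bb0(%x: f64, %y: f64):
              %cmp = arith.cmpf "olt", %x, %y : f64
              %m = arith.select %cmp, %x, %y : f64
              sparse_tensor.yield %m : f64
          }
        linalg.yield %1 : f64
    } -> tensor<f64>
    return %0 : tensor<f64>
  }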