diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td @@ -415,7 +415,7 @@ would be equivalent to a union operation where non-overlapping values in the inputs are copied to the output unchanged. - Example of isEqual applied to intersecting elements only: + Example of isEqual applied to intersecting elements only. ```mlir %C = sparse_tensor.init... %0 = linalg.generic #trait @@ -435,7 +435,8 @@ } -> tensor ``` - Example of A+B in upper triangle, A-B in lower triangle: + Example of A+B in upper triangle, A-B in lower triangle + (not working yet, but construct will be available soon). ```mlir %C = sparse_tensor.init... %1 = linalg.generic #trait diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h --- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h +++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h @@ -46,6 +46,8 @@ kCastIdx, kTruncI, kBitCast, + kBinaryBranch, // semiring unary branch created from a binary op + kUnary, // semiring unary op // Binary operations. kMulF, kMulI, @@ -62,6 +64,7 @@ kShrS, // signed kShrU, // unsigned kShlI, + kBinary, // semiring binary op }; /// Children subexpressions of tensor operations. @@ -72,7 +75,7 @@ /// Tensor expression. Represents a MLIR expression in tensor index notation. struct TensorExp { - TensorExp(Kind k, unsigned x, unsigned y, Value v); + TensorExp(Kind k, unsigned x, unsigned y, Value v, Operation *operation); /// Tensor expression kind. Kind kind; @@ -92,6 +95,12 @@ /// infer destination type) of a cast operation During code generation, /// this field may be used to cache "hoisted" loop invariant tensor loads. Value val; + + /// Code blocks used by semirings. For the case of kUnary and + /// kBinary, this holds the original operation with all regions. For + /// kBinaryBranch, this holds the YieldOp for the left or right half + /// to be merged into a nested scf loop. + Operation *op; }; /// Lattice point. Each lattice point consists of a conjunction of tensor @@ -110,7 +119,7 @@ /// must execute. Pre-computed during codegen to avoid repeated eval. BitVector simple; - /// Index of the tensor expresssion. + /// Index of the tensor expression. unsigned exp; }; @@ -130,9 +139,14 @@ hasSparseOut(false), dims(t + 1, std::vector(l, Dim::kUndef)) {} /// Adds a tensor expression. Returns its index. - unsigned addExp(Kind k, unsigned e0, unsigned e1 = -1u, Value v = Value()); - unsigned addExp(Kind k, unsigned e, Value v) { return addExp(k, e, -1u, v); } - unsigned addExp(Kind k, Value v) { return addExp(k, -1u, -1u, v); } + unsigned addExp(Kind k, unsigned e0, unsigned e1 = -1u, Value v = Value(), + Operation *op = nullptr); + unsigned addExp(Kind k, unsigned e, Value v, Operation *op = nullptr) { + return addExp(k, e, -1u, v, op); + } + unsigned addExp(Kind k, Value v, Operation *op = nullptr) { + return addExp(k, -1u, -1u, v, op); + } /// Adds an iteration lattice point. Returns its index. unsigned addLat(unsigned t, unsigned i, unsigned e); @@ -144,20 +158,31 @@ /// of loop indices (effectively constructing a larger "intersection" of those /// indices) with a newly constructed tensor (sub)expression of given kind. /// Returns the index of the new lattice point. 
- unsigned conjLatPoint(Kind kind, unsigned p0, unsigned p1); + unsigned conjLatPoint(Kind kind, unsigned p0, unsigned p1, + Operation *op = nullptr); /// Conjunctive merge of two lattice sets L0 and L1 is conjunction of /// cartesian product. Returns the index of the new set. - unsigned takeConj(Kind kind, unsigned s0, unsigned s1); + unsigned takeConj(Kind kind, unsigned s0, unsigned s1, + Operation *op = nullptr); /// Disjunctive merge of two lattice sets L0 and L1 is (L0 /\_op L1, L0, L1). /// Returns the index of the new set. - unsigned takeDisj(Kind kind, unsigned s0, unsigned s1); + unsigned takeDisj(Kind kind, unsigned s0, unsigned s1, + Operation *op = nullptr); + + /// Disjunctive merge of two lattice sets L0 and L1 with custom handling of + /// the overlap, left, and right regions. Any region may be left missing in + /// the output. Returns the index of the new set. + unsigned takeCombi(Kind kind, unsigned s0, unsigned s1, Operation *orig, + bool includeLeft, Kind ltrans, Operation *opleft, + bool includeRight, Kind rtrans, Operation *opright); /// Maps the unary operator over the lattice set of the operand, i.e. each /// lattice point on an expression E is simply copied over, but with OP E /// as new expression. Returns the index of the new set. - unsigned mapSet(Kind kind, unsigned s0, Value v = Value()); + unsigned mapSet(Kind kind, unsigned s0, Value v = Value(), + Operation *op = nullptr); /// Optimizes the iteration lattice points in the given set. This /// method should be called right before code generation to avoid diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -781,7 +781,7 @@ /// Generates a store on a dense or sparse tensor. static void genTensorStore(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter, linalg::GenericOp op, - Value rhs) { + unsigned exp, Value rhs) { Location loc = op.getLoc(); // Test if this is a scalarized reduction. if (codegen.redVal) { @@ -794,7 +794,14 @@ // Store during insertion. OpOperand *t = op.getOutputOperand(0); if (t == codegen.sparseOut) { - genInsertionStore(codegen, rewriter, op, t, rhs); + if (!rhs) { + // Only unary and binary are allowed to return uninitialized rhs + // to indicate missing output. + Kind kind = merger.exp(exp).kind; + assert(kind == kUnary || kind == kBinary); + } else { + genInsertionStore(codegen, rewriter, op, t, rhs); + } return; } // Actual store. @@ -974,7 +981,7 @@ updateReduc(merger, codegen, Value()); codegen.redExp = -1u; codegen.redKind = kNoReduc; - genTensorStore(merger, codegen, rewriter, op, redVal); + genTensorStore(merger, codegen, rewriter, op, exp, redVal); } } else { // Start or end loop invariant hoisting of a tensor load. @@ -1217,8 +1224,7 @@ /// Emit a while-loop for co-iteration over multiple indices. static Operation *genWhile(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter, linalg::GenericOp op, - unsigned idx, bool needsUniv, - BitVector &indices) { + unsigned idx, bool needsUniv, BitVector &indices) { SmallVector types; SmallVector operands; // Construct the while-loop with a parameter for each index. 
@@ -1365,8 +1371,7 @@ static void genWhileInduction(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter, linalg::GenericOp op, unsigned idx, bool needsUniv, - BitVector &induction, - scf::WhileOp whileOp) { + BitVector &induction, scf::WhileOp whileOp) { Location loc = op.getLoc(); // Finalize each else branch of all if statements. if (codegen.redVal || codegen.expValues) { @@ -1591,7 +1596,7 @@ if (at == topSort.size()) { unsigned ldx = topSort[at - 1]; Value rhs = genExp(merger, codegen, rewriter, op, exp, ldx); - genTensorStore(merger, codegen, rewriter, op, rhs); + genTensorStore(merger, codegen, rewriter, op, exp, rhs); return; } diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp --- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp +++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp @@ -8,6 +8,7 @@ #include "mlir/Dialect/SparseTensor/Utils/Merger.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" #include "mlir/IR/Operation.h" #include "llvm/Support/Debug.h" @@ -19,18 +20,18 @@ // Constructors. //===----------------------------------------------------------------------===// -TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v) - : kind(k), val(v) { +TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v, Operation *o) + : kind(k), val(v), op(o) { switch (kind) { case kTensor: - assert(x != -1u && y == -1u && !v); + assert(x != -1u && y == -1u && !v && !o); tensor = x; break; case kInvariant: - assert(x == -1u && y == -1u && v); + assert(x == -1u && y == -1u && v && !o); break; case kIndex: - assert(x != -1u && y == -1u && !v); + assert(x != -1u && y == -1u && !v && !o); index = x; break; case kAbsF: @@ -38,7 +39,7 @@ case kFloorF: case kNegF: case kNegI: - assert(x != -1u && y == -1u && !v); + assert(x != -1u && y == -1u && !v && !o); children.e0 = x; children.e1 = y; break; @@ -53,12 +54,29 @@ case kCastIdx: case kTruncI: case kBitCast: - assert(x != -1u && y == -1u && v); + assert(x != -1u && y == -1u && v && !o); + children.e0 = x; + children.e1 = y; + break; + case kBinaryBranch: + assert(x != -1u && y == -1u && !v && o); + children.e0 = x; + children.e1 = y; + break; + case kUnary: + // No assertion on y can be made, as the branching paths involve both + // a unary (mapSet) and binary (takeDisj) pathway. + assert(x != -1u && !v && o); + children.e0 = x; + children.e1 = y; + break; + case kBinary: + assert(x != -1u && y != -1u && !v && o); children.e0 = x; children.e1 = y; break; default: - assert(x != -1u && y != -1u && !v); + assert(x != -1u && y != -1u && !v && !o); children.e0 = x; children.e1 = y; break; @@ -77,9 +95,10 @@ // Lattice methods. 
//===----------------------------------------------------------------------===// -unsigned Merger::addExp(Kind k, unsigned e0, unsigned e1, Value v) { +unsigned Merger::addExp(Kind k, unsigned e0, unsigned e1, Value v, + Operation *op) { unsigned e = tensorExps.size(); - tensorExps.push_back(TensorExp(k, e0, e1, v)); + tensorExps.push_back(TensorExp(k, e0, e1, v, op)); return e; } @@ -96,29 +115,31 @@ return s; } -unsigned Merger::conjLatPoint(Kind kind, unsigned p0, unsigned p1) { +unsigned Merger::conjLatPoint(Kind kind, unsigned p0, unsigned p1, + Operation *op) { unsigned p = latPoints.size(); BitVector nb = BitVector(latPoints[p0].bits); nb |= latPoints[p1].bits; - unsigned e = addExp(kind, latPoints[p0].exp, latPoints[p1].exp); + unsigned e = addExp(kind, latPoints[p0].exp, latPoints[p1].exp, Value(), op); latPoints.push_back(LatPoint(nb, e)); return p; } -unsigned Merger::takeConj(Kind kind, unsigned s0, unsigned s1) { +unsigned Merger::takeConj(Kind kind, unsigned s0, unsigned s1, Operation *op) { unsigned s = addSet(); for (unsigned p0 : latSets[s0]) for (unsigned p1 : latSets[s1]) - latSets[s].push_back(conjLatPoint(kind, p0, p1)); + latSets[s].push_back(conjLatPoint(kind, p0, p1, op)); return s; } -unsigned Merger::takeDisj(Kind kind, unsigned s0, unsigned s1) { - unsigned s = takeConj(kind, s0, s1); +unsigned Merger::takeDisj(Kind kind, unsigned s0, unsigned s1, Operation *op) { + unsigned s = takeConj(kind, s0, s1, op); // Followed by all in s0. for (unsigned p : latSets[s0]) latSets[s].push_back(p); // Map binary 0-y to unary -y. + // TODO: move this if-else logic into buildLattices if (kind == kSubF) s1 = mapSet(kNegF, s1); else if (kind == kSubI) @@ -129,11 +150,32 @@ return s; } -unsigned Merger::mapSet(Kind kind, unsigned s0, Value v) { - assert(kAbsF <= kind && kind <= kBitCast); +unsigned Merger::takeCombi(Kind kind, unsigned s0, unsigned s1, Operation *orig, + bool includeLeft, Kind ltrans, Operation *opleft, + bool includeRight, Kind rtrans, Operation *opright) { + unsigned s = takeConj(kind, s0, s1, orig); + // Left Region. + if (includeLeft) { + if (opleft) + s0 = mapSet(ltrans, s0, Value(), opleft); + for (unsigned p : latSets[s0]) + latSets[s].push_back(p); + } + // Right Region. + if (includeRight) { + if (opright) + s1 = mapSet(rtrans, s1, Value(), opright); + for (unsigned p : latSets[s1]) + latSets[s].push_back(p); + } + return s; +} + +unsigned Merger::mapSet(Kind kind, unsigned s0, Value v, Operation *op) { + assert(kAbsF <= kind && kind <= kUnary); unsigned s = addSet(); for (unsigned p : latSets[s0]) { - unsigned e = addExp(kind, latPoints[p].exp, v); + unsigned e = addExp(kind, latPoints[p].exp, v, op); latPoints.push_back(LatPoint(latPoints[p].bits, e)); latSets[s].push_back(latPoints.size() - 1); } @@ -303,6 +345,10 @@ case kTruncI: case kBitCast: return "cast"; + case kBinaryBranch: + return "binary_branch"; + case kUnary: + return "unary"; case kMulF: return "*"; case kMulI: @@ -333,6 +379,8 @@ return ">>"; case kShlI: return "<<"; + case kBinary: + return "binary"; } llvm_unreachable("unexpected kind for symbol"); } @@ -474,6 +522,35 @@ // | 0 |-y | return mapSet(kind, buildLattices(tensorExps[e].children.e0, i), tensorExps[e].val); + case kBinaryBranch: + // The left or right half of a binary operation which has already + // been split into separate operations for each region. + return mapSet(kind, buildLattices(tensorExps[e].children.e0, i), Value(), + tensorExps[e].op); + case kUnary: + // A custom unary operation. 
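+    // (The `present` region rewrites each value that is stored in the
+    // operand; the `absent` region, if given, yields a value for entries
+    // that are implicitly zero.)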
+    //
+    //  op y|    !y    |     y      |
+    //  ----+----------+------------+
+    //      | absent() | present(y) |
+    {
+      unsigned child0 = buildLattices(tensorExps[e].children.e0, i);
+      UnaryOp unop = cast<UnaryOp>(tensorExps[e].op);
+      Region &absentRegion = unop.absentRegion();
+
+      if (absentRegion.empty()) {
+        // Simple mapping over existing values.
+        return mapSet(kind, child0, Value(), unop);
+      } else {
+        // Use a disjunction with `unop` on the left and the absent value as an
+        // invariant on the right.
+        Block &absentBlock = absentRegion.front();
+        YieldOp absentYield = cast<YieldOp>(absentBlock.getTerminator());
+        Value absentVal = absentYield.result();
+        unsigned rhs = addExp(kInvariant, absentVal);
+        return takeDisj(kind, child0, buildLattices(rhs, i), unop);
+      }
+    }
   case kMulF:
   case kMulI:
   case kAndI:
@@ -533,6 +610,37 @@
     return takeConj(kind, // take binary conjunction
                     buildLattices(tensorExps[e].children.e0, i),
                     buildLattices(tensorExps[e].children.e1, i));
+  case kBinary:
+    // A custom binary operation.
+    //
+    //  x op y|   !y    |       y      |
+    //  ------+---------+--------------+
+    //    !x  |  empty  |   right(y)   |
+    //     x  | left(x) | overlap(x,y) |
+    {
+      unsigned child0 = buildLattices(tensorExps[e].children.e0, i);
+      unsigned child1 = buildLattices(tensorExps[e].children.e1, i);
+      BinaryOp binop = cast<BinaryOp>(tensorExps[e].op);
+      Region &leftRegion = binop.leftRegion();
+      Region &rightRegion = binop.rightRegion();
+      // Left Region.
+      Operation *leftYield = nullptr;
+      if (!leftRegion.empty()) {
+        Block &leftBlock = leftRegion.front();
+        leftYield = leftBlock.getTerminator();
+      }
+      // Right Region.
+      Operation *rightYield = nullptr;
+      if (!rightRegion.empty()) {
+        Block &rightBlock = rightRegion.front();
+        rightYield = rightBlock.getTerminator();
+      }
+      bool includeLeft = binop.left_identity() || !leftRegion.empty();
+      bool includeRight = binop.right_identity() || !rightRegion.empty();
+      return takeCombi(kBinary, child0, child1, binop, includeLeft,
+                       kBinaryBranch, leftYield, includeRight, kBinaryBranch,
+                       rightYield);
+    }
   }
   llvm_unreachable("unexpected expression kind");
 }
@@ -627,6 +735,8 @@
         return addExp(kTruncI, e, v);
       if (isa<arith::BitcastOp>(def))
         return addExp(kBitCast, e, v);
+      if (isa<UnaryOp>(def))
+        return addExp(kUnary, e, Value(), def);
     }
   }
   // Construct binary operations if subexpressions can be built.
@@ -668,12 +778,59 @@
         return addExp(kShrU, e0, e1);
       if (isa<arith::ShLIOp>(def) && isInvariant(e1))
         return addExp(kShlI, e0, e1);
+      if (isa<BinaryOp>(def))
+        return addExp(kBinary, e0, e1, Value(), def);
     }
   }
   // Cannot build.
   return None;
 }
 
+static Value insertYieldOp(PatternRewriter &rewriter, Location loc,
+                           Region &region, ValueRange vals) {
+  // Make a clone of overlap region.
+  Region tmpRegion;
+  BlockAndValueMapping mapper;
+  region.cloneInto(&tmpRegion, tmpRegion.begin(), mapper);
+  Block &clonedBlock = tmpRegion.front();
+  YieldOp clonedYield = cast<YieldOp>(clonedBlock.getTerminator());
+  // Merge cloned block and return yield value.
+  Operation *placeholder = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  rewriter.mergeBlockBefore(&tmpRegion.front(), placeholder, vals);
+  Value val = clonedYield.result();
+  rewriter.eraseOp(clonedYield);
+  rewriter.eraseOp(placeholder);
+  return val;
+}
+
+static Value buildUnaryPresent(PatternRewriter &rewriter, Location loc,
+                               Operation *op, Value v0) {
+  if (!v0)
+    // Empty input value must be propagated.
+    return Value();
+  UnaryOp unop = cast<UnaryOp>(op);
+  Region &presentRegion = unop.presentRegion();
+  if (presentRegion.empty())
+    // Uninitialized Value() will be interpreted as missing data in the
+    // output.
+    return Value();
+  return insertYieldOp(rewriter, loc, presentRegion, {v0});
+}
+
+static Value buildBinaryOverlap(PatternRewriter &rewriter, Location loc,
+                                Operation *op, Value v0, Value v1) {
+  if (!v0 || !v1)
+    // Empty input values must be propagated.
+    return Value();
+  BinaryOp binop = cast<BinaryOp>(op);
+  Region &overlapRegion = binop.overlapRegion();
+  if (overlapRegion.empty())
+    // Uninitialized Value() will be interpreted as missing data in the
+    // output.
+    return Value();
+  return insertYieldOp(rewriter, loc, overlapRegion, {v0, v1});
+}
+
 Value Merger::buildExp(PatternRewriter &rewriter, Location loc, unsigned e,
                        Value v0, Value v1) {
   switch (tensorExps[e].kind) {
@@ -749,6 +906,14 @@
     return rewriter.create<arith::ShRUIOp>(loc, v0, v1);
   case kShlI:
     return rewriter.create<arith::ShLIOp>(loc, v0, v1);
+  // Semiring ops with custom logic.
+  case kBinaryBranch:
+    return insertYieldOp(rewriter, loc,
+                         *tensorExps[e].op->getBlock()->getParent(), {v0});
+  case kUnary:
+    return buildUnaryPresent(rewriter, loc, tensorExps[e].op, v0);
+  case kBinary:
+    return buildBinaryOverlap(rewriter, loc, tensorExps[e].op, v0, v1);
   }
   llvm_unreachable("unexpected expression kind in build");
 }
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_binary.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_binary.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_binary.mlir
@@ -0,0 +1,294 @@
+// RUN: mlir-opt %s --sparse-compiler | \
+// RUN: mlir-cpu-runner \
+// RUN: -e entry -entry-point-result=void \
+// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
+#DCSR = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}>
+
+//
+// Traits for tensor operations.
+//
+#trait_vec_scale = {
+  indexing_maps = [
+    affine_map<(i) -> (i)>,  // a (in)
+    affine_map<(i) -> (i)>   // x (out)
+  ],
+  iterator_types = ["parallel"]
+}
+#trait_vec_op = {
+  indexing_maps = [
+    affine_map<(i) -> (i)>,  // a (in)
+    affine_map<(i) -> (i)>,  // b (in)
+    affine_map<(i) -> (i)>   // x (out)
+  ],
+  iterator_types = ["parallel"]
+}
+#trait_mat_op = {
+  indexing_maps = [
+    affine_map<(i,j) -> (i,j)>,  // A (in)
+    affine_map<(i,j) -> (i,j)>,  // B (in)
+    affine_map<(i,j) -> (i,j)>   // X (out)
+  ],
+  iterator_types = ["parallel", "parallel"],
+  doc = "X(i,j) = A(i,j) OP B(i,j)"
+}
+
+module {
+  // Creates a new sparse vector using the minimum values from two input sparse vectors.
+  // When there is no overlap, include the present value in the output.
+  func @vector_min(%arga: tensor<?xf64, #SparseVector>,
+                   %argb: tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector> {
+    %c = arith.constant 0 : index
+    %d = tensor.dim %arga, %c : tensor<?xf64, #SparseVector>
+    %xv = sparse_tensor.init [%d] : tensor<?xf64, #SparseVector>
+    %0 = linalg.generic #trait_vec_op
+       ins(%arga, %argb: tensor<?xf64, #SparseVector>, tensor<?xf64, #SparseVector>)
+        outs(%xv: tensor<?xf64, #SparseVector>) {
+        ^bb(%a: f64, %b: f64, %x: f64):
+          %1 = sparse_tensor.binary %a, %b : f64, f64 to f64
+            overlap={
+              ^bb0(%a0: f64, %b0: f64):
+                %cmp = arith.cmpf "olt", %a0, %b0 : f64
+                %2 = arith.select %cmp, %a0, %b0: f64
+                sparse_tensor.yield %2 : f64
+            }
+            left=identity
+            right=identity
+          linalg.yield %1 : f64
+    } -> tensor<?xf64, #SparseVector>
+    return %0 : tensor<?xf64, #SparseVector>
+  }
+
+  // Creates a new sparse vector by multiplying a sparse vector with a dense vector.
+  // When there is no overlap, leave the result empty.
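+  // (The empty left={} and right={} regions below drop entries that appear
+  // in only one of the operands, so the result is restricted to the
+  // intersection of the two index sets.)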
+  func @vector_mul(%arga: tensor<?xf64, #SparseVector>,
+                   %argb: tensor<?xf64>) -> tensor<?xf64, #SparseVector> {
+    %c = arith.constant 0 : index
+    %d = tensor.dim %arga, %c : tensor<?xf64, #SparseVector>
+    %xv = sparse_tensor.init [%d] : tensor<?xf64, #SparseVector>
+    %0 = linalg.generic #trait_vec_op
+       ins(%arga, %argb: tensor<?xf64, #SparseVector>, tensor<?xf64>)
+        outs(%xv: tensor<?xf64, #SparseVector>) {
+        ^bb(%a: f64, %b: f64, %x: f64):
+          %1 = sparse_tensor.binary %a, %b : f64, f64 to f64
+            overlap={
+              ^bb0(%a0: f64, %b0: f64):
+                %ret = arith.mulf %a0, %b0 : f64
+                sparse_tensor.yield %ret : f64
+            }
+            left={}
+            right={}
+          linalg.yield %1 : f64
+    } -> tensor<?xf64, #SparseVector>
+    return %0 : tensor<?xf64, #SparseVector>
+  }
+
+  // Take a set difference of two sparse vectors. The result will include only those
+  // sparse elements present in the first, but not the second vector.
+  func @vector_setdiff(%arga: tensor<?xf64, #SparseVector>,
+                       %argb: tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector> {
+    %c = arith.constant 0 : index
+    %d = tensor.dim %arga, %c : tensor<?xf64, #SparseVector>
+    %xv = sparse_tensor.init [%d] : tensor<?xf64, #SparseVector>
+    %0 = linalg.generic #trait_vec_op
+       ins(%arga, %argb: tensor<?xf64, #SparseVector>, tensor<?xf64, #SparseVector>)
+        outs(%xv: tensor<?xf64, #SparseVector>) {
+        ^bb(%a: f64, %b: f64, %x: f64):
+          %1 = sparse_tensor.binary %a, %b : f64, f64 to f64
+            overlap={}
+            left=identity
+            right={}
+          linalg.yield %1 : f64
+    } -> tensor<?xf64, #SparseVector>
+    return %0 : tensor<?xf64, #SparseVector>
+  }
+
+  // Return the index of each entry
+  func @vector_index(%arga: tensor<?xf64, #SparseVector>) -> tensor<?xi32, #SparseVector> {
+    %c = arith.constant 0 : index
+    %d = tensor.dim %arga, %c : tensor<?xf64, #SparseVector>
+    %xv = sparse_tensor.init [%d] : tensor<?xi32, #SparseVector>
+    %0 = linalg.generic #trait_vec_scale
+       ins(%arga: tensor<?xf64, #SparseVector>)
+        outs(%xv: tensor<?xi32, #SparseVector>) {
+        ^bb(%a: f64, %x: i32):
+          %idx = linalg.index 0 : index
+          %1 = sparse_tensor.binary %a, %idx : f64, index to i32
+            overlap={
+              ^bb0(%x0: f64, %i: index):
+                %ret = arith.index_cast %i : index to i32
+                sparse_tensor.yield %ret : i32
+            }
+            left={}
+            right={}
+          linalg.yield %1 : i32
+    } -> tensor<?xi32, #SparseVector>
+    return %0 : tensor<?xi32, #SparseVector>
+  }
+
+  // Adds two sparse matrices when they intersect. Where they don't intersect,
+  // negate the 2nd argument's values; ignore 1st argument-only values.
+  func @matrix_intersect(%arga: tensor<?x?xf64, #DCSR>,
+                         %argb: tensor<?x?xf64, #DCSR>) -> tensor<?x?xf64, #DCSR> {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %d0 = tensor.dim %arga, %c0 : tensor<?x?xf64, #DCSR>
+    %d1 = tensor.dim %arga, %c1 : tensor<?x?xf64, #DCSR>
+    %xv = sparse_tensor.init [%d0, %d1] : tensor<?x?xf64, #DCSR>
+    %0 = linalg.generic #trait_mat_op
+       ins(%arga, %argb: tensor<?x?xf64, #DCSR>, tensor<?x?xf64, #DCSR>)
+        outs(%xv: tensor<?x?xf64, #DCSR>) {
+        ^bb(%a: f64, %b: f64, %x: f64):
+          %1 = sparse_tensor.binary %a, %b: f64, f64 to f64
+            overlap={
+              ^bb0(%x0: f64, %y0: f64):
+                %ret = arith.addf %x0, %y0 : f64
+                sparse_tensor.yield %ret : f64
+            }
+            left={}
+            right={
+              ^bb0(%x1: f64):
+                %lret = arith.negf %x1 : f64
+                sparse_tensor.yield %lret : f64
+            }
+          linalg.yield %1 : f64
+    } -> tensor<?x?xf64, #DCSR>
+    return %0 : tensor<?x?xf64, #DCSR>
+  }
+
+  // Dumps a sparse vector of type f64.
+  func @dump_vec(%arg0: tensor<?xf64, #SparseVector>) {
+    // Dump the values array to verify only sparse contents are stored.
+    %c0 = arith.constant 0 : index
+    %d0 = arith.constant -1.0 : f64
+    %0 = sparse_tensor.values %arg0 : tensor<?xf64, #SparseVector> to memref<?xf64>
+    %1 = vector.transfer_read %0[%c0], %d0: memref<?xf64>, vector<16xf64>
+    vector.print %1 : vector<16xf64>
+    // Dump the dense vector to verify structure is correct.
+    %dv = sparse_tensor.convert %arg0 : tensor<?xf64, #SparseVector> to tensor<?xf64>
+    %2 = bufferization.to_memref %dv : memref<?xf64>
+    %3 = vector.transfer_read %2[%c0], %d0: memref<?xf64>, vector<32xf64>
+    vector.print %3 : vector<32xf64>
+    memref.dealloc %2 : memref<?xf64>
+    return
+  }
+
+  // Dumps a sparse vector of type i32.
+  func @dump_vec_i32(%arg0: tensor<?xi32, #SparseVector>) {
+    // Dump the values array to verify only sparse contents are stored.
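+    // (The -1 padding value marks vector lanes past the stored entries.)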
+    %c0 = arith.constant 0 : index
+    %d0 = arith.constant -1 : i32
+    %0 = sparse_tensor.values %arg0 : tensor<?xi32, #SparseVector> to memref<?xi32>
+    %1 = vector.transfer_read %0[%c0], %d0: memref<?xi32>, vector<24xi32>
+    vector.print %1 : vector<24xi32>
+    // Dump the dense vector to verify structure is correct.
+    %dv = sparse_tensor.convert %arg0 : tensor<?xi32, #SparseVector> to tensor<?xi32>
+    %2 = bufferization.to_memref %dv : memref<?xi32>
+    %3 = vector.transfer_read %2[%c0], %d0: memref<?xi32>, vector<32xi32>
+    vector.print %3 : vector<32xi32>
+    memref.dealloc %2 : memref<?xi32>
+    return
+  }
+
+  // Dump a sparse matrix.
+  func @dump_mat(%arg0: tensor<?x?xf64, #DCSR>) {
+    %d0 = arith.constant 0.0 : f64
+    %c0 = arith.constant 0 : index
+    %dm = sparse_tensor.convert %arg0 : tensor<?x?xf64, #DCSR> to tensor<?x?xf64>
+    %0 = bufferization.to_memref %dm : memref<?x?xf64>
+    %1 = vector.transfer_read %0[%c0, %c0], %d0: memref<?x?xf64>, vector<4x8xf64>
+    vector.print %1 : vector<4x8xf64>
+    memref.dealloc %0 : memref<?x?xf64>
+    return
+  }
+
+  // Driver method to call and verify vector kernels.
+  func @entry() {
+    %c0 = arith.constant 0 : index
+
+    // Setup sparse vectors.
+    %v1 = arith.constant sparse<
+       [ [0], [3], [11], [17], [20], [21], [28], [29], [31] ],
+         [ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0 ]
+    > : tensor<32xf64>
+    %v2 = arith.constant sparse<
+       [ [1], [3], [4], [10], [16], [18], [21], [28], [29], [31] ],
+         [11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0 ]
+    > : tensor<32xf64>
+    %v3 = arith.constant dense<
+       [0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 0., 1., 2., 3., 4., 5., 6., 7., 8., 9.,
+        0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 0., 1.]
+    > : tensor<32xf64>
+    %sv1 = sparse_tensor.convert %v1 : tensor<32xf64> to tensor<?xf64, #SparseVector>
+    %sv2 = sparse_tensor.convert %v2 : tensor<32xf64> to tensor<?xf64, #SparseVector>
+    %dv3 = tensor.cast %v3 : tensor<32xf64> to tensor<?xf64>
+
+    // Setup sparse matrices.
+    %m1 = arith.constant sparse<
+       [ [0,0], [0,1], [1,7], [2,2], [2,4], [2,7], [3,0], [3,2], [3,3] ],
+         [ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0 ]
+    > : tensor<4x8xf64>
+    %m2 = arith.constant sparse<
+       [ [0,0], [0,7], [1,0], [1,6], [2,1], [2,7] ],
+         [6.0, 5.0, 4.0, 3.0, 2.0, 1.0 ]
+    > : tensor<4x8xf64>
+    %sm1 = sparse_tensor.convert %m1 : tensor<4x8xf64> to tensor<?x?xf64, #DCSR>
+    %sm2 = sparse_tensor.convert %m2 : tensor<4x8xf64> to tensor<?x?xf64, #DCSR>
+
+    // Call sparse vector kernels.
+    %0 = call @vector_min(%sv1, %sv2)
+       : (tensor<?xf64, #SparseVector>,
+          tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector>
+    %1 = call @vector_mul(%sv1, %dv3)
+       : (tensor<?xf64, #SparseVector>,
+          tensor<?xf64>) -> tensor<?xf64, #SparseVector>
+    %2 = call @vector_setdiff(%sv1, %sv2)
+       : (tensor<?xf64, #SparseVector>,
+          tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector>
+    %3 = call @vector_index(%sv1)
+       : (tensor<?xf64, #SparseVector>) -> tensor<?xi32, #SparseVector>
+
+    // Call sparse matrix kernels.
+    %5 = call @matrix_intersect(%sm1, %sm2)
+       : (tensor<?x?xf64, #DCSR>, tensor<?x?xf64, #DCSR>) -> tensor<?x?xf64, #DCSR>
+
+    //
+    // Verify the results.
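+    // (For each vector dumped below, the first line shows the stored values,
+    // padded with -1, and the second line shows the densified form.)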
+ // + // CHECK: ( 1, 2, 3, 4, 5, 6, 7, 8, 9, -1, -1, -1, -1, -1, -1, -1 ) + // CHECK-NEXT: ( 1, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 4, 0, 0, 5, 6, 0, 0, 0, 0, 0, 0, 7, 8, 0, 9 ) + // CHECK-NEXT: ( 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, -1, -1, -1, -1, -1, -1 ) + // CHECK-NEXT: ( 0, 11, 0, 12, 13, 0, 0, 0, 0, 0, 14, 0, 0, 0, 0, 0, 15, 0, 16, 0, 0, 17, 0, 0, 0, 0, 0, 0, 18, 19, 0, 20 ) + // CHECK-NEXT: ( 1, 11, 2, 13, 14, 3, 15, 4, 16, 5, 6, 7, 8, 9, -1, -1 ) + // CHECK-NEXT: ( 1, 11, 0, 2, 13, 0, 0, 0, 0, 0, 14, 3, 0, 0, 0, 0, 15, 4, 16, 0, 5, 6, 0, 0, 0, 0, 0, 0, 7, 8, 0, 9 ) + // CHECK-NEXT: ( 0, 6, 3, 28, 0, 6, 56, 72, 9, -1, -1, -1, -1, -1, -1, -1 ) + // CHECK-NEXT: ( 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 28, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 56, 72, 0, 9 ) + // CHECK-NEXT: ( 1, 3, 4, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 ) + // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 4, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ) + // CHECK-NEXT: ( 0, 3, 11, 17, 20, 21, 28, 29, 31, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 ) + // CHECK-NEXT: ( 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 11, 0, 0, 0, 0, 0, 17, 0, 0, 20, 21, 0, 0, 0, 0, 0, 0, 28, 29, 0, 31 ) + // CHECK-NEXT: ( ( 7, 0, 0, 0, 0, 0, 0, -5 ), ( -4, 0, 0, 0, 0, 0, -3, 0 ), ( 0, -2, 0, 0, 0, 0, 0, 7 ), ( 0, 0, 0, 0, 0, 0, 0, 0 ) ) + // + call @dump_vec(%sv1) : (tensor) -> () + call @dump_vec(%sv2) : (tensor) -> () + call @dump_vec(%0) : (tensor) -> () + call @dump_vec(%1) : (tensor) -> () + call @dump_vec(%2) : (tensor) -> () + call @dump_vec_i32(%3) : (tensor) -> () + call @dump_mat(%5) : (tensor) -> () + + // Release the resources. + sparse_tensor.release %sv1 : tensor + sparse_tensor.release %sv2 : tensor + sparse_tensor.release %sm1 : tensor + sparse_tensor.release %sm2 : tensor + sparse_tensor.release %0 : tensor + sparse_tensor.release %1 : tensor + sparse_tensor.release %2 : tensor + sparse_tensor.release %3 : tensor + sparse_tensor.release %5 : tensor + return + } +} \ No newline at end of file diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matrix_ops.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matrix_ops.mlir --- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matrix_ops.mlir +++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matrix_ops.mlir @@ -131,7 +131,7 @@ %sm1 = sparse_tensor.convert %m1 : tensor<4x8xf64> to tensor %sm2 = sparse_tensor.convert %m2 : tensor<4x8xf64> to tensor - // Call sparse vector kernels. + // Call sparse matrix kernels. %0 = call @matrix_scale(%sm1) : (tensor) -> tensor %1 = call @matrix_scale_inplace(%sm1) diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_unary.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_unary.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_unary.mlir @@ -0,0 +1,205 @@ +// RUN: mlir-opt %s --sparse-compiler | \ +// RUN: mlir-cpu-runner \ +// RUN: -e entry -entry-point-result=void \ +// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \ +// RUN: FileCheck %s + +#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}> +#DCSR = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}> + +// +// Traits for tensor operations. 
+// +#trait_vec_scale = { + indexing_maps = [ + affine_map<(i) -> (i)>, // a (in) + affine_map<(i) -> (i)> // x (out) + ], + iterator_types = ["parallel"] +} +#trait_mat_scale = { + indexing_maps = [ + affine_map<(i,j) -> (i,j)>, // A (in) + affine_map<(i,j) -> (i,j)> // X (out) + ], + iterator_types = ["parallel", "parallel"] +} + +module { + // Invert the structure of a sparse vector. Present values become missing. + // Missing values are filled with 1 (i32). + func @vector_complement(%arga: tensor) -> tensor { + %c = arith.constant 0 : index + %ci1 = arith.constant 1 : i32 + %d = tensor.dim %arga, %c : tensor + %xv = sparse_tensor.init [%d] : tensor + %0 = linalg.generic #trait_vec_scale + ins(%arga: tensor) + outs(%xv: tensor) { + ^bb(%a: f64, %x: i32): + %1 = sparse_tensor.unary %a : f64 to i32 + present={} + absent={ + sparse_tensor.yield %ci1 : i32 + } + linalg.yield %1 : i32 + } -> tensor + return %0 : tensor + } + + // Negate existing values. Fill missing ones with +1. + func @vector_negation(%arga: tensor) -> tensor { + %c = arith.constant 0 : index + %cf1 = arith.constant 1.0 : f64 + %d = tensor.dim %arga, %c : tensor + %xv = sparse_tensor.init [%d] : tensor + %0 = linalg.generic #trait_vec_scale + ins(%arga: tensor) + outs(%xv: tensor) { + ^bb(%a: f64, %x: f64): + %1 = sparse_tensor.unary %a : f64 to f64 + present={ + ^bb0(%x0: f64): + %ret = arith.negf %x0 : f64 + sparse_tensor.yield %ret : f64 + } + absent={ + sparse_tensor.yield %cf1 : f64 + } + linalg.yield %1 : f64 + } -> tensor + return %0 : tensor + } + + // Clips values to the range [3, 7]. + func @matrix_clip(%argx: tensor) -> tensor { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %cfmin = arith.constant 3.0 : f64 + %cfmax = arith.constant 7.0 : f64 + %d0 = tensor.dim %argx, %c0 : tensor + %d1 = tensor.dim %argx, %c1 : tensor + %xv = sparse_tensor.init [%d0, %d1] : tensor + %0 = linalg.generic #trait_mat_scale + ins(%argx: tensor) + outs(%xv: tensor) { + ^bb(%a: f64, %x: f64): + %1 = sparse_tensor.unary %a: f64 to f64 + present={ + ^bb0(%x0: f64): + %mincmp = arith.cmpf "ogt", %x0, %cfmin : f64 + %x1 = arith.select %mincmp, %x0, %cfmin : f64 + %maxcmp = arith.cmpf "olt", %x1, %cfmax : f64 + %x2 = arith.select %maxcmp, %x1, %cfmax : f64 + sparse_tensor.yield %x2 : f64 + } + absent={} + linalg.yield %1 : f64 + } -> tensor + return %0 : tensor + } + + // Dumps a sparse vector of type f64. + func @dump_vec_f64(%arg0: tensor) { + // Dump the values array to verify only sparse contents are stored. + %c0 = arith.constant 0 : index + %d0 = arith.constant -1.0 : f64 + %0 = sparse_tensor.values %arg0 : tensor to memref + %1 = vector.transfer_read %0[%c0], %d0: memref, vector<32xf64> + vector.print %1 : vector<32xf64> + // Dump the dense vector to verify structure is correct. + %dv = sparse_tensor.convert %arg0 : tensor to tensor + %2 = bufferization.to_memref %dv : memref + %3 = vector.transfer_read %2[%c0], %d0: memref, vector<32xf64> + vector.print %3 : vector<32xf64> + memref.dealloc %2 : memref + return + } + + // Dumps a sparse vector of type i32. + func @dump_vec_i32(%arg0: tensor) { + // Dump the values array to verify only sparse contents are stored. + %c0 = arith.constant 0 : index + %d0 = arith.constant -1 : i32 + %0 = sparse_tensor.values %arg0 : tensor to memref + %1 = vector.transfer_read %0[%c0], %d0: memref, vector<24xi32> + vector.print %1 : vector<24xi32> + // Dump the dense vector to verify structure is correct. 
+ %dv = sparse_tensor.convert %arg0 : tensor to tensor + %2 = bufferization.to_memref %dv : memref + %3 = vector.transfer_read %2[%c0], %d0: memref, vector<32xi32> + vector.print %3 : vector<32xi32> + memref.dealloc %2 : memref + return + } + + // Dump a sparse matrix. + func @dump_mat(%arg0: tensor) { + %c0 = arith.constant 0 : index + %d0 = arith.constant -1.0 : f64 + %0 = sparse_tensor.values %arg0 : tensor to memref + %1 = vector.transfer_read %0[%c0], %d0: memref, vector<16xf64> + vector.print %1 : vector<16xf64> + %dm = sparse_tensor.convert %arg0 : tensor to tensor + %2 = bufferization.to_memref %dm : memref + %3 = vector.transfer_read %2[%c0, %c0], %d0: memref, vector<4x8xf64> + vector.print %3 : vector<4x8xf64> + memref.dealloc %2 : memref + return + } + + // Driver method to call and verify vector kernels. + func @entry() { + %c0 = arith.constant 0 : index + + // Setup sparse vectors. + %v1 = arith.constant sparse< + [ [0], [3], [11], [17], [20], [21], [28], [29], [31] ], + [ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0 ] + > : tensor<32xf64> + %sv1 = sparse_tensor.convert %v1 : tensor<32xf64> to tensor + + // Setup sparse matrices. + %m1 = arith.constant sparse< + [ [0,0], [0,1], [1,7], [2,2], [2,4], [2,7], [3,0], [3,2], [3,3] ], + [ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0 ] + > : tensor<4x8xf64> + %sm1 = sparse_tensor.convert %m1 : tensor<4x8xf64> to tensor + + // Call sparse vector kernels. + %0 = call @vector_complement(%sv1) + : (tensor) -> tensor + %1 = call @vector_negation(%sv1) + : (tensor) -> tensor + + + // Call sparse matrix kernels. + %2 = call @matrix_clip(%sm1) + : (tensor) -> tensor + + // + // Verify the results. + // + // CHECK: ( 1, 2, 3, 4, 5, 6, 7, 8, 9, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 ) + // CHECK-NEXT: ( 1, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 4, 0, 0, 5, 6, 0, 0, 0, 0, 0, 0, 7, 8, 0, 9 ) + // CHECK-NEXT: ( 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1 ) + // CHECK-NEXT: ( 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0 ) + // CHECK-NEXT: ( -1, 1, 1, -2, 1, 1, 1, 1, 1, 1, 1, -3, 1, 1, 1, 1, 1, -4, 1, 1, -5, -6, 1, 1, 1, 1, 1, 1, -7, -8, 1, -9 ) + // CHECK-NEXT: ( -1, 1, 1, -2, 1, 1, 1, 1, 1, 1, 1, -3, 1, 1, 1, 1, 1, -4, 1, 1, -5, -6, 1, 1, 1, 1, 1, 1, -7, -8, 1, -9 ) + // CHECK-NEXT: ( 3, 3, 3, 4, 5, 6, 7, 7, 7, -1, -1, -1, -1, -1, -1, -1 ) + // CHECK-NEXT: ( ( 3, 3, 0, 0, 0, 0, 0, 0 ), ( 0, 0, 0, 0, 0, 0, 0, 3 ), ( 0, 0, 4, 0, 5, 0, 0, 6 ), ( 7, 0, 7, 7, 0, 0, 0, 0 ) ) + // + call @dump_vec_f64(%sv1) : (tensor) -> () + call @dump_vec_i32(%0) : (tensor) -> () + call @dump_vec_f64(%1) : (tensor) -> () + call @dump_mat(%2) : (tensor) -> () + + // Release the resources. + sparse_tensor.release %sv1 : tensor + sparse_tensor.release %sm1 : tensor + sparse_tensor.release %0 : tensor + sparse_tensor.release %1 : tensor + sparse_tensor.release %2 : tensor + return + } +} \ No newline at end of file diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_vector_ops.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_vector_ops.mlir --- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_vector_ops.mlir +++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_vector_ops.mlir @@ -125,19 +125,19 @@ // Sum reduces dot product of two sparse vectors. 
   func @vector_dotprod(%arga: tensor<?xf64, #SparseVector>,
                        %argb: tensor<?xf64, #SparseVector>,
-                       %argx: tensor<f64> {linalg.inplaceable = true}) -> tensor<f64> {
+                       %argx: tensor<f64> {linalg.inplaceable = true}) -> tensor<f64> {
     %0 = linalg.generic #trait_dot
        ins(%arga, %argb: tensor<?xf64, #SparseVector>, tensor<?xf64, #SparseVector>)
         outs(%argx: tensor<f64>) {
         ^bb(%a: f64, %b: f64, %x: f64):
           %1 = arith.mulf %a, %b : f64
-          %2 = arith.addf %x, %1 : f64
+          %2 = arith.addf %x, %1 : f64
           linalg.yield %2 : f64
     } -> tensor<f64>
     return %0 : tensor<f64>
   }
 
-  // Dumps a sparse vector.
+  // Dumps a sparse vector of type f64.
   func @dump(%arg0: tensor<?xf64, #SparseVector>) {
     // Dump the values array to verify only sparse contents are stored.
     %c0 = arith.constant 0 : index