diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td @@ -415,7 +415,7 @@ would be equivalent to a union operation where non-overlapping values in the inputs are copied to the output unchanged. - Example of isEqual applied to intersecting elements only: + Example of isEqual applied to intersecting elements only. ```mlir %C = sparse_tensor.init... %0 = linalg.generic #trait @@ -435,32 +435,24 @@ } -> tensor ``` - Example of A+B in upper triangle, A-B in lower triangle: + Example of replacing every element by its column index. This uses + `linalg.index` to get the column index, treating the value like a + dense tensor usable in `sparse_tensor.binary`. ```mlir %C = sparse_tensor.init... %1 = linalg.generic #trait - ins(%A: tensor, %B: tensor - outs(%C: tensor { - ^bb0(%a: f64, %b: f64, %c: f64) : - %row = linalg.index 0 : index + ins(%A: tensor) + outs(%C: tensor) { + ^bb0(%a: f64, %c: i32) : %col = linalg.index 1 : index - %result = sparse_tensor.binary %a, %b : f64, f64 to f64 + %col32 = arith.index_cast %col : index to i32 + %result = sparse_tensor.binary %a, %col32 : f64, i32 to i32 overlap={ - ^bb0(%x: f64, %y: f64): - %cmp = arith.cmpi "uge", %column, %row : index - %upperTriangleResult = arith.addf %x, %y : f64 - %lowerTriangleResult = arith.subf %x, %y : f64 - %ret = arith.select %cmp, %upperTriangleResult, %lowerTriangleResult : f64 - sparse_tensor.yield %ret : f64 - } - left=identity - right={ - ^bb0(%y: f64): - %cmp = arith.cmpi "uge", %column, %row : index - %lowerTriangleResult = arith.negf %y : f64 - %ret = arith.select %cmp, %y, %lowerTriangleResult - sparse_tensor.yield %ret : f64 + ^bb0(%x: f64, %y: i32): + sparse_tensor.yield %y : i32 } + left={} + right={} - linalg.yield %result : f64 + linalg.yield %result : i32 } -> tensor ``` diff --git 
a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h --- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h +++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h @@ -46,6 +46,7 @@ kCastIdx, kTruncI, kBitCast, + kUnary, // Binary operations. kMulF, kMulI, @@ -62,6 +63,7 @@ kShrS, // signed kShrU, // unsigned kShlI, + kBinary, }; /// Children subexpressions of tensor operations. @@ -72,7 +74,8 @@ /// Tensor expression. Represents a MLIR expression in tensor index notation. struct TensorExp { - TensorExp(Kind k, unsigned x, unsigned y, Value v); + TensorExp(Kind k, unsigned x, unsigned y, Value v, Operation *orig, + Operation *merge); /// Tensor expression kind. Kind kind; @@ -92,6 +95,12 @@ /// infer destination type) of a cast operation During code generation, /// this field may be used to cache "hoisted" loop invariant tensor loads. Value val; + + /// Code blocks used by unary and binary. The original op must be + /// carried around for each nested level and might contain several + /// merge blocks for the various sparse regions of overlap. + Operation *origOp; + Operation *mergeOp; }; /// Lattice point. Each lattice point consists of a conjunction of tensor @@ -110,7 +119,7 @@ /// must execute. Pre-computed during codegen to avoid repeated eval. BitVector simple; - /// Index of the tensor expresssion. + /// Index of the tensor expression. unsigned exp; }; @@ -130,9 +139,16 @@ hasSparseOut(false), dims(t + 1, std::vector(l, Dim::kUndef)) {} /// Adds a tensor expression. Returns its index. 
- unsigned addExp(Kind k, unsigned e0, unsigned e1 = -1u, Value v = Value()); - unsigned addExp(Kind k, unsigned e, Value v) { return addExp(k, e, -1u, v); } - unsigned addExp(Kind k, Value v) { return addExp(k, -1u, -1u, v); } + unsigned addExp(Kind k, unsigned e0, unsigned e1 = -1u, Value v = Value(), + Operation *orig = nullptr, Operation *merge = nullptr); + unsigned addExp(Kind k, unsigned e, Value v, Operation *orig = nullptr, + Operation *merge = nullptr) { + return addExp(k, e, -1u, v, orig, merge); + } + unsigned addExp(Kind k, Value v, Operation *orig = nullptr, + Operation *merge = nullptr) { + return addExp(k, -1u, -1u, v, orig, merge); + } /// Adds an iteration lattice point. Returns its index. unsigned addLat(unsigned t, unsigned i, unsigned e); @@ -144,20 +160,36 @@ /// of loop indices (effectively constructing a larger "intersection" of those /// indices) with a newly constructed tensor (sub)expression of given kind. /// Returns the index of the new lattice point. - unsigned conjLatPoint(Kind kind, unsigned p0, unsigned p1); + unsigned conjLatPoint(Kind kind, unsigned p0, unsigned p1, + Operation *orig = nullptr, Operation *merge = nullptr); /// Conjunctive merge of two lattice sets L0 and L1 is conjunction of /// cartesian product. Returns the index of the new set. - unsigned takeConj(Kind kind, unsigned s0, unsigned s1); + unsigned takeConj(Kind kind, unsigned s0, unsigned s1, + Operation *orig = nullptr, Operation *merge = nullptr); /// Disjunctive merge of two lattice sets L0 and L1 is (L0 /\_op L1, L0, L1). - /// Returns the index of the new set. - unsigned takeDisj(Kind kind, unsigned s0, unsigned s1); + /// Returns the index of the new set. Either left or right disjunction may be + /// dropped or modified as part of the function call. 
+ unsigned takeDisj(Kind kind, unsigned s0, unsigned s1, Operation *orig, + Operation *opboth, bool includeLeft, Kind ltrans, + Operation *opleft, bool includeRight, Kind rtrans, + Operation *opright); + unsigned takeDisj(Kind kind, unsigned s0, unsigned s1, + Operation *orig = nullptr, Operation *merge = nullptr) { + return takeDisj(kind, s0, s1, orig, merge, true, kind, nullptr, true, kind, + nullptr); + } + unsigned takeDisj(Kind kind, unsigned s0, unsigned s1, Kind rtrans) { + return takeDisj(kind, s0, s1, nullptr, nullptr, true, kind, nullptr, true, + rtrans, nullptr); + } /// Maps the unary operator over the lattice set of the operand, i.e. each /// lattice point on an expression E is simply copied over, but with OP E /// as new expression. Returns the index of the new set. - unsigned mapSet(Kind kind, unsigned s0, Value v = Value()); + unsigned mapSet(Kind kind, unsigned s0, Value v = Value(), + Operation *orig = nullptr, Operation *merge = nullptr); /// Optimizes the iteration lattice points in the given set. This /// method should be called right before code generation to avoid diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -794,7 +794,10 @@ // Store during insertion. OpOperand *t = op.getOutputOperand(0); if (t == codegen.sparseOut) { - genInsertionStore(codegen, rewriter, op, t, rhs); + // Some operations have conditional output (e.g. sparse_tensor.unary) and + // indicate missing output by passing an uninitialized Value() for rhs. + if (rhs) + genInsertionStore(codegen, rewriter, op, t, rhs); return; } // Actual store. @@ -1217,8 +1220,7 @@ /// Emit a while-loop for co-iteration over multiple indices. 
static Operation *genWhile(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter, linalg::GenericOp op, - unsigned idx, bool needsUniv, - BitVector &indices) { + unsigned idx, bool needsUniv, BitVector &indices) { SmallVector types; SmallVector operands; // Construct the while-loop with a parameter for each index. @@ -1365,8 +1367,7 @@ static void genWhileInduction(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter, linalg::GenericOp op, unsigned idx, bool needsUniv, - BitVector &induction, - scf::WhileOp whileOp) { + BitVector &induction, scf::WhileOp whileOp) { Location loc = op.getLoc(); // Finalize each else branch of all if statements. if (codegen.redVal || codegen.expValues) { diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp --- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp +++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp @@ -8,6 +8,7 @@ #include "mlir/Dialect/SparseTensor/Utils/Merger.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" #include "mlir/IR/Operation.h" #include "llvm/Support/Debug.h" @@ -19,18 +20,19 @@ // Constructors. 
//===----------------------------------------------------------------------===// -TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v) +TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v, Operation *orig, + Operation *merge) : kind(k), val(v) { switch (kind) { case kTensor: - assert(x != -1u && y == -1u && !v); + assert(x != -1u && y == -1u && !v && !orig && !merge); tensor = x; break; case kInvariant: - assert(x == -1u && y == -1u && v); + assert(x == -1u && y == -1u && v && !orig && !merge); break; case kIndex: - assert(x != -1u && y == -1u && !v); + assert(x != -1u && y == -1u && !v && !orig && !merge); index = x; break; case kAbsF: @@ -38,7 +40,7 @@ case kFloorF: case kNegF: case kNegI: - assert(x != -1u && y == -1u && !v); + assert(x != -1u && y == -1u && !v && !orig && !merge); children.e0 = x; children.e1 = y; break; @@ -53,12 +55,28 @@ case kCastIdx: case kTruncI: case kBitCast: - assert(x != -1u && y == -1u && v); + assert(x != -1u && y == -1u && v && !orig && !merge); children.e0 = x; children.e1 = y; break; + case kUnary: + assert(x != -1u && !v && (orig || merge)); + if (!orig) + assert(y == -1u); + children.e0 = x; + children.e1 = y; + origOp = orig; + mergeOp = merge; + break; + case kBinary: + assert(x != -1u && y != -1u && !v && orig); + children.e0 = x; + children.e1 = y; + origOp = orig; + mergeOp = merge; + break; default: - assert(x != -1u && y != -1u && !v); + assert(x != -1u && y != -1u && !v && !orig && !merge); children.e0 = x; children.e1 = y; break; @@ -77,9 +95,10 @@ // Lattice methods. 
//===----------------------------------------------------------------------===// -unsigned Merger::addExp(Kind k, unsigned e0, unsigned e1, Value v) { +unsigned Merger::addExp(Kind k, unsigned e0, unsigned e1, Value v, + Operation *orig, Operation *merge) { unsigned e = tensorExps.size(); - tensorExps.push_back(TensorExp(k, e0, e1, v)); + tensorExps.push_back(TensorExp(k, e0, e1, v, orig, merge)); return e; } @@ -96,44 +115,54 @@ return s; } -unsigned Merger::conjLatPoint(Kind kind, unsigned p0, unsigned p1) { +unsigned Merger::conjLatPoint(Kind kind, unsigned p0, unsigned p1, + Operation *orig, Operation *merge) { unsigned p = latPoints.size(); BitVector nb = BitVector(latPoints[p0].bits); nb |= latPoints[p1].bits; - unsigned e = addExp(kind, latPoints[p0].exp, latPoints[p1].exp); + unsigned e = + addExp(kind, latPoints[p0].exp, latPoints[p1].exp, Value(), orig, merge); latPoints.push_back(LatPoint(nb, e)); return p; } -unsigned Merger::takeConj(Kind kind, unsigned s0, unsigned s1) { +unsigned Merger::takeConj(Kind kind, unsigned s0, unsigned s1, Operation *orig, + Operation *merge) { unsigned s = addSet(); for (unsigned p0 : latSets[s0]) for (unsigned p1 : latSets[s1]) - latSets[s].push_back(conjLatPoint(kind, p0, p1)); + latSets[s].push_back(conjLatPoint(kind, p0, p1, orig, merge)); return s; } -unsigned Merger::takeDisj(Kind kind, unsigned s0, unsigned s1) { - unsigned s = takeConj(kind, s0, s1); - // Followed by all in s0. - for (unsigned p : latSets[s0]) - latSets[s].push_back(p); - // Map binary 0-y to unary -y. - if (kind == kSubF) - s1 = mapSet(kNegF, s1); - else if (kind == kSubI) - s1 = mapSet(kNegI, s1); - // Followed by all in s1. 
- for (unsigned p : latSets[s1]) - latSets[s].push_back(p); +unsigned Merger::takeDisj(Kind kind, unsigned s0, unsigned s1, Operation *orig, + Operation *opboth, bool includeLeft, Kind ltrans, + Operation *opleft, bool includeRight, Kind rtrans, + Operation *opright) { + unsigned s = takeConj(kind, s0, s1, orig, opboth); + // Left Region. + if (includeLeft) { + if (opleft || ltrans != kind) + s0 = mapSet(ltrans, s0, Value(), nullptr, opleft); + for (unsigned p : latSets[s0]) + latSets[s].push_back(p); + } + // Right Region. + if (includeRight) { + if (opright || rtrans != kind) + s1 = mapSet(rtrans, s1, Value(), nullptr, opright); + for (unsigned p : latSets[s1]) + latSets[s].push_back(p); + } return s; } -unsigned Merger::mapSet(Kind kind, unsigned s0, Value v) { - assert(kAbsF <= kind && kind <= kBitCast); +unsigned Merger::mapSet(Kind kind, unsigned s0, Value v, Operation *orig, + Operation *merge) { + assert(kAbsF <= kind && kind <= kUnary); unsigned s = addSet(); for (unsigned p : latSets[s0]) { - unsigned e = addExp(kind, latPoints[p].exp, v); + unsigned e = addExp(kind, latPoints[p].exp, v, orig, merge); latPoints.push_back(LatPoint(latPoints[p].bits, e)); latSets[s].push_back(latPoints.size() - 1); } @@ -303,6 +332,8 @@ case kTruncI: case kBitCast: return "cast"; + case kUnary: + return "unary"; case kMulF: return "*"; case kMulI: @@ -333,6 +364,8 @@ return ">>"; case kShlI: return "<<"; + case kBinary: + return "binary"; } llvm_unreachable("unexpected kind for symbol"); } @@ -474,6 +507,43 @@ // | 0 |-y | return mapSet(kind, buildLattices(tensorExps[e].children.e0, i), tensorExps[e].val); + case kUnary: + // A custom unary operation + // + // op y| !y | y | + // ----+----------+------------+ + // | absent() | present(y) | + { + unsigned child0 = buildLattices(tensorExps[e].children.e0, i); + if (!tensorExps[e].origOp) { + // The unary op has already been handled and this + // is a splinter that needs to be propagated. 
+ return mapSet(kind, child0, Value(), nullptr, tensorExps[e].mergeOp); + } + UnaryOp unop = dyn_cast(tensorExps[e].origOp); + assert(unop); + Region &presentRegion = unop.presentRegion(); + Region &absentRegion = unop.absentRegion(); + + Operation *presentYield = nullptr; + if (!presentRegion.empty()) { + Block &presentBlock = presentRegion.front(); + presentYield = presentBlock.getTerminator(); + } + if (absentRegion.empty()) { + // Simple mapping over existing values + return mapSet(kind, child0, Value(), unop, presentYield); + } else { + // Use a disjunction with `unop` on the left and the absent value as an + // invariant on the right + Block &absentBlock = absentRegion.front(); + YieldOp absentYield = dyn_cast(absentBlock.getTerminator()); + Value absentVal = absentYield.result(); + unsigned rhs = addExp(kInvariant, absentVal); + return takeDisj(kind, child0, buildLattices(rhs, i), unop, + presentYield); + } + } case kMulF: case kMulI: case kAndI: @@ -509,8 +579,6 @@ buildLattices(tensorExps[e].children.e1, i)); case kAddF: case kAddI: - case kSubF: - case kSubI: case kOrI: case kXorI: // An additive operation needs to be performed @@ -523,6 +591,12 @@ return takeDisj(kind, // take binary disjunction buildLattices(tensorExps[e].children.e0, i), buildLattices(tensorExps[e].children.e1, i)); + case kSubF: + return takeDisj(kind, buildLattices(tensorExps[e].children.e0, i), + buildLattices(tensorExps[e].children.e1, i), kNegF); + case kSubI: + return takeDisj(kind, buildLattices(tensorExps[e].children.e0, i), + buildLattices(tensorExps[e].children.e1, i), kNegI); case kShrS: case kShrU: case kShlI: @@ -533,6 +607,47 @@ return takeConj(kind, // take binary conjunction buildLattices(tensorExps[e].children.e0, i), buildLattices(tensorExps[e].children.e1, i)); + case kBinary: + // A custom binary operation + // + // x op y| !y | y | + // ------+---------+--------------+ + // !x | empty | right(y) | + // x | left(x) | overlap(x,y) | + { + unsigned child0 = 
buildLattices(tensorExps[e].children.e0, i); + unsigned child1 = buildLattices(tensorExps[e].children.e1, i); + BinaryOp binop = dyn_cast(tensorExps[e].origOp); + assert(binop); + Region &overlapRegion = binop.overlapRegion(); + Region &leftRegion = binop.leftRegion(); + Region &rightRegion = binop.rightRegion(); + + // Overlap Region. + Operation *overlapYield = nullptr; + if (!overlapRegion.empty()) { + Block &overlapBlock = overlapRegion.front(); + overlapYield = overlapBlock.getTerminator(); + } + // Left Region. + Operation *leftYield = nullptr; + if (!leftRegion.empty()) { + Block &leftBlock = leftRegion.front(); + leftYield = leftBlock.getTerminator(); + } + // Right Region. + Operation *rightYield = nullptr; + if (!rightRegion.empty()) { + Block &rightBlock = rightRegion.front(); + rightYield = rightBlock.getTerminator(); + } + bool includeLeft = binop.left_identity() || !leftRegion.empty(); + bool includeRight = binop.right_identity() || !rightRegion.empty(); + return takeDisj(kBinary, child0, child1, binop, overlapYield, includeLeft, + binop.left_identity() ? kBinary : kUnary, leftYield, + includeRight, binop.right_identity() ? kBinary : kUnary, + rightYield); + } } llvm_unreachable("unexpected expression kind"); } @@ -627,6 +742,8 @@ return addExp(kTruncI, e, v); if (isa(def)) return addExp(kBitCast, e, v); + if (isa(def)) + return addExp(kUnary, e, Value(), def); } } // Construct binary operations if subexpressions can be built. @@ -668,6 +785,8 @@ return addExp(kShrU, e0, e1); if (isa(def) && isInvariant(e1)) return addExp(kShlI, e0, e1); + if (isa(def)) + return addExp(kBinary, e0, e1, Value(), def); } } // Cannot build. @@ -749,6 +868,39 @@ return rewriter.create(loc, v0, v1); case kShlI: return rewriter.create(loc, v0, v1); + // Set-like ops with custom logic. + case kUnary: + case kBinary: { + Operation *op = tensorExps[e].mergeOp; + // Null op indicates no output (i.e. the corresponding region + // in the output will have missing data). 
+ if (!op) + return Value(); + // Make a clone of the block. + Region tmpRegion; + BlockAndValueMapping mapper; + op->getBlock()->getParent()->cloneInto(&tmpRegion, tmpRegion.begin(), + mapper); + Block &clonedBlock = tmpRegion.front(); + YieldOp clonedYield = dyn_cast(clonedBlock.getTerminator()); + // Merge cloned block and return yield value. + Operation *placeholder = rewriter.create(loc, 0); + if (clonedBlock.getNumArguments() == 2) { + // Empty input values must be propagated. + if (!v0 || !v1) + return Value(); + rewriter.mergeBlockBefore(&tmpRegion.front(), placeholder, {v0, v1}); + } else { + // Empty input value must be propagated + if (!v0) + return Value(); + rewriter.mergeBlockBefore(&tmpRegion.front(), placeholder, {v0}); + } + Value val = clonedYield.result(); + rewriter.eraseOp(clonedYield); + rewriter.eraseOp(placeholder); + return val; + } } llvm_unreachable("unexpected expression kind in build"); } diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_binary.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_binary.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_binary.mlir @@ -0,0 +1,256 @@ +// RUN: mlir-opt %s --sparse-compiler | \ +// RUN: mlir-cpu-runner \ +// RUN: -e entry -entry-point-result=void \ +// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \ +// RUN: FileCheck %s + +#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}> +#DCSR = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}> + +// +// Traits for tensor operations. 
+// +#trait_vec_scale = { + indexing_maps = [ + affine_map<(i) -> (i)>, // a (in) + affine_map<(i) -> (i)> // x (out) + ], + iterator_types = ["parallel"] +} +#trait_vec_op = { + indexing_maps = [ + affine_map<(i) -> (i)>, // a (in) + affine_map<(i) -> (i)>, // b (in) + affine_map<(i) -> (i)> // x (out) + ], + iterator_types = ["parallel"] +} +#trait_mat_op = { + indexing_maps = [ + affine_map<(i,j) -> (i,j)>, // A (in) + affine_map<(i,j) -> (i,j)>, // B (in) + affine_map<(i,j) -> (i,j)> // X (out) + ], + iterator_types = ["parallel", "parallel"], + doc = "X(i,j) = A(i,j) OP B(i,j)" +} + +module { + // Creates a new sparse vector using the minimum values from two input sparse vectors. + // When there is no overlap, include the present value in the output. + func @vector_min(%arga: tensor, + %argb: tensor) -> tensor { + %c = arith.constant 0 : index + %d = tensor.dim %arga, %c : tensor + %xv = sparse_tensor.init [%d] : tensor + %0 = linalg.generic #trait_vec_op + ins(%arga, %argb: tensor, tensor) + outs(%xv: tensor) { + ^bb(%a: f64, %b: f64, %x: f64): + %1 = sparse_tensor.binary %a, %b : f64, f64 to f64 + overlap={ + ^bb0(%a0: f64, %b0: f64): + %cmp = arith.cmpf "olt", %a0, %b0 : f64 + %2 = arith.select %cmp, %a0, %b0: f64 + sparse_tensor.yield %2 : f64 + } + left=identity + right=identity + linalg.yield %1 : f64 + } -> tensor + return %0 : tensor + } + + // Take a set difference of two sparse vectors. The result will include only those + // sparse elements present in the first, but not the second vector. 
+ func @vector_setdiff(%arga: tensor, + %argb: tensor) -> tensor { + %c = arith.constant 0 : index + %d = tensor.dim %arga, %c : tensor + %xv = sparse_tensor.init [%d] : tensor + %0 = linalg.generic #trait_vec_op + ins(%arga, %argb: tensor, tensor) + outs(%xv: tensor) { + ^bb(%a: f64, %b: f64, %x: f64): + %1 = sparse_tensor.binary %a, %b : f64, f64 to f64 + overlap={} + left=identity + right={} + linalg.yield %1 : f64 + } -> tensor + return %0 : tensor + } + + // Return the index of each entry + func @vector_index(%arga: tensor) -> tensor { + %c = arith.constant 0 : index + %d = tensor.dim %arga, %c : tensor + %xv = sparse_tensor.init [%d] : tensor + %0 = linalg.generic #trait_vec_scale + ins(%arga: tensor) + outs(%xv: tensor) { + ^bb(%a: f64, %x: i32): + %idx = linalg.index 0 : index + %1 = sparse_tensor.binary %a, %idx : f64, index to i32 + overlap={ + ^bb0(%x0: f64, %i: index): + %ret = arith.index_cast %i : index to i32 + sparse_tensor.yield %ret : i32 + } + left={} + right={} + linalg.yield %1 : i32 + } -> tensor + return %0 : tensor + } + + // Adds two sparse matrices when they intersect. Where they don't intersect, + // negate the 2nd argument's values; ignore 1st argument-only values. + func @matrix_intersect(%arga: tensor, + %argb: tensor) -> tensor { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %d0 = tensor.dim %arga, %c0 : tensor + %d1 = tensor.dim %arga, %c1 : tensor + %xv = sparse_tensor.init [%d0, %d1] : tensor + %0 = linalg.generic #trait_mat_op + ins(%arga, %argb: tensor, tensor) + outs(%xv: tensor) { + ^bb(%a: f64, %b: f64, %x: f64): + %1 = sparse_tensor.binary %a, %b: f64, f64 to f64 + overlap={ + ^bb0(%x0: f64, %y0: f64): + %ret = arith.addf %x0, %y0 : f64 + sparse_tensor.yield %ret : f64 + } + left={} + right={ + ^bb0(%x1: f64): + %lret = arith.negf %x1 : f64 + sparse_tensor.yield %lret : f64 + } + linalg.yield %1 : f64 + } -> tensor + return %0 : tensor + } + + // Dumps a sparse vector of type f64. 
+ func @dump_vec(%arg0: tensor) { + // Dump the values array to verify only sparse contents are stored. + %c0 = arith.constant 0 : index + %d0 = arith.constant -1.0 : f64 + %0 = sparse_tensor.values %arg0 : tensor to memref + %1 = vector.transfer_read %0[%c0], %d0: memref, vector<16xf64> + vector.print %1 : vector<16xf64> + // Dump the dense vector to verify structure is correct. + %dv = sparse_tensor.convert %arg0 : tensor to tensor + %2 = bufferization.to_memref %dv : memref + %3 = vector.transfer_read %2[%c0], %d0: memref, vector<32xf64> + vector.print %3 : vector<32xf64> + memref.dealloc %2 : memref + return + } + + // Dumps a sparse vector of type i32. + func @dump_vec_i32(%arg0: tensor) { + // Dump the values array to verify only sparse contents are stored. + %c0 = arith.constant 0 : index + %d0 = arith.constant -1 : i32 + %0 = sparse_tensor.values %arg0 : tensor to memref + %1 = vector.transfer_read %0[%c0], %d0: memref, vector<24xi32> + vector.print %1 : vector<24xi32> + // Dump the dense vector to verify structure is correct. + %dv = sparse_tensor.convert %arg0 : tensor to tensor + %2 = bufferization.to_memref %dv : memref + %3 = vector.transfer_read %2[%c0], %d0: memref, vector<32xi32> + vector.print %3 : vector<32xi32> + memref.dealloc %2 : memref + return + } + + // Dump a sparse matrix. + func @dump_mat(%arg0: tensor) { + %d0 = arith.constant 0.0 : f64 + %c0 = arith.constant 0 : index + %dm = sparse_tensor.convert %arg0 : tensor to tensor + %0 = bufferization.to_memref %dm : memref + %1 = vector.transfer_read %0[%c0, %c0], %d0: memref, vector<4x8xf64> + vector.print %1 : vector<4x8xf64> + memref.dealloc %0 : memref + return + } + + // Driver method to call and verify vector kernels. + func @entry() { + %c0 = arith.constant 0 : index + + // Setup sparse vectors. 
+ %v1 = arith.constant sparse< + [ [0], [3], [11], [17], [20], [21], [28], [29], [31] ], + [ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0 ] + > : tensor<32xf64> + %v2 = arith.constant sparse< + [ [1], [3], [4], [10], [16], [18], [21], [28], [29], [31] ], + [11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0 ] + > : tensor<32xf64> + %sv1 = sparse_tensor.convert %v1 : tensor<32xf64> to tensor + %sv2 = sparse_tensor.convert %v2 : tensor<32xf64> to tensor + + // Setup sparse matrices. + %m1 = arith.constant sparse< + [ [0,0], [0,1], [1,7], [2,2], [2,4], [2,7], [3,0], [3,2], [3,3] ], + [ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0 ] + > : tensor<4x8xf64> + %m2 = arith.constant sparse< + [ [0,0], [0,7], [1,0], [1,6], [2,1], [2,7] ], + [6.0, 5.0, 4.0, 3.0, 2.0, 1.0 ] + > : tensor<4x8xf64> + %sm1 = sparse_tensor.convert %m1 : tensor<4x8xf64> to tensor + %sm2 = sparse_tensor.convert %m2 : tensor<4x8xf64> to tensor + + // Call sparse vector kernels. + %0 = call @vector_min(%sv1, %sv2) + : (tensor, + tensor) -> tensor + %1 = call @vector_setdiff(%sv1, %sv2) + : (tensor, + tensor) -> tensor + %2 = call @vector_index(%sv1) + : (tensor) -> tensor + + // Call sparse matrix kernels. + %5 = call @matrix_intersect(%sm1, %sm2) + : (tensor, tensor) -> tensor + + // + // Verify the results. 
+ // + // CHECK: ( 1, 2, 3, 4, 5, 6, 7, 8, 9, -1, -1, -1, -1, -1, -1, -1 ) + // CHECK-NEXT: ( 1, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 4, 0, 0, 5, 6, 0, 0, 0, 0, 0, 0, 7, 8, 0, 9 ) + // CHECK-NEXT: ( 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, -1, -1, -1, -1, -1, -1 ) + // CHECK-NEXT: ( 0, 11, 0, 12, 13, 0, 0, 0, 0, 0, 14, 0, 0, 0, 0, 0, 15, 0, 16, 0, 0, 17, 0, 0, 0, 0, 0, 0, 18, 19, 0, 20 ) + // CHECK-NEXT: ( 1, 11, 2, 13, 14, 3, 15, 4, 16, 5, 6, 7, 8, 9, -1, -1 ) + // CHECK-NEXT: ( 1, 11, 0, 2, 13, 0, 0, 0, 0, 0, 14, 3, 0, 0, 0, 0, 15, 4, 16, 0, 5, 6, 0, 0, 0, 0, 0, 0, 7, 8, 0, 9 ) + // CHECK-NEXT: ( 1, 3, 4, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 ) + // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 4, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ) + // CHECK-NEXT: ( 0, 3, 11, 17, 20, 21, 28, 29, 31, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 ) + // CHECK-NEXT: ( 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 11, 0, 0, 0, 0, 0, 17, 0, 0, 20, 21, 0, 0, 0, 0, 0, 0, 28, 29, 0, 31 ) + // CHECK-NEXT: ( ( 7, 0, 0, 0, 0, 0, 0, -5 ), ( -4, 0, 0, 0, 0, 0, -3, 0 ), ( 0, -2, 0, 0, 0, 0, 0, 7 ), ( 0, 0, 0, 0, 0, 0, 0, 0 ) ) + // + call @dump_vec(%sv1) : (tensor) -> () + call @dump_vec(%sv2) : (tensor) -> () + call @dump_vec(%0) : (tensor) -> () + call @dump_vec(%1) : (tensor) -> () + call @dump_vec_i32(%2) : (tensor) -> () + call @dump_mat(%5) : (tensor) -> () + + // Release the resources. 
+ sparse_tensor.release %sv1 : tensor + sparse_tensor.release %sv2 : tensor + sparse_tensor.release %0 : tensor + sparse_tensor.release %1 : tensor + sparse_tensor.release %2 : tensor + sparse_tensor.release %5 : tensor + return + } +} \ No newline at end of file diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matrix_ops.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matrix_ops.mlir --- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matrix_ops.mlir +++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matrix_ops.mlir @@ -131,7 +131,7 @@ %sm1 = sparse_tensor.convert %m1 : tensor<4x8xf64> to tensor %sm2 = sparse_tensor.convert %m2 : tensor<4x8xf64> to tensor - // Call sparse vector kernels. + // Call sparse matrix kernels. %0 = call @matrix_scale(%sm1) : (tensor) -> tensor %1 = call @matrix_scale_inplace(%sm1) diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_unary.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_unary.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_unary.mlir @@ -0,0 +1,169 @@ +// RUN: mlir-opt %s --sparse-compiler | \ +// RUN: mlir-cpu-runner \ +// RUN: -e entry -entry-point-result=void \ +// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \ +// RUN: FileCheck %s + +#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}> +#DCSR = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}> + +// +// Traits for tensor operations. +// +#trait_vec_scale = { + indexing_maps = [ + affine_map<(i) -> (i)>, // a (in) + affine_map<(i) -> (i)> // x (out) + ], + iterator_types = ["parallel"] +} +#trait_mat_scale = { + indexing_maps = [ + affine_map<(i,j) -> (i,j)>, // A (in) + affine_map<(i,j) -> (i,j)> // X (out) + ], + iterator_types = ["parallel", "parallel"] +} + +module { + // Invert the structure of a sparse vector. Present values become missing. 
+ // Missing values are filled with 1 (i32). + func @vector_complement(%arga: tensor) -> tensor { + %c = arith.constant 0 : index + %ci1 = arith.constant 1 : i32 + %d = tensor.dim %arga, %c : tensor + %xv = sparse_tensor.init [%d] : tensor + %0 = linalg.generic #trait_vec_scale + ins(%arga: tensor) + outs(%xv: tensor) { + ^bb(%a: f64, %x: i32): + %1 = sparse_tensor.unary %a : f64 to i32 + present={} + absent={ + sparse_tensor.yield %ci1 : i32 + } + linalg.yield %1 : i32 + } -> tensor + return %0 : tensor + } + + // Clips values to the range [3, 7]. + func @matrix_clip(%argx: tensor) -> tensor { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %cfmin = arith.constant 3.0 : f64 + %cfmax = arith.constant 7.0 : f64 + %d0 = tensor.dim %argx, %c0 : tensor + %d1 = tensor.dim %argx, %c1 : tensor + %xv = sparse_tensor.init [%d0, %d1] : tensor + %0 = linalg.generic #trait_mat_scale + ins(%argx: tensor) + outs(%xv: tensor) { + ^bb(%a: f64, %x: f64): + %1 = sparse_tensor.unary %a: f64 to f64 + present={ + ^bb0(%x0: f64): + %mincmp = arith.cmpf "ogt", %x0, %cfmin : f64 + %x1 = arith.select %mincmp, %x0, %cfmin : f64 + %maxcmp = arith.cmpf "olt", %x1, %cfmax : f64 + %x2 = arith.select %maxcmp, %x1, %cfmax : f64 + sparse_tensor.yield %x2 : f64 + } + absent={} + linalg.yield %1 : f64 + } -> tensor + return %0 : tensor + } + + // Dumps a sparse vector of type f64. + func @dump_vec(%arg0: tensor) { + // Dump the values array to verify only sparse contents are stored. + %c0 = arith.constant 0 : index + %d0 = arith.constant -1.0 : f64 + %0 = sparse_tensor.values %arg0 : tensor to memref + %1 = vector.transfer_read %0[%c0], %d0: memref, vector<16xf64> + vector.print %1 : vector<16xf64> + // Dump the dense vector to verify structure is correct. 
+ %dv = sparse_tensor.convert %arg0 : tensor to tensor + %2 = bufferization.to_memref %dv : memref + %3 = vector.transfer_read %2[%c0], %d0: memref, vector<32xf64> + vector.print %3 : vector<32xf64> + memref.dealloc %2 : memref + return + } + + // Dumps a sparse vector of type i32. + func @dump_vec_i32(%arg0: tensor) { + // Dump the values array to verify only sparse contents are stored. + %c0 = arith.constant 0 : index + %d0 = arith.constant -1 : i32 + %0 = sparse_tensor.values %arg0 : tensor to memref + %1 = vector.transfer_read %0[%c0], %d0: memref, vector<24xi32> + vector.print %1 : vector<24xi32> + // Dump the dense vector to verify structure is correct. + %dv = sparse_tensor.convert %arg0 : tensor to tensor + %2 = bufferization.to_memref %dv : memref + %3 = vector.transfer_read %2[%c0], %d0: memref, vector<32xi32> + vector.print %3 : vector<32xi32> + memref.dealloc %2 : memref + return + } + + // Dump a sparse matrix. + func @dump_mat(%arg0: tensor) { + %d0 = arith.constant 0.0 : f64 + %c0 = arith.constant 0 : index + %dm = sparse_tensor.convert %arg0 : tensor to tensor + %0 = bufferization.to_memref %dm : memref + %1 = vector.transfer_read %0[%c0, %c0], %d0: memref, vector<4x8xf64> + vector.print %1 : vector<4x8xf64> + memref.dealloc %0 : memref + return + } + + // Driver method to call and verify vector kernels. + func @entry() { + %c0 = arith.constant 0 : index + + // Setup sparse vectors. + %v1 = arith.constant sparse< + [ [0], [3], [11], [17], [20], [21], [28], [29], [31] ], + [ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0 ] + > : tensor<32xf64> + %sv1 = sparse_tensor.convert %v1 : tensor<32xf64> to tensor + + // Setup sparse matrices. + %m1 = arith.constant sparse< + [ [0,0], [0,1], [1,7], [2,2], [2,4], [2,7], [3,0], [3,2], [3,3] ], + [ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0 ] + > : tensor<4x8xf64> + %sm1 = sparse_tensor.convert %m1 : tensor<4x8xf64> to tensor + + // Call sparse vector kernels. 
+ %0 = call @vector_complement(%sv1) + : (tensor) -> tensor + + // Call sparse matrix kernels. + %1 = call @matrix_clip(%sm1) + : (tensor) -> tensor + + // + // Verify the results. + // + // CHECK: ( 1, 2, 3, 4, 5, 6, 7, 8, 9, -1, -1, -1, -1, -1, -1, -1 ) + // CHECK-NEXT: ( 1, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 4, 0, 0, 5, 6, 0, 0, 0, 0, 0, 0, 7, 8, 0, 9 ) + // CHECK-NEXT: ( 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1 ) + // CHECK-NEXT: ( 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0 ) + // CHECK-NEXT: ( ( 3, 3, 0, 0, 0, 0, 0, 0 ), ( 0, 0, 0, 0, 0, 0, 0, 3 ), ( 0, 0, 4, 0, 5, 0, 0, 6 ), ( 7, 0, 7, 7, 0, 0, 0, 0 ) ) + // + call @dump_vec(%sv1) : (tensor) -> () + call @dump_vec_i32(%0) : (tensor) -> () + call @dump_mat(%1) : (tensor) -> () + + // Release the resources. + sparse_tensor.release %sv1 : tensor + sparse_tensor.release %0 : tensor + sparse_tensor.release %1 : tensor + return + } +} \ No newline at end of file diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_vector_ops.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_vector_ops.mlir --- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_vector_ops.mlir +++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_vector_ops.mlir @@ -125,19 +125,19 @@ // Sum reduces dot product of two sparse vectors. func @vector_dotprod(%arga: tensor, %argb: tensor, - %argx: tensor {linalg.inplaceable = true}) -> tensor { + %argx: tensor {linalg.inplaceable = true}) -> tensor { %0 = linalg.generic #trait_dot ins(%arga, %argb: tensor, tensor) outs(%argx: tensor) { ^bb(%a: f64, %b: f64, %x: f64): %1 = arith.mulf %a, %b : f64 - %2 = arith.addf %x, %1 : f64 + %2 = arith.addf %x, %1 : f64 linalg.yield %2 : f64 } -> tensor return %0 : tensor } - // Dumps a sparse vector. + // Dumps a sparse vector of type f64. func @dump(%arg0: tensor) { // Dump the values array to verify only sparse contents are stored. 
%c0 = arith.constant 0 : index