diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -46,6 +46,7 @@
   kCastIdx,
   kTruncI,
   kBitCast,
+  kUnary,
   // Binary operations.
   kMulF,
   kMulI,
@@ -62,6 +63,7 @@
   kShrS, // signed
   kShrU, // unsigned
   kShlI,
+  kBinary,
 };
 
 /// Children subexpressions of tensor operations.
@@ -72,7 +74,7 @@
 
 /// Tensor expression. Represents a MLIR expression in tensor index notation.
 struct TensorExp {
-  TensorExp(Kind k, unsigned x, unsigned y, Value v);
+  TensorExp(Kind k, unsigned x, unsigned y, Value v, Operation *op);
 
   /// Tensor expression kind.
   Kind kind;
@@ -92,6 +94,9 @@
   /// infer destination type) of a cast operation During code generation,
   /// this field may be used to cache "hoisted" loop invariant tensor loads.
   Value val;
+
+  /// Holder for a block of custom code to be merged.
+  Operation *operation;
 };
 
 /// Lattice point. Each lattice point consists of a conjunction of tensor
@@ -110,7 +115,7 @@
   /// must execute. Pre-computed during codegen to avoid repeated eval.
   BitVector simple;
 
-  /// Index of the tensor expresssion.
+  /// Index of the tensor expression.
   unsigned exp;
 };
 
@@ -130,9 +135,14 @@
         hasSparseOut(false), dims(t + 1, std::vector<Dim>(l, Dim::kUndef)) {}
 
   /// Adds a tensor expression. Returns its index.
-  unsigned addExp(Kind k, unsigned e0, unsigned e1 = -1u, Value v = Value());
-  unsigned addExp(Kind k, unsigned e, Value v) { return addExp(k, e, -1u, v); }
-  unsigned addExp(Kind k, Value v) { return addExp(k, -1u, -1u, v); }
+  unsigned addExp(Kind k, unsigned e0, unsigned e1 = -1u, Value v = Value(),
+                  Operation *op = nullptr);
+  unsigned addExp(Kind k, unsigned e, Value v, Operation *op = nullptr) {
+    return addExp(k, e, -1u, v, op);
+  }
+  unsigned addExp(Kind k, Value v, Operation *op = nullptr) {
+    return addExp(k, -1u, -1u, v, op);
+  }
 
   /// Adds an iteration lattice point. Returns its index.
   unsigned addLat(unsigned t, unsigned i, unsigned e);
@@ -144,20 +154,26 @@
   /// of loop indices (effectively constructing a larger "intersection" of those
   /// indices) with a newly constructed tensor (sub)expression of given kind.
   /// Returns the index of the new lattice point.
-  unsigned conjLatPoint(Kind kind, unsigned p0, unsigned p1);
+  unsigned conjLatPoint(Kind kind, unsigned p0, unsigned p1,
+                        Operation *op = nullptr);
 
   /// Conjunctive merge of two lattice sets L0 and L1 is conjunction of
   /// cartesian product. Returns the index of the new set.
-  unsigned takeConj(Kind kind, unsigned s0, unsigned s1);
+  unsigned takeConj(Kind kind, unsigned s0, unsigned s1,
+                    Operation *op = nullptr);
 
   /// Disjunctive merge of two lattice sets L0 and L1 is (L0 /\_op L1, L0, L1).
   /// Returns the index of the new set.
   unsigned takeDisj(Kind kind, unsigned s0, unsigned s1);
+  unsigned takeDisj(Kind kind, unsigned s0, unsigned s1, bool includeLeft,
+                    bool includeRight, Operation *opboth, Operation *opleft,
+                    Operation *opright);
 
   /// Maps the unary operator over the lattice set of the operand, i.e. each
   /// lattice point on an expression E is simply copied over, but with OP E
   /// as new expression. Returns the index of the new set.
-  unsigned mapSet(Kind kind, unsigned s0, Value v = Value());
+  unsigned mapSet(Kind kind, unsigned s0, Value v = Value(),
+                  Operation *op = nullptr);
 
   /// Optimizes the iteration lattice points in the given set. This
   /// method should be called right before code generation to avoid
@@ -228,7 +244,7 @@
   /// Builds the iteration lattices in a bottom-up traversal given the remaining
   /// tensor (sub)expression and the next loop index in the iteration graph.
   /// Returns index of the root expression.
-  unsigned buildLattices(unsigned e, unsigned i);
+  unsigned buildLattices(unsigned e, unsigned i, unsigned z);
 
   /// Builds a tensor expression from the given Linalg operation.
   /// Returns index of the root expression on success.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -794,7 +794,10 @@
   // Store during insertion.
   OpOperand *t = op.getOutputOperand(0);
   if (t == codegen.sparseOut) {
-    genInsertionStore(codegen, rewriter, op, t, rhs);
+    // A few kinds (e.g. sparse_tensor.unary) have conditional output and
+    // indicate "no output" by passing an uninitialized Value().
+    if (rhs)
+      genInsertionStore(codegen, rewriter, op, t, rhs);
     return;
   }
   // Actual store.
@@ -1217,8 +1220,7 @@
 /// Emit a while-loop for co-iteration over multiple indices.
 static Operation *genWhile(Merger &merger, CodeGen &codegen,
                            PatternRewriter &rewriter, linalg::GenericOp op,
-                           unsigned idx, bool needsUniv,
-                           BitVector &indices) {
+                           unsigned idx, bool needsUniv, BitVector &indices) {
   SmallVector<Type, 4> types;
   SmallVector<Value, 4> operands;
   // Construct the while-loop with a parameter for each index.
@@ -1365,8 +1367,7 @@
 static void genWhileInduction(Merger &merger, CodeGen &codegen,
                               PatternRewriter &rewriter, linalg::GenericOp op,
                               unsigned idx, bool needsUniv,
-                              BitVector &induction,
-                              scf::WhileOp whileOp) {
+                              BitVector &induction, scf::WhileOp whileOp) {
   Location loc = op.getLoc();
   // Finalize each else branch of all if statements.
   if (codegen.redVal || codegen.expValues) {
@@ -1598,7 +1599,8 @@
   // Construct iteration lattices for current loop index, with L0 at top.
   unsigned idx = topSort[at];
   unsigned ldx = at == 0 ? -1u : topSort[at - 1];
-  unsigned lts = merger.optimizeSet(merger.buildLattices(exp, idx));
+  unsigned lts =
+      merger.optimizeSet(merger.buildLattices(exp, idx, topSort.size() - at));
 
   // Start a loop sequence.
   bool needsUniv = startLoopSeq(merger, codegen, rewriter, op, topSort, exp, at,
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Dialect/SparseTensor/Utils/Merger.h"
 
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
 #include "mlir/IR/Operation.h"
 #include "llvm/Support/Debug.h"
 
@@ -19,18 +20,18 @@
 // Constructors.
 //===----------------------------------------------------------------------===//
 
-TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v)
+TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v, Operation *op)
     : kind(k), val(v) {
   switch (kind) {
   case kTensor:
-    assert(x != -1u && y == -1u && !v);
+    assert(x != -1u && y == -1u && !v && !op);
     tensor = x;
     break;
   case kInvariant:
-    assert(x == -1u && y == -1u && v);
+    assert(x == -1u && y == -1u && v && !op);
    break;
   case kIndex:
-    assert(x != -1u && y == -1u && !v);
+    assert(x != -1u && y == -1u && !v && !op);
     index = x;
     break;
   case kAbsF:
@@ -38,7 +39,7 @@
   case kFloorF:
   case kNegF:
   case kNegI:
-    assert(x != -1u && y == -1u && !v);
+    assert(x != -1u && y == -1u && !v && !op);
     children.e0 = x;
     children.e1 = y;
     break;
@@ -53,12 +54,25 @@
   case kCastIdx:
   case kTruncI:
   case kBitCast:
-    assert(x != -1u && y == -1u && v);
+    assert(x != -1u && y == -1u && v && !op);
+    children.e0 = x;
+    children.e1 = y;
+    break;
+  case kUnary:
+    assert(x != -1u && y == -1u && !v);
     children.e0 = x;
     children.e1 = y;
+    operation = op;
+    break;
+  case kBinary:
+    // Note: y may still be -1u here, namely when only the left or right
+    // region of the binary operation is mapped over a lattice set.
+    assert(x != -1u && !v);
+    children.e0 = x;
+    children.e1 = y;
+    operation = op;
     break;
   default:
-    assert(x != -1u && y != -1u && !v);
+    assert(x != -1u && y != -1u && !v && !op);
     children.e0 = x;
     children.e1 = y;
     break;
@@ -77,9 +91,10 @@
 // Lattice methods.
 //===----------------------------------------------------------------------===//
 
-unsigned Merger::addExp(Kind k, unsigned e0, unsigned e1, Value v) {
+unsigned Merger::addExp(Kind k, unsigned e0, unsigned e1, Value v,
+                        Operation *op) {
   unsigned e = tensorExps.size();
-  tensorExps.push_back(TensorExp(k, e0, e1, v));
+  tensorExps.push_back(TensorExp(k, e0, e1, v, op));
   return e;
 }
 
@@ -96,20 +111,21 @@
   return s;
 }
 
-unsigned Merger::conjLatPoint(Kind kind, unsigned p0, unsigned p1) {
+unsigned Merger::conjLatPoint(Kind kind, unsigned p0, unsigned p1,
+                              Operation *op) {
   unsigned p = latPoints.size();
   BitVector nb = BitVector(latPoints[p0].bits);
   nb |= latPoints[p1].bits;
-  unsigned e = addExp(kind, latPoints[p0].exp, latPoints[p1].exp);
+  unsigned e = addExp(kind, latPoints[p0].exp, latPoints[p1].exp, Value(), op);
   latPoints.push_back(LatPoint(nb, e));
   return p;
 }
 
-unsigned Merger::takeConj(Kind kind, unsigned s0, unsigned s1) {
+unsigned Merger::takeConj(Kind kind, unsigned s0, unsigned s1, Operation *op) {
   unsigned s = addSet();
   for (unsigned p0 : latSets[s0])
     for (unsigned p1 : latSets[s1])
-      latSets[s].push_back(conjLatPoint(kind, p0, p1));
+      latSets[s].push_back(conjLatPoint(kind, p0, p1, op));
   return s;
 }
 
@@ -129,11 +145,32 @@
   return s;
 }
 
-unsigned Merger::mapSet(Kind kind, unsigned s0, Value v) {
-  assert(kAbsF <= kind && kind <= kBitCast);
+unsigned Merger::takeDisj(Kind kind, unsigned s0, unsigned s1, bool includeLeft,
+                          bool includeRight, Operation *opboth,
+                          Operation *opleft, Operation *opright) {
+  unsigned s = takeConj(kind, s0, s1, opboth);
+  // Left Region
+  if (includeLeft) {
+    if (opleft)
+      s0 = mapSet(kind, s0, Value(), opleft);
+    for (unsigned p : latSets[s0])
+      latSets[s].push_back(p);
+  }
+  // Right Region
+  if (includeRight) {
+    if (opright)
+      s1 = mapSet(kind, s1, Value(), opright);
+    for (unsigned p : latSets[s1])
+      latSets[s].push_back(p);
+  }
+  return s;
+}
+
+unsigned Merger::mapSet(Kind kind, unsigned s0, Value v, Operation *op) {
+  assert(kind == kBinary || (kAbsF <= kind && kind <= kUnary));
   unsigned s = addSet();
   for (unsigned p : latSets[s0]) {
-    unsigned e = addExp(kind, latPoints[p].exp, v);
+    unsigned e = addExp(kind, latPoints[p].exp, v, op);
     latPoints.push_back(LatPoint(latPoints[p].bits, e));
     latSets[s].push_back(latPoints.size() - 1);
   }
@@ -303,6 +340,8 @@
   case kTruncI:
   case kBitCast:
     return "cast";
+  case kUnary:
+    return "unary";
   case kMulF:
     return "*";
   case kMulI:
@@ -333,6 +372,8 @@
     return ">>";
   case kShlI:
     return "<<";
+  case kBinary:
+    return "binary";
   }
   llvm_unreachable("unexpected kind for symbol");
 }
@@ -429,7 +470,7 @@
 // Builder methods.
 //===----------------------------------------------------------------------===//
 
-unsigned Merger::buildLattices(unsigned e, unsigned i) {
+unsigned Merger::buildLattices(unsigned e, unsigned i, unsigned z) {
   Kind kind = tensorExps[e].kind;
   switch (kind) {
   case kTensor:
@@ -472,8 +513,37 @@
     //  -y|!y | y |
     //  --+---+---+
     //    | 0 |-y |
-    return mapSet(kind, buildLattices(tensorExps[e].children.e0, i),
+    return mapSet(kind, buildLattices(tensorExps[e].children.e0, i, z),
                   tensorExps[e].val);
+  case kUnary:
+    // A custom unary operation
+    //
+    //  op y|    !y    |     y      |
+    //  ----+----------+------------+
+    //      | absent() | present(y) |
+    {
+      // Present Region
+      UnaryOp unop = dyn_cast<UnaryOp>(tensorExps[e].operation);
+      assert(unop);
+      Region &presentRegion = unop.presentRegion();
+      Region &absentRegion = unop.absentRegion();
+      Operation *presentYield = nullptr;
+      if (!presentRegion.empty()) {
+        Block &presentBlock = presentRegion.front();
+        presentYield = presentBlock.getTerminator();
+      }
+      unsigned s = mapSet(kind, buildLattices(tensorExps[e].children.e0, i, z),
+                          Value(), presentYield);
+      // Absent Region
+      if (!absentRegion.empty()) {
+        Block &absentBlock = absentRegion.front();
+        YieldOp absentYield = dyn_cast<YieldOp>(absentBlock.getTerminator());
+        Value absentVal = absentYield.result();
+        // TODO: figure out what to do here
+        assert(false);
+      }
+      return s;
+    }
   case kMulF:
   case kMulI:
   case kAndI:
@@ -485,8 +555,8 @@
     //  !x | 0 | 0 |
    //   x | 0 |x*y|
     return takeConj(kind, // take binary conjunction
-                    buildLattices(tensorExps[e].children.e0, i),
-                    buildLattices(tensorExps[e].children.e1, i));
+                    buildLattices(tensorExps[e].children.e0, i, z),
+                    buildLattices(tensorExps[e].children.e1, i, z));
   case kDivF:
   case kDivS:
   case kDivU:
@@ -505,8 +575,8 @@
     // construction is concerned).
     assert(!maybeZero(tensorExps[e].children.e1));
     return takeConj(kind, // take binary conjunction
-                    buildLattices(tensorExps[e].children.e0, i),
-                    buildLattices(tensorExps[e].children.e1, i));
+                    buildLattices(tensorExps[e].children.e0, i, z),
+                    buildLattices(tensorExps[e].children.e1, i, z));
   case kAddF:
   case kAddI:
   case kSubF:
@@ -521,8 +591,8 @@
     //  !x | 0 | y |    !x | 0 |-y |
    //   x | x |x+y|     x | x |x-y|
     return takeDisj(kind, // take binary disjunction
-                    buildLattices(tensorExps[e].children.e0, i),
-                    buildLattices(tensorExps[e].children.e1, i));
+                    buildLattices(tensorExps[e].children.e0, i, z),
+                    buildLattices(tensorExps[e].children.e1, i, z));
   case kShrS:
   case kShrU:
   case kShlI:
@@ -531,8 +601,54 @@
     // with the conjuction rule.
     assert(isInvariant(tensorExps[e].children.e1));
     return takeConj(kind, // take binary conjunction
-                    buildLattices(tensorExps[e].children.e0, i),
-                    buildLattices(tensorExps[e].children.e1, i));
+                    buildLattices(tensorExps[e].children.e0, i, z),
+                    buildLattices(tensorExps[e].children.e1, i, z));
+  case kBinary:
+    // A custom binary operation
+    //
+    //  x op y|   !y    |       y      |
+    //  ------+---------+--------------+
+    //    !x  |  empty  |   right(y)   |
+    //     x  | left(x) | overlap(x,y) |
+    {
+      BinaryOp binop = dyn_cast<BinaryOp>(tensorExps[e].operation);
+      assert(binop);
+      Region &overlapRegion = binop.overlapRegion();
+      Region &leftRegion = binop.leftRegion();
+      Region &rightRegion = binop.rightRegion();
+      unsigned child0 = buildLattices(tensorExps[e].children.e0, i, z);
+      unsigned child1 = buildLattices(tensorExps[e].children.e1, i, z);
+
+      if (z != 1) {
+        // Only the innermost loop (z == 1) resolves the full disjunction;
+        // for outer loops, conservatively take the conjunction.
+        return takeConj(kind, child0, child1, tensorExps[e].operation);
+      }
+
+      // Overlap Region
+      Operation *overlapYield = nullptr;
+      if (!overlapRegion.empty()) {
+        Block &overlapBlock = overlapRegion.front();
+        overlapYield = overlapBlock.getTerminator();
+      }
+      // Left Region
+      Operation *leftYield = nullptr;
+      if (!leftRegion.empty()) {
+        Block &leftBlock = leftRegion.front();
+        leftYield = leftBlock.getTerminator();
+      }
+      // Right Region
+      Operation *rightYield = nullptr;
+      if (!rightRegion.empty()) {
+        Block &rightBlock = rightRegion.front();
+        rightYield = rightBlock.getTerminator();
+      }
+      return takeDisj(kind, child0, child1,
+                      binop.left_identity() || !leftRegion.empty(),
+                      binop.right_identity() || !rightRegion.empty(),
+                      overlapYield, leftYield, rightYield);
+    }
   }
   llvm_unreachable("unexpected expression kind");
 }
@@ -627,6 +743,9 @@
       if (isa<arith::TruncIOp>(def))
         return addExp(kTruncI, e, v);
       if (isa<arith::BitcastOp>(def))
         return addExp(kBitCast, e, v);
+      if (isa<UnaryOp>(def)) {
+        return addExp(kUnary, e, Value(), def);
+      }
     }
   }
   // Construct binary operations if subexpressions can be built.
@@ -668,6 +787,9 @@
       if (isa<arith::ShRUIOp>(def) && isInvariant(e1))
         return addExp(kShrU, e0, e1);
       if (isa<arith::ShLIOp>(def) && isInvariant(e1))
         return addExp(kShlI, e0, e1);
+      if (isa<BinaryOp>(def)) {
+        return addExp(kBinary, e0, e1, Value(), def);
+      }
     }
   }
   // Cannot build.
@@ -749,6 +871,35 @@
     return rewriter.create<arith::ShRUIOp>(loc, v0, v1);
   case kShlI:
     return rewriter.create<arith::ShLIOp>(loc, v0, v1);
+  // Set-like ops with custom logic.
+  case kUnary:
+  case kBinary: {
+    Operation *op = tensorExps[e].operation;
+    if (!op)
+      return Value();
+    // Make a clone of the block.
+    Region tmpRegion;
+    BlockAndValueMapping mapper;
+    op->getBlock()->getParent()->cloneInto(&tmpRegion, tmpRegion.begin(),
+                                           mapper);
+    Block &clonedBlock = tmpRegion.front();
+    YieldOp clonedYield = dyn_cast<YieldOp>(clonedBlock.getTerminator());
+    // Merge the cloned block and return the yielded value.
+    Operation *placeholder = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    if (clonedBlock.getNumArguments() == 2) {
+      if (!v0 || !v1)
+        return Value();
+      rewriter.mergeBlockBefore(&tmpRegion.front(), placeholder, {v0, v1});
+    } else {
+      if (!v0)
+        return Value();
+      rewriter.mergeBlockBefore(&tmpRegion.front(), placeholder, {v0});
+    }
+    Value val = clonedYield.result();
+    rewriter.eraseOp(clonedYield);
+    rewriter.eraseOp(placeholder);
+    return val;
+  }
   }
   llvm_unreachable("unexpected expression kind in build");
 }
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matrix_ops.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matrix_ops.mlir
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matrix_ops.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matrix_ops.mlir
@@ -102,6 +102,36 @@
     return %0 : tensor<?x?xf64, #DCSR>
   }
 
+  // Adds two sparse matrices where they intersect. Where they don't intersect,
+  // negate the 2nd argument's values; values only in the 1st argument are not included.
+  func @matrix_intersect(%arga: tensor<?x?xf64, #DCSR>,
+                         %argb: tensor<?x?xf64, #DCSR>) -> tensor<?x?xf64, #DCSR> {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %d0 = tensor.dim %arga, %c0 : tensor<?x?xf64, #DCSR>
+    %d1 = tensor.dim %arga, %c1 : tensor<?x?xf64, #DCSR>
+    %xv = sparse_tensor.init [%d0, %d1] : tensor<?x?xf64, #DCSR>
+    %0 = linalg.generic #trait_op
+       ins(%arga, %argb: tensor<?x?xf64, #DCSR>, tensor<?x?xf64, #DCSR>)
+        outs(%xv: tensor<?x?xf64, #DCSR>) {
+        ^bb(%a: f64, %b: f64, %x: f64):
+          %1 = sparse_tensor.binary %a, %b: f64, f64 to f64
+            overlap={
+              ^bb0(%x0: f64, %y0: f64):
+                %ret = arith.addf %x0, %y0 : f64
+                sparse_tensor.yield %ret : f64
+            }
+            left={}
+            right={
+              ^bb0(%x1: f64):
+                %lret = arith.negf %x1 : f64
+                sparse_tensor.yield %lret : f64
+            }
+          linalg.yield %1 : f64
+    } -> tensor<?x?xf64, #DCSR>
+    return %0 : tensor<?x?xf64, #DCSR>
+  }
+
   // Dump a sparse matrix.
   func @dump(%arg0: tensor<?x?xf64, #DCSR>) {
     %d0 = arith.constant 0.0 : f64
@@ -140,6 +170,8 @@
         : (tensor<?x?xf64, #DCSR>, tensor<?x?xf64, #DCSR>) -> tensor<?x?xf64, #DCSR>
     %3 = call @matrix_mul(%sm1, %sm2)
         : (tensor<?x?xf64, #DCSR>, tensor<?x?xf64, #DCSR>) -> tensor<?x?xf64, #DCSR>
+    %4 = call @matrix_intersect(%sm1, %sm2)
+        : (tensor<?x?xf64, #DCSR>, tensor<?x?xf64, #DCSR>) -> tensor<?x?xf64, #DCSR>
 
     //
     // Verify the results.
@@ -150,6 +182,7 @@
     // CHECK-NEXT: ( ( 2, 4, 0, 0, 0, 0, 0, 0 ), ( 0, 0, 0, 0, 0, 0, 0, 6 ), ( 0, 0, 8, 0, 10, 0, 0, 12 ), ( 14, 0, 16, 18, 0, 0, 0, 0 ) )
     // CHECK-NEXT: ( ( 8, 4, 0, 0, 0, 0, 0, 5 ), ( 4, 0, 0, 0, 0, 0, 3, 6 ), ( 0, 2, 8, 0, 10, 0, 0, 13 ), ( 14, 0, 16, 18, 0, 0, 0, 0 ) )
     // CHECK-NEXT: ( ( 12, 0, 0, 0, 0, 0, 0, 0 ), ( 0, 0, 0, 0, 0, 0, 0, 0 ), ( 0, 0, 0, 0, 0, 0, 0, 12 ), ( 0, 0, 0, 0, 0, 0, 0, 0 ) )
+    // CHECK-NEXT: ( ( 8, 0, 0, 0, 0, 0, 0, -5 ), ( -4, 0, 0, 0, 0, 0, -3, 0 ), ( 0, -2, 0, 0, 0, 0, 0, 13 ), ( 0, 0, 0, 0, 0, 0, 0, 0 ) )
     //
     call @dump(%sm1) : (tensor<?x?xf64, #DCSR>) -> ()
     call @dump(%sm2) : (tensor<?x?xf64, #DCSR>) -> ()
@@ -157,6 +190,7 @@
     call @dump(%1) : (tensor<?x?xf64, #DCSR>) -> ()
     call @dump(%2) : (tensor<?x?xf64, #DCSR>) -> ()
     call @dump(%3) : (tensor<?x?xf64, #DCSR>) -> ()
+    call @dump(%4) : (tensor<?x?xf64, #DCSR>) -> ()
 
     // Release the resources.
     sparse_tensor.release %sm1 : tensor<?x?xf64, #DCSR>
@@ -164,6 +198,7 @@
     sparse_tensor.release %0 : tensor<?x?xf64, #DCSR>
     sparse_tensor.release %2 : tensor<?x?xf64, #DCSR>
     sparse_tensor.release %3 : tensor<?x?xf64, #DCSR>
+    sparse_tensor.release %4 : tensor<?x?xf64, #DCSR>
     return
   }
 }
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_vector_ops.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_vector_ops.mlir
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_vector_ops.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_vector_ops.mlir
@@ -90,6 +90,31 @@
     return %0 : tensor<?xf64, #SparseVector>
   }
 
+  // Creates a new sparse vector using the minimum values from two input sparse vectors.
+  // When there is no overlap, include the present value in the output.
+  func @vector_min(%arga: tensor<?xf64, #SparseVector>,
+                   %argb: tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector> {
+    %c = arith.constant 0 : index
+    %d = tensor.dim %arga, %c : tensor<?xf64, #SparseVector>
+    %xv = sparse_tensor.init [%d] : tensor<?xf64, #SparseVector>
+    %0 = linalg.generic #trait_op
+       ins(%arga, %argb: tensor<?xf64, #SparseVector>, tensor<?xf64, #SparseVector>)
+        outs(%xv: tensor<?xf64, #SparseVector>) {
+        ^bb(%a: f64, %b: f64, %x: f64):
+          %1 = sparse_tensor.binary %a, %b : f64, f64 to f64
+            overlap={
+              ^bb0(%a0: f64, %b0: f64):
+                %cmp = arith.cmpf "olt", %a0, %b0 : f64
+                %2 = arith.select %cmp, %a0, %b0: f64
+                sparse_tensor.yield %2 : f64
+            }
+            left=identity
+            right=identity
+          linalg.yield %1 : f64
+    } -> tensor<?xf64, #SparseVector>
+    return %0 : tensor<?xf64, #SparseVector>
+  }
+
   // Multiplies two sparse vectors into a new sparse vector.
   func @vector_mul(%arga: tensor<?xf64, #SparseVector>,
                    %argb: tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector> {
@@ -125,19 +150,60 @@
   // Sum reduces dot product of two sparse vectors.
   func @vector_dotprod(%arga: tensor<?xf64, #SparseVector>,
                        %argb: tensor<?xf64, #SparseVector>,
-                       %argx: tensor<f64> {linalg.inplaceable = true}) -> tensor<f64> {
+                       %argx: tensor<f64> {linalg.inplaceable = true}) -> tensor<f64> {
     %0 = linalg.generic #trait_dot
       ins(%arga, %argb: tensor<?xf64, #SparseVector>, tensor<?xf64, #SparseVector>)
       outs(%argx: tensor<f64>) {
       ^bb(%a: f64, %b: f64, %x: f64):
         %1 = arith.mulf %a, %b : f64
-        %2 = arith.addf %x, %1 : f64
+        %2 = arith.addf %x, %1 : f64
        linalg.yield %2 : f64
     } -> tensor<f64>
     return %0 : tensor<f64>
   }
 
-  // Dumps a sparse vector.
+  // Take a set difference of two sparse vectors. The result will include only those
+  // sparse elements present in the first, but not the second vector.
+  func @vector_setdiff(%arga: tensor<?xf64, #SparseVector>,
+                       %argb: tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector> {
+    %c = arith.constant 0 : index
+    %d = tensor.dim %arga, %c : tensor<?xf64, #SparseVector>
+    %xv = sparse_tensor.init [%d] : tensor<?xf64, #SparseVector>
+    %0 = linalg.generic #trait_op
+       ins(%arga, %argb: tensor<?xf64, #SparseVector>, tensor<?xf64, #SparseVector>)
+        outs(%xv: tensor<?xf64, #SparseVector>) {
+        ^bb(%a: f64, %b: f64, %x: f64):
+          %1 = sparse_tensor.binary %a, %b : f64, f64 to f64
+            overlap={}
+            left=identity
+            right={}
+          linalg.yield %1 : f64
+    } -> tensor<?xf64, #SparseVector>
+    return %0 : tensor<?xf64, #SparseVector>
+  }
+
+  // Invert the structure of a sparse vector. Present values become missing.
+  // Missing values are filled with 1 (i32).
+  func @vector_complement(%arga: tensor<?xf64, #SparseVector>) -> tensor<?xi32, #SparseVector> {
+    %c = arith.constant 0 : index
+    %ci1 = arith.constant 1 : i32
+    %d = tensor.dim %arga, %c : tensor<?xf64, #SparseVector>
+    %xv = sparse_tensor.init [%d] : tensor<?xi32, #SparseVector>
+    %0 = linalg.generic #trait_scale
+       ins(%arga: tensor<?xf64, #SparseVector>)
+        outs(%xv: tensor<?xi32, #SparseVector>) {
+        ^bb(%a: f64, %x: i32):
+          %1 = sparse_tensor.unary %a : f64 to i32
+            present={}
+            absent={
+              sparse_tensor.yield %ci1 : i32
+            }
+          linalg.yield %1 : i32
+    } -> tensor<?xi32, #SparseVector>
+    return %0 : tensor<?xi32, #SparseVector>
+  }
+
+  // Dumps a sparse vector of type f64.
   func @dump(%arg0: tensor<?xf64, #SparseVector>) {
     // Dump the values array to verify only sparse contents are stored.
@@ -154,6 +220,23 @@
     return
   }
 
+  // Dumps a sparse vector of type i32.
+  func @dumpi32(%arg0: tensor<?xi32, #SparseVector>) {
+    // Dump the values array to verify only sparse contents are stored.
+    %c0 = arith.constant 0 : index
+    %d0 = arith.constant -1 : i32
+    %0 = sparse_tensor.values %arg0 : tensor<?xi32, #SparseVector> to memref<?xi32>
+    %1 = vector.transfer_read %0[%c0], %d0: memref<?xi32>, vector<24xi32>
+    vector.print %1 : vector<24xi32>
+    // Dump the dense vector to verify structure is correct.
+    %dv = sparse_tensor.convert %arg0 : tensor<?xi32, #SparseVector> to tensor<?xi32>
+    %2 = bufferization.to_memref %dv : memref<?xi32>
+    %3 = vector.transfer_read %2[%c0], %d0: memref<?xi32>, vector<32xi32>
+    vector.print %3 : vector<32xi32>
+    memref.dealloc %2 : memref<?xi32>
+    return
+  }
+
   // Driver method to call and verify vector kernels.
   func @entry() {
     %c0 = arith.constant 0 : index
@@ -184,15 +267,23 @@
     %2 = call @vector_add(%sv1, %sv2)
        : (tensor<?xf64, #SparseVector>,
           tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector>
-    %3 = call @vector_mul(%sv1, %sv2)
+    %3 = call @vector_min(%sv1, %sv2)
+       : (tensor<?xf64, #SparseVector>,
+          tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector>
+    %4 = call @vector_mul(%sv1, %sv2)
        : (tensor<?xf64, #SparseVector>,
           tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector>
-    %4 = call @vector_mul_d(%sv1, %sv2)
+    %5 = call @vector_mul_d(%sv1, %sv2)
        : (tensor<?xf64, #SparseVector>,
          tensor<?xf64, #SparseVector>) -> tensor<?xf64, #DenseVector>
-    %5 = call @vector_dotprod(%sv1, %sv2, %x)
+    %6 = call @vector_dotprod(%sv1, %sv2, %x)
       : (tensor<?xf64, #SparseVector>,
         tensor<?xf64, #SparseVector>, tensor<f64>) -> tensor<f64>
+    %7 = call @vector_setdiff(%sv1, %sv2)
+       : (tensor<?xf64, #SparseVector>,
+          tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector>
+    %8 = call @vector_complement(%sv1)
+       : (tensor<?xf64, #SparseVector>) -> tensor<?xi32, #SparseVector>
 
     //
     // Verify the results.
@@ -207,10 +298,16 @@
     // CHECK-NEXT: ( 2, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 8, 0, 0, 10, 12, 0, 0, 0, 0, 0, 0, 14, 16, 0, 18 )
     // CHECK-NEXT: ( 2, 11, 16, 13, 14, 6, 15, 8, 16, 10, 29, 32, 35, 38, -1, -1 )
     // CHECK-NEXT: ( 2, 11, 0, 16, 13, 0, 0, 0, 0, 0, 14, 6, 0, 0, 0, 0, 15, 8, 16, 0, 10, 29, 0, 0, 0, 0, 0, 0, 32, 35, 0, 38 )
+    // CHECK-NEXT: ( 2, 11, 4, 13, 14, 6, 15, 8, 16, 10, 12, 14, 16, 18, -1, -1 )
+    // CHECK-NEXT: ( 2, 11, 0, 4, 13, 0, 0, 0, 0, 0, 14, 6, 0, 0, 0, 0, 15, 8, 16, 0, 10, 12, 0, 0, 0, 0, 0, 0, 14, 16, 0, 18 )
     // CHECK-NEXT: ( 48, 204, 252, 304, 360, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 )
     // CHECK-NEXT: ( 0, 0, 0, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 204, 0, 0, 0, 0, 0, 0, 252, 304, 0, 360 )
     // CHECK-NEXT: ( 0, 0, 0, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 204, 0, 0, 0, 0, 0, 0, 252, 304, 0, 360 )
     // CHECK-NEXT: 1169.1
+    // CHECK-NEXT: ( 2, 6, 8, 10, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 )
+    // CHECK-NEXT: ( 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 8, 0, 0, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
+    // CHECK-NEXT: ( 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1 )
+    // CHECK-NEXT: ( 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0 )
     //
     call @dump(%sv1) : (tensor<?xf64, #SparseVector>) -> ()
     call @dump(%sv2) : (tensor<?xf64, #SparseVector>) -> ()
@@ -218,12 +315,15 @@
     call @dump(%1) : (tensor<?xf64, #SparseVector>) -> ()
     call @dump(%2) : (tensor<?xf64, #SparseVector>) -> ()
     call @dump(%3) : (tensor<?xf64, #SparseVector>) -> ()
-    %m4 = sparse_tensor.values %4 : tensor<?xf64, #DenseVector> to memref<?xf64>
-    %v4 = vector.load %m4[%c0]: memref<?xf64>, vector<32xf64>
-    vector.print %v4 : vector<32xf64>
-    %m5 = bufferization.to_memref %5 : memref<f64>
-    %v5 = memref.load %m5[] : memref<f64>
-    vector.print %v5 : f64
+    call @dump(%4) : (tensor<?xf64, #SparseVector>) -> ()
+    %m5 = sparse_tensor.values %5 : tensor<?xf64, #DenseVector> to memref<?xf64>
+    %v5 = vector.load %m5[%c0]: memref<?xf64>, vector<32xf64>
+    vector.print %v5 : vector<32xf64>
+    %m6 = bufferization.to_memref %6 : memref<f64>
+    %v6 = memref.load %m6[] : memref<f64>
+    vector.print %v6 : f64
+    call @dump(%7) : (tensor<?xf64, #SparseVector>) -> ()
+    call @dumpi32(%8) : (tensor<?xi32, #SparseVector>) -> ()
 
     // Release the resources.
     sparse_tensor.release %sv1 : tensor<?xf64, #SparseVector>
@@ -231,7 +331,10 @@
     sparse_tensor.release %0 : tensor<?xf64, #SparseVector>
     sparse_tensor.release %2 : tensor<?xf64, #SparseVector>
     sparse_tensor.release %3 : tensor<?xf64, #SparseVector>
-    sparse_tensor.release %4 : tensor<?xf64, #DenseVector>
+    sparse_tensor.release %4 : tensor<?xf64, #SparseVector>
+    sparse_tensor.release %5 : tensor<?xf64, #DenseVector>
+    sparse_tensor.release %7 : tensor<?xf64, #SparseVector>
+    sparse_tensor.release %8 : tensor<?xi32, #SparseVector>
     memref.dealloc %xdata : memref<f64>
     return
   }
 }
diff --git a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
--- a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
+++ b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
@@ -223,7 +223,7 @@
   auto e = addf(tensor(t0), tensor(t1));
 
   // Build lattices and check.
-  auto s = merger.buildLattices(e, l0);
+  auto s = merger.buildLattices(e, l0, 1);
   expectNumLatPoints(s, 3);
   expectLatPoint(s, lat(0), addfPattern(tensorPattern(t0), tensorPattern(t1)),
                  loopsToBits({{l0, t0}, {l0, t1}}));
@@ -254,7 +254,7 @@
   auto e = mulf(t0, t1);
 
   // Build lattices and check.
-  auto s = merger.buildLattices(e, l0);
+  auto s = merger.buildLattices(e, l0, 1);
   expectNumLatPoints(s, 1);
   expectLatPoint(s, lat(0), mulfPattern(tensorPattern(t0), tensorPattern(t1)),
                  loopsToBits({{l0, t0}, {l0, t1}}));