// Review notes (not part of the patch: `git apply` ignores any text that
// precedes the first `diff --git` header).
//
// Summary: this patch lowers sparse=>dense `convert`, which was previously a
// TODO that returned failure(). It is lowered to a while-loop that repeatedly
// calls new runtime entry points `_mlir_ciface_getNext{F64,F32,I64,I32,I16,I8}`
// on an iterator obtained from `newSparseTensor` action 4. Each element is
// stored into a dense buffer produced by `allocDenseTensor` (zero-filled, per
// the `dst = new MemRef(0)` pseudo-code). The patch also factors out `getZero`,
// `allocaIndices` and `allocDenseTensor`, and drops the explicit `Type`
// parameter of `genIsNonzero` (the type is now read from the value).
//
// NOTE(review): this copy of the patch appears damaged and will not apply or
// compile as shown:
//  * Every `<...>` span is missing. Examples: the op class in
//    `rewriter.create(...)`, `SmallVector params;`, `StridedMemRefType *iref`
//    and `static_cast *>(ptr)`.
//  * Hunk `@@ -447,7 +533,7 @@` shows only six of its seven lines.
// Recover the original patch before acting on it.
//
// NOTE(review): points to confirm in SparseTensorConversion.cpp:
//  * The sparse=>dense branch runs only when `encDst` is null, because the
//    encDst && encSrc case returned earlier. Yet it calls
//    `genNewCall(rewriter, op, encDst, 4, perm, src)`. `encSrc` looks like
//    the intended encoding; confirm against genNewCall, which is not shown.
//    Also, at that point `if (!encDst || encSrc)` is equivalent to
//    `if (!encDst)`.
//  * `genGetNextCall` gets the element type from
//    `iter.getType().cast<...>().getElementType()`. But `iter` comes from
//    `genNewCall`, whose runtime counterpart `_mlir_ciface_newSparseTensor`
//    returns `void *`. If `iter` is an opaque pointer type, the cast fails.
//    Consider passing the element type explicitly, e.g.
//    `shapeTp.getElementType()`.
//  * `elemVal` is defined in the while-loop's "before" region but used in its
//    "after" region. scf.while regions are siblings, so the value must be
//    forwarded as an scf.condition operand and received as an "after" block
//    argument. Here both `before->getArguments()` and `argTypes` are empty.
//  * The "after" block is left without a terminator. Only the store is
//    created after `setInsertionPointToStart(after)`, but scf.while requires
//    the "after" region to end in scf.yield.
//  * The store receives `indices`, a rank-1 memref from `allocaIndices`, as
//    its index operand. If it is a memref store, it needs one index-typed
//    value per dimension, each loaded from `indices`.
//  * `rewriter.replaceOp(op, dst)` substitutes a memref for the op's tensor
//    result. Unless the type converter maps dense tensors to memrefs, a
//    memref-to-tensor op is needed first.
//  * The explanatory pseudo-code mentions `iter->hasNext()`, but the runtime
//    exposes no such call; termination is driven by the zero sentinel
//    returned from getNext.
//  * Now that `genIsNonzero` no longer takes a `Type`, check whether
//    `eltType` is still used inside `genIndexAndValueForDense`.
//
// NOTE(review): points to confirm in SparseUtils.cpp:
//  * The generated loop stops at the first element whose value compares
//    equal to zero (`genIsNonzero` uses UNE for floats, so -0.0 also stops
//    it). Confirm that `SparseTensorCOO` can never hold such a value; its
//    add() path is not part of this patch.
//  * Nothing releases the iterator (`new Iterator(*this)`) or the COO that
//    `toCOO(perm)` returns for action 4. The iterator only holds a reference
//    to the COO's `elements`, so the COO must outlive it. A release entry
//    point seems needed.
//  * `Iterator::pos` is `unsigned` while `elements.size()` is `size_t`. With
//    a 32-bit `unsigned`, `pos++` would wrap after 2^32-1 elements and
//    restart iteration. Prefer `uint64_t`/`size_t`.
//  * Making `sizes` a `const` member is unrelated to the iterator and makes
//    `SparseTensorCOO` non-assignable; confirm this is intended.
//  * The action-4 doc says "call IMPL_COO_GETNEXT to use", but that is the
//    macro name; the callable entry points are `getNextF64` ... `getNextI8`.
//  * Typo in a new comment: "nexted class" should read "nested class".
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -182,12 +182,18 @@
   return call.getResult(0);
 }
 
+/// Generates a constant zero of the appropriate type.
+static Value getZero(ConversionPatternRewriter &r, Location loc, Type t) {
+  return r.create(loc, r.getZeroAttr(t));
+}
+
 /// Generates the comparison `v != 0` where `v` is of numeric type `t`.
 /// For floating types, we use the "unordered" comparator (i.e., returns
 /// true if `v` is NaN).
 static Value genIsNonzero(ConversionPatternRewriter &rewriter, Location loc,
-                          Type t, Value v) {
-  Value zero = rewriter.create(loc, rewriter.getZeroAttr(t));
+                          Value v) {
+  Type t = v.getType();
+  Value zero = getZero(rewriter, loc, t);
   if (t.isa())
     return rewriter.create(loc, CmpFPredicate::UNE, v, zero);
   if (t.isIntOrIndex())
@@ -207,7 +213,7 @@
                                        Value ind, ValueRange ivs) {
   Location loc = op->getLoc();
   Value val = rewriter.create(loc, tensor, ivs);
-  Value cond = genIsNonzero(rewriter, loc, eltType, val);
+  Value cond = genIsNonzero(rewriter, loc, val);
   scf::IfOp ifOp = rewriter.create(loc, cond, /*else*/ false);
   rewriter.setInsertionPointToStart(&ifOp.thenRegion().front());
   unsigned i = 0;
@@ -288,6 +294,43 @@
   return rewriter.create(loc, values, ivs[0]);
 }
 
+/// Generates a call to SparseTensorCOO::Iterator::getNext()
+/// If there is a next `Element`: the `indices` will be filled from
+/// that element, and the returned `Value` will be the `V` of that element.
+/// If there is no next `Element`: the `indices` will be left in an
+/// indeterminate state (in practice it'll be left unmodified), and
+/// the returned `Value` is zero-- which (by definition) is never a
+/// valid `V` for `SparseTensorCOO` to contain, so there's no chance
+/// of confusion, nor any loss of expressivity.
+static Value genGetNextCall(ConversionPatternRewriter &rewriter, Operation *op,
+                            Value iter, Value indices, Value perm) {
+  Location loc = op->getLoc();
+  Type elemTp = iter.getType().cast().getElementType();
+  StringRef name;
+  if (elemTp.isF64())
+    name = "getNextF64";
+  else if (elemTp.isF32())
+    name = "getNextF32";
+  else if (elemTp.isInteger(64))
+    name = "getNextI64";
+  else if (elemTp.isInteger(32))
+    name = "getNextI32";
+  else if (elemTp.isInteger(16))
+    name = "getNextI16";
+  else if (elemTp.isInteger(8))
+    name = "getNextI8";
+  else
+    llvm_unreachable("Unknown element type");
+  SmallVector params;
+  params.push_back(iter);
+  params.push_back(indices);
+  params.push_back(perm);
+  auto call = rewriter.create(
+      loc, elemTp, getFunc(op, name, elemTp, params, /*emitCInterface=*/true),
+      params);
+  return call.getResult(0);
+}
+
 //===----------------------------------------------------------------------===//
 // Conversion rules.
 //===----------------------------------------------------------------------===//
@@ -352,6 +395,23 @@
   }
 };
 
+static Value allocaIndices(ConversionPatternRewriter &rewriter, Location loc,
+                           int64_t rank) {
+  auto indexTp = rewriter.getIndexType();
+  auto memTp = MemRefType::get({ShapedType::kDynamicSize}, indexTp);
+  Value arg = rewriter.create(loc, rewriter.getIndexAttr(rank));
+  return rewriter.create(loc, memTp, ValueRange{arg});
+}
+
+static Value allocDenseTensor(ConversionPatternRewriter &rewriter, Location loc,
+                              ShapedType shapeTp) {
+  Type elemTp = shapeTp.getElementType();
+  Value tensor = rewriter.create(
+      loc, MemRefType::get(shapeTp.getShape(), elemTp));
+  rewriter.create(loc, getZero(rewriter, loc, elemTp), tensor);
+  return tensor;
+}
+
 /// Sparse conversion rule for the convert operator.
 class SparseTensorConvertConverter : public OpConversionPattern {
   using OpConversionPattern::OpConversionPattern;
@@ -361,6 +421,7 @@
     Type resType = op.getType();
     auto encDst = getSparseTensorEncoding(resType);
     auto encSrc = getSparseTensorEncoding(op.source().getType());
+    auto src = adaptor.getOperands()[0];
     if (encDst && encSrc) {
       // This is a sparse => sparse conversion, which is handled as follows:
       //   t = src->toCOO();         ; src to COO in dst order
@@ -369,14 +430,44 @@
       // yield the fastest conversion but avoids the need for a full
       // O(N^2) conversion matrix.
       Value perm;
-      Value coo =
-          genNewCall(rewriter, op, encDst, 3, perm, adaptor.getOperands()[0]);
+      Value coo = genNewCall(rewriter, op, encDst, 3, perm, src);
       rewriter.replaceOp(op, genNewCall(rewriter, op, encDst, 1, perm, coo));
       return success();
     }
     if (!encDst || encSrc) {
-      // TODO: sparse => dense
-      return failure();
+      // This is sparse => dense conversion, which is handled as follows:
+      //   dst = new MemRef(0);
+      //   iter = src->toCOO()->getIterator();
+      //   while (iter->hasNext()) {
+      //     elem = iter->getNext();
+      //     dst[elem.indices] = elem.value;
+      //   }
+      // While it would be more efficient to inline the iterator logic
+      // directly rather than allocating an object and calling methods,
+      // this is good enough for now.
+      Location loc = op->getLoc();
+      ShapedType shapeTp = resType.cast();
+      Value dst = allocDenseTensor(rewriter, loc, shapeTp);
+      Value indices = allocaIndices(rewriter, loc, shapeTp.getRank());
+      Value perm;
+      Value iter = genNewCall(rewriter, op, encDst, 4, perm, src);
+      // Generate the while-loop Op itself.
+      TypeRange argTypes{};
+      ValueRange args{};
+      scf::WhileOp whileOp = rewriter.create(loc, argTypes, args);
+      Block *before = rewriter.createBlock(&whileOp.before(), {}, argTypes);
+      Block *after = rewriter.createBlock(&whileOp.after(), {}, argTypes);
+      // Build the while-loop's "before" region.
+      rewriter.setInsertionPointToEnd(before);
+      Value elemVal = genGetNextCall(rewriter, op, iter, indices, perm);
+      Value cond = genIsNonzero(rewriter, loc, elemVal);
+      rewriter.create(loc, cond, before->getArguments());
+      // Build the while-loop's "after" region.
+      rewriter.setInsertionPointToStart(after);
+      rewriter.create(loc, elemVal, dst, indices);
+      // Finish up.
+      rewriter.replaceOp(op, dst);
+      return success();
     }
     // This is a dense => sparse conversion or a sparse constant in COO =>
     // sparse conversion, which is handled as follows:
@@ -406,20 +497,15 @@
     // loop is generated by genAddElt().
     Location loc = op->getLoc();
     ShapedType shape = resType.cast();
-    auto memTp =
-        MemRefType::get({ShapedType::kDynamicSize}, rewriter.getIndexType());
     Value perm;
     Value ptr = genNewCall(rewriter, op, encDst, 2, perm);
-    Value arg = rewriter.create(
-        loc, rewriter.getIndexAttr(shape.getRank()));
-    Value ind = rewriter.create(loc, memTp, ValueRange{arg});
+    Value ind = allocaIndices(rewriter, loc, shape.getRank());
     SmallVector lo;
     SmallVector hi;
     SmallVector st;
     Value zero = rewriter.create(loc, rewriter.getIndexAttr(0));
     Value one = rewriter.create(loc, rewriter.getIndexAttr(1));
-    Value tensor = adaptor.getOperands()[0];
-    auto indicesValues = genSplitSparseConstant(rewriter, op, tensor);
+    auto indicesValues = genSplitSparseConstant(rewriter, op, src);
     bool isCOOConstant = indicesValues.hasValue();
     Value indices;
     Value values;
@@ -432,7 +518,7 @@
     } else {
       for (unsigned i = 0, rank = shape.getRank(); i < rank; i++) {
         lo.push_back(zero);
-        hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, tensor, i));
+        hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, src, i));
         st.push_back(one);
       }
     }
@@ -447,7 +533,7 @@
                   rewriter, op, indices, values, ind, ivs, rank);
             else
               val = genIndexAndValueForDense(rewriter, op, eltType,
-                                             tensor, ind, ivs);
+                                             src, ind, ivs);
             genAddEltCall(rewriter, op, eltType, ptr, val, ind, perm);
             return {};
diff --git a/mlir/lib/ExecutionEngine/SparseUtils.cpp b/mlir/lib/ExecutionEngine/SparseUtils.cpp
--- a/mlir/lib/ExecutionEngine/SparseUtils.cpp
+++ b/mlir/lib/ExecutionEngine/SparseUtils.cpp
@@ -107,6 +107,12 @@
   /// Getter for elements array.
   const std::vector> &getElements() const { return elements; }
 
+  // Forward declaration of the class required by getIterator. We make
+  // it a nexted class so that it can access the private fields.
+  class Iterator;
+  /// Returns an iterator over the elements of a SparseTensorCOO.
+  Iterator *getIterator() const { return new Iterator(*this); }
+
   /// Factory method. Permutes the original dimensions according to
   /// the given ordering and expects subsequent add() calls to honor
   /// that same ordering for the given indices. The result is a
@@ -132,10 +138,39 @@
     }
     return false;
   }
-  std::vector sizes; // per-rank dimension sizes
+  const std::vector sizes; // per-rank dimension sizes
   std::vector> elements;
 };
 
+/// This iterator is specifically designed for the needs of the MLIR
+/// generated for sparse=>dense conversion; hence why it is so
+/// idiosyncratic compared to a more conventional iterator for use within
+/// C++ itself.
+template
+class SparseTensorCOO::Iterator {
+  // TODO(wrengr): really this class should be a thin wrapper/subclass
+  // of the std::vector, rather than needing to do a dereference every
+  // time a method is called; but we don't want to actually copy the whole
+  // contents of the underlying array(s) when this class is initialized.
+  // Maybe we should be a thin wrapper/subclass of SparseTensorCOO?
+  // Or have a variant of SparseTensorStorage::toCOO() to construct this
+  // iterator directly?
+  const std::vector> &elements;
+  unsigned pos;
+
+public:
+  // TODO(wrengr): to guarantee safety we'd either need to consume the
+  // SparseTensorCOO (e.g., requiring an rvalue-reference) or get notified
+  // somehow whenever the SparseTensorCOO adds new elements, sorts, etc.
+  Iterator(const SparseTensorCOO &coo) : elements(coo.elements), pos(0) {}
+
+  const Element *getNext() {
+    if (pos < elements.size())
+      return &(elements[pos++]);
+    return nullptr;
+  }
+};
+
 /// Abstract base class of sparse tensor storage. Note that we use
 /// function overloading to implement "partial" method specialization.
 class SparseTensorStorageBase {
@@ -517,8 +552,12 @@
       tensor = static_cast *>(ptr); \
     else if (action == 2) \
       return SparseTensorCOO::newSparseTensorCOO(size, sizes, perm); \
-    else \
-      return static_cast *>(ptr)->toCOO(perm); \
+    else { \
+      tensor = static_cast *>(ptr)->toCOO(perm); \
+      if (action == 3) \
+        return tensor; \
+      return tensor->getIterator(); \
+    } \
     return SparseTensorStorage::newSparseTensor(tensor, sparsity, \
                                                 perm); \
   }
@@ -560,6 +599,33 @@
     return tensor; \
   }
 
+// TODO(wrengr): Why do we need to handle the permutation manually
+// here? Why doesn't the SparseTensorCOO type handle it itself when constructed?
+/// Calls SparseTensorCOO::Iterator::getNext() with the following semantics.
+/// If there is a next `Element`: the `iref` will be filled from that
+/// element, and the element's value is returned.
+/// If there is no next `Element`: the `iref` will be left in an
+/// indeterminate state (in practice it'll be left unmodified), and
+/// the return value is zero-- which (by definition) is never a valid `V`
+/// for `SparseTensorCOO` to contain, so there's no chance of confusion,
+/// nor any loss of expressivity.
+#define IMPL_COO_GETNEXT(NAME, V) \
+  V _mlir_ciface_##NAME(void *ptr, StridedMemRefType *iref, \
+                        StridedMemRefType *pref) { \
+    assert(iref->strides[0] == 1 && pref->strides[0] == 1); \
+    assert(iref->sizes[0] == pref->sizes[0]); \
+    uint64_t *indx = iref->data + iref->offset; \
+    const uint64_t *perm = pref->data + pref->offset; \
+    const uint64_t isize = iref->sizes[0]; \
+    auto iter = static_cast::Iterator *>(ptr); \
+    const Element *elem = iter->getNext(); \
+    if (elem == nullptr) \
+      return 0; \
+    for (uint64_t r = 0; r < isize; r++) \
+      indx[r] = elem->indices[perm[r]]; \
+    return elem->value; \
+  }
+
 enum OverheadTypeEnum : uint64_t { kU64 = 1, kU32 = 2, kU16 = 3, kU8 = 4 };
 
 enum PrimaryTypeEnum : uint64_t {
@@ -578,6 +644,7 @@
 /// 1 : ptr contains coordinate scheme to assign to new storage
 /// 2 : returns empty coordinate scheme to fill (call back 1 to setup)
 /// 3 : returns coordinate scheme from storage in ptr (call back 1 to convert)
+/// 4 : returns iterator from storage in ptr (call IMPL_COO_GETNEXT to use)
 void *
 _mlir_ciface_newSparseTensor(StridedMemRefType *aref, // NOLINT
                              StridedMemRefType *sref,
@@ -674,10 +741,19 @@
 IMPL3(addEltI16, int16_t)
 IMPL3(addEltI8, int8_t)
 
+/// Helper to enumerate elements of coordinate scheme, one per value type.
+IMPL_COO_GETNEXT(getNextF64, double)
+IMPL_COO_GETNEXT(getNextF32, float)
+IMPL_COO_GETNEXT(getNextI64, int64_t)
+IMPL_COO_GETNEXT(getNextI32, int32_t)
+IMPL_COO_GETNEXT(getNextI16, int16_t)
+IMPL_COO_GETNEXT(getNextI8, int8_t)
+
 #undef CASE
 #undef IMPL1
 #undef IMPL2
 #undef IMPL3
+#undef IMPL_COO_GETNEXT
 
 //===----------------------------------------------------------------------===//
 //