diff --git a/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp b/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
--- a/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
+++ b/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
@@ -27,6 +27,7 @@
 #include
 #include
 #include
+#include <functional>
 #include
 #include
 #include
@@ -94,6 +95,13 @@
   V value;
 };
 
+/// The type of callback functions which receive an element.  We avoid
+/// packaging the coordinates and value together as an `Element` object
+/// because this helps keep code somewhat cleaner.
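+///
+/// For illustration only (not code added by this change): a consumer is
+/// any callable taking the permuted indices and the value, e.g. the lambda
+///   [&coo](const std::vector<uint64_t> &ind, V val) { coo->add(ind, val); }
+/// used by `SparseTensorStorage::toCOO` below.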
+template <typename V>
+using ElementConsumer =
+    const std::function<void(const std::vector<uint64_t> &, V)> &;
+
 /// A memory-resident sparse tensor in coordinate scheme (collection of
 /// elements). This data structure is used to read a sparse tensor from
 /// any external format into memory and sort the elements lexicographically
@@ -220,6 +228,7 @@
                           const uint64_t *perm, const DimLevelType *sparsity)
       : dimSizes(szs), rev(getRank()),
         dimTypes(sparsity, sparsity + getRank()) {
+    assert(perm && sparsity);
     const uint64_t rank = getRank();
     // Validate parameters.
     assert(rank > 0 && "Trivial shape is unsupported");
@@ -310,6 +319,16 @@
   /// Finishes insertion.
   virtual void endInsert() = 0;
 
+protected:
+  // Since this class is virtual, we must disallow public copying in
+  // order to avoid "slicing".  Since this class has data members,
+  // that means making copying protected.
+  //
+  SparseTensorStorageBase(const SparseTensorStorageBase &) = default;
+  // Copy-assignment would be implicitly deleted (because `dimSizes`
+  // is const), so we explicitly delete it for clarity.
+  SparseTensorStorageBase &operator=(const SparseTensorStorageBase &) = delete;
+
 private:
   static void fatal(const char *tp) {
     fprintf(stderr, "unsupported %s\n", tp);
@@ -321,6 +340,10 @@
   const std::vector<DimLevelType> dimTypes;
 };
 
+// Forward.
+template <typename P, typename I, typename V>
+class SparseTensorEnumerator;
+
 /// A memory-resident sparse tensor using a storage scheme based on
 /// per-dimension sparse/dense annotations. This data structure provides a
 /// bufferized form of a sparse tensor type. In contrast to generating setup
@@ -443,24 +466,13 @@
   /// sparse tensor in coordinate scheme with the given dimension order.
   ///
   /// Precondition: `perm` must be valid for `getRank()`.
-  SparseTensorCOO<V> *toCOO(const uint64_t *perm) {
-    // Restore original order of the dimension sizes and allocate coordinate
-    // scheme with desired new ordering specified in perm.
-    const uint64_t rank = getRank();
-    const auto &rev = getRev();
-    const auto &sizes = getDimSizes();
-    std::vector<uint64_t> orgsz(rank);
-    for (uint64_t r = 0; r < rank; r++)
-      orgsz[rev[r]] = sizes[r];
-    SparseTensorCOO<V> *coo = SparseTensorCOO<V>::newSparseTensorCOO(
-        rank, orgsz.data(), perm, values.size());
-    // Populate coordinate scheme restored from old ordering and changed with
-    // new ordering. Rather than applying both reorderings during the recursion,
-    // we compute the combine permutation in advance.
-    std::vector<uint64_t> reord(rank);
-    for (uint64_t r = 0; r < rank; r++)
-      reord[r] = perm[rev[r]];
-    toCOO(*coo, reord, 0, 0);
+  SparseTensorCOO<V> *toCOO(const uint64_t *perm) const {
+    SparseTensorEnumerator<P, I, V> enumerator(*this, getRank(), perm);
+    SparseTensorCOO<V> *coo =
+        new SparseTensorCOO<V>(enumerator.permutedSizes(), values.size());
+    enumerator.forallElements([&coo](const std::vector<uint64_t> &ind, V val) {
+      coo->add(ind, val);
+    });
     // TODO: This assertion assumes there are no stored zeros,
     // or if there are then that we don't filter them out.
     // Cf.,
@@ -543,9 +555,10 @@ /// and pointwise less-than).
   void fromCOO(const std::vector<Element<V>> &elements, uint64_t lo,
                uint64_t hi, uint64_t d) {
+    uint64_t rank = getRank();
+    assert(d <= rank && hi <= elements.size());
     // Once dimensions are exhausted, insert the numerical values.
-    assert(d <= getRank() && hi <= elements.size());
-    if (d == getRank()) {
+    if (d == rank) {
       assert(lo < hi);
       values.push_back(elements[lo].value);
       return;
     }
@@ -569,31 +582,6 @@
       finalizeSegment(d, full);
   }
 
-  /// Stores the sparse tensor storage scheme into a memory-resident sparse
-  /// tensor in coordinate scheme.
-  void toCOO(SparseTensorCOO<V> &tensor, std::vector<uint64_t> &reord,
-             uint64_t pos, uint64_t d) {
-    assert(d <= getRank());
-    if (d == getRank()) {
-      assert(pos < values.size());
-      tensor.add(idx, values[pos]);
-    } else if (isCompressedDim(d)) {
-      // Sparse dimension.
-      for (uint64_t ii = pointers[d][pos]; ii < pointers[d][pos + 1]; ii++) {
-        idx[reord[d]] = indices[d][ii];
-        toCOO(tensor, reord, ii, d + 1);
-      }
-    } else {
-      // Dense dimension.
-      const uint64_t sz = getDimSizes()[d];
-      const uint64_t off = pos * sz;
-      for (uint64_t i = 0; i < sz; i++) {
-        idx[reord[d]] = i;
-        toCOO(tensor, reord, off + i, d + 1);
-      }
-    }
-  }
-
   /// Finalize the sparse pointer structure at this dimension.
   void finalizeSegment(uint64_t d, uint64_t full = 0, uint64_t count = 1) {
     if (count == 0)
@@ -649,13 +637,151 @@
     return -1u;
   }
 
-private:
+  // Allow `SparseTensorEnumerator` to access the data-members (to avoid
+  // the cost of virtual-function dispatch in inner loops), without
+  // making them public to other client code.
+  friend class SparseTensorEnumerator<P, I, V>;
+
   std::vector<std::vector<P>> pointers;
   std::vector<std::vector<I>> indices;
   std::vector<V> values;
   std::vector<uint64_t> idx; // index cursor for lexicographic insertion.
 };
 
+/// A (higher-order) function object for enumerating the elements of some
+/// `SparseTensorStorage` under a permutation.  That is, the `forallElements`
+/// method encapsulates the loop-nest for enumerating the elements of
+/// the source tensor (in whatever order is best for the source tensor),
+/// and applies a permutation to the coordinates/indices before handing
+/// each element to the callback.  A single enumerator object can be
+/// freely reused for several calls to `forallElements`, just so long
+/// as each call is sequential with respect to one another.
+///
+/// N.B., this class stores a reference to the `SparseTensorStorageBase`
+/// passed to the constructor; thus, objects of this class must not
+/// outlive the sparse tensor they depend on.
+///
+/// Design Note: The reason we define this class instead of simply using
+/// `SparseTensorEnumerator<P,I,V>` is because we need to hide/generalize
+/// the `<P,I>` template parameters from MLIR client code (to simplify the
+/// type parameters used for direct sparse-to-sparse conversion).  And the
+/// reason we define the `SparseTensorEnumerator<P,I,V>` subclasses rather
+/// than simply using this class, is to avoid the cost of virtual-method
+/// dispatch within the loop-nest.
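+///
+/// A typical use follows the pattern below (an illustrative sketch; cf.
+/// `SparseTensorStorage<P,I,V>::toCOO` above for the concrete use in this
+/// patch):
+///
+///   SparseTensorEnumerator<P, I, V> enumerator(tensor, rank, perm);
+///   enumerator.forallElements(
+///       [&](const std::vector<uint64_t> &ind, V val) { /* consume */ });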
+template <typename V>
+class SparseTensorEnumeratorBase {
+public:
+  /// Constructs an enumerator with the given permutation for mapping
+  /// the semantic-ordering of dimensions to the desired target-ordering.
+  ///
+  /// Preconditions:
+  /// * the `tensor` must have the same `V` value type.
+  /// * `perm` must be valid for `rank`.
+  SparseTensorEnumeratorBase(const SparseTensorStorageBase &tensor,
+                             uint64_t rank, const uint64_t *perm)
+      : src(tensor), permsz(src.getRev().size()), reord(getRank()),
+        cursor(getRank()) {
+    assert(perm && "Received nullptr for permutation");
+    assert(rank == getRank() && "Permutation rank mismatch");
+    const auto &rev = src.getRev();         // source stg-order -> semantic-order
+    const auto &sizes = src.getDimSizes();  // in source storage-order
+    for (uint64_t s = 0; s < rank; s++) {   // `s` source storage-order
+      uint64_t t = perm[rev[s]];            // `t` target-order
+      reord[s] = t;
+      permsz[t] = sizes[s];
+    }
+  }
+
+  virtual ~SparseTensorEnumeratorBase() = default;
+
+  // We disallow copying to help avoid leaking the `src` reference.
+  // (In addition to avoiding the problem of slicing.)
+  SparseTensorEnumeratorBase(const SparseTensorEnumeratorBase &) = delete;
+  SparseTensorEnumeratorBase &
+  operator=(const SparseTensorEnumeratorBase &) = delete;
+
+  /// Returns the source/target tensor's rank.  (The source-rank and
+  /// target-rank are always equal since we only support permutations.
+  /// Though once we add support for other dimension mappings, this
+  /// method will have to be split in two.)
+  uint64_t getRank() const { return permsz.size(); }
+
+  /// Returns the target tensor's dimension sizes.
+  const std::vector<uint64_t> &permutedSizes() const { return permsz; }
+
+  /// Enumerates all elements of the source tensor, permutes their
+  /// indices, and passes the permuted element to the callback.
+  /// The callback must not store the cursor reference directly,
+  /// since this function reuses the storage.  Instead, the callback
+  /// must copy it if they want to keep it.
+  virtual void forallElements(ElementConsumer<V> yield) = 0;
+
+protected:
+  const SparseTensorStorageBase &src;
+  std::vector<uint64_t> permsz; // in target order.
+  std::vector<uint64_t> reord;  // source storage-order -> target order.
+  std::vector<uint64_t> cursor; // in target order.
+};
+
+template <typename P, typename I, typename V>
+class SparseTensorEnumerator final : public SparseTensorEnumeratorBase<V> {
+  using Base = SparseTensorEnumeratorBase<V>;
+
+public:
+  /// Constructs an enumerator with the given permutation for mapping
+  /// the semantic-ordering of dimensions to the desired target-ordering.
+  ///
+  /// Precondition: `perm` must be valid for `rank`.
+  SparseTensorEnumerator(const SparseTensorStorage<P, I, V> &tensor,
+                         uint64_t rank, const uint64_t *perm)
+      : Base(tensor, rank, perm) {}
+
+  ~SparseTensorEnumerator() final override = default;
+
+  void forallElements(ElementConsumer<V> yield) final override {
+    forallElements(yield, 0, 0);
+  }
+
+private:
+  /// The recursive component of the public `forallElements`.
+  void forallElements(ElementConsumer<V> yield, uint64_t parentPos,
+                      uint64_t d) {
+    // Recover the `<P, I>` type parameters of `src`.
+    const auto &src =
+        static_cast<const SparseTensorStorage<P, I, V> &>(this->src);
+    if (d == Base::getRank()) {
+      assert(parentPos < src.values.size() &&
+             "Value position is out of bounds");
+      // TODO:
+      yield(this->cursor, src.values[parentPos]);
+    } else if (src.isCompressedDim(d)) {
+      // Look up the bounds of the `d`-level segment determined by the
+      // `d-1`-level position `parentPos`.
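+      // (The pointer entries are widened to `uint64_t` below, since the
+      // `P` type parameter may be a narrower integer type.)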
+      const std::vector<P> &pointers_d = src.pointers[d];
+      assert(parentPos + 1 < pointers_d.size() &&
+             "Parent pointer position is out of bounds");
+      const uint64_t pstart = static_cast<uint64_t>(pointers_d[parentPos]);
+      const uint64_t pstop = static_cast<uint64_t>(pointers_d[parentPos + 1]);
+      // Loop-invariant code for looking up the `d`-level coordinates/indices.
+      const std::vector<I> &indices_d = src.indices[d];
+      assert(pstop - 1 < indices_d.size() && "Index position is out of bounds");
+      uint64_t &cursor_reord_d = this->cursor[this->reord[d]];
+      for (uint64_t pos = pstart; pos < pstop; pos++) {
+        cursor_reord_d = static_cast<uint64_t>(indices_d[pos]);
+        forallElements(yield, pos, d + 1);
+      }
+    } else { // Dense dimension.
+      const uint64_t sz = src.getDimSizes()[d];
+      const uint64_t pstart = parentPos * sz;
+      uint64_t &cursor_reord_d = this->cursor[this->reord[d]];
+      for (uint64_t i = 0; i < sz; i++) {
+        cursor_reord_d = i;
+        forallElements(yield, pstart + i, d + 1);
+      }
+    }
+  }
+};
+
 /// Helper to convert string to lower case.
 static char *toLower(char *token) {
   for (char *c = token; *c; c++)