diff --git a/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp b/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp --- a/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp +++ b/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp @@ -27,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -102,6 +103,13 @@ } }; +/// The type of callback functions which receive an element. We avoid +/// packaging the coordinates and value together as an `Element` object +/// because this helps keep code somewhat cleaner. +template +using ElementConsumer = + const std::function &, V)> &; + /// A memory-resident sparse tensor in coordinate scheme (collection of /// elements). This data structure is used to read a sparse tensor from /// any external format into memory and sort the elements lexicographically @@ -194,6 +202,7 @@ const uint64_t *perm, const DimLevelType *sparsity) : dimSizes(szs), rev(getRank()), dimTypes(sparsity, sparsity + getRank()) { + assert(perm && sparsity); const uint64_t rank = getRank(); // Validate parameters. assert(rank > 0 && "Trivial shape is unsupported"); @@ -284,6 +293,31 @@ /// Finishes insertion. virtual void endInsert() = 0; + // The following two methods provide the additional interface required + // by `SparseTensorEnumerator`. They abstract over the `<P,I,V>` + // templating of a `SparseTensorStorage` object, which is needed + // so that direct sparse-to-sparse conversion need not take separate + // `<P,I>` parameters for the source vs the target (nor require that + // the source parameters match the target parameters). + + /// Looks up the `d`-level position/pointer stored at the given + /// `d-1`-level position, and converts the stored `P` into `uint64_t`. + virtual uint64_t getPointer(uint64_t d, uint64_t parentPos) const = 0; + + /// Looks up the `d`-level coordinate/index stored at the given + /// `d`-level position, and converts the stored `I` into `uint64_t`. 
+ virtual uint64_t getIndex(uint64_t d, uint64_t pos) const = 0; + +protected: + // Since this class is virtual, we must disallow public copying in + // order to avoid "slicing". Since this class has data members, + // that means making copying protected. + // + SparseTensorStorageBase(const SparseTensorStorageBase &) = default; + // Copy-assignment would be implicitly deleted (because `dimSizes` + // is const), so we explicitly delete it for clarity. + SparseTensorStorageBase &operator=(const SparseTensorStorageBase &) = delete; + private: static void fatal(const char *tp) { fprintf(stderr, "unsupported %s\n", tp); @@ -295,6 +329,97 @@ const std::vector dimTypes; }; +/// A (higher-order) function object for enumerating the elements of some +/// `SparseTensorStorage` under a permutation. That is, the `forallElements` +/// method encapsulates the loop-nest for enumerating the elements of +/// the source tensor (in whatever order is best for the source tensor), +/// and applies a permutation to the coordinates/indices before handing +/// each element to the callback. A single enumerator object can be +/// freely reused for several calls to `forallElements`, just so long +/// as each call is sequential with respect to one another. +/// +/// N.B., this class stores a reference to the `SparseTensorStorageBase` +/// passed to the constructor; thus, objects of this class must not +/// outlive the sparse tensor they depend on. +template +class SparseTensorEnumerator { +public: + /// Constructs an enumerator with the given permutation for mapping + /// the semantic-ordering of dimensions to the desired target-ordering. + /// + /// Preconditions: + /// * the `tensor` must have the same value type `V`. + /// * `perm` must be valid for `rank`. 
+ SparseTensorEnumerator(const SparseTensorStorageBase &tensor, uint64_t rank, + const uint64_t *perm) + : src(tensor), permsz(src.getRev().size()), reord(getRank()), + cursor(getRank()) { + assert(perm); + assert(rank == getRank() && "Permutation rank mismatch"); + const auto &rev = src.getRev(); // source stg-order -> semantic-order + const auto &sizes = src.getDimSizes(); // in source storage-order + for (uint64_t s = 0; s < rank; s++) { // `s` source storage-order + uint64_t t = perm[rev[s]]; // `t` target-order + reord[s] = t; + permsz[t] = sizes[s]; + } + } + + // We disallow copying to help avoid leaking the `src` reference. + SparseTensorEnumerator(const SparseTensorEnumerator &) = delete; + SparseTensorEnumerator &operator=(const SparseTensorEnumerator &) = delete; + + /// Returns the source/target tensor's rank. (The source-rank and + /// target-rank are always equal since we only support permutations. + /// Though once we add support for other dimension mappings, this + /// method will have to be split in two.) + uint64_t getRank() const { return permsz.size(); } + + /// Returns the target tensor's dimension sizes. + const std::vector &permutedSizes() const { return permsz; } + + /// Enumerates all elements of the source tensor, permutes their + /// indices, and passes the permuted element to the callback. + /// The callback must not store the cursor reference directly, since + /// this function reuses the storage. Instead, the callback must copy + /// it if they want to keep it. + void forallElements(ElementConsumer yield) { forallElements(yield, 0, 0); } + +private: + /// The recursive component of the public `forallElements`. + void forallElements(ElementConsumer yield, uint64_t parentPos, + uint64_t d) { + if (d == getRank()) { + std::vector *values; + // We must cast away the const on `src` to be allowed to call + // `getValues`; which is safe since we do not mutate `*values`. 
+ const_cast(src).getValues(&values); + assert(parentPos < values->size() && "Value position is out of bounds"); + // TODO: + yield(cursor, (*values)[parentPos]); + } else if (src.isCompressedDim(d)) { + const uint64_t pstart = src.getPointer(d, parentPos); + const uint64_t pstop = src.getPointer(d, parentPos + 1); + for (uint64_t pos = pstart; pos < pstop; pos++) { + cursor[reord[d]] = src.getIndex(d, pos); + forallElements(yield, pos, d + 1); + } + } else { // Dense dimension. + const uint64_t sz = src.getDimSizes()[d]; + const uint64_t pstart = parentPos * sz; + for (uint64_t i = 0; i < sz; i++) { + cursor[reord[d]] = i; + forallElements(yield, pstart + i, d + 1); + } + } + } + + const SparseTensorStorageBase &src; + std::vector permsz; // in target order. + std::vector reord; // source storage-order -> target order. + std::vector cursor; // in target order. +}; + /// A memory-resident sparse tensor using a storage scheme based on /// per-dimension sparse/dense annotations. This data structure provides a /// bufferized form of a sparse tensor type. In contrast to generating setup @@ -417,24 +542,13 @@ /// sparse tensor in coordinate scheme with the given dimension order. /// /// Precondition: `perm` must be valid for `getRank()`. - SparseTensorCOO *toCOO(const uint64_t *perm) { - // Restore original order of the dimension sizes and allocate coordinate - // scheme with desired new ordering specified in perm. - const uint64_t rank = getRank(); - const auto &rev = getRev(); - const auto &sizes = getDimSizes(); - std::vector orgsz(rank); - for (uint64_t r = 0; r < rank; r++) - orgsz[rev[r]] = sizes[r]; - SparseTensorCOO *coo = SparseTensorCOO::newSparseTensorCOO( - rank, orgsz.data(), perm, values.size()); - // Populate coordinate scheme restored from old ordering and changed with - // new ordering. Rather than applying both reorderings during the recursion, - // we compute the combine permutation in advance. 
- std::vector reord(rank); - for (uint64_t r = 0; r < rank; r++) - reord[r] = perm[rev[r]]; - toCOO(*coo, reord, 0, 0); + SparseTensorCOO *toCOO(const uint64_t *perm) const { + SparseTensorEnumerator enumerator(*this, getRank(), perm); + SparseTensorCOO *coo = + new SparseTensorCOO(enumerator.permutedSizes(), values.size()); + enumerator.forallElements([&coo](const std::vector &ind, V val) { + coo->add(ind, val); + }); // TODO: This assertion assumes there are no stored zeros, // or if there are then that we don't filter them out. // Cf., @@ -517,9 +631,10 @@ /// and pointwise less-than). void fromCOO(const std::vector> &elements, uint64_t lo, uint64_t hi, uint64_t d) { + uint64_t rank = getRank(); + assert(d <= rank && hi <= elements.size()); // Once dimensions are exhausted, insert the numerical values. - assert(d <= getRank() && hi <= elements.size()); - if (d == getRank()) { + if (d == rank) { assert(lo < hi); values.push_back(elements[lo].value); return; @@ -543,31 +658,6 @@ finalizeSegment(d, full); } - /// Stores the sparse tensor storage scheme into a memory-resident sparse - /// tensor in coordinate scheme. - void toCOO(SparseTensorCOO &tensor, std::vector &reord, - uint64_t pos, uint64_t d) { - assert(d <= getRank()); - if (d == getRank()) { - assert(pos < values.size()); - tensor.add(idx, values[pos]); - } else if (isCompressedDim(d)) { - // Sparse dimension. - for (uint64_t ii = pointers[d][pos]; ii < pointers[d][pos + 1]; ii++) { - idx[reord[d]] = indices[d][ii]; - toCOO(tensor, reord, ii, d + 1); - } - } else { - // Dense dimension. - const uint64_t sz = getDimSizes()[d]; - const uint64_t off = pos * sz; - for (uint64_t i = 0; i < sz; i++) { - idx[reord[d]] = i; - toCOO(tensor, reord, off + i, d + 1); - } - } - } - /// Finalize the sparse pointer structure at this dimension. 
void finalizeSegment(uint64_t d, uint64_t full = 0, uint64_t count = 1) { if (count == 0) @@ -623,7 +713,21 @@ return -1u; } -private: + // Allow `SparseTensorEnumerator` to access the following methods + // without making them public to other client code. + friend class SparseTensorEnumerator; + uint64_t getPointer(uint64_t d, uint64_t parentPos) const override { + assert(isCompressedDim(d)); + assert(parentPos < pointers[d].size() && + "Pointer position is out of bounds"); + return pointers[d][parentPos]; // Converts the stored `P` into `uint64_t`. + } + uint64_t getIndex(uint64_t d, uint64_t pos) const override { + assert(isCompressedDim(d)); + assert(pos < indices[d].size() && "Index position is out of bounds"); + return indices[d][pos]; // Converts the stored `I` into `uint64_t`. + } + std::vector> pointers; std::vector> indices; std::vector values;