diff --git a/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp b/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
--- a/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
+++ b/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
@@ -27,6 +27,7 @@
 #include <cstdlib>
 #include <cstring>
 #include <fstream>
+#include <functional>
 #include <iostream>
 #include <limits>
 #include <numeric>
@@ -149,6 +150,8 @@
   /// the given ordering and expects subsequent add() calls to honor
   /// that same ordering for the given indices. The result is a
   /// fully permuted coordinate scheme.
+  ///
+  /// Precondition: `sizes` and `perm` must be valid for `rank`.
   static SparseTensorCOO<V> *newSparseTensorCOO(uint64_t rank,
                                                 const uint64_t *sizes,
                                                 const uint64_t *perm,
@@ -233,6 +236,144 @@
   }
 };
 
+//===----------------------------------------------------------------------===//
+/// This class provides the interface required by `SparseTensorEnumerator`,
+/// and is a superclass of `SparseTensorStorage` for abstracting
+/// over the `<P,I>` templating. We need that abstraction so that direct
+/// sparse-to-sparse conversion need not take separate `<P,I>` parameters
+/// for the source vs the target (nor require that the source parameters
+/// match the target parameters).
+template <typename V>
+class EnumerableSparseTensorStorage : public SparseTensorStorageBase {
+public:
+  // NOTE: The first four methods do not strictly need to exist, nor
+  // be virtual if they do exist; they're just implemented as such
+  // for convenience. See the design doc for discussion of alternative
+  // implementations.
+  //
+  // Conversely, the `getPointer` and `getIndex` methods must remain
+  // virtual, since they're the methods which perform the
+  // erasure/abstraction over the `P` and `I` types.
+
+  /// Gets the "reverse" permutation stored in `SparseTensorStorage`.
+  virtual const std::vector<uint64_t> &getRev() const = 0;
+  /// Gets the dimension sizes array stored in `SparseTensorStorage`.
+  virtual const std::vector<uint64_t> &getSizes() const = 0;
+  /// Returns true if the dimension is compressed.
+  virtual bool isCompressedDim(uint64_t d) const = 0;
+  /// Looks up the value stored at the given position.
+  virtual V getValue(uint64_t pos) const = 0;
+  /// Looks up the `d`-level position/pointer stored at the given
+  /// `d-1`-level position, and converts the stored `P` into `uint64_t`.
+  virtual uint64_t getPointer(uint64_t d, uint64_t parentPos) const = 0;
+  /// Looks up the `d`-level coordinate/index stored at the given
+  /// `d`-level position, and converts the stored `I` into `uint64_t`.
+  virtual uint64_t getIndex(uint64_t d, uint64_t pos) const = 0;
+};
+
+/// A (higher-order) function object for enumerating the elements of some
+/// `SparseTensorStorage` under a permutation. That is, the `forallElements`
+/// method encapsulates the loop-nest for enumerating the elements of
+/// the source tensor (in whatever order is best for the source tensor),
+/// and applies a permutation to the coordinates/indices before handing
+/// each element to the callback. A single enumerator object can be
+/// freely reused for several calls to `forallElements`, just so long
+/// as each call is sequential with respect to one another.
+///
+/// N.B., this class stores a reference to the `EnumerableSparseTensorStorage`
+/// passed to the constructor; thus, objects of this class must not
+/// outlive the sparse tensor they depend on.
+template <typename V>
+class SparseTensorEnumerator {
+  const EnumerableSparseTensorStorage<V> &src;
+  // These two are conceptually constant, but cannot be marked `const`
+  // without reorganizing the constructor.
+  std::vector<uint64_t> permsz; // in target order.
+  std::vector<uint64_t> reord;  // source storage-order -> target order.
+  // This one is actually mutated.
+  std::vector<uint64_t> cursor; // in target order.
+public:
+  /// Constructs an enumerator with the trivial permutation, thus
+  /// enumerating elements with the semantic-ordering of dimensions.
+  explicit SparseTensorEnumerator(
+      const EnumerableSparseTensorStorage<V> &tensor)
+      : src(tensor), permsz(src.getRev().size()), reord(src.getRev()),
+        cursor(getRank()) {
+    const auto &sizes = src.getSizes();
+    for (uint64_t rank = getRank(), s = 0; s < rank; s++)
+      permsz[reord[s]] = sizes[s];
+  }
+  /// Constructs an enumerator with the given permutation for mapping
+  /// the semantic-ordering of dimensions to the desired target-ordering.
+  ///
+  /// Precondition: `perm` must be valid for `rank`.
+  SparseTensorEnumerator(const EnumerableSparseTensorStorage<V> &tensor,
+                         uint64_t rank, const uint64_t *perm)
+      : src(tensor), permsz(src.getRev().size()), reord(getRank()),
+        cursor(getRank()) {
+    assert(rank == getRank() && "Permutation rank mismatch");
+    const auto &rev = src.getRev();     // src storage-order -> semantic-order
+    const auto &sizes = src.getSizes(); // in source storage-order
+    for (uint64_t s = 0; s < rank; s++) { // `s` source storage-order
+      uint64_t t = perm[rev[s]];          // `t` target-order
+      reord[s] = t;
+      permsz[t] = sizes[s];
+    }
+  }
+
+  SparseTensorEnumerator(const SparseTensorEnumerator &) = delete;
+  SparseTensorEnumerator(SparseTensorEnumerator &&) = delete;
+  SparseTensorEnumerator &operator=(const SparseTensorEnumerator &) = delete;
+  SparseTensorEnumerator &operator=(SparseTensorEnumerator &&) = delete;
+
+  /// Returns the source/target tensor's rank. (The source-rank and
+  /// target-rank are always equal since we only support permutations.
+  /// Though once we add support for other dimension mappings, this
+  /// method will have to be split in two.)
+  uint64_t getRank() const { return permsz.size(); }
+
+  /// Returns the target tensor's dimension sizes.
+  const std::vector<uint64_t> &permutedSizes() const { return permsz; }
+
+  /// The type of callback functions which receive an element (in target
+  /// order). We avoid packaging the coordinates and value together
+  /// as an `Element` object because this helps keep code somewhat cleaner.
+  typedef const std::function<void(const std::vector<uint64_t> &, V)>
+      &ElementConsumer;
+
+  /// Enumerates all elements of the source tensor, permutes their
+  /// indices, and passes the permuted element to the callback.
+  /// The callback must not store the cursor reference directly, since
+  /// this function reuses the storage. Instead, the callback must copy
+  /// it if they want to keep it.
+  void forallElements(ElementConsumer yield) { forallElements(yield, 0, 0); }
+
+private:
+  /// The recursive component of the public `forallElements`.
+  void forallElements(ElementConsumer yield, uint64_t parentPos, uint64_t d) {
+    assert(d <= getRank());
+    if (d == getRank()) {
+      // TODO:
+      yield(cursor, src.getValue(parentPos));
+    } else if (src.isCompressedDim(d)) {
+      const uint64_t pstart = src.getPointer(d, parentPos);
+      const uint64_t pstop = src.getPointer(d, parentPos + 1);
+      for (uint64_t pos = pstart; pos < pstop; pos++) {
+        cursor[reord[d]] = src.getIndex(d, pos);
+        forallElements(yield, pos, d + 1);
+      }
+    } else { // Dense dimension.
+      const uint64_t sz = src.getSizes()[d];
+      const uint64_t pstart = parentPos * sz;
+      for (uint64_t i = 0; i < sz; i++) {
+        cursor[reord[d]] = i;
+        forallElements(yield, pstart + i, d + 1);
+      }
+    }
+  }
+}; // class SparseTensorEnumerator
+
+//===----------------------------------------------------------------------===//
 /// A memory-resident sparse tensor using a storage scheme based on
 /// per-dimension sparse/dense annotations. This data structure provides a
 /// bufferized form of a sparse tensor type. In contrast to generating setup
@@ -240,11 +381,13 @@
 /// a convenient "one-size-fits-all" solution that simply takes an input tensor
 /// and annotations to implement all required setup in a general manner.
 template <typename P, typename I, typename V>
-class SparseTensorStorage : public SparseTensorStorageBase {
+class SparseTensorStorage : public EnumerableSparseTensorStorage<V> {
 public:
   /// Constructs a sparse tensor storage scheme with the given dimensions,
   /// permutation, and per-dimension dense/sparse annotations, using
   /// the coordinate scheme tensor for the initial contents if provided.
+  ///
+  /// Precondition: `perm` and `sparsity` must be valid for `szs.size()`.
   SparseTensorStorage(const std::vector<uint64_t> &szs, const uint64_t *perm,
                       const DimLevelType *sparsity,
                       SparseTensorCOO<V> *tensor = nullptr)
@@ -365,24 +508,20 @@
 
   /// Returns this sparse tensor storage scheme as a new memory-resident
   /// sparse tensor in coordinate scheme with the given dimension order.
-  SparseTensorCOO<V> *toCOO(const uint64_t *perm) {
-    // Restore original order of the dimension sizes and allocate coordinate
-    // scheme with desired new ordering specified in perm.
-    uint64_t rank = getRank();
-    std::vector<uint64_t> orgsz(rank);
-    for (uint64_t r = 0; r < rank; r++)
-      orgsz[rev[r]] = sizes[r];
-    SparseTensorCOO<V> *tensor = SparseTensorCOO<V>::newSparseTensorCOO(
-        rank, orgsz.data(), perm, values.size());
-    // Populate coordinate scheme restored from old ordering and changed with
-    // new ordering. Rather than applying both reorderings during the recursion,
-    // we compute the combine permutation in advance.
-    std::vector<uint64_t> reord(rank);
-    for (uint64_t r = 0; r < rank; r++)
-      reord[r] = perm[rev[r]];
-    toCOO(*tensor, reord, 0, 0);
-    assert(tensor->getElements().size() == values.size());
-    return tensor;
+  ///
+  /// Precondition: `perm` must be valid for `getRank()`.
+  SparseTensorCOO<V> *toCOO(const uint64_t *perm) const {
+    SparseTensorEnumerator<V> enumerator(*this, getRank(), perm);
+    SparseTensorCOO<V> *coo =
+        new SparseTensorCOO<V>(enumerator.permutedSizes(), values.size());
+    enumerator.forallElements([&coo](const std::vector<uint64_t> &ind, V val) {
+      coo->add(ind, val);
+    });
+    // TODO: This assertion assumes there are no stored zeros,
+    // or if there are then that we don't filter them out.
+    // Cf.,
+    assert(coo->getElements().size() == values.size());
+    return coo;
   }
 
   /// Factory method. Constructs a sparse tensor storage scheme with the given
@@ -390,12 +529,14 @@
   /// using the coordinate scheme tensor for the initial contents if provided.
   /// In the latter case, the coordinate scheme must respect the same
   /// permutation as is desired for the new sparse tensor storage.
+  ///
+  /// Precondition: `shape`, `perm`, and `sparsity` must be valid for `rank`.
   static SparseTensorStorage<P, I, V> *
   newSparseTensor(uint64_t rank, const uint64_t *shape, const uint64_t *perm,
                   const DimLevelType *sparsity, SparseTensorCOO<V> *tensor) {
     SparseTensorStorage<P, I, V> *n = nullptr;
     if (tensor) {
-      assert(tensor->getRank() == rank);
+      assert(tensor->getRank() == rank && "Tensor rank mismatch");
       for (uint64_t r = 0; r < rank; r++)
         assert(shape[r] == 0 || shape[r] == tensor->getSizes()[perm[r]]);
       n = new SparseTensorStorage<P, I, V>(tensor->getSizes(), perm, sparsity,
@@ -483,29 +624,6 @@
     }
   }
 
-  /// Stores the sparse tensor storage scheme into a memory-resident sparse
-  /// tensor in coordinate scheme.
-  void toCOO(SparseTensorCOO<V> &tensor, std::vector<uint64_t> &reord,
-             uint64_t pos, uint64_t d) {
-    assert(d <= getRank());
-    if (d == getRank()) {
-      assert(pos < values.size());
-      tensor.add(idx, values[pos]);
-    } else if (isCompressedDim(d)) {
-      // Sparse dimension.
-      for (uint64_t ii = pointers[d][pos]; ii < pointers[d][pos + 1]; ii++) {
-        idx[reord[d]] = indices[d][ii];
-        toCOO(tensor, reord, ii, d + 1);
-      }
-    } else {
-      // Dense dimension.
-      for (uint64_t i = 0, sz = sizes[d], off = pos * sz; i < sz; i++) {
-        idx[reord[d]] = i;
-        toCOO(tensor, reord, off + i, d + 1);
-      }
-    }
-  }
-
   /// Ends a deeper, never seen before dimension.
   void endDim(uint64_t d) {
     assert(d <= getRank());
@@ -564,11 +682,33 @@
   }
 
   /// Returns true if dimension is compressed.
-  inline bool isCompressedDim(uint64_t d) const {
+  inline bool isCompressedDim(uint64_t d) const override {
     assert(d < getRank());
     return (!pointers[d].empty());
   }
 
+protected:
+  // Allow `SparseTensorEnumerator` to access these methods without
+  // making them public to other client code.
+  friend class SparseTensorEnumerator<V>;
+  const std::vector<uint64_t> &getRev() const override { return rev; }
+  const std::vector<uint64_t> &getSizes() const override { return sizes; }
+  V getValue(uint64_t pos) const override {
+    assert(pos < values.size() && "Value position is out of bounds");
+    return values[pos];
+  }
+  uint64_t getPointer(uint64_t d, uint64_t parentPos) const override {
+    assert(isCompressedDim(d)); // Entails `d < getRank()`.
+    assert(parentPos < pointers[d].size() &&
+           "Pointer position is out of bounds");
+    return pointers[d][parentPos]; // Converts the stored `P` into `uint64_t`.
+  }
+  uint64_t getIndex(uint64_t d, uint64_t pos) const override {
+    assert(isCompressedDim(d)); // Entails `d < getRank()`.
+    assert(pos < indices[d].size() && "Index position is out of bounds");
+    return indices[d][pos]; // Converts the stored `I` into `uint64_t`.
+  }
+
 private:
   const std::vector<uint64_t> sizes; // per-dimension sizes
   std::vector<uint64_t> rev;   // "reverse" permutation