diff --git a/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp b/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
--- a/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
+++ b/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
@@ -86,20 +86,11 @@
 /// ({i,j,k,l,m}, a[i,j,k,l,m])
+/// The indices are not owned by the element: they point into a shared index
+/// pool owned by the enclosing SparseTensorCOO, which keeps them up to date.
 template <typename V>
 struct Element {
-  Element(const std::vector<uint64_t> &ind, V val) : indices(ind), value(val){};
-  std::vector<uint64_t> indices;
+  Element(uint64_t *ind, V val) : indices(ind), value(val) {}
+  uint64_t *indices; // pointer into shared index pool
   V value;
-  /// Returns true if indices of e1 < indices of e2.
-  static bool lexOrder(const Element<V> &e1, const Element<V> &e2) {
-    uint64_t rank = e1.indices.size();
-    assert(rank == e2.indices.size());
-    for (uint64_t r = 0; r < rank; r++) {
-      if (e1.indices[r] == e2.indices[r])
-        continue;
-      return e1.indices[r] < e2.indices[r];
-    }
-    return false;
-  }
 };
 
 /// A memory-resident sparse tensor in coordinate scheme (collection of
@@ -112,29 +103,68 @@
 public:
   SparseTensorCOO(const std::vector<uint64_t> &szs, uint64_t capacity)
       : sizes(szs) {
-    if (capacity)
+    if (capacity) {
       elements.reserve(capacity);
+      indices.reserve(capacity * getRank());
+    }
   }
+
+  // Elements hold raw pointers into this object's own `indices` pool, so a
+  // member-wise copy would leave the copy pointing into the source's pool.
+  // Moving is fine, since moving a std::vector transfers its buffer.
+  SparseTensorCOO(const SparseTensorCOO &) = delete;
+  SparseTensorCOO &operator=(const SparseTensorCOO &) = delete;
+  SparseTensorCOO(SparseTensorCOO &&) = default;
+
   /// Adds element as indices and value.
   void add(const std::vector<uint64_t> &ind, V val) {
     assert(!iteratorLocked && "Attempt to add() after startIterator()");
+    uint64_t *base = indices.data();
+    uint64_t size = indices.size();
     uint64_t rank = getRank();
     assert(rank == ind.size());
-    for (uint64_t r = 0; r < rank; r++)
+    for (uint64_t r = 0; r < rank; r++) {
       assert(ind[r] < sizes[r]); // within bounds
-    elements.emplace_back(ind, val);
+      indices.push_back(ind[r]);
+    }
+    // This base only changes if indices were reallocated. In that case, we
+    // need to correct all previous pointers into the vector. Note that this
+    // only happens if we did not set the initial capacity right, and then only
+    // for every internal vector reallocation (which with the doubling rule
+    // should only incur an amortized linear overhead).
+    uint64_t *newBase = indices.data();
+    if (newBase != base) {
+      for (uint64_t i = 0, n = elements.size(); i < n; i++)
+        elements[i].indices = newBase + (elements[i].indices - base);
+      base = newBase;
+    }
+    // Add element as (pointer into shared index pool, value) pair.
+    elements.emplace_back(base + size, val);
   }
+
   /// Sorts elements lexicographically by index.
   void sort() {
     assert(!iteratorLocked && "Attempt to sort() after startIterator()");
     // TODO: we may want to cache an `isSorted` bit, to avoid
     // unnecessary/redundant sorting.
-    std::sort(elements.begin(), elements.end(), Element<V>::lexOrder);
+    uint64_t rank = getRank();
+    std::sort(elements.begin(), elements.end(),
+              [rank](const Element<V> &e1, const Element<V> &e2) {
+                for (uint64_t r = 0; r < rank; r++) {
+                  if (e1.indices[r] == e2.indices[r])
+                    continue;
+                  return e1.indices[r] < e2.indices[r];
+                }
+                return false;
+              });
   }
+
   /// Returns rank.
   uint64_t getRank() const { return sizes.size(); }
+
   /// Getter for sizes array.
   const std::vector<uint64_t> &getSizes() const { return sizes; }
+
   /// Getter for elements array.
   const std::vector<Element<V>> &getElements() const { return elements; }
 
@@ -143,6 +173,7 @@
     iteratorLocked = true;
     iteratorPos = 0;
   }
+
   /// Get the next element.
   const Element<V> *getNext() {
     assert(iteratorLocked && "Attempt to getNext() before startIterator()");
@@ -172,7 +203,8 @@
 
 private:
   const std::vector<uint64_t> sizes; // per-dimension sizes
-  std::vector<Element<V>> elements;
+  std::vector<Element<V>> elements; // all COO elements
+  std::vector<uint64_t> indices;    // shared index pool
   bool iteratorLocked = false;
   unsigned iteratorPos = 0;
 };