diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/COO.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/COO.h
--- a/mlir/include/mlir/ExecutionEngine/SparseTensor/COO.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/COO.h
@@ -48,6 +48,25 @@
   V value;
 };
 
+/// Closure object for `operator<` on `Element` with a given rank.
+template <typename V>
+struct ElementLT final {
+  const uint64_t rank;
+  ElementLT(uint64_t rank) : rank(rank) {}
+
+  /// Compare two elements a la `operator<`.
+  ///
+  /// Precondition: the elements must both be valid for `rank`.
+  bool operator()(const Element<V> &e1, const Element<V> &e2) const {
+    for (uint64_t d = 0; d < rank; ++d) {
+      if (e1.indices[d] == e2.indices[d])
+        continue;
+      return e1.indices[d] < e2.indices[d];
+    }
+    return false;
+  }
+};
+
 /// The type of callback functions which receive an element. We avoid
 /// packaging the coordinates and value together as an `Element` object
 /// because this helps keep code somewhat cleaner.
@@ -65,12 +84,14 @@
   const std::vector<uint64_t> dimSizes; // per-dimension sizes
   std::vector<Element<V>> elements;     // all elements
   std::vector<uint64_t> indices;        // shared index pool
-  bool iteratorLocked = false;
-  unsigned iteratorPos = 0;
+  bool isSorted;
+  bool iteratorLocked;
+  unsigned iteratorPos;
 
 public:
   SparseTensorCOO(const std::vector<uint64_t> &dimSizes, uint64_t capacity)
-      : dimSizes(dimSizes) {
+      : dimSizes(dimSizes), isSorted(true), iteratorLocked(false),
+        iteratorPos(0) {
     if (capacity) {
       elements.reserve(capacity);
       indices.reserve(capacity * getRank());
@@ -106,6 +127,9 @@
   /// Get the elements array.
   const std::vector<Element<V>> &getElements() const { return elements; }
 
+  /// Returns the `operator<` closure object for the COO's element type.
+  ElementLT<V> getElementLT() const { return ElementLT<V>(getRank()); }
+
   /// Adds an element to the tensor. This method does not check whether
   /// `ind` is already associated with a value, it adds it regardless.
   /// Resolving such conflicts is left up to clients of the iterator
@@ -136,8 +160,11 @@
         elements[i].indices = newBase + (elements[i].indices - base);
       base = newBase;
     }
-    // Add element as (pointer into shared index pool, value) pair.
-    elements.emplace_back(base + size, val);
+    // Add the new element and update the sorted bit.
+    Element<V> addedElem(base + size, val);
+    if (!elements.empty() && isSorted)
+      isSorted = getElementLT()(elements.back(), addedElem);
+    elements.push_back(addedElem);
   }
 
   /// Sorts elements lexicographically by index. If an index is mapped to
@@ -146,18 +173,10 @@
   /// Asserts: is not in iterator mode.
   void sort() {
     assert(!iteratorLocked && "Attempt to sort() after startIterator()");
-    // TODO: we may want to cache an `isSorted` bit, to avoid
-    // unnecessary/redundant sorting.
-    uint64_t rank = getRank();
-    std::sort(elements.begin(), elements.end(),
-              [rank](const Element<V> &e1, const Element<V> &e2) {
-                for (uint64_t r = 0; r < rank; ++r) {
-                  if (e1.indices[r] == e2.indices[r])
-                    continue;
-                  return e1.indices[r] < e2.indices[r];
-                }
-                return false;
-              });
+    if (isSorted)
+      return;
+    std::sort(elements.begin(), elements.end(), getElementLT());
+    isSorted = true;
   }
 
   /// Switch into iterator mode. If already in iterator mode, then