diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/COO.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/COO.h --- a/mlir/include/mlir/ExecutionEngine/SparseTensor/COO.h +++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/COO.h @@ -65,6 +65,7 @@ const std::vector dimSizes; // per-dimension sizes std::vector> elements; // all elements std::vector indices; // shared index pool + bool isSorted = false; bool iteratorLocked = false; unsigned iteratorPos = 0; @@ -137,6 +138,9 @@ base = newBase; } // Add element as (pointer into shared index pool, value) pair. + // TODO: Only invalidate the `isSorted` bit if the new element isn't + // greater than the immediately preceding one. + isSorted = false; elements.emplace_back(base + size, val); } @@ -146,8 +150,8 @@ /// Asserts: is not in iterator mode. void sort() { assert(!iteratorLocked && "Attempt to sort() after startIterator()"); - // TODO: we may want to cache an `isSorted` bit, to avoid - // unnecessary/redundant sorting. + if (isSorted) + return; uint64_t rank = getRank(); std::sort(elements.begin(), elements.end(), [rank](const Element &e1, const Element &e2) { @@ -158,6 +162,7 @@ } return false; }); + isSorted = true; } /// Switch into iterator mode. If already in iterator mode, then