diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h --- a/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h +++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h @@ -209,9 +209,25 @@ return detail::readCOOValue(&linePtr, isPattern()); } - /// Allocates a new COO object for `lvlSizes`, initializes it by reading - /// all the elements from the file and applying `dim2lvl` to their indices, - /// and then closes the file. + /// Reads all `getNNZ()`-many elements from the file, passes each + /// element to the callback, and finally closes the file. The callback + /// receives dimension-indices (so it must map them to level-indices) in a + /// vector that is reused for later elements. For symmetric matrices, + /// off-diagonal elements are passed again with their indices swapped. + /// + /// Preconditions: + /// * the file's actual value type can be read as `V`. + /// + /// Asserts: + /// * `isValid()`. + // + // NOTE: This method is factored out from `readCOO` so that it can be + // reused by the codegen-pass. + template + void readAllElements(ElementConsumer yield); + + /// Allocates a new COO object for `lvlSizes`, and initializes it via + /// `readAllElements`, applying `dim2lvl` to each element's indices. /// /// Preconditions: /// * `lvlSizes` must be valid for `lvlRank`. @@ -234,9 +250,8 @@ const uint64_t *dim2lvl); /// Allocates a new sparse-tensor storage object with the given encoding, - /// initializes it by reading all the elements from the file, and then - /// closes the file. Preconditions/assertions are as per `readCOO` - /// and `SparseTensorStorage::newFromCOO`. + /// and initializes it via `readAllElements`. Preconditions/assertions + /// are as per `readCOO` and `SparseTensorStorage::newFromCOO`. template SparseTensorStorage * readSparseTensor(uint64_t lvlRank, const uint64_t *lvlSizes, @@ -263,17 +278,11 @@ /// Precondition: `indices` is valid for `getRank()`. 
char *readCOOIndices(uint64_t *indices); - /// The internal implementation of `readCOO`. We template over - /// `IsPattern` and `IsSymmetric` in order to perform LICM without - /// needing to duplicate the source code. - // - // TODO: We currently take the `dim2lvl` argument as a `PermutationRef` - // since that's what `readCOO` creates. Once we update `readCOO` to - // functionalize the mapping, then this helper will just take that - // same function. + /// The internal implementation of `readAllElements`. We template over + /// `IsPattern` and `IsSymmetric` in order to perform LICM without needing + /// to duplicate the source code. template - void readCOOLoop(uint64_t lvlRank, detail::PermutationRef dim2lvl, - SparseTensorCOO *lvlCOO); + void readAllElementsImpl(ElementConsumer yield); /// Reads the MME header of a general sparse matrix of type real. void readMMEHeader(); @@ -309,49 +318,53 @@ detail::PermutationRef d2l(dimRank, dim2lvl); // Prepare a COO object with the number of nonzeros as initial capacity. auto *lvlCOO = new SparseTensorCOO(lvlRank, lvlSizes, getNNZ()); + // Parse all elements into the COO, reusing `lvlInd` as scratch space. + std::vector lvlInd(lvlRank); + readAllElements([=, &lvlInd](const std::vector &dimInd, V value) { + d2l.pushforward(dimRank, dimInd.data(), lvlInd.data()); + lvlCOO->add(lvlInd, value); + }); + return lvlCOO; +} + +template +void SparseTensorReader::readAllElements(ElementConsumer yield) { + assert(isValid() && "Attempt to readAllElements() before readHeader()"); // Do some manual LICM, to avoid assertions in the for-loop. 
const bool IsPattern = isPattern(); const bool IsSymmetric = (isSymmetric() && getRank() == 2); if (IsPattern && IsSymmetric) - readCOOLoop(lvlRank, d2l, lvlCOO); + readAllElementsImpl(yield); else if (IsPattern) - readCOOLoop(lvlRank, d2l, lvlCOO); + readAllElementsImpl(yield); else if (IsSymmetric) - readCOOLoop(lvlRank, d2l, lvlCOO); + readAllElementsImpl(yield); else - readCOOLoop(lvlRank, d2l, lvlCOO); - // Close the file and return the COO. - closeFile(); - return lvlCOO; + readAllElementsImpl(yield); } template -void SparseTensorReader::readCOOLoop(uint64_t lvlRank, - detail::PermutationRef dim2lvl, - SparseTensorCOO *lvlCOO) { - const uint64_t dimRank = getRank(); - std::vector dimInd(dimRank); - std::vector lvlInd(lvlRank); +void SparseTensorReader::readAllElementsImpl(ElementConsumer yield) { + std::vector dimInd(getRank()); for (uint64_t nnz = getNNZ(), k = 0; k < nnz; ++k) { // We inline `readCOOElement` here in order to avoid redundant // assertions, since they're guaranteed by the call to `isValid()` // and the construction of `dimInd` above. char *linePtr = readCOOIndices(dimInd.data()); const V value = detail::readCOOValue(&linePtr); - dim2lvl.pushforward(dimRank, dimInd.data(), lvlInd.data()); // TODO: - lvlCOO->add(lvlInd, value); + yield(dimInd, value); // We currently chose to deal with symmetric matrices by fully // constructing them. In the future, we may want to make symmetry // implicit for storage reasons. if constexpr (IsSymmetric) if (dimInd[0] != dimInd[1]) { - // Must recompute `lvlInd`, since arbitrary mappings don't preserve swap. std::swap(dimInd[0], dimInd[1]); - d2l.pushforward(dimRank, dimInd.data(), lvlInd.data()); - lvlCOO->add(lvlInd, value); + yield(dimInd, value); } } + // All elements have been passed to `yield`; now close the file. + closeFile(); } /// Writes the sparse tensor to `filename` in extended FROSTT format.