diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h --- a/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h +++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h @@ -209,9 +209,30 @@ return detail::readCOOValue(&linePtr, isPattern()); } - /// Allocates a new COO object for `lvlSizes`, initializes it by reading - /// all the elements from the file and applying `dim2lvl` to their indices, - /// and then closes the file. + /// Reads all `getNNZ()`-many elements from the file, applies the + /// `dim2lvl` mapping to each element's indices, passes each element + /// to the callback, and finally closes the file. + /// + /// Preconditions: + /// * the file's actual value type can be read as `V`. + /// + /// Asserts: + /// * `isValid()`. + /// * `getRank() == dim2lvl.size()`. + /// * `lvlRank == dim2lvl.size()`. + // + // TODO: For now, we take the `dim2lvl` argument as a `PermutationRef` + // for convenience. However, this argument will eventually be changed + // to a function, so that it can perform arbitrary non-permutation mappings. + // + // NOTE: This method is factored out from `readCOO` so that it can be + // reused by the codegen-pass. + template + void readAllElements(uint64_t lvlRank, detail::PermutationRef dim2lvl, + ElementConsumer yield); + + /// Allocates a new COO object for `lvlSizes`, and initializes it via + /// `readAllElements`. /// /// Preconditions: /// * `lvlSizes` must be valid for `lvlRank`. @@ -234,9 +255,8 @@ const uint64_t *dim2lvl); /// Allocates a new sparse-tensor storage object with the given encoding, - /// initializes it by reading all the elements from the file, and then - /// closes the file. Preconditions/assertions are as per `readCOO` - /// and `SparseTensorStorage::newFromCOO`. + /// and initializes it via `readAllElements`. Preconditions/assertions + /// are as per `readCOO` and `SparseTensorStorage::newFromCOO`. 
template SparseTensorStorage * readSparseTensor(uint64_t lvlRank, const uint64_t *lvlSizes, @@ -263,17 +283,17 @@ /// Precondition: `indices` is valid for `getRank()`. char *readCOOIndices(uint64_t *indices); - /// The internal implementation of `readCOO`. We template over - /// `IsPattern` and `IsSymmetric` in order to perform LICM without - /// needing to duplicate the source code. + /// The internal implementation of `readAllElements`. We template over + /// `IsPattern` and `IsSymmetric` in order to perform LICM without needing + /// to duplicate the source code. // - // TODO: We currently take the `dim2lvl` argument as a `PermutationRef` - // since that's what `readCOO` creates. Once we update `readCOO` to - // functionalize the mapping, then this helper will just take that + // NOTE: We currently take the `dim2lvl` argument as a `PermutationRef` + // since that's what `readAllElements` takes. Once that's updated + // to functionalize the mapping, then this helper will just take that // same function. template - void readCOOLoop(uint64_t lvlRank, detail::PermutationRef dim2lvl, - SparseTensorCOO *lvlCOO); + void readAllElementsImpl(uint64_t lvlRank, detail::PermutationRef dim2lvl, + ElementConsumer yield); /// Reads the MME header of a general sparse matrix of type real. void readMMEHeader(); @@ -309,26 +329,38 @@ detail::PermutationRef d2l(dimRank, dim2lvl); // Prepare a COO object with the number of nonzeros as initial capacity. auto *lvlCOO = new SparseTensorCOO(lvlRank, lvlSizes, getNNZ()); + // Parse all the elements and add them to the COO. 
+ readAllElements(lvlRank, d2l, + [lvlCOO](const std::vector &lvlInd, V value) { + lvlCOO->add(lvlInd, value); + }); + return lvlCOO; +} + +template +void SparseTensorReader::readAllElements(uint64_t lvlRank, + detail::PermutationRef dim2lvl, + ElementConsumer yield) { + assert(isValid() && "Attempt to readAllElements() before readHeader()"); + assert(getRank() == dim2lvl.size() && "Dimension-rank mismatch"); + assert(lvlRank == dim2lvl.size() && "Level-rank mismatch"); // Do some manual LICM, to avoid assertions in the for-loop. const bool IsPattern = isPattern(); const bool IsSymmetric = (isSymmetric() && getRank() == 2); if (IsPattern && IsSymmetric) - readCOOLoop(lvlRank, d2l, lvlCOO); + readAllElementsImpl(lvlRank, dim2lvl, yield); else if (IsPattern) - readCOOLoop(lvlRank, d2l, lvlCOO); + readAllElementsImpl(lvlRank, dim2lvl, yield); else if (IsSymmetric) - readCOOLoop(lvlRank, d2l, lvlCOO); + readAllElementsImpl(lvlRank, dim2lvl, yield); else - readCOOLoop(lvlRank, d2l, lvlCOO); - // Close the file and return the COO. - closeFile(); - return lvlCOO; + readAllElementsImpl(lvlRank, dim2lvl, yield); } template -void SparseTensorReader::readCOOLoop(uint64_t lvlRank, - detail::PermutationRef dim2lvl, - SparseTensorCOO *lvlCOO) { +void SparseTensorReader::readAllElementsImpl(uint64_t lvlRank, + detail::PermutationRef dim2lvl, + ElementConsumer yield) { const uint64_t dimRank = getRank(); std::vector dimInd(dimRank); std::vector lvlInd(lvlRank); @@ -340,14 +372,16 @@ const V value = detail::readCOOValue(&linePtr); dim2lvl.pushforward(dimRank, dimInd.data(), lvlInd.data()); // TODO: - lvlCOO->add(lvlInd, value); + yield(lvlInd, value); // We currently chose to deal with symmetric matrices by fully // constructing them. In the future, we may want to make symmetry // implicit for storage reasons. 
if constexpr (IsSymmetric) if (lvlInd[0] != lvlInd[1]) - lvlCOO->add({lvlInd[1], lvlInd[0]}, value); + yield({lvlInd[1], lvlInd[0]}, value); } + // Every element has now been passed to `yield`, so close the file. + closeFile(); } /// Writes the sparse tensor to `filename` in extended FROSTT format.