diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h
--- a/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h
@@ -42,32 +42,46 @@
 template <typename T>
 struct is_complex<std::complex<T>> final : public std::true_type {};
 
-/// Reads an element of a non-complex type for the current indices in
-/// coordinate scheme.
-template <typename V>
-inline std::enable_if_t<!is_complex<V>::value, V>
-readCOOValue(char **linePtr, bool is_pattern) {
+/// Returns an element-value of non-complex type. If `IsPattern` is true,
+/// then returns an arbitrary value. If `IsPattern` is false, then
+/// reads the value from the current line buffer beginning at `linePtr`.
+template <typename V, bool IsPattern>
+inline std::enable_if_t<!is_complex<V>::value, V> readCOOValue(char **linePtr) {
   // The external formats always store these numerical values with the type
   // double, but we cast these values to the sparse tensor object type.
   // For a pattern tensor, we arbitrarily pick the value 1 for all entries.
-  return is_pattern ? 1.0 : strtod(*linePtr, linePtr);
+  if constexpr (IsPattern)
+    return 1.0;
+  return strtod(*linePtr, linePtr);
 }
 
-/// Reads an element of a complex type for the current indices in
-/// coordinate scheme.
-template <typename V>
-inline std::enable_if_t<is_complex<V>::value, V> readCOOValue(char **linePtr,
-                                                              bool is_pattern) {
+/// Returns an element-value of complex type. If `IsPattern` is true,
+/// then returns an arbitrary value. If `IsPattern` is false, then reads
+/// the value from the current line buffer beginning at `linePtr`.
+template <typename V, bool IsPattern>
+inline std::enable_if_t<is_complex<V>::value, V> readCOOValue(char **linePtr) {
   // Read two values to make a complex. The external formats always store
   // numerical values with the type double, but we cast these values to the
   // sparse tensor object type. For a pattern tensor, we arbitrarily pick the
   // value 1 for all entries.
-  double re = is_pattern ? 1.0 : strtod(*linePtr, linePtr);
-  double im = is_pattern ? 1.0 : strtod(*linePtr, linePtr);
+  if constexpr (IsPattern)
+    return V(1.0, 1.0);
+  double re = strtod(*linePtr, linePtr);
+  double im = strtod(*linePtr, linePtr);
   // Avoiding brace-notation since that forbids narrowing to `float`.
   return V(re, im);
 }
 
+/// Returns an element-value. If `is_pattern` is true, then returns an
+/// arbitrary value. If `is_pattern` is false, then reads the value from
+/// the current line buffer beginning at `linePtr`.
+template <typename V>
+inline V readCOOValue(char **linePtr, bool is_pattern) {
+  if (is_pattern)
+    return readCOOValue<V, true>(linePtr);
+  return readCOOValue<V, false>(linePtr);
+}
+
 } // namespace detail
 
 //===----------------------------------------------------------------------===//
@@ -249,6 +263,18 @@
   /// Precondition: `indices` is valid for `getRank()`.
   char *readCOOIndices(uint64_t *indices);
 
+  /// The internal implementation of `readCOO`. We template over
+  /// `IsPattern` and `IsSymmetric` in order to perform LICM without
+  /// needing to duplicate the source code.
+  //
+  // TODO: We currently take the `dim2lvl` argument as a `PermutationRef`
+  // since that's what `readCOO` creates. Once we update `readCOO` to
+  // functionalize the mapping, then this helper will just take that
+  // same function.
+  template <typename V, bool IsPattern, bool IsSymmetric>
+  void readCOOLoop(uint64_t lvlRank, detail::PermutationRef dim2lvl,
+                   SparseTensorCOO<V> *lvlCOO);
+
   /// Reads the MME header of a general sparse matrix of type real.
   void readMMEHeader();
 
@@ -282,36 +308,50 @@
   assert(lvlRank == dimRank && "Rank mismatch");
   detail::PermutationRef d2l(dimRank, dim2lvl);
   // Prepare a COO object with the number of nonzeros as initial capacity.
-  const uint64_t nnz = getNNZ();
-  auto *lvlCOO = new SparseTensorCOO<V>(lvlRank, lvlSizes, nnz);
-  // Read all nonzero elements.
+  auto *lvlCOO = new SparseTensorCOO<V>(lvlRank, lvlSizes, getNNZ());
+  // Do some manual LICM, to avoid assertions in the for-loop.
+  const bool IsPattern = isPattern();
+  const bool IsSymmetric = (isSymmetric() && getRank() == 2);
+  if (IsPattern && IsSymmetric)
+    readCOOLoop<V, true, true>(lvlRank, d2l, lvlCOO);
+  else if (IsPattern)
+    readCOOLoop<V, true, false>(lvlRank, d2l, lvlCOO);
+  else if (IsSymmetric)
+    readCOOLoop<V, false, true>(lvlRank, d2l, lvlCOO);
+  else
+    readCOOLoop<V, false, false>(lvlRank, d2l, lvlCOO);
+  // Close the file and return the COO.
+  closeFile();
+  return lvlCOO;
+}
+
+template <typename V, bool IsPattern, bool IsSymmetric>
+void SparseTensorReader::readCOOLoop(uint64_t lvlRank,
+                                     detail::PermutationRef dim2lvl,
+                                     SparseTensorCOO<V> *lvlCOO) {
+  const uint64_t dimRank = getRank();
   std::vector<uint64_t> dimInd(dimRank);
   std::vector<uint64_t> lvlInd(lvlRank);
-  // Do some manual LICM, to avoid assertions in the for-loop.
-  const bool addSymmetric = (isSymmetric() && dimRank == 2);
-  const bool isPattern_ = isPattern();
-  for (uint64_t k = 0; k < nnz; ++k) {
+  for (uint64_t nnz = getNNZ(), k = 0; k < nnz; ++k) {
     // We inline `readCOOElement` here in order to avoid redundant
     // assertions, since they're guaranteed by the call to `isValid()`
     // and the construction of `dimInd` above.
     char *linePtr = readCOOIndices(dimInd.data());
-    const V value = detail::readCOOValue<V>(&linePtr, isPattern_);
-    d2l.pushforward(dimRank, dimInd.data(), lvlInd.data());
+    const V value = detail::readCOOValue<V, IsPattern>(&linePtr);
+    dim2lvl.pushforward(dimRank, dimInd.data(), lvlInd.data());
     // TODO:
     lvlCOO->add(lvlInd, value);
     // We currently chose to deal with symmetric matrices by fully
     // constructing them. In the future, we may want to make symmetry
     // implicit for storage reasons.
-    if (addSymmetric && dimInd[0] != dimInd[1]) {
-      // Must recompute `lvlInd`, since arbitrary mappings don't preserve swap.
-      std::swap(dimInd[0], dimInd[1]);
-      d2l.pushforward(dimRank, dimInd.data(), lvlInd.data());
-      lvlCOO->add(lvlInd, value);
-    }
+    if constexpr (IsSymmetric)
+      if (dimInd[0] != dimInd[1]) {
+        // Must recompute `lvlInd`, since arbitrary mappings don't preserve swap.
+        std::swap(dimInd[0], dimInd[1]);
+        dim2lvl.pushforward(dimRank, dimInd.data(), lvlInd.data());
+        lvlCOO->add(lvlInd, value);
+      }
   }
-  // Close the file and return the COO.
-  closeFile();
-  return lvlCOO;
 }
 
 /// Writes the sparse tensor to `filename` in extended FROSTT format.