[mlir][sparse] Add SparseTensorReader::readCOOElement

Factor the per-line element parsing (1-based coordinates followed by the
value) that was duplicated in openSparseTensorCOO and in the IMPL_GETNEXT
runtime entry points into a single templated member,
SparseTensorReader::readCOOElement<V>. It asserts that the caller's rank
matches the file's rank, stores 0-based coordinates into `indices`
(optionally permuted by `perm`), and returns the parsed value.

The detail::is_complex / detail::readCOOValue helpers are moved above the
SparseTensorReader class so the new member can call them.

diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h
--- a/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h
@@ -33,6 +33,44 @@
 namespace mlir {
 namespace sparse_tensor {
 
+namespace detail {
+
+template <typename T>
+struct is_complex final : public std::false_type {};
+
+template <typename T>
+struct is_complex<std::complex<T>> final : public std::true_type {};
+
+/// Reads an element of a non-complex type for the current indices in
+/// coordinate scheme.
+template <typename V>
+inline std::enable_if_t<!is_complex<V>::value, V>
+readCOOValue(char **linePtr, bool is_pattern) {
+  // The external formats always store these numerical values with the type
+  // double, but we cast these values to the sparse tensor object type.
+  // For a pattern tensor, we arbitrarily pick the value 1 for all entries.
+  return is_pattern ? 1.0 : strtod(*linePtr, linePtr);
+}
+
+/// Reads an element of a complex type for the current indices in
+/// coordinate scheme.
+template <typename V>
+inline std::enable_if_t<is_complex<V>::value, V> readCOOValue(char **linePtr,
+                                                              bool is_pattern) {
+  // Read two values to make a complex. The external formats always store
+  // numerical values with the type double, but we cast these values to the
+  // sparse tensor object type. For a pattern tensor, we arbitrarily pick the
+  // value 1 for all entries.
+  double re = is_pattern ? 1.0 : strtod(*linePtr, linePtr);
+  double im = is_pattern ? 1.0 : strtod(*linePtr, linePtr);
+  // Avoiding brace-notation since that forbids narrowing to `float`.
+  return V(re, im);
+}
+
+} // namespace detail
+
+//===----------------------------------------------------------------------===//
+
 // TODO: benchmark whether to keep various methods inline vs moving them
 // off to the cpp file.
 
@@ -132,6 +170,31 @@
   /// valid after parsing the header.
   void assertMatchesShape(uint64_t rank, const uint64_t *shape) const;
 
+  /// Reads a sparse tensor element from the next line in the input file and
+  /// returns the value of the element. Stores the coordinates of the element
+  /// to the `indices` array.
+  template <typename V>
+  V readCOOElement(uint64_t rank, uint64_t *indices,
+                   const uint64_t *perm = nullptr) {
+    assert(rank == getRank() && "Rank mismatch");
+    char *linePtr = readLine();
+    if (perm)
+      for (uint64_t r = 0; r < rank; ++r) {
+        // Parse the 1-based index.
+        uint64_t idx = strtoul(linePtr, &linePtr, 10);
+        // Store the 0-based index.
+        indices[perm[r]] = idx - 1;
+      }
+    else
+      for (uint64_t r = 0; r < rank; ++r) {
+        // Parse the 1-based index.
+        uint64_t idx = strtoul(linePtr, &linePtr, 10);
+        // Store the 0-based index.
+        indices[r] = idx - 1;
+      }
+    return detail::readCOOValue<V>(&linePtr, isPattern());
+  }
+
 private:
   /// Reads the MME header of a general sparse matrix of type real.
   void readMMEHeader();
@@ -152,41 +215,6 @@
 };
 
 //===----------------------------------------------------------------------===//
-namespace detail {
-
-template <typename T>
-struct is_complex final : public std::false_type {};
-
-template <typename T>
-struct is_complex<std::complex<T>> final : public std::true_type {};
-
-/// Reads an element of a non-complex type for the current indices in
-/// coordinate scheme.
-template <typename V>
-inline std::enable_if_t<!is_complex<V>::value, V>
-readCOOValue(char **linePtr, bool is_pattern) {
-  // The external formats always store these numerical values with the type
-  // double, but we cast these values to the sparse tensor object type.
-  // For a pattern tensor, we arbitrarily pick the value 1 for all entries.
-  return is_pattern ? 1.0 : strtod(*linePtr, linePtr);
-}
-
-/// Reads an element of a complex type for the current indices in
-/// coordinate scheme.
-template <typename V>
-inline std::enable_if_t<is_complex<V>::value, V> readCOOValue(char **linePtr,
-                                                              bool is_pattern) {
-  // Read two values to make a complex. The external formats always store
-  // numerical values with the type double, but we cast these values to the
-  // sparse tensor object type. For a pattern tensor, we arbitrarily pick the
-  // value 1 for all entries.
-  double re = is_pattern ? 1.0 : strtod(*linePtr, linePtr);
-  double im = is_pattern ? 1.0 : strtod(*linePtr, linePtr);
-  // Avoiding brace-notation since that forbids narrowing to `float`.
-  return V(re, im);
-}
-
-} // namespace detail
 
 /// Reads a sparse tensor with the given filename into a memory-resident
 /// sparse tensor in coordinate scheme.
@@ -211,14 +239,7 @@
   // Read all nonzero elements.
   std::vector<uint64_t> indices(rank);
   for (uint64_t k = 0; k < nnz; ++k) {
-    char *linePtr = stfile.readLine();
-    for (uint64_t r = 0; r < rank; ++r) {
-      // Parse the 1-based index.
-      uint64_t idx = strtoul(linePtr, &linePtr, 10);
-      // Add the 0-based index.
-      indices[perm[r]] = idx - 1;
-    }
-    const V value = detail::readCOOValue<V>(&linePtr, stfile.isPattern());
+    const V value = stfile.readCOOElement<V>(rank, indices.data(), perm);
     // TODO: <https://github.com/llvm/llvm-project/issues/54179>
     coo->add(indices, value);
     // We currently chose to deal with symmetric matrices by fully
diff --git a/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp b/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
--- a/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
+++ b/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
@@ -626,13 +626,8 @@
     index_type *indices = iref->data + iref->offset;                           \
     SparseTensorReader *stfile = static_cast<SparseTensorReader *>(p);         \
     index_type rank = stfile->getRank();                                       \
-    char *linePtr = stfile->readLine();                                        \
-    for (index_type r = 0; r < rank; ++r) {                                    \
-      uint64_t idx = strtoul(linePtr, &linePtr, 10);                           \
-      indices[r] = idx - 1;                                                    \
-    }                                                                          \
     V *value = vref->data + vref->offset;                                      \
-    *value = detail::readCOOValue<V>(&linePtr, stfile->isPattern());           \
+    *value = stfile->readCOOElement<V>(rank, indices);                         \
   }
 MLIR_SPARSETENSOR_FOREVERY_V(IMPL_GETNEXT)
 #undef IMPL_GETNEXT