diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h
--- a/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h
@@ -188,6 +188,29 @@
 
 } // namespace detail
 
+/// Reads a sparse tensor element from the next line in the input file and
+/// returns the value of the element. Stores the coordinates of the element
+/// to the `indices` array.
+///
+/// The line is parsed as `rank` 1-based indices; each is stored 0-based and
+/// the rest of the line is handed to `detail::readCOOValue`. When `perm` is
+/// non-null, the r-th parsed index is stored at `indices[perm[r]]`;
+/// otherwise it is stored at `indices[r]`.
+/// NOTE(review): indices are not validated; a parsed 0 wraps around when
+/// shifted to 0-based.
+template <typename V>
+inline V readCOOElement(SparseTensorReader &stfile, uint64_t *indices,
+                        uint64_t rank, const uint64_t *perm = nullptr) {
+  char *linePtr = stfile.readLine();
+  for (uint64_t r = 0; r < rank; ++r) {
+    // Parse the 1-based index.
+    uint64_t idx = strtoul(linePtr, &linePtr, 10);
+    // Store the 0-based index.
+    indices[perm ? perm[r] : r] = idx - 1;
+  }
+  return detail::readCOOValue<V>(&linePtr, stfile.isPattern());
+}
+
 /// Reads a sparse tensor with the given filename into a memory-resident
 /// sparse tensor in coordinate scheme.
 template <typename V>
@@ -211,14 +234,7 @@
   // Read all nonzero elements.
   std::vector<uint64_t> indices(rank);
   for (uint64_t k = 0; k < nnz; ++k) {
-    char *linePtr = stfile.readLine();
-    for (uint64_t r = 0; r < rank; ++r) {
-      // Parse the 1-based index.
-      uint64_t idx = strtoul(linePtr, &linePtr, 10);
-      // Add the 0-based index.
-      indices[perm[r]] = idx - 1;
-    }
-    const V value = detail::readCOOValue<V>(&linePtr, stfile.isPattern());
+    const V value = readCOOElement<V>(stfile, indices.data(), rank, perm);
     // TODO:
     coo->add(indices, value);
     // We currently chose to deal with symmetric matrices by fully
diff --git a/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp b/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
--- a/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
+++ b/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
@@ -625,12 +625,7 @@
     index_type *indices = iref->data + iref->offset;                           \
     SparseTensorReader *stfile = static_cast<SparseTensorReader *>(p);         \
     index_type rank = stfile->getRank();                                       \
-    char *linePtr = stfile->readLine();                                        \
-    for (index_type r = 0; r < rank; ++r) {                                    \
-      uint64_t idx = strtoul(linePtr, &linePtr, 10);                           \
-      indices[r] = idx - 1;                                                    \
-    }                                                                          \
-    return detail::readCOOValue<V>(&linePtr, stfile->isPattern());             \
+    return readCOOElement<V>(*stfile, indices, rank);                          \
   }
 MLIR_SPARSETENSOR_FOREVERY_V(IMPL_GETNEXT)
 #undef IMPL_GETNEXT