diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h
--- a/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h
@@ -154,50 +154,37 @@
 //===----------------------------------------------------------------------===//
 namespace detail {
 
-// Adds a value to a tensor in coordinate scheme. If is_symmetric_value is true,
-// also adds the value to its symmetric location.
-template <typename T, typename V>
-inline void addValue(T *coo, V value, const std::vector<uint64_t> indices,
-                     bool is_symmetric_value) {
-  // TODO:
-  coo->add(indices, value);
-  // We currently chose to deal with symmetric matrices by fully constructing
-  // them. In the future, we may want to make symmetry implicit for storage
-  // reasons.
-  if (is_symmetric_value)
-    coo->add({indices[1], indices[0]}, value);
+/// Type trait that is true iff `T` is an instantiation of `std::complex`.
+template <typename T>
+struct is_complex final : public std::false_type {};
+
+template <typename T>
+struct is_complex<std::complex<T>> final : public std::true_type {};
+
+/// Reads an element of a non-complex type for the current indices in
+/// coordinate scheme.
+template <typename V>
+inline typename std::enable_if<!is_complex<V>::value, V>::type
+readCOOValue(char **linePtr, bool is_pattern) {
+  // The external formats always store these numerical values with the type
+  // double, but we cast these values to the sparse tensor object type.
+  // For a pattern tensor, we arbitrarily pick the value 1 for all entries.
+  return is_pattern ? 1.0 : strtod(*linePtr, linePtr);
 }
 
 /// Reads an element of a complex type for the current indices in
 /// coordinate scheme.
 template <typename V>
-inline void readCOOValue(SparseTensorCOO<std::complex<V>> *coo,
-                         const std::vector<uint64_t> indices, char **linePtr,
-                         bool is_pattern, bool add_symmetric_value) {
+inline typename std::enable_if<is_complex<V>::value, V>::type
+readCOOValue(char **linePtr, bool is_pattern) {
   // Read two values to make a complex. The external formats always store
   // numerical values with the type double, but we cast these values to the
   // sparse tensor object type. For a pattern tensor, we arbitrarily pick the
   // value 1 for all entries.
-  V re = is_pattern ? 1.0 : strtod(*linePtr, linePtr);
-  V im = is_pattern ? 1.0 : strtod(*linePtr, linePtr);
-  std::complex<V> value = {re, im};
-  addValue(coo, value, indices, add_symmetric_value);
-}
-
-// Reads an element of a non-complex type for the current indices in coordinate
-// scheme.
-template <typename V,
-          typename std::enable_if<
-              !std::is_same<std::complex<float>, V>::value &&
-              !std::is_same<std::complex<double>, V>::value>::type * = nullptr>
-inline void readCOOValue(SparseTensorCOO<V> *coo,
-                         const std::vector<uint64_t> indices, char **linePtr,
-                         bool is_pattern, bool is_symmetric_value) {
-  // The external formats always store these numerical values with the type
-  // double, but we cast these values to the sparse tensor object type.
-  // For a pattern tensor, we arbitrarily pick the value 1 for all entries.
-  double value = is_pattern ? 1.0 : strtod(*linePtr, linePtr);
-  addValue(coo, value, indices, is_symmetric_value);
+  double re = is_pattern ? 1.0 : strtod(*linePtr, linePtr);
+  double im = is_pattern ? 1.0 : strtod(*linePtr, linePtr);
+  // Avoiding brace-notation since that forbids narrowing to `float`.
+  return V(re, im);
 }
 
 } // namespace detail
@@ -232,8 +219,14 @@
       // Add the 0-based index.
       indices[perm[r]] = idx - 1;
     }
-    detail::readCOOValue(coo, indices, &linePtr, stfile.isPattern(),
-                         stfile.isSymmetric() && indices[0] != indices[1]);
+    const V value = detail::readCOOValue<V>(&linePtr, stfile.isPattern());
+    // TODO:
+    coo->add(indices, value);
+    // We currently chose to deal with symmetric matrices by fully
+    // constructing them. In the future, we may want to make symmetry
+    // implicit for storage reasons.
+    if (stfile.isSymmetric() && indices[0] != indices[1])
+      coo->add({indices[1], indices[0]}, value);
   }
   // Close the file and return tensor.
   stfile.closeFile();