diff --git a/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp b/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp --- a/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp +++ b/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp @@ -269,9 +269,10 @@ pointers[r].push_back(0); // Then assign contents from coordinate scheme tensor if provided. if (tensor) { - uint64_t nnz = tensor->getElements().size(); + const std::vector> &elements = tensor->getElements(); + uint64_t nnz = elements.size(); values.reserve(nnz); - fromCOO(tensor, 0, nnz, 0); + fromCOO(elements, 0, nnz, 0); } else if (allDense) { values.resize(sz, 0); } @@ -367,7 +368,7 @@ std::vector reord(rank); for (uint64_t r = 0; r < rank; r++) reord[r] = perm[rev[r]]; - toCOO(tensor, reord, 0, 0); + toCOO(*tensor, reord, 0, 0); assert(tensor->getElements().size() == values.size()); return tensor; } @@ -402,9 +403,8 @@ /// Initializes sparse tensor storage scheme from a memory-resident sparse /// tensor in coordinate scheme. This method prepares the pointers and /// indices arrays under the given per-dimension dense/sparse annotations. - void fromCOO(SparseTensorCOO *tensor, uint64_t lo, uint64_t hi, - uint64_t d) { - const std::vector> &elements = tensor->getElements(); + void fromCOO(const std::vector> &elements, uint64_t lo, + uint64_t hi, uint64_t d) { // Once dimensions are exhausted, insert the numerical values. assert(d <= getRank()); if (d == getRank()) { @@ -432,7 +432,7 @@ endDim(d + 1); full++; } - fromCOO(tensor, lo, seg, d + 1); + fromCOO(elements, lo, seg, d + 1); // And move on to next segment in interval. lo = seg; } @@ -449,12 +449,12 @@ /// Stores the sparse tensor storage scheme into a memory-resident sparse /// tensor in coordinate scheme. - void toCOO(SparseTensorCOO *tensor, std::vector &reord, + void toCOO(SparseTensorCOO &tensor, std::vector &reord, uint64_t pos, uint64_t d) { assert(d <= getRank()); if (d == getRank()) { assert(pos < values.size()); - tensor->add(idx, values[pos]); + tensor.add(idx, values[pos]); } else if (isCompressedDim(d)) { // Sparse dimension. for (uint64_t ii = pointers[d][pos]; ii < pointers[d][pos + 1]; ii++) {