diff --git a/mlir/lib/ExecutionEngine/SparseUtils.cpp b/mlir/lib/ExecutionEngine/SparseUtils.cpp
--- a/mlir/lib/ExecutionEngine/SparseUtils.cpp
+++ b/mlir/lib/ExecutionEngine/SparseUtils.cpp
@@ -151,12 +151,12 @@
 /// each differently annotated sparse tensor, this method provides a convenient
 /// "one-size-fits-all" solution that simply takes an input tensor and
 /// annotations to implement all required setup in a general manner.
-template <typename P, typename I, typename V>
+template <typename P, typename I, typename V>
 class SparseTensorStorage : public SparseTensorStorageBase {
 public:
   /// Constructs sparse tensor storage scheme following the given
   /// per-rank dimension dense/sparse annotations.
-  SparseTensorStorage(SparseTensor *tensor, uint8_t *sparsity)
+  SparseTensorStorage(SparseTensor<V> *tensor, uint8_t *sparsity)
       : sizes(tensor->getSizes()), pointers(getRank()), indices(getRank()) {
     // Provide hints on capacity.
     // TODO: needs fine-tuning based on sparsity
@@ -191,14 +191,23 @@
   }
   void getValues(std::vector<V> **out) override { *out = &values; }
 
+  /// Factory method.
+  static SparseTensorStorage<P, I, V> *newSparseTensor(SparseTensor<V> *t,
+                                                       uint8_t *s) {
+    t->sort(); // sort lexicographically
+    SparseTensorStorage<P, I, V> *n = new SparseTensorStorage<P, I, V>(t, s);
+    delete t;
+    return n;
+  }
+
 private:
   /// Initializes sparse tensor storage scheme from a memory-resident
   /// representation of an external sparse tensor. This method prepares
   /// the pointers and indices arrays under the given per-rank dimension
   /// dense/sparse annotations.
-  void traverse(SparseTensor *tensor, uint8_t *sparsity, uint64_t lo,
+  void traverse(SparseTensor<V> *tensor, uint8_t *sparsity, uint64_t lo,
                 uint64_t hi, uint64_t d) {
-    const std::vector<Element> &elements = tensor->getElements();
+    const std::vector<Element<V>> &elements = tensor->getElements();
     // Once dimensions are exhausted, insert the numerical values.
     if (d == getRank()) {
       values.push_back(lo < hi ? elements[lo].value : 0);
@@ -321,9 +330,9 @@
 }
 
 /// Reads a sparse tensor with the given filename into a memory-resident
-/// sparse tensor in coordinate scheme. The external formats always store
-/// the numerical values with the type double.
-static SparseTensor *openTensor(char *filename, uint64_t *perm) {
+/// sparse tensor in coordinate scheme.
+template <typename V>
+static SparseTensor<V> *openTensor(char *filename, uint64_t *perm) {
   // Open the file.
   FILE *file = fopen(filename, "r");
   if (!file) {
@@ -347,7 +356,7 @@
   std::vector<uint64_t> indices(rank);
   for (uint64_t r = 0; r < rank; r++)
     indices[perm[r]] = idata[2 + r];
-  SparseTensor *tensor = new SparseTensor(indices, nnz);
+  SparseTensor<V> *tensor = new SparseTensor<V>(indices, nnz);
   // Read all nonzero elements.
   for (uint64_t k = 0; k < nnz; k++) {
     uint64_t idx = -1;
@@ -359,6 +368,8 @@
       // Add 0-based index.
      indices[perm[r]] = idx - 1;
     }
+    // The external formats always store the numerical values with the type
+    // double, but we cast these values to the sparse tensor object type.
    double value;
    if (fscanf(file, "%lg\n", &value) != 1) {
      fprintf(stderr, "Cannot find next value in %s\n", filename);
@@ -366,21 +377,8 @@
     }
     tensor->add(indices, value);
   }
-  // Close the file and return sorted tensor.
+  // Close the file and return tensor.
   fclose(file);
-  tensor->sort(); // sort lexicographically
-  return tensor;
-}
-
-/// Templated reader.
-template <typename P, typename I, typename V>
-void *newSparseTensor(char *filename, uint8_t *sparsity, uint64_t *perm,
-                      uint64_t size) {
-  SparseTensor *t = openTensor(filename, perm);
-  assert(size == t->getRank()); // sparsity array must match rank
-  SparseTensorStorageBase *tensor =
-      new SparseTensorStorage<P, I, V>(t, sparsity);
-  delete t;
   return tensor;
 }
 
@@ -419,8 +417,11 @@
 }
 
 #define CASE(p, i, v, P, I, V)                                                \
-  if (ptrTp == (p) && indTp == (i) && valTp == (v))                           \
-    return newSparseTensor<P, I, V>(filename, sparsity, perm, asize)
+  if (ptrTp == (p) && indTp == (i) && valTp == (v)) {                         \
+    SparseTensor<V> *tensor = openTensor<V>(filename, perm);                  \
+    assert(asize == tensor->getRank());                                       \
+    return SparseTensorStorage<P, I, V>::newSparseTensor(tensor, sparsity);   \
+  }
 
 #define IMPL1(RET, NAME, TYPE, LIB)                                           \
   RET NAME(void *tensor) {                                                    \
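
Note (not part of the patch): with the reader now templated on the value type
and the lexicographic sort moved into the storage factory, every expansion of
the CASE macro follows the same open/assert/build sequence. As a rough sketch,
the expansion for pointer/index type uint64_t and value type double would read
as below; the kU64/kF64 constants in the condition are hypothetical stand-ins
for whatever type codes the caller passes as p, i, and v:

  if (ptrTp == (kU64) && indTp == (kU64) && valTp == (kF64)) {
    // openTensor<double>() reads the external format into a COO tensor,
    // casting each double value to the requested value type (a no-op here).
    SparseTensor<double> *tensor = openTensor<double>(filename, perm);
    assert(asize == tensor->getRank()); // sparsity array must match rank
    // newSparseTensor() sorts lexicographically, builds the compressed
    // storage, and deletes the temporary COO tensor.
    return SparseTensorStorage<uint64_t, uint64_t, double>::newSparseTensor(
        tensor, sparsity);
  }

Centralizing the sort and the delete in the factory means callers can no
longer forget either step, and the COO object's lifetime ends exactly where
the storage scheme takes over.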