diff --git a/mlir/lib/ExecutionEngine/SparseUtils.cpp b/mlir/lib/ExecutionEngine/SparseUtils.cpp --- a/mlir/lib/ExecutionEngine/SparseUtils.cpp +++ b/mlir/lib/ExecutionEngine/SparseUtils.cpp @@ -76,8 +76,8 @@ } /// Adds element as indices and value. void add(const std::vector<uint64_t> &ind, double val) { - assert(sizes.size() == ind.size()); - for (int64_t r = 0, rank = sizes.size(); r < rank; r++) + assert(getRank() == ind.size()); + for (int64_t r = 0, rank = getRank(); r < rank; r++) assert(ind[r] < sizes[r]); // within bounds elements.emplace_back(Element(ind, val)); } @@ -85,6 +85,8 @@ void sort() { std::sort(elements.begin(), elements.end(), lexOrder); } /// Primitive one-time iteration. const Element &next() { return elements[pos++]; } + /// Returns rank. + uint64_t getRank() const { return sizes.size(); } /// Getter for sizes array. const std::vector<uint64_t> &getSizes() const { return sizes; } /// Getter for elements array. @@ -139,13 +141,13 @@ /// Constructs sparse tensor storage scheme following the given /// per-rank dimension dense/sparse annotations. SparseTensorStorage(SparseTensor *tensor, bool *sparsity) - : sizes(tensor->getSizes()), pointers(sizes.size()), - indices(sizes.size()) { + : sizes(tensor->getSizes()), pointers(getRank()), indices(getRank()) { // Provide hints on capacity. // TODO: needs fine-tuning based on sparsity - values.reserve(tensor->getElements().size()); - for (uint64_t d = 0, s = 1, rank = sizes.size(); d < rank; d++) { - s *= tensor->getSizes()[d]; + uint64_t nnz = tensor->getElements().size(); + values.reserve(nnz); + for (uint64_t d = 0, s = 1, rank = getRank(); d < rank; d++) { + s *= sizes[d]; if (sparsity[d]) { pointers[d].reserve(s + 1); indices[d].reserve(s); @@ -153,12 +155,16 @@ } } // Then setup the tensor.
- traverse(tensor, sparsity, 0, tensor->getElements().size(), 0); + traverse(tensor, sparsity, 0, nnz, 0); } virtual ~SparseTensorStorage() {} + uint64_t getRank() const { return sizes.size(); } + uint64_t getDimSize(uint64_t d) override { return sizes[d]; } + + // Partially specialize these three methods based on template types. void getPointers(std::vector<P> **out, uint64_t d) override { *out = &pointers[d]; } @@ -176,7 +182,7 @@ uint64_t d) { const std::vector<Element> &elements = tensor->getElements(); // Once dimensions are exhausted, insert the numerical values. - if (d == sizes.size()) { + if (d == getRank()) { values.push_back(lo < hi ? elements[lo].value : 0.0); return; } @@ -221,9 +227,10 @@ /// Templated reader. template <typename P, typename I, typename V> -void *newSparseTensor(char *filename, bool *sparsity) { +void *newSparseTensor(char *filename, bool *sparsity, uint64_t size) { uint64_t idata[64]; SparseTensor *t = static_cast<SparseTensor *>(openTensorC(filename, idata)); + assert(size == t->getRank()); // sparsity array must match rank SparseTensorStorageBase *tensor = new SparseTensorStorage<P, I, V>(t, sparsity); delete t; @@ -481,21 +488,29 @@ assert(astride == 1); bool *sparsity = abase + aoff; if (ptrTp == kU64 && indTp == kU64 && valTp == kF64) - return newSparseTensor<uint64_t, uint64_t, double>(filename, sparsity); + return newSparseTensor<uint64_t, uint64_t, double>(filename, sparsity, + asize); if (ptrTp == kU64 && indTp == kU64 && valTp == kF32) - return newSparseTensor<uint64_t, uint64_t, float>(filename, sparsity); + return newSparseTensor<uint64_t, uint64_t, float>(filename, sparsity, + asize); if (ptrTp == kU64 && indTp == kU32 && valTp == kF64) - return newSparseTensor<uint64_t, uint32_t, double>(filename, sparsity); + return newSparseTensor<uint64_t, uint32_t, double>(filename, sparsity, + asize); if (ptrTp == kU64 && indTp == kU32 && valTp == kF32) - return newSparseTensor<uint64_t, uint32_t, float>(filename, sparsity); + return newSparseTensor<uint64_t, uint32_t, float>(filename, sparsity, + asize); if (ptrTp == kU32 && indTp == kU64 && valTp == kF64) - return newSparseTensor<uint32_t, uint64_t, double>(filename, sparsity); + return newSparseTensor<uint32_t, uint64_t, double>(filename, sparsity, + asize); if (ptrTp == kU32 && indTp == kU64 && valTp == kF32) - return newSparseTensor<uint32_t, uint64_t, float>(filename, sparsity); + return newSparseTensor<uint32_t, uint64_t, float>(filename, sparsity, + asize); if (ptrTp == kU32 && indTp == kU32 && valTp == kF64) - return newSparseTensor<uint32_t, uint32_t, double>(filename, sparsity); + return newSparseTensor<uint32_t, uint32_t, double>(filename, sparsity, + asize); if (ptrTp == kU32 && indTp == kU32 && valTp == kF32) - return newSparseTensor<uint32_t, uint32_t, float>(filename, sparsity); + return 
newSparseTensor<uint32_t, uint32_t, float>(filename, sparsity, + asize); fputs("unsupported combination of types\n", stderr); exit(1); }