diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensorUtils.h b/mlir/include/mlir/ExecutionEngine/SparseTensorUtils.h --- a/mlir/include/mlir/ExecutionEngine/SparseTensorUtils.h +++ b/mlir/include/mlir/ExecutionEngine/SparseTensorUtils.h @@ -274,6 +274,11 @@ /// defined with the naming convention ${TENSOR0}, ${TENSOR1}, etc. MLIR_CRUNNERUTILS_EXPORT char *getTensorFilename(index_type id); +/// Helper function to read the header of a file and return the +/// shape/sizes, without parsing the elements of the file. +MLIR_CRUNNERUTILS_EXPORT void readSparseTensorShape(char *filename, + std::vector *out); + /// Initializes sparse tensor from a COO-flavored format expressed using /// C-style data structures. The expected parameters are: /// diff --git a/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp b/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp --- a/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp +++ b/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp @@ -88,9 +88,11 @@ exit(1); \ } -// TODO: adjust this so it can be used by `openSparseTensorCOO` too. -// That version doesn't have the permutation, and the `dimSizes` are -// a pointer/C-array rather than `std::vector`. +// TODO: try to unify this with `SparseTensorFile::assertMatchesShape` +// which is used by `openSparseTensorCOO`. It's easy enough to resolve +// the `std::vector` vs pointer mismatch for `dimSizes`; but it's trickier +// to resolve the presence/absence of `perm` (without introducing extra +// overhead), so perhaps the code duplication is unavoidable. // /// Asserts that the `dimSizes` (in target-order) under the `perm` (mapping /// semantic-order to target-order) are a refinement of the desired `shape` @@ -1099,9 +1101,128 @@ return token; } +/// This class abstracts over the information stored in file headers, +/// as well as providing the buffers and methods for parsing those headers. 
+class SparseTensorFile final { +public: + explicit SparseTensorFile(char *filename) : filename(filename) { + assert(filename && "Received nullptr for filename"); + } + + // Disallows copying, to avoid duplicating the `file` pointer. + SparseTensorFile(const SparseTensorFile &) = delete; + SparseTensorFile &operator=(const SparseTensorFile &) = delete; + + // This dtor tries to avoid leaking the `file`. (Though it's better + // to call `closeFile` explicitly when possible, since there are + // circumstances where dtors are not called reliably.) + ~SparseTensorFile() { closeFile(); } + + /// Opens the file for reading. + void openFile() { + if (file) + FATAL("Already opened file %s\n", filename); + file = fopen(filename, "r"); + if (!file) + FATAL("Cannot find file %s\n", filename); + } + + /// Closes the file. + void closeFile() { + if (file) { + fclose(file); + file = nullptr; + } + } + + // TODO(wrengr/bixia): figure out how to reorganize the element-parsing + // loop of `openSparseTensorCOO` into methods of this class, so we can + // avoid leaking access to the `line` pointer (both for general hygiene + // and because we can't mark it const due to the second argument of + // `strtoul`/`strtoud` being `char * *restrict` rather than + // `char const* *restrict`). + // + /// Attempts to read a line from the file. + char *readLine() { + if (fgets(line, kColWidth, file)) + return line; + FATAL("Cannot read next line of %s\n", filename); + } + + /// Reads and parses the file's header. + void readHeader() { + assert(file && "Attempt to readHeader() before openFile()"); + if (strstr(filename, ".mtx")) + readMMEHeader(); + else if (strstr(filename, ".tns")) + readExtFROSTTHeader(); + else + FATAL("Unknown format %s\n", filename); + assert(isValid && "Failed to read the header"); + } + + /// Gets the MME "pattern" property setting. Is only valid after + /// parsing the header. 
+ bool isPattern() const { + assert(isValid && "Attempt to isPattern() before readHeader()"); + return isPattern_; + } + + /// Gets the MME "symmetric" property setting. Is only valid after + /// parsing the header. + bool isSymmetric() const { + assert(isValid && "Attempt to isSymmetric() before readHeader()"); + return isSymmetric_; + } + + /// Gets the rank of the tensor. Is only valid after parsing the header. + uint64_t getRank() const { + assert(isValid && "Attempt to getRank() before readHeader()"); + return idata[0]; + } + + /// Gets the number of non-zeros. Is only valid after parsing the header. + uint64_t getNNZ() const { + assert(isValid && "Attempt to getNNZ() before readHeader()"); + return idata[1]; + } + + /// Gets the dimension-sizes array. The pointer itself is always + /// valid; however, the values stored therein are only valid after + /// parsing the header. + const uint64_t *getDimSizes() const { return idata + 2; } + + /// Safely gets the size of the given dimension. Is only valid + /// after parsing the header. + uint64_t getDimSize(uint64_t d) const { + assert(d < getRank()); + return idata[2 + d]; + } + + /// Asserts the shape subsumes the actual dimension sizes. Is only + /// valid after parsing the header. + void assertMatchesShape(uint64_t rank, const uint64_t *shape) const { + assert(rank == getRank() && "Rank mismatch"); + for (uint64_t r = 0; r < rank; r++) + assert((shape[r] == 0 || shape[r] == idata[2 + r]) && + "Dimension size mismatch"); + } + +private: + void readMMEHeader(); + void readExtFROSTTHeader(); + + const char *filename; + FILE *file = nullptr; + bool isValid = false; + bool isPattern_ = false; + bool isSymmetric_ = false; + uint64_t idata[512]; + char line[kColWidth]; +}; + /// Read the MME header of a general sparse matrix of type real. 
-static void readMMEHeader(FILE *file, char *filename, char *line, - uint64_t *idata, bool *isPattern, bool *isSymmetric) { +void SparseTensorFile::readMMEHeader() { char header[64]; char object[64]; char format[64]; @@ -1112,19 +1233,18 @@ symmetry) != 5) FATAL("Corrupt header in %s\n", filename); // Set properties - *isPattern = (strcmp(toLower(field), "pattern") == 0); - *isSymmetric = (strcmp(toLower(symmetry), "symmetric") == 0); + isPattern_ = (strcmp(toLower(field), "pattern") == 0); + isSymmetric_ = (strcmp(toLower(symmetry), "symmetric") == 0); // Make sure this is a general sparse matrix. if (strcmp(toLower(header), "%%matrixmarket") || strcmp(toLower(object), "matrix") || strcmp(toLower(format), "coordinate") || - (strcmp(toLower(field), "real") && !(*isPattern)) || - (strcmp(toLower(symmetry), "general") && !(*isSymmetric))) + (strcmp(toLower(field), "real") && !isPattern_) || + (strcmp(toLower(symmetry), "general") && !isSymmetric_)) FATAL("Cannot find a general sparse matrix in %s\n", filename); // Skip comments. while (true) { - if (!fgets(line, kColWidth, file)) - FATAL("Cannot find data in %s\n", filename); + readLine(); if (line[0] != '%') break; } @@ -1133,18 +1253,17 @@ if (sscanf(line, "%" PRIu64 "%" PRIu64 "%" PRIu64 "\n", idata + 2, idata + 3, idata + 1) != 3) FATAL("Cannot find size in %s\n", filename); + isValid = true; } /// Read the "extended" FROSTT header. Although not part of the documented /// format, we assume that the file starts with optional comments followed /// by two lines that define the rank, the number of nonzeros, and the /// dimensions sizes (one per rank) of the sparse tensor. -static void readExtFROSTTHeader(FILE *file, char *filename, char *line, - uint64_t *idata) { +void SparseTensorFile::readExtFROSTTHeader() { // Skip comments.
while (true) { - if (!fgets(line, kColWidth, file)) - FATAL("Cannot find data in %s\n", filename); + readLine(); if (line[0] != '#') break; } @@ -1155,7 +1274,8 @@ for (uint64_t r = 0; r < idata[0]; r++) if (fscanf(file, "%" PRIu64, idata + 2 + r) != 1) FATAL("Cannot find dimension size %s\n", filename); - fgets(line, kColWidth, file); // end of line + readLine(); // end of line. NOTE(review): unlike the old unchecked fgets, this is fatal at EOF, e.g. when the file ends right after the dimension sizes with no trailing newline; confirm that is intended. + isValid = true; } /// Reads a sparse tensor with the given filename into a memory-resident @@ -1164,38 +1284,19 @@ static SparseTensorCOO *openSparseTensorCOO(char *filename, uint64_t rank, const uint64_t *shape, const uint64_t *perm) { - // Open the file. - assert(filename && "Received nullptr for filename"); - FILE *file = fopen(filename, "r"); - if (!file) - FATAL("Cannot find file %s\n", filename); - // Perform some file format dependent set up. - char line[kColWidth]; - uint64_t idata[512]; - bool isPattern = false; - bool isSymmetric = false; - if (strstr(filename, ".mtx")) { - readMMEHeader(file, filename, line, idata, &isPattern, &isSymmetric); - } else if (strstr(filename, ".tns")) { - readExtFROSTTHeader(file, filename, line, idata); - } else { - FATAL("Unknown format %s\n", filename); - } + SparseTensorFile stfile(filename); + stfile.openFile(); + stfile.readHeader(); + stfile.assertMatchesShape(rank, shape); // Prepare sparse tensor object with per-dimension sizes // and the number of nonzeros as initial capacity. - assert(rank == idata[0] && "rank mismatch"); - uint64_t nnz = idata[1]; - for (uint64_t r = 0; r < rank; r++) - assert((shape[r] == 0 || shape[r] == idata[2 + r]) && - "dimension size mismatch"); - SparseTensorCOO *tensor = - SparseTensorCOO::newSparseTensorCOO(rank, idata + 2, perm, nnz); + uint64_t nnz = stfile.getNNZ(); + auto *coo = SparseTensorCOO::newSparseTensorCOO(rank, stfile.getDimSizes(), + perm, nnz); // Read all nonzero elements.
std::vector indices(rank); for (uint64_t k = 0; k < nnz; k++) { - if (!fgets(line, kColWidth, file)) - FATAL("Cannot find next line of data in %s\n", filename); - char *linePtr = line; + char *linePtr = stfile.readLine(); for (uint64_t r = 0; r < rank; r++) { uint64_t idx = strtoul(linePtr, &linePtr, 10); // Add 0-based index. @@ -1204,17 +1305,18 @@ // The external formats always store the numerical values with the type // double, but we cast these values to the sparse tensor object type. // For a pattern tensor, we arbitrarily pick the value 1 for all entries. - double value = isPattern ? 1.0 : strtod(linePtr, &linePtr); - tensor->add(indices, value); + double value = stfile.isPattern() ? 1.0 : strtod(linePtr, &linePtr); + // TODO(review): this TODO has no text; restore the intended note or remove the line. + coo->add(indices, value); // We currently chose to deal with symmetric matrices by fully constructing // them. In the future, we may want to make symmetry implicit for storage // reasons. - if (isSymmetric && indices[0] != indices[1]) - tensor->add({indices[1], indices[0]}, value); + if (stfile.isSymmetric() && indices[0] != indices[1]) + coo->add({indices[1], indices[0]}, value); } // Close the file and return tensor. - fclose(file); - return tensor; + stfile.closeFile(); + return coo; } /// Writes the sparse tensor to `dest` in extended FROSTT format. @@ -1670,6 +1772,18 @@ return env; } +void readSparseTensorShape(char *filename, std::vector *out) { + assert(out && "Received nullptr for out-parameter"); + SparseTensorFile stfile(filename); + stfile.openFile(); + stfile.readHeader(); + stfile.closeFile(); + const uint64_t rank = stfile.getRank(); + const uint64_t *dimSizes = stfile.getDimSizes(); + out->reserve(rank); // NOTE(review): redundant, since the `assign` below already sizes the vector to `rank`. + out->assign(dimSizes, dimSizes + rank); +} + // TODO: generalize beyond 64-bit indices. #define IMPL_CONVERTTOMLIRSPARSETENSOR(VNAME, V) \ void *convertToMLIRSparseTensor##VNAME( \