diff --git a/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp b/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
--- a/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
+++ b/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
@@ -84,6 +84,19 @@
   return lhs * rhs;
 }
 
+// This macro helps minimize repetition of this idiom, as well as ensuring
+// we have some additional output indicating where the error is coming from.
+// (Since `fprintf` doesn't provide a stacktrace, this helps make it easier
+// to track down whether an error is coming from our code vs somewhere else
+// in MLIR.)
+// The format (first argument) must be a string literal. The `do/while (0)`
+// wrapper lets `FATAL(...);` serve as the unbraced body of an `if`/`else`.
+#define FATAL(...)                                                             \
+  do {                                                                         \
+    fprintf(stderr, "SparseTensorUtils: " __VA_ARGS__);                        \
+    exit(1);                                                                   \
+  } while (0)
+
 // TODO: adjust this so it can be used by `openSparseTensorCOO` too.
 // That version doesn't have the permutation, and the `dimSizes` are
 // a pointer/C-array rather than `std::vector`.
@@ -262,6 +275,12 @@
 template <typename V>
 class SparseTensorEnumeratorBase;
 
+// Helper macro for generating error messages when some
+// `SparseTensorStorage<P,I,V>` is cast to `SparseTensorStorageBase`
+// and then the wrong "partial method specialization" is called.
+// `NAME` is a (possibly concatenated) string literal naming the method.
+#define FATAL_PIV(NAME) FATAL("<P,I,V> type mismatch for: %s\n", NAME)
+
 /// Abstract base class for `SparseTensorStorage<P,I,V>`. This class
 /// takes responsibility for all the `<P,I,V>`-independent aspects
 /// of the tensor (e.g., shape, sparsity, permutation). In addition,
@@ -325,37 +344,53 @@
 #define DECL_NEWENUMERATOR(VNAME, V)                                           \
   virtual void newEnumerator(SparseTensorEnumeratorBase<V> **, uint64_t,       \
                              const uint64_t *) const {                         \
-    fatal("newEnumerator" #VNAME);                                             \
+    FATAL_PIV("newEnumerator" #VNAME);                                         \
   }
   FOREVERY_V(DECL_NEWENUMERATOR)
 #undef DECL_NEWENUMERATOR
 
   /// Overhead storage.
-  virtual void getPointers(std::vector<uint64_t> **, uint64_t) { fatal("p64"); }
-  virtual void getPointers(std::vector<uint32_t> **, uint64_t) { fatal("p32"); }
-  virtual void getPointers(std::vector<uint16_t> **, uint64_t) { fatal("p16"); }
-  virtual void getPointers(std::vector<uint8_t> **, uint64_t) { fatal("p8"); }
-  virtual void getIndices(std::vector<uint64_t> **, uint64_t) { fatal("i64"); }
-  virtual void getIndices(std::vector<uint32_t> **, uint64_t) { fatal("i32"); }
-  virtual void getIndices(std::vector<uint16_t> **, uint64_t) { fatal("i16"); }
-  virtual void getIndices(std::vector<uint8_t> **, uint64_t) { fatal("i8"); }
+  virtual void getPointers(std::vector<uint64_t> **, uint64_t) {
+    FATAL_PIV("p64");
+  }
+  virtual void getPointers(std::vector<uint32_t> **, uint64_t) {
+    FATAL_PIV("p32");
+  }
+  virtual void getPointers(std::vector<uint16_t> **, uint64_t) {
+    FATAL_PIV("p16");
+  }
+  virtual void getPointers(std::vector<uint8_t> **, uint64_t) {
+    FATAL_PIV("p8");
+  }
+  virtual void getIndices(std::vector<uint64_t> **, uint64_t) {
+    FATAL_PIV("i64");
+  }
+  virtual void getIndices(std::vector<uint32_t> **, uint64_t) {
+    FATAL_PIV("i32");
+  }
+  virtual void getIndices(std::vector<uint16_t> **, uint64_t) {
+    FATAL_PIV("i16");
+  }
+  virtual void getIndices(std::vector<uint8_t> **, uint64_t) {
+    FATAL_PIV("i8");
+  }
 
   /// Primary storage.
 #define DECL_GETVALUES(VNAME, V)                                               \
-  virtual void getValues(std::vector<V> **) { fatal("getValues" #VNAME); }
+  virtual void getValues(std::vector<V> **) { FATAL_PIV("getValues" #VNAME); }
   FOREVERY_V(DECL_GETVALUES)
 #undef DECL_GETVALUES
 
   /// Element-wise insertion in lexicographic index order.
 #define DECL_LEXINSERT(VNAME, V)                                               \
-  virtual void lexInsert(const uint64_t *, V) { fatal("lexInsert" #VNAME); }
+  virtual void lexInsert(const uint64_t *, V) { FATAL_PIV("lexInsert" #VNAME); }
   FOREVERY_V(DECL_LEXINSERT)
 #undef DECL_LEXINSERT
 
   /// Expanded insertion.
 #define DECL_EXPINSERT(VNAME, V)                                               \
   virtual void expInsert(uint64_t *, V *, bool *, uint64_t *, uint64_t) {      \
-    fatal("expInsert" #VNAME);                                                 \
+    FATAL_PIV("expInsert" #VNAME);                                             \
   }
   FOREVERY_V(DECL_EXPINSERT)
 #undef DECL_EXPINSERT
@@ -374,16 +409,13 @@
   SparseTensorStorageBase &operator=(const SparseTensorStorageBase &) = delete;
 
 private:
-  static void fatal(const char *tp) {
-    fprintf(stderr, "unsupported %s\n", tp);
-    exit(1);
-  }
-
   const std::vector<uint64_t> dimSizes;
   std::vector<uint64_t> rev;
   const std::vector<DimLevelType> dimTypes;
 };
 
+#undef FATAL_PIV
+
 // Forward.
 template <typename V>
 class SparseTensorEnumerator;
@@ -1122,10 +1154,8 @@
   char symmetry[64];
   // Read header line.
   if (fscanf(file, "%63s %63s %63s %63s %63s\n", header, object, format, field,
-             symmetry) != 5) {
-    fprintf(stderr, "Corrupt header in %s\n", filename);
-    exit(1);
-  }
+             symmetry) != 5)
+    FATAL("Corrupt header in %s\n", filename);
   // Set properties
   *isPattern = (strcmp(toLower(field), "pattern") == 0);
   *isSymmetric = (strcmp(toLower(symmetry), "symmetric") == 0);
@@ -1134,26 +1164,20 @@
       strcmp(toLower(object), "matrix") ||
       strcmp(toLower(format), "coordinate") ||
       (strcmp(toLower(field), "real") && !(*isPattern)) ||
-      (strcmp(toLower(symmetry), "general") && !(*isSymmetric))) {
-    fprintf(stderr, "Cannot find a general sparse matrix in %s\n", filename);
-    exit(1);
-  }
+      (strcmp(toLower(symmetry), "general") && !(*isSymmetric)))
+    FATAL("Cannot find a general sparse matrix in %s\n", filename);
   // Skip comments.
   while (true) {
-    if (!fgets(line, kColWidth, file)) {
-      fprintf(stderr, "Cannot find data in %s\n", filename);
-      exit(1);
-    }
+    if (!fgets(line, kColWidth, file))
+      FATAL("Cannot find data in %s\n", filename);
     if (line[0] != '%')
       break;
   }
   // Next line contains M N NNZ.
   idata[0] = 2; // rank
   if (sscanf(line, "%" PRIu64 "%" PRIu64 "%" PRIu64 "\n", idata + 2, idata + 3,
-             idata + 1) != 3) {
-    fprintf(stderr, "Cannot find size in %s\n", filename);
-    exit(1);
-  }
+             idata + 1) != 3)
+    FATAL("Cannot find size in %s\n", filename);
 }
 
 /// Read the "extended" FROSTT header. Although not part of the documented
@@ -1164,25 +1188,18 @@
                                 uint64_t *idata) {
   // Skip comments.
   while (true) {
-    if (!fgets(line, kColWidth, file)) {
-      fprintf(stderr, "Cannot find data in %s\n", filename);
-      exit(1);
-    }
+    if (!fgets(line, kColWidth, file))
+      FATAL("Cannot find data in %s\n", filename);
     if (line[0] != '#')
       break;
   }
   // Next line contains RANK and NNZ.
-  if (sscanf(line, "%" PRIu64 "%" PRIu64 "\n", idata, idata + 1) != 2) {
-    fprintf(stderr, "Cannot find metadata in %s\n", filename);
-    exit(1);
-  }
+  if (sscanf(line, "%" PRIu64 "%" PRIu64 "\n", idata, idata + 1) != 2)
+    FATAL("Cannot find metadata in %s\n", filename);
   // Followed by a line with the dimension sizes (one per rank).
-  for (uint64_t r = 0; r < idata[0]; r++) {
-    if (fscanf(file, "%" PRIu64, idata + 2 + r) != 1) {
-      fprintf(stderr, "Cannot find dimension size %s\n", filename);
-      exit(1);
-    }
-  }
+  for (uint64_t r = 0; r < idata[0]; r++)
+    if (fscanf(file, "%" PRIu64, idata + 2 + r) != 1)
+      FATAL("Cannot find dimension size %s\n", filename);
   fgets(line, kColWidth, file); // end of line
 }
 
@@ -1193,12 +1210,10 @@
                                                const uint64_t *shape,
                                                const uint64_t *perm) {
   // Open the file.
+  assert(filename && "Received nullptr for filename");
   FILE *file = fopen(filename, "r");
-  if (!file) {
-    assert(filename && "Received nullptr for filename");
-    fprintf(stderr, "Cannot find file %s\n", filename);
-    exit(1);
-  }
+  if (!file)
+    FATAL("Cannot find file %s\n", filename);
   // Perform some file format dependent set up.
   char line[kColWidth];
   uint64_t idata[512];
@@ -1209,8 +1224,7 @@
   } else if (strstr(filename, ".tns")) {
     readExtFROSTTHeader(file, filename, line, idata);
   } else {
-    fprintf(stderr, "Unknown format %s\n", filename);
-    exit(1);
+    FATAL("Unknown format %s\n", filename);
   }
   // Prepare sparse tensor object with per-dimension sizes
   // and the number of nonzeros as initial capacity.
@@ -1224,10 +1238,8 @@
   // Read all nonzero elements.
   std::vector<uint64_t> indices(rank);
   for (uint64_t k = 0; k < nnz; k++) {
-    if (!fgets(line, kColWidth, file)) {
-      fprintf(stderr, "Cannot find next line of data in %s\n", filename);
-      exit(1);
-    }
+    if (!fgets(line, kColWidth, file))
+      FATAL("Cannot find next line of data in %s\n", filename);
     char *linePtr = line;
     for (uint64_t r = 0; r < rank; r++) {
       uint64_t idx = strtoul(linePtr, &linePtr, 10);
@@ -1290,22 +1302,15 @@
   // Verify that perm is a permutation of 0..(rank-1).
   std::vector<uint64_t> order(perm, perm + rank);
   std::sort(order.begin(), order.end());
-  for (uint64_t i = 0; i < rank; ++i) {
-    if (i != order[i]) {
-      fprintf(stderr, "Not a permutation of 0..%" PRIu64 "\n", rank);
-      exit(1);
-    }
-  }
+  for (uint64_t i = 0; i < rank; ++i)
+    if (i != order[i])
+      FATAL("Not a permutation of 0..%" PRIu64 "\n", rank);
 
   // Verify that the sparsity values are supported.
-  for (uint64_t i = 0; i < rank; ++i) {
+  for (uint64_t i = 0; i < rank; ++i)
     if (sparsity[i] != DimLevelType::kDense &&
-        sparsity[i] != DimLevelType::kCompressed) {
-      fprintf(stderr, "Unsupported sparsity value %d\n",
-              static_cast<int>(sparsity[i]));
-      exit(1);
-    }
-  }
+        sparsity[i] != DimLevelType::kCompressed)
+      FATAL("Unsupported sparsity value %d\n", static_cast<int>(sparsity[i]));
 #endif
 
   // Convert external format to internal COO.
@@ -1539,8 +1544,10 @@
   CASE_SECSAME(OverheadType::kU64, PrimaryType::kC32, uint64_t, complex32);
 
   // Unsupported case (add above if needed).
-  fputs("unsupported combination of types\n", stderr);
-  exit(1);
+  // TODO: better pretty-printing of enum values!
+  FATAL("unsupported combination of types: <P=%d, I=%d, V=%d>\n",
+        static_cast<int>(ptrTp), static_cast<int>(indTp),
+        static_cast<int>(valTp));
 }
 #undef CASE
 #undef CASE_SECSAME
@@ -1704,10 +1711,8 @@
   char var[80];
   sprintf(var, "TENSOR%" PRIu64, id);
   char *env = getenv(var);
-  if (!env) {
-    fprintf(stderr, "Environment variable %s is not set\n", var);
-    exit(1);
-  }
+  if (!env)
+    FATAL("Environment variable %s is not set\n", var);
   return env;
 }
 