diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h
--- a/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h
@@ -96,6 +96,10 @@
   /// Checks if a header has been successfully read.
   bool isValid() const { return valueKind_ != ValueKind::kInvalid; }
 
+  /// Checks if the file's ValueKind can be converted into the given
+  /// tensor PrimaryType. Is only valid after parsing the header.
+  bool canReadAs(PrimaryType valTy) const;
+
   /// Gets the MME "pattern" property setting. Is only valid after
   /// parsing the header.
   bool isPattern() const {
@@ -210,16 +214,10 @@
   stfile.openFile();
   stfile.readHeader();
   // Check tensor element type against the value type in the input file.
-  SparseTensorFile::ValueKind valueKind = stfile.getValueKind();
-  bool tensorIsInteger =
-      (valTp >= PrimaryType::kI64 && valTp <= PrimaryType::kI8);
-  bool tensorIsReal = (valTp >= PrimaryType::kF64 && valTp <= PrimaryType::kI8);
-  if ((valueKind == SparseTensorFile::ValueKind::kReal && tensorIsInteger) ||
-      (valueKind == SparseTensorFile::ValueKind::kComplex && tensorIsReal)) {
+  if (!stfile.canReadAs(valTp))
     MLIR_SPARSETENSOR_FATAL(
         "Tensor element type %d not compatible with values in file %s\n",
         static_cast<int>(valTp), filename);
-  }
   stfile.assertMatchesShape(rank, shape);
   // Prepare sparse tensor object with per-dimension sizes
   // and the number of nonzeros as initial capacity.
diff --git a/mlir/lib/ExecutionEngine/SparseTensor/File.cpp b/mlir/lib/ExecutionEngine/SparseTensor/File.cpp
--- a/mlir/lib/ExecutionEngine/SparseTensor/File.cpp
+++ b/mlir/lib/ExecutionEngine/SparseTensor/File.cpp
@@ -83,6 +83,37 @@
            "Dimension size mismatch");
 }
 
+bool SparseTensorFile::canReadAs(PrimaryType valTy) const {
+  switch (valueKind_) {
+  case ValueKind::kInvalid:
+    assert(false && "Must readHeader() before calling canReadAs()");
+    return false; // In case assertions are disabled.
+  case ValueKind::kPattern:
+    return true;
+  case ValueKind::kInteger:
+    // When the file is specified to store integer values, we still
+    // allow implicitly converting those to floating primary-types.
+    return isRealPrimaryType(valTy);
+  case ValueKind::kReal:
+    // When the file is specified to store real/floating values, then
+    // we disallow implicit conversion to integer primary-types.
+    return isFloatingPrimaryType(valTy);
+  case ValueKind::kComplex:
+    // When the file is specified to store complex values, then we
+    // require a complex primary-type.
+    return isComplexPrimaryType(valTy);
+  case ValueKind::kUndefined:
+    // The "extended" FROSTT format doesn't specify a ValueKind.
+    // So we allow implicitly converting the stored values to both
+    // integer and floating primary-types.
+    return isRealPrimaryType(valTy);
+  }
+  // Every enumerator returns above, but an out-of-range `valueKind_` must not
+  // fall off the end of this non-void function (UB; GCC -Wreturn-type).
+  MLIR_SPARSETENSOR_FATAL("Unknown ValueKind: %d\n",
+                          static_cast<int>(valueKind_));
+}
+
 /// Helper to convert C-style strings (i.e., '\0' terminated) to lower case.
 static inline char *toLower(char *token) {
   for (char *c = token; *c; ++c)