diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -33,6 +33,7 @@
 
 #include <cinttypes>
 #include <complex>
+#include <optional>
 
 namespace mlir {
 namespace sparse_tensor {
@@ -157,6 +158,14 @@
   SingletonNuNo = 19,  // 0b100_11
 };
 
+/// This enum defines all the storage formats supported by the sparse compiler,
+/// without the level properties.
+enum class LevelFormat : uint8_t {
+  Dense = 4,      // 0b001_00
+  Compressed = 8, // 0b010_00
+  Singleton = 16, // 0b100_00
+};
+
 /// Returns string representation of the given dimension level type.
 inline std::string toMLIRString(DimLevelType dlt) {
   switch (dlt) {
@@ -231,6 +240,63 @@
   return !(static_cast<uint8_t>(dlt) & 1);
 }
 
+/// Convert a DimLevelType to its corresponding LevelFormat.
+/// Returns std::nullopt when input dlt is Undef.
+constexpr std::optional<LevelFormat> getLevelFormat(DimLevelType dlt) {
+  if (dlt == DimLevelType::Undef)
+    return std::nullopt;
+  return static_cast<LevelFormat>(static_cast<uint8_t>(dlt) & ~3);
+}
+
+/// Convert a LevelFormat to its corresponding DimLevelType with the given
+/// properties. Returns std::nullopt when the properties are not applicable for
+/// the input level format.
+/// TODO: factor out a new LevelProperties type so we can add new properties
+/// without changing this function's signature
+constexpr std::optional<DimLevelType>
+getDimLevelType(LevelFormat lf, bool ordered, bool unique) {
+  auto dlt = static_cast<DimLevelType>(static_cast<uint8_t>(lf) |
+                                       (ordered ? 0 : 2) | (unique ? 0 : 1));
+  return isValidDLT(dlt) ? std::optional(dlt) : std::nullopt;
+}
+
+/// Ensure the above conversions work as intended.
+static_assert(
+    (getLevelFormat(DimLevelType::Undef) == std::nullopt &&
+     *getLevelFormat(DimLevelType::Dense) == LevelFormat::Dense &&
+     *getLevelFormat(DimLevelType::Compressed) == LevelFormat::Compressed &&
+     *getLevelFormat(DimLevelType::CompressedNu) == LevelFormat::Compressed &&
+     *getLevelFormat(DimLevelType::CompressedNo) == LevelFormat::Compressed &&
+     *getLevelFormat(DimLevelType::CompressedNuNo) == LevelFormat::Compressed &&
+     *getLevelFormat(DimLevelType::Singleton) == LevelFormat::Singleton &&
+     *getLevelFormat(DimLevelType::SingletonNu) == LevelFormat::Singleton &&
+     *getLevelFormat(DimLevelType::SingletonNo) == LevelFormat::Singleton &&
+     *getLevelFormat(DimLevelType::SingletonNuNo) == LevelFormat::Singleton),
+    "getLevelFormat conversion is broken");
+
+static_assert(
+    (getDimLevelType(LevelFormat::Dense, false, true) == std::nullopt &&
+     getDimLevelType(LevelFormat::Dense, true, false) == std::nullopt &&
+     getDimLevelType(LevelFormat::Dense, false, false) == std::nullopt &&
+     *getDimLevelType(LevelFormat::Dense, true, true) == DimLevelType::Dense &&
+     *getDimLevelType(LevelFormat::Compressed, true, true) ==
+         DimLevelType::Compressed &&
+     *getDimLevelType(LevelFormat::Compressed, true, false) ==
+         DimLevelType::CompressedNu &&
+     *getDimLevelType(LevelFormat::Compressed, false, true) ==
+         DimLevelType::CompressedNo &&
+     *getDimLevelType(LevelFormat::Compressed, false, false) ==
+         DimLevelType::CompressedNuNo &&
+     *getDimLevelType(LevelFormat::Singleton, true, true) ==
+         DimLevelType::Singleton &&
+     *getDimLevelType(LevelFormat::Singleton, true, false) ==
+         DimLevelType::SingletonNu &&
+     *getDimLevelType(LevelFormat::Singleton, false, true) ==
+         DimLevelType::SingletonNo &&
+     *getDimLevelType(LevelFormat::Singleton, false, false) ==
+         DimLevelType::SingletonNuNo),
+    "getDimLevelType conversion is broken");
+
 // Ensure the above predicates work as intended.
 static_assert((isValidDLT(DimLevelType::Undef) &&
                isValidDLT(DimLevelType::Dense) &&