diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -33,6 +33,7 @@
 
 #include <cinttypes>
 #include <complex>
+#include <optional>
 
 namespace mlir {
 namespace sparse_tensor {
@@ -157,6 +158,16 @@
   SingletonNuNo = 19, // 0b100_11
 };
 
+/// This enum defines all the sparse representations supportable by
+/// the SparseTensor dialect. Unlike DimLevelType, it does not encode level
+/// properties, which are irrelevant to the sparse tensor storage scheme.
+enum class LevelFormat : uint8_t {
+  Undef = 0,      // 0b000_00
+  Dense = 4,      // 0b001_00
+  Compressed = 8, // 0b010_00
+  Singleton = 16, // 0b100_00
+};
+
 /// Returns string representation of the given dimension level type.
 inline std::string toMLIRString(DimLevelType dlt) {
   switch (dlt) {
@@ -231,6 +242,70 @@
   return !(static_cast<uint8_t>(dlt) & 1);
 }
 
+/// Convert a DimLevelType to its corresponding LevelFormat by clearing the
+/// two low-order bits, which hold the level properties.
+constexpr LevelFormat getLevelFormat(DimLevelType dlt) {
+  return static_cast<LevelFormat>(static_cast<uint8_t>(dlt) & ~3);
+}
+
+/// Convert a LevelFormat, together with its ordered/unique properties, to the
+/// corresponding DimLevelType. Returns std::nullopt when the combination is
+/// rejected by isValidDLT (e.g. a non-ordered or non-unique Dense level).
+constexpr std::optional<DimLevelType>
+getDimLevelType(LevelFormat lf, bool ordered, bool unique) {
+  // Bit 1 is set for non-ordered levels and bit 0 for non-unique levels.
+  auto dlt = static_cast<DimLevelType>(static_cast<uint8_t>(lf) |
+                                       (ordered ? 0 : 2) | (unique ? 0 : 1));
+  if (isValidDLT(dlt))
+    return dlt;
+  return std::nullopt;
+}
+
+/// Ensure both LevelFormat and DimLevelType have the same underlying type.
+static_assert(std::is_same_v<std::underlying_type_t<LevelFormat>,
+                             std::underlying_type_t<DimLevelType>>);
+
+/// Ensure the above conversions work as intended.
+static_assert(
+    (getLevelFormat(DimLevelType::Undef) == LevelFormat::Undef &&
+     getLevelFormat(DimLevelType::Dense) == LevelFormat::Dense &&
+     getLevelFormat(DimLevelType::Compressed) == LevelFormat::Compressed &&
+     getLevelFormat(DimLevelType::CompressedNu) == LevelFormat::Compressed &&
+     getLevelFormat(DimLevelType::CompressedNo) == LevelFormat::Compressed &&
+     getLevelFormat(DimLevelType::CompressedNuNo) == LevelFormat::Compressed &&
+     getLevelFormat(DimLevelType::Singleton) == LevelFormat::Singleton &&
+     getLevelFormat(DimLevelType::SingletonNu) == LevelFormat::Singleton &&
+     getLevelFormat(DimLevelType::SingletonNo) == LevelFormat::Singleton &&
+     getLevelFormat(DimLevelType::SingletonNuNo) == LevelFormat::Singleton),
+    "getLevelFormat conversion is broken");
+
+static_assert(
+    (getDimLevelType(LevelFormat::Undef, false, true) == std::nullopt &&
+     getDimLevelType(LevelFormat::Undef, true, false) == std::nullopt &&
+     getDimLevelType(LevelFormat::Undef, false, false) == std::nullopt &&
+     getDimLevelType(LevelFormat::Dense, false, true) == std::nullopt &&
+     getDimLevelType(LevelFormat::Dense, true, false) == std::nullopt &&
+     getDimLevelType(LevelFormat::Dense, false, false) == std::nullopt &&
+     *getDimLevelType(LevelFormat::Undef, true, true) == DimLevelType::Undef &&
+     *getDimLevelType(LevelFormat::Dense, true, true) == DimLevelType::Dense &&
+     *getDimLevelType(LevelFormat::Compressed, true, true) ==
+         DimLevelType::Compressed &&
+     *getDimLevelType(LevelFormat::Compressed, true, false) ==
+         DimLevelType::CompressedNu &&
+     *getDimLevelType(LevelFormat::Compressed, false, true) ==
+         DimLevelType::CompressedNo &&
+     *getDimLevelType(LevelFormat::Compressed, false, false) ==
+         DimLevelType::CompressedNuNo &&
+     *getDimLevelType(LevelFormat::Singleton, true, true) ==
+         DimLevelType::Singleton &&
+     *getDimLevelType(LevelFormat::Singleton, true, false) ==
+         DimLevelType::SingletonNu &&
+     *getDimLevelType(LevelFormat::Singleton, false, true) ==
+         DimLevelType::SingletonNo &&
+     *getDimLevelType(LevelFormat::Singleton, false, false) ==
+         DimLevelType::SingletonNuNo),
+    "getDimLevelType conversion is broken");
+
 // Ensure the above predicates work as intended.
 static_assert((isValidDLT(DimLevelType::Undef) &&
                isValidDLT(DimLevelType::Dense) &&