diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/Enums.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/Enums.h
--- a/mlir/include/mlir/ExecutionEngine/SparseTensor/Enums.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/Enums.h
@@ -157,6 +157,67 @@
   kSingletonNuNo = 8,
 };
 
+/// Check if the `DimLevelType` is dense.
+constexpr MLIR_SPARSETENSOR_EXPORT bool isDenseDLT(DimLevelType dlt) {
+  return dlt == DimLevelType::kDense;
+}
+
+/// Check if the `DimLevelType` is compressed (regardless of properties).
+constexpr MLIR_SPARSETENSOR_EXPORT bool isCompressedDLT(DimLevelType dlt) {
+  switch (dlt) {
+  case DimLevelType::kCompressed:
+  case DimLevelType::kCompressedNu:
+  case DimLevelType::kCompressedNo:
+  case DimLevelType::kCompressedNuNo:
+    return true;
+  default:
+    return false;
+  }
+}
+
+/// Check if the `DimLevelType` is singleton (regardless of properties).
+constexpr MLIR_SPARSETENSOR_EXPORT bool isSingletonDLT(DimLevelType dlt) {
+  switch (dlt) {
+  case DimLevelType::kSingleton:
+  case DimLevelType::kSingletonNu:
+  case DimLevelType::kSingletonNo:
+  case DimLevelType::kSingletonNuNo:
+    return true;
+  default:
+    return false;
+  }
+}
+
+/// Check if the `DimLevelType` is ordered (regardless of storage format).
+/// Only the `*No` and `*NuNo` variants are unordered; every other level-type
+/// (including `kDense`) is reported as ordered.
+constexpr MLIR_SPARSETENSOR_EXPORT bool isOrderedDLT(DimLevelType dlt) {
+  switch (dlt) {
+  case DimLevelType::kCompressedNo:
+  case DimLevelType::kCompressedNuNo:
+  case DimLevelType::kSingletonNo:
+  case DimLevelType::kSingletonNuNo:
+    return false;
+  default:
+    return true;
+  }
+}
+
+/// Check if the `DimLevelType` is unique (regardless of storage format).
+/// Only the `*Nu` and `*NuNo` variants are non-unique; every other level-type
+/// (including `kDense`) is reported as unique.
+constexpr MLIR_SPARSETENSOR_EXPORT bool isUniqueDLT(DimLevelType dlt) {
+  switch (dlt) {
+  case DimLevelType::kCompressedNu:
+  case DimLevelType::kCompressedNuNo:
+  case DimLevelType::kSingletonNu:
+  case DimLevelType::kSingletonNuNo:
+    return false;
+  default:
+    return true;
+  }
+}
+
 } // namespace sparse_tensor
 } // namespace mlir
 
diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
--- a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
@@ -102,67 +102,32 @@
   /// Get the dimension-types array, in storage-order.
   const std::vector<DimLevelType> &getDimTypes() const { return dimTypes; }
 
-  /// Safely check if the (storage-order) dimension uses dense storage.
-  bool isDenseDim(uint64_t d) const {
+  /// Safely look up the level-type of the given (storage-order) dimension.
+  /// All `is*Dim` predicates below go through this accessor, so
+  /// `ASSERT_VALID_DIM` is applied in exactly one place.
+  DimLevelType getDimType(uint64_t d) const {
     ASSERT_VALID_DIM(d);
-    return dimTypes[d] == DimLevelType::kDense;
+    return dimTypes[d];
   }
 
+  /// Safely check if the (storage-order) dimension uses dense storage.
+  bool isDenseDim(uint64_t d) const { return isDenseDLT(getDimType(d)); }
+
   /// Safely check if the (storage-order) dimension uses compressed storage.
   bool isCompressedDim(uint64_t d) const {
-    ASSERT_VALID_DIM(d);
-    switch (dimTypes[d]) {
-    case DimLevelType::kCompressed:
-    case DimLevelType::kCompressedNu:
-    case DimLevelType::kCompressedNo:
-    case DimLevelType::kCompressedNuNo:
-      return true;
-    default:
-      return false;
-    }
+    return isCompressedDLT(getDimType(d));
   }
 
   /// Safely check if the (storage-order) dimension uses singleton storage.
   bool isSingletonDim(uint64_t d) const {
-    ASSERT_VALID_DIM(d);
-    switch (dimTypes[d]) {
-    case DimLevelType::kSingleton:
-    case DimLevelType::kSingletonNu:
-    case DimLevelType::kSingletonNo:
-    case DimLevelType::kSingletonNuNo:
-      return true;
-    default:
-      return false;
-    }
+    return isSingletonDLT(getDimType(d));
   }
 
   /// Safely check if the (storage-order) dimension is ordered.
-  bool isOrderedDim(uint64_t d) const {
-    ASSERT_VALID_DIM(d);
-    switch (dimTypes[d]) {
-    case DimLevelType::kCompressedNo:
-    case DimLevelType::kCompressedNuNo:
-    case DimLevelType::kSingletonNo:
-    case DimLevelType::kSingletonNuNo:
-      return false;
-    default:
-      return true;
-    }
-  }
+  bool isOrderedDim(uint64_t d) const { return isOrderedDLT(getDimType(d)); }
 
   /// Safely check if the (storage-order) dimension is unique.
-  bool isUniqueDim(uint64_t d) const {
-    ASSERT_VALID_DIM(d);
-    switch (dimTypes[d]) {
-    case DimLevelType::kCompressedNu:
-    case DimLevelType::kCompressedNuNo:
-    case DimLevelType::kSingletonNu:
-    case DimLevelType::kSingletonNuNo:
-      return false;
-    default:
-      return true;
-    }
-  }
+  bool isUniqueDim(uint64_t d) const { return isUniqueDLT(getDimType(d)); }
 
   /// Allocate a new enumerator.
 #define DECL_NEWENUMERATOR(VNAME, V) \