diff --git a/mlir/include/mlir-c/Dialect/SparseTensor.h b/mlir/include/mlir-c/Dialect/SparseTensor.h --- a/mlir/include/mlir-c/Dialect/SparseTensor.h +++ b/mlir/include/mlir-c/Dialect/SparseTensor.h @@ -26,15 +26,19 @@ /// If updating, keep them in sync and update the static_assert in the impl /// file. enum MlirSparseTensorDimLevelType { - MLIR_SPARSE_TENSOR_DIM_LEVEL_DENSE = 4, // 0b001_00 - MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED = 8, // 0b010_00 - MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU = 9, // 0b010_01 - MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NO = 10, // 0b010_10 - MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU_NO = 11, // 0b010_11 - MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON = 16, // 0b100_00 - MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU = 17, // 0b100_01 - MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NO = 18, // 0b100_10 - MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU_NO = 19, // 0b100_11 + MLIR_SPARSE_TENSOR_DIM_LEVEL_DENSE = 4, // 0b0001_00 + MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED = 8, // 0b0010_00 + MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU = 9, // 0b0010_01 + MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NO = 10, // 0b0010_10 + MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU_NO = 11, // 0b0010_11 + MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON = 16, // 0b0100_00 + MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU = 17, // 0b0100_01 + MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NO = 18, // 0b0100_10 + MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU_NO = 19, // 0b0100_11 + MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_WITH_HI = 32, // 0b1000_00 + MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_WITH_HI_NU = 33, // 0b1000_01 + MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_WITH_HI_NO = 34, // 0b1000_10 + MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_WITH_HI_NU_NO = 35, // 0b1000_11 }; //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h --- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h @@ -172,24 +172,29 @@ /// It should not be used externally, since it does not indicate an /// actual/representable format. enum class DimLevelType : uint8_t { - Undef = 0, // 0b000_00 - Dense = 4, // 0b001_00 - Compressed = 8, // 0b010_00 - CompressedNu = 9, // 0b010_01 - CompressedNo = 10, // 0b010_10 - CompressedNuNo = 11, // 0b010_11 - Singleton = 16, // 0b100_00 - SingletonNu = 17, // 0b100_01 - SingletonNo = 18, // 0b100_10 - SingletonNuNo = 19, // 0b100_11 + Undef = 0, // 0b0000_00 + Dense = 4, // 0b0001_00 + Compressed = 8, // 0b0010_00 + CompressedNu = 9, // 0b0010_01 + CompressedNo = 10, // 0b0010_10 + CompressedNuNo = 11, // 0b0010_11 + Singleton = 16, // 0b0100_00 + SingletonNu = 17, // 0b0100_01 + SingletonNo = 18, // 0b0100_10 + SingletonNuNo = 19, // 0b0100_11 + CompressedWithHi = 32, // 0b1000_00 + CompressedWithHiNu = 33, // 0b1000_01 + CompressedWithHiNo = 34, // 0b1000_10 + CompressedWithHiNuNo = 35, // 0b1000_11 }; /// This enum defines all the storage formats supported by the sparse compiler, /// without the level properties. enum class LevelFormat : uint8_t { - Dense = 4, // 0b001_00 - Compressed = 8, // 0b010_00 - Singleton = 16, // 0b100_00 + Dense = 4, // 0b0001_00 + Compressed = 8, // 0b0010_00 + Singleton = 16, // 0b0100_00 + CompressedWithHi = 32, // 0b1000_00 }; /// Returns string representation of the given dimension level type. @@ -216,6 +221,14 @@ return "singleton-no"; case DimLevelType::SingletonNuNo: return "singleton-nu-no"; + case DimLevelType::CompressedWithHi: + return "compressed-hi"; + case DimLevelType::CompressedWithHiNu: + return "compressed-hi-nu"; + case DimLevelType::CompressedWithHiNo: + return "compressed-hi-no"; + case DimLevelType::CompressedWithHiNuNo: + return "compressed-hi-nu-no"; } return ""; } @@ -226,8 +239,9 @@ const uint8_t propertyBits = static_cast<uint8_t>(dlt) & 3; // If undefined or dense, then must be unique and ordered.
// Otherwise, the format must be one of the known ones. - return (formatBits <= 1) ? (propertyBits == 0) - : (formatBits == 2 || formatBits == 4); + return (formatBits <= 1) + ? (propertyBits == 0) + : (formatBits == 2 || formatBits == 4 || formatBits == 8); } /// Check if the `DimLevelType` is the special undefined value. @@ -250,6 +264,12 @@ static_cast<uint8_t>(DimLevelType::Compressed); } +/// Check if the `DimLevelType` is compressed-hi (regardless of properties). +constexpr bool isCompressedWithHiDLT(DimLevelType dlt) { + return (static_cast<uint8_t>(dlt) & ~3) == + static_cast<uint8_t>(DimLevelType::CompressedWithHi); +} + /// Check if the `DimLevelType` is singleton (regardless of properties). constexpr bool isSingletonDLT(DimLevelType dlt) { return (static_cast<uint8_t>(dlt) & ~3) == @@ -333,7 +353,11 @@ isValidDLT(DimLevelType::Singleton) && isValidDLT(DimLevelType::SingletonNu) && isValidDLT(DimLevelType::SingletonNo) && - isValidDLT(DimLevelType::SingletonNuNo)), + isValidDLT(DimLevelType::SingletonNuNo) && + isValidDLT(DimLevelType::CompressedWithHi) && + isValidDLT(DimLevelType::CompressedWithHiNu) && + isValidDLT(DimLevelType::CompressedWithHiNo) && + isValidDLT(DimLevelType::CompressedWithHiNuNo)), "isValidDLT definition is broken"); static_assert((!isCompressedDLT(DimLevelType::Dense) && @@ -347,6 +371,17 @@ !isCompressedDLT(DimLevelType::SingletonNuNo)), "isCompressedDLT definition is broken"); +static_assert((!isCompressedWithHiDLT(DimLevelType::Dense) && + isCompressedWithHiDLT(DimLevelType::CompressedWithHi) && + isCompressedWithHiDLT(DimLevelType::CompressedWithHiNu) && + isCompressedWithHiDLT(DimLevelType::CompressedWithHiNo) && + isCompressedWithHiDLT(DimLevelType::CompressedWithHiNuNo) && + !isCompressedWithHiDLT(DimLevelType::Singleton) && + !isCompressedWithHiDLT(DimLevelType::SingletonNu) && + !isCompressedWithHiDLT(DimLevelType::SingletonNo) && + !isCompressedWithHiDLT(DimLevelType::SingletonNuNo)), + "isCompressedWithHiDLT definition is broken"); +
static_assert((!isSingletonDLT(DimLevelType::Dense) && !isSingletonDLT(DimLevelType::Compressed) && !isSingletonDLT(DimLevelType::CompressedNu) && @@ -366,7 +401,11 @@ isOrderedDLT(DimLevelType::Singleton) && isOrderedDLT(DimLevelType::SingletonNu) && !isOrderedDLT(DimLevelType::SingletonNo) && - !isOrderedDLT(DimLevelType::SingletonNuNo)), + !isOrderedDLT(DimLevelType::SingletonNuNo) && + isOrderedDLT(DimLevelType::CompressedWithHi) && + isOrderedDLT(DimLevelType::CompressedWithHiNu) && + !isOrderedDLT(DimLevelType::CompressedWithHiNo) && + !isOrderedDLT(DimLevelType::CompressedWithHiNuNo)), "isOrderedDLT definition is broken"); static_assert((isUniqueDLT(DimLevelType::Dense) && @@ -377,7 +416,11 @@ isUniqueDLT(DimLevelType::Singleton) && !isUniqueDLT(DimLevelType::SingletonNu) && isUniqueDLT(DimLevelType::SingletonNo) && - !isUniqueDLT(DimLevelType::SingletonNuNo)), + !isUniqueDLT(DimLevelType::SingletonNuNo) && + isUniqueDLT(DimLevelType::CompressedWithHi) && + !isUniqueDLT(DimLevelType::CompressedWithHiNu) && + isUniqueDLT(DimLevelType::CompressedWithHiNo) && + !isUniqueDLT(DimLevelType::CompressedWithHiNuNo)), "isUniqueDLT definition is broken"); } // namespace sparse_tensor diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td @@ -337,6 +337,7 @@ bool isDenseLvl(::mlir::sparse_tensor::Level l) const { return isDenseDLT(getLvlType(l)); } bool isCompressedLvl(::mlir::sparse_tensor::Level l) const { return isCompressedDLT(getLvlType(l)); } + bool isCompressedWithHiLvl(::mlir::sparse_tensor::Level l) const { return isCompressedWithHiDLT(getLvlType(l)); } bool isSingletonLvl(::mlir::sparse_tensor::Level l) const { return isSingletonDLT(getLvlType(l)); } bool isOrderedLvl(::mlir::sparse_tensor::Level l) const { return
isOrderedDLT(getLvlType(l)); } bool isUniqueLvl(::mlir::sparse_tensor::Level l) const { return isUniqueDLT(getLvlType(l)); } diff --git a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp --- a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp +++ b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp @@ -26,7 +26,14 @@ .value("singleton", MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON) .value("singleton-nu", MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU) .value("singleton-no", MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NO) - .value("singleton-nu-no", MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU_NO); + .value("singleton-nu-no", MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU_NO) + .value("compressed-hi", MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_WITH_HI) + .value("compressed-hi-nu", + MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_WITH_HI_NU) + .value("compressed-hi-no", + MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_WITH_HI_NO) + .value("compressed-hi-nu-no", + MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_WITH_HI_NU_NO); mlir_attribute_subclass(m, "EncodingAttr", mlirAttributeIsASparseTensorEncodingAttr) diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp --- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp @@ -198,12 +198,19 @@ return getStaticDimSliceStride(toOrigDim(*this, lvl)); } -const static DimLevelType validDLTs[] = { - DimLevelType::Dense, DimLevelType::Compressed, - DimLevelType::CompressedNu, DimLevelType::CompressedNo, - DimLevelType::CompressedNuNo, DimLevelType::Singleton, - DimLevelType::SingletonNu, DimLevelType::SingletonNo, - DimLevelType::SingletonNuNo}; +const static DimLevelType validDLTs[] = {DimLevelType::Dense, + DimLevelType::Compressed, + DimLevelType::CompressedNu, + DimLevelType::CompressedNo, + DimLevelType::CompressedNuNo, + DimLevelType::Singleton, + DimLevelType::SingletonNu,
DimLevelType::SingletonNo, + DimLevelType::SingletonNuNo, + DimLevelType::CompressedWithHi, + DimLevelType::CompressedWithHiNu, + DimLevelType::CompressedWithHiNo, + DimLevelType::CompressedWithHiNuNo}; static std::optional<DimLevelType> parseDLT(StringRef str) { for (DimLevelType dlt : validDLTs) diff --git a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir --- a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir +++ b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir @@ -55,6 +55,16 @@ // ----- +#BCOO = #sparse_tensor.encoding<{ + dimLevelType = [ "dense", "compressed-hi-nu", "singleton" ] +}> + +// CHECK-LABEL: func private @sparse_bcoo( +// CHECK-SAME: tensor>) +func.func private @sparse_bcoo(tensor) + +// ----- + #SortedCOO = #sparse_tensor.encoding<{ dimLevelType = [ "compressed-nu", "singleton" ] }>