diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -169,6 +169,9 @@
 ///
 // TODO: We should generalize TwoOutOfFour to N out of M and use property to
 // encode the value of N and M.
+// TODO: Update DimLevelType to use lower 8 bits for storage formats and the
+// higher 4 bits to store level properties. Consider CompressedWithHi and
+// TwoOutOfFour as properties instead of formats.
 enum class DimLevelType : uint8_t {
   Undef = 0, // 0b00000_00
   Dense = 4, // 0b00001_00
@@ -197,6 +200,14 @@
   TwoOutOfFour = 64, // 0b10000_00
 };
 
+/// Bit flags for the nondefault properties of a level's storage format.
+enum class LevelNondefaultProperty : uint8_t {
+  Nonunique = 1,  // 0b00000_01
+  Nonordered = 2, // 0b00000_10
+  High = 32,      // 0b01000_00
+  Block2_4 = 64   // 0b10000_00
+};
+
 /// Returns string representation of the given dimension level type.
 constexpr const char *toMLIRString(DimLevelType dlt) {
   switch (dlt) {
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp
@@ -391,7 +391,7 @@
     os << '{';
     llvm::interleaveComma(
         lvlSpecs, os, [&](LvlSpec const &spec) { os << spec.getBoundVar(); });
-    os << '}';
+    os << "} ";
   }
 
   // Dimension specifiers.
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp
@@ -354,7 +354,8 @@
   const auto type = lvlTypeParser.parseLvlType(parser);
   FAILURE_IF_FAILED(type)
 
-  lvlSpecs.emplace_back(var, expr, *type);
+  // `parseLvlType` yields raw bits it already checked with `isValidDLT`.
+  lvlSpecs.emplace_back(var, expr, static_cast<DimLevelType>(*type));
   return success();
 }
 
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.h b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.h
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.h
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.h
@@ -9,56 +9,24 @@
 #ifndef MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_LVLTYPEPARSER_H
 #define MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_LVLTYPEPARSER_H
 
-#include "mlir/Dialect/SparseTensor/IR/Enums.h"
 #include "mlir/IR/OpImplementation.h"
-#include "llvm/ADT/StringMap.h"
 
 namespace mlir {
 namespace sparse_tensor {
 namespace ir_detail {
 
-//===----------------------------------------------------------------------===//
-// These macros are for generating a C++ expression of type
-// `std::initializer_list<std::pair<StringRef,DimLevelType>>` since there's
-// no way to construct an object of that type directly via C++ code.
-#define FOREVERY_LEVELTYPE(DO)                                                 \
-  DO(DimLevelType::Dense)                                                      \
-  DO(DimLevelType::Compressed)                                                 \
-  DO(DimLevelType::CompressedNu)                                               \
-  DO(DimLevelType::CompressedNo)                                               \
-  DO(DimLevelType::CompressedNuNo)                                             \
-  DO(DimLevelType::Singleton)                                                  \
-  DO(DimLevelType::SingletonNu)                                                \
-  DO(DimLevelType::SingletonNo)                                                \
-  DO(DimLevelType::SingletonNuNo)                                              \
-  DO(DimLevelType::CompressedWithHi)                                           \
-  DO(DimLevelType::CompressedWithHiNu)                                         \
-  DO(DimLevelType::CompressedWithHiNo)                                         \
-  DO(DimLevelType::CompressedWithHiNuNo)                                       \
-  DO(DimLevelType::TwoOutOfFour)
-#define LEVELTYPE_INITLIST_ELEMENT(lvlType)                                    \
-  std::make_pair(StringRef(toMLIRString(lvlType)), lvlType),
-#define LEVELTYPE_INITLIST                                                     \
-  { FOREVERY_LEVELTYPE(LEVELTYPE_INITLIST_ELEMENT) }
-
-// TODO(wrengr): Since this parser is non-trivial to construct, is there
-// any way to hook into the parsing process so that we construct it only once
-// at the begining of parsing and then destroy it once parsing has finished?
+/// Parser for level types written as `format` or `format(property, ...)`.
 class LvlTypeParser {
-  const llvm::StringMap<DimLevelType> map;
-
 public:
-  explicit LvlTypeParser() : map(LEVELTYPE_INITLIST) {}
-#undef LEVELTYPE_INITLIST
-#undef LEVELTYPE_INITLIST_ELEMENT
-#undef FOREVERY_LEVELTYPE
+  LvlTypeParser() = default;
+  /// Parses a level format keyword followed by an optional parenthesized
+  /// property list; on success returns the combined bits as a `uint8_t`.
+  FailureOr<uint8_t> parseLvlType(AsmParser &parser) const;
 
-  std::optional<DimLevelType> lookup(StringRef str) const;
-  std::optional<DimLevelType> lookup(StringAttr str) const;
-  ParseResult parseLvlType(AsmParser &parser, DimLevelType &out) const;
-  FailureOr<DimLevelType> parseLvlType(AsmParser &parser) const;
-  // TODO(wrengr): `parseOptionalLvlType`?
-  // TODO(wrengr): `parseLvlTypeList`?
+private:
+  /// Parses a single level property keyword and ORs its bit into
+  /// `*properties`.
+  ParseResult parseProperty(AsmParser &parser, uint8_t *properties) const;
 };
 
 } // namespace ir_detail
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "LvlTypeParser.h"
+#include "mlir/Dialect/SparseTensor/IR/Enums.h"
 
 using namespace mlir;
 using namespace mlir::sparse_tensor;
@@ -46,34 +47,79 @@
 // `LvlTypeParser` implementation.
 //===----------------------------------------------------------------------===//
 
-std::optional<DimLevelType> LvlTypeParser::lookup(StringRef str) const {
-  // NOTE: `StringMap::lookup` will return a default-constructed value if
-  // the key isn't found; which for enums means zero, and therefore makes
-  // it impossible to distinguish between actual zero-DimLevelType vs
-  // not-found. Whereas `StringMap::at` asserts that the key is found,
-  // which we don't want either.
-  const auto it = map.find(str);
-  return it == map.end() ? std::nullopt : std::make_optional(it->second);
-}
+/// Parses a level type written as `format` or `format(property, ...)` and
+/// returns the format and property bits combined into a single `uint8_t`.
+/// Emits a diagnostic and returns failure when the format keyword is missing
+/// or unknown, when a property cannot be parsed, or when `isValidDLT` rejects
+/// the combined value.
+FailureOr<uint8_t> LvlTypeParser::parseLvlType(AsmParser &parser) const {
+  // Record the location before consuming the keyword so that diagnostics
+  // point at the offending token instead of at whatever follows it.
+  const auto loc = parser.getCurrentLocation();
+  StringRef base;
+  if (failed(parser.parseOptionalKeyword(&base))) {
+    // Report the missing keyword here rather than failing without a message.
+    parser.emitError(loc, "expected valid level format (e.g. dense, "
+                          "compressed or singleton)");
+    return failure();
+  }
+  uint8_t properties = 0;
 
-std::optional<DimLevelType> LvlTypeParser::lookup(StringAttr str) const {
-  return str ? lookup(str.getValue()) : std::nullopt;
-}
+  ParseResult res = parser.parseCommaSeparatedList(
+      mlir::OpAsmParser::Delimiter::OptionalParen,
+      [&]() -> ParseResult { return parseProperty(parser, &properties); },
+      " in level property list");
+  FAILURE_IF_FAILED(res)
 
-FailureOr<DimLevelType> LvlTypeParser::parseLvlType(AsmParser &parser) const {
-  DimLevelType out;
-  FAILURE_IF_FAILED(parseLvlType(parser, out))
-  return out;
+  // Set the base bit for properties.
+  if (base.compare("dense") == 0) {
+    properties |= static_cast<uint8_t>(LevelFormat::Dense);
+  } else if (base.compare("compressed") == 0) {
+    // TODO: Remove this condition once dimLvlType enum is refactored. Current
+    // enum treats High and TwoOutOfFour as formats instead of properties.
+    if (!(properties & static_cast<uint8_t>(LevelNondefaultProperty::High) ||
+          properties &
+              static_cast<uint8_t>(LevelNondefaultProperty::Block2_4))) {
+      properties |= static_cast<uint8_t>(LevelFormat::Compressed);
+    }
+  } else if (base.compare("singleton") == 0) {
+    properties |= static_cast<uint8_t>(LevelFormat::Singleton);
+  } else {
+    parser.emitError(loc, "unknown level format: ") << base;
+    return failure();
+  }
+
+  ERROR_IF(!isValidDLT(static_cast<DimLevelType>(properties)),
+           "invalid level type");
+  return properties;
 }
 
-ParseResult LvlTypeParser::parseLvlType(AsmParser &parser,
-                                        DimLevelType &out) const {
+/// Parses one level property keyword (`nonunique`, `nonordered`, `high` or
+/// `block2_4`) and ORs the matching `LevelNondefaultProperty` bit into
+/// `*properties`. Emits a diagnostic and returns failure otherwise.
+ParseResult LvlTypeParser::parseProperty(AsmParser &parser,
+                                         uint8_t *properties) const {
+  // Record the location up front: once the keyword has been consumed the
+  // current location no longer points at it.
   const auto loc = parser.getCurrentLocation();
   StringRef strVal;
-  FAILURE_IF_FAILED(parser.parseOptionalKeyword(&strVal));
-  const auto lvlType = lookup(strVal);
-  ERROR_IF(!lvlType, "unknown level-type '" + strVal + "'")
-  out = *lvlType;
+  if (failed(parser.parseOptionalKeyword(&strVal))) {
+    parser.emitError(loc, "expected valid level property (e.g. nonordered, "
+                          "nonunique, high or block2_4)");
+    return failure();
+  }
+  if (strVal.compare("nonunique") == 0) {
+    *properties |= static_cast<uint8_t>(LevelNondefaultProperty::Nonunique);
+  } else if (strVal.compare("nonordered") == 0) {
+    *properties |= static_cast<uint8_t>(LevelNondefaultProperty::Nonordered);
+  } else if (strVal.compare("high") == 0) {
+    *properties |= static_cast<uint8_t>(LevelNondefaultProperty::High);
+  } else if (strVal.compare("block2_4") == 0) {
+    *properties |= static_cast<uint8_t>(LevelNondefaultProperty::Block2_4);
+  } else {
+    parser.emitError(loc, "unknown level property: ") << strVal;
+    return failure();
+  }
   return success();
 }
 
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
--- a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
@@ -55,7 +55,7 @@
 // -----
 
 #COO = #sparse_tensor.encoding<{
-  lvlTypes = [ "compressed_nu_no", "singleton_no" ]
+  map = (d0, d1) -> (d0 : compressed(nonunique, nonordered), d1 : singleton(nonordered))
 }>
 
 // CHECK-LABEL: func private @sparse_coo(
@@ -65,7 +65,7 @@
 // -----
 
 #BCOO = #sparse_tensor.encoding<{
-  lvlTypes = [ "dense", "compressed_hi_nu", "singleton" ]
+  map = (d0, d1, d2) -> (d0 : dense, d1 : compressed(nonunique, high), d2 : singleton)
 }>
 
 // CHECK-LABEL: func private @sparse_bcoo(
@@ -75,7 +75,7 @@
 // -----
 
 #SortedCOO = #sparse_tensor.encoding<{
-  lvlTypes = [ "compressed_nu", "singleton" ]
+  map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton)
 }>
 
 // CHECK-LABEL: func private @sparse_sorted_coo(
@@ -144,7 +144,7 @@
 // below) to encode a 2D matrix, but it would require dim2lvl mapping which is not ready yet.
 // So we take the simple path for now.
 #NV_24= #sparse_tensor.encoding<{
-  lvlTypes = [ "dense", "compressed24" ],
+  map = (d0, d1) -> (d0 : dense, d1 : compressed(block2_4))
 }>
 
 // CHECK-LABEL: func private @sparse_2_out_of_4(
@@ -195,7 +195,7 @@
   map = ( i, j ) ->
   ( i : dense,
     j floordiv 4 : dense,
-    j mod 4 : compressed24
+    j mod 4 : compressed(block2_4)
   )
 }>
 