diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -157,6 +157,34 @@
   SingletonNuNo = 19, // 0b100_11
 };
 
+/// Returns the double-quoted MLIR spelling of the given `DimLevelType`.
+constexpr const char *toMLIRString(DimLevelType dlt) {
+  switch (dlt) {
+  // TODO: should probably raise an error instead of printing it...
+  case DimLevelType::Undef:
+    return "\"undef\"";
+  case DimLevelType::Dense:
+    return "\"dense\"";
+  case DimLevelType::Compressed:
+    return "\"compressed\"";
+  case DimLevelType::CompressedNu:
+    return "\"compressed-nu\"";
+  case DimLevelType::CompressedNo:
+    return "\"compressed-no\"";
+  case DimLevelType::CompressedNuNo:
+    return "\"compressed-nu-no\"";
+  case DimLevelType::Singleton:
+    return "\"singleton\"";
+  case DimLevelType::SingletonNu:
+    return "\"singleton-nu\"";
+  case DimLevelType::SingletonNo:
+    return "\"singleton-no\"";
+  case DimLevelType::SingletonNuNo:
+    return "\"singleton-nu-no\"";
+  }
+  return ""; // Reached only if `dlt` is not one of the enumerators above.
+}
+
 /// Check that the `DimLevelType` contains a valid (possibly undefined) value.
 constexpr bool isValidDLT(DimLevelType dlt) {
   const uint8_t formatBits = static_cast<uint8_t>(dlt) >> 2;
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -160,39 +160,7 @@
   // Print the struct-like storage in dictionary fashion.
   printer << "<{ dimLevelType = [ ";
   for (unsigned i = 0, e = getDimLevelType().size(); i < e; i++) {
-    switch (getDimLevelType()[i]) {
-    case DimLevelType::Undef:
-      // TODO: should probably raise an error instead of printing it...
-      printer << "\"undef\"";
-      break;
-    case DimLevelType::Dense:
-      printer << "\"dense\"";
-      break;
-    case DimLevelType::Compressed:
-      printer << "\"compressed\"";
-      break;
-    case DimLevelType::CompressedNu:
-      printer << "\"compressed-nu\"";
-      break;
-    case DimLevelType::CompressedNo:
-      printer << "\"compressed-no\"";
-      break;
-    case DimLevelType::CompressedNuNo:
-      printer << "\"compressed-nu-no\"";
-      break;
-    case DimLevelType::Singleton:
-      printer << "\"singleton\"";
-      break;
-    case DimLevelType::SingletonNu:
-      printer << "\"singleton-nu\"";
-      break;
-    case DimLevelType::SingletonNo:
-      printer << "\"singleton-no\"";
-      break;
-    case DimLevelType::SingletonNuNo:
-      printer << "\"singleton-nu-no\"";
-      break;
-    }
+    printer << toMLIRString(getDimLevelType()[i]);
     if (i != e - 1)
       printer << ", ";
   }