diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -50,6 +50,36 @@
 // Dimension level types.
 //
 
+/// Returns the textual form of the given dimension level type, as emitted in
+/// the `dimLevelType` list when printing the sparse tensor encoding attribute.
+/// NOTE: the returned string includes the surrounding double quotes.
+inline StringRef toString(DimLevelType dlt) {
+  switch (dlt) {
+  case DimLevelType::Undef:
+    // TODO: should probably raise an error instead of printing it...
+    return "\"undef\"";
+  case DimLevelType::Dense:
+    return "\"dense\"";
+  case DimLevelType::Compressed:
+    return "\"compressed\"";
+  case DimLevelType::CompressedNu:
+    return "\"compressed-nu\"";
+  case DimLevelType::CompressedNo:
+    return "\"compressed-no\"";
+  case DimLevelType::CompressedNuNo:
+    return "\"compressed-nu-no\"";
+  case DimLevelType::Singleton:
+    return "\"singleton\"";
+  case DimLevelType::SingletonNu:
+    return "\"singleton-nu\"";
+  case DimLevelType::SingletonNo:
+    return "\"singleton-no\"";
+  case DimLevelType::SingletonNuNo:
+    return "\"singleton-nu-no\"";
+  }
+  llvm_unreachable("unknown DimLevelType");
+}
+
 // MSVC does not allow this function to be constexpr, because
 // `SparseTensorEncodingAttr::operator bool` isn't declared constexpr.
 // And therefore all functions calling it cannot be constexpr either.
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -160,39 +160,7 @@
   // Print the struct-like storage in dictionary fashion.
   printer << "<{ dimLevelType = [ ";
   for (unsigned i = 0, e = getDimLevelType().size(); i < e; i++) {
-    switch (getDimLevelType()[i]) {
-    case DimLevelType::Undef:
-      // TODO: should probably raise an error instead of printing it...
-      printer << "\"undef\"";
-      break;
-    case DimLevelType::Dense:
-      printer << "\"dense\"";
-      break;
-    case DimLevelType::Compressed:
-      printer << "\"compressed\"";
-      break;
-    case DimLevelType::CompressedNu:
-      printer << "\"compressed-nu\"";
-      break;
-    case DimLevelType::CompressedNo:
-      printer << "\"compressed-no\"";
-      break;
-    case DimLevelType::CompressedNuNo:
-      printer << "\"compressed-nu-no\"";
-      break;
-    case DimLevelType::Singleton:
-      printer << "\"singleton\"";
-      break;
-    case DimLevelType::SingletonNu:
-      printer << "\"singleton-nu\"";
-      break;
-    case DimLevelType::SingletonNo:
-      printer << "\"singleton-no\"";
-      break;
-    case DimLevelType::SingletonNuNo:
-      printer << "\"singleton-nu-no\"";
-      break;
-    }
+    printer << toString(getDimLevelType()[i]);
     if (i != e - 1)
       printer << ", ";
   }