diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp @@ -132,18 +132,17 @@ auto rank = src.getRank(); SmallVector dims; - // An unordered and non-unique compressed dim at beginning unless the tensor - // is a 1D tensor. - if (rank > 1) - dims.push_back(DimLevelType::CompressedNuNo); + // An unordered and non-unique compressed dim at beginning. + dims.push_back(DimLevelType::CompressedNuNo); // TODO: it is actually ordered at the level for ordered input. // Followed by unordered non-unique n-2 singleton levels. std::fill_n(std::back_inserter(dims), rank - 2, DimLevelType::SingletonNuNo); // TODO: only if all the inputs (for concatentate) are unique at the last // level should the COO has a unique level at the end. Ends by a unordered - // unique singleton level. - dims.push_back(DimLevelType::SingletonNo); + // unique singleton level unless the tensor rank is 1. + if (rank > 1) + dims.push_back(DimLevelType::SingletonNo); SparseTensorEncodingAttr encSrc = getSparseTensorEncoding(src); // TODO: Maybe pick the bitwidth based on input/output tensors (probably the // largest one among them) in the original operation instead of using the