diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp @@ -132,18 +132,19 @@ auto rank = src.getRank(); SmallVector dims; - // An unordered and non-unique compressed dim at beginning unless the tensor - // is a 1D tensor. - if (rank > 1) - dims.push_back(DimLevelType::CompressedNuNo); - - // TODO: it is actually ordered at the level for ordered input. - // Followed by unordered non-unique n-2 singleton levels. - std::fill_n(std::back_inserter(dims), rank - 2, DimLevelType::SingletonNuNo); - // TODO: only if all the inputs (for concatentate) are unique at the last - // level should the COO has a unique level at the end. Ends by a unordered - // unique singleton level. - dims.push_back(DimLevelType::SingletonNo); + // An unordered and non-unique compressed dim at beginning. + dims.push_back(DimLevelType::CompressedNuNo); + + if (rank > 1) { + // TODO: it is actually ordered at the level for ordered input. + // Followed by unordered non-unique n-2 singleton levels. + std::fill_n(std::back_inserter(dims), rank - 2, + DimLevelType::SingletonNuNo); + // TODO: only if all the inputs (for concatentate) are unique at the last + // level should the COO has a unique level at the end. Ends by a unordered + // unique singleton level unless the tensor rank is 1. + dims.push_back(DimLevelType::SingletonNo); + } SparseTensorEncodingAttr encSrc = getSparseTensorEncoding(src); // TODO: Maybe pick the bitwidth based on input/output tensors (probably the // largest one among them) in the original operation instead of using the