diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -146,12 +146,17 @@
   // level should the COO has a unique level at the end. Ends by a unordered
   // unique singleton level.
   dims.push_back(SparseTensorEncodingAttr::DimLevelType::SingletonNo);
 
+  SparseTensorEncodingAttr encSrc = getSparseTensorEncoding(src);
   // TODO: Maybe pick the bitwidth based on input/output tensors (probably the
   // largest one among them) in the original operation instead of using the
   // default value.
+  // Inherit the pointer/index bitwidths from `src` when it is sparse. If `src`
+  // has no sparse encoding, `encSrc` is a null attribute and must not be
+  // dereferenced, so fall back to the previous default bitwidth (0).
+  unsigned pointerBitWidth = encSrc ? encSrc.getPointerBitWidth() : 0;
+  unsigned indexBitWidth = encSrc ? encSrc.getIndexBitWidth() : 0;
   auto enc = SparseTensorEncodingAttr::get(
-      ctx, dims, AffineMap::getMultiDimIdentityMap(rank, ctx), AffineMap(), 0,
-      0);
+      ctx, dims, AffineMap::getMultiDimIdentityMap(rank, ctx), AffineMap(),
+      pointerBitWidth, indexBitWidth);
   return RankedTensorType::get(src.getShape(), src.getElementType(), enc);
 }