diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
@@ -364,6 +364,11 @@
 Value genToCoordinates(OpBuilder &builder, Location loc, Value tensor,
                        Level lvl, Level cooStart);
 
+/// Infers the result type and generates `ToCoordinatesBufferOp`. The result
+/// type is a 1-D memref (without layout) of the coordinate type declared by
+/// the tensor's sparse encoding, so `tensor` must carry such an encoding.
+Value genToCoordinatesBuffer(OpBuilder &builder, Location loc, Value tensor);
+
 /// Infers the result type and generates `ToValuesOp`.
 Value genToValues(OpBuilder &builder, Location loc, Value tensor);
 
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -679,6 +679,14 @@
                                          builder.getIndexAttr(lvl));
 }
 
+Value sparse_tensor::genToCoordinatesBuffer(OpBuilder &builder, Location loc,
+                                            Value tensor) {
+  const auto srcTp = getSparseTensorType(tensor);
+  const Type crdTp = srcTp.getEncoding().getCrdType();
+  const Type memTp = get1DMemRefType(crdTp, /*withLayout=*/false);
+  return builder.create<ToCoordinatesBufferOp>(loc, memTp, tensor);
+}
+
 Value sparse_tensor::genToValues(OpBuilder &builder, Location loc,
                                  Value tensor) {
   RankedTensorType srcTp = getRankedTensorType(tensor);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -895,9 +895,7 @@
       // coordinates for the storage ordering of the dst tensor. Use SortCoo
       // if the COO tensor has the same ordering as the dst tensor.
       if (dimRank > 1 && srcTp.hasSameDimToLvlMap(dstTp)) {
-        MemRefType coordsTp =
-            get1DMemRefType(encSrc.getCrdType(), /*withLayout=*/false);
-        Value xs = rewriter.create<ToCoordinatesBufferOp>(loc, coordsTp, src);
+        Value xs = genToCoordinatesBuffer(rewriter, loc, src);
         rewriter.create<SortCooOp>(
             loc, nnz, xs, ValueRange{y}, rewriter.getIndexAttr(dimRank),
             rewriter.getIndexAttr(0), SparseTensorSortKind::HybridQuickSort);