diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -358,9 +358,12 @@
     Value srcTensor = op.getSrc();
     auto srcTp = getRankedTensorType(srcTensor);
     auto dstTp = getRankedTensorType(op.getResult());
-    SparseTensorEncodingAttr encSrc = getSparseTensorEncoding(srcTp);
-    SparseTensorEncodingAttr encDst = getSparseTensorEncoding(dstTp);
-    if (!encDst || !encSrc) {
+
+    SparseTensorType srcStt(srcTp);
+    SparseTensorType dstStt(dstTp);
+
+    const auto encSrc = srcStt.getEncoding();
+    if (!srcStt.hasEncoding() || !dstStt.hasEncoding()) {
       return failure();
     }
@@ -382,22 +385,29 @@
         dstDynSizes.push_back(dstSizes[idx]);
       }
     }
-
-    // Implement the sparse2sparse reshape as follows:
-    //   %tmp = bufferization.alloc_tensor : unordered COO
-    //   foreach srcCoords %srcTensor
-    //     insert reshapeCvs(srcCoords), %tmp
-    //   %t = sparse_tensor.cast %tmp
     Value nnz = rewriter.create<NumberOfEntriesOp>(loc, srcTensor);
-    RankedTensorType cooTp = getUnorderedCOOFromType(dstTp);
-    Value cooBuffer =
+    // Only need an unordered COO buffer if input and output are not sorted
+    // in the same way.
+    Type bufferTp =
+        srcStt.isAllOrdered() && srcStt.isIdentity() && dstStt.isIdentity()
+            ? dstTp
+            : getUnorderedCOOFromType(dstTp);
+
+    Value buffer =
         rewriter
-            .create<AllocTensorOp>(loc, cooTp, dstDynSizes, Value(),
+            .create<AllocTensorOp>(loc, bufferTp, dstDynSizes, Value(),
                                    /*sizeHint=*/nnz, Attribute())
             .getResult();
+    // Implement the sparse2sparse reshape as follows:
+    //   foreach srcCoords %srcTensor
+    //     insert reshapeCvs(srcCoords), %buffer
+    //
+    // followed by an optional
+    //   %t = sparse_tensor.cast %tmp
+    // depending on whether the input/output are sorted in the same way.
     ForeachOp foreachOp = rewriter.create<ForeachOp>(
-        loc, srcTensor, cooBuffer,
+        loc, srcTensor, buffer,
         [&](OpBuilder &builder, Location loc, ValueRange srcLcvs, Value v,
             ValueRange reduc) {
           const Dimension dimRank = srcTp.getRank();
@@ -414,10 +424,14 @@
           auto t = builder.create<InsertOp>(loc, v, reduc.front(), dstDcvs);
           builder.create<sparse_tensor::YieldOp>(loc, t);
         });
-    auto t = rewriter.create<LoadOp>(loc, foreachOp.getResult(0), true);
-    auto converted = rewriter.create<ConvertOp>(loc, dstTp, t).getResult();
-    rewriter.create<DeallocTensorOp>(loc, t);
-    rewriter.replaceOp(op, converted);
+
+    Value t = rewriter.create<LoadOp>(loc, foreachOp.getResult(0), true);
+    if (bufferTp != dstTp) {
+      Value converted = rewriter.create<ConvertOp>(loc, dstTp, t).getResult();
+      rewriter.create<DeallocTensorOp>(loc, t);
+      t = converted;
+    }
+    rewriter.replaceOp(op, t);
     return success();
   }
 };
diff --git a/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir b/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir
--- a/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir
@@ -61,8 +61,8 @@
 // CHECK-RWT:       scf.yield %[[NT:.*]]
 // CHECK-RWT:     }
 // CHECK-RWT:     %[[NT1:.*]] = sparse_tensor.load %[[RET]] hasInserts
-// CHECK-RWT:     %[[T:.*]] = sparse_tensor.convert %[[NT1]]
-// CHECK-RWT:     return %[[T]] : tensor<10x10xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
+// CHECK-RWT-NOT: sparse_tensor.convert
+// CHECK-RWT:     return %[[NT1]] : tensor<10x10xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
 //
 func.func @sparse_expand(%arg0: tensor<100xf64, #SparseVector>) -> tensor<10x10xf64, #SparseMatrix> {
   %0 = tensor.expand_shape %arg0 [[0, 1]] :
@@ -134,8 +134,8 @@
 // CHECK-RWT        scf.yield %[[RET_1]]
 // CHECK-RWT:     }
 // CHECK-RWT:     %[[NT1:.*]] = sparse_tensor.load %[[RET]] hasInserts
-// CHECK-RWT:     %[[T:.*]] = sparse_tensor.convert %[[NT1]]
-// CHECK-RWT:     return %[[T]] : tensor<100xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>>
+// CHECK-RWT-NOT: sparse_tensor.convert
+// CHECK-RWT:     return %[[NT1]] : tensor<100xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>>
 //
 func.func @sparse_collapse(%arg0: tensor<10x10xf64, #SparseMatrix>) -> tensor<100xf64, #SparseVector> {
   %0 = tensor.collapse_shape %arg0 [[0, 1]] :
@@ -209,8 +209,8 @@
 // CHECK-RWT:       scf.yield %[[NT]]
 // CHECK-RWT:     }
 // CHECK-RWT:     %[[NT1:.*]] = sparse_tensor.load %[[RET]] hasInserts
-// CHECK-RWT:     %[[T:.*]] = sparse_tensor.convert %[[NT1]]
-// CHECK-RWT:     return %[[T]] : tensor<?x10xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
+// CHECK-RWT-NOT: sparse_tensor.convert
+// CHECK-RWT:     return %[[NT1]] : tensor<?x10xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
 //
 func.func @dynamic_sparse_expand(%arg0: tensor<?xf64, #SparseVector>) -> tensor<?x10xf64, #SparseMatrix> {
   %0 = tensor.expand_shape %arg0 [[0, 1]] :
@@ -291,8 +291,8 @@
 // CHECK-RWT        scf.yield %[[RET_1]]
 // CHECK-RWT:     }
 // CHECK-RWT:     %[[NT1:.*]] = sparse_tensor.load %[[RET]] hasInserts
-// CHECK-RWT:     %[[T:.*]] = sparse_tensor.convert %[[NT1]]
-// CHECK-RWT:     return %[[T]] : tensor<?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>>
+// CHECK-RWT-NOT: sparse_tensor.convert
+// CHECK-RWT:     return %[[NT1]] : tensor<?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>>
 //
 func.func @dynamic_sparse_collapse(%arg0: tensor<10x?xf64, #SparseMatrix>) -> tensor<?xf64, #SparseVector> {
  %0 = tensor.collapse_shape %arg0 [[0, 1]] :
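
As a rough illustration of what the pattern now emits, here is an abbreviated sketch of the tail of the rewritten IR for the @sparse_expand test, based on the CHECK-RWT patterns above (the SSA value names are illustrative, not the actual FileCheck captures). When the source is all-ordered and both sides use identity dimension orderings, the foreach inserts directly into a buffer of the destination type, so the trailing conversion and its dealloc disappear:

    // After this change (ordered source, identity orderings on both sides):
    %nt1 = sparse_tensor.load %ret hasInserts
    return %nt1 : tensor<10x10xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>

    // Before this change (always routed through a temporary unordered COO buffer):
    %nt1 = sparse_tensor.load %ret hasInserts
    %t = sparse_tensor.convert %nt1
    return %t : tensor<10x10xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>

When the source is unordered or either side uses a non-identity dimension ordering, the pattern still allocates the unordered COO buffer and emits the sparse_tensor.convert / dealloc pair as before.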