diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -48,6 +48,27 @@
   return false;
 }
 
+static bool isAllDimOrdered(RankedTensorType rtp) {
+  if (auto enc = getSparseTensorEncoding(rtp)) {
+    return llvm::all_of(enc.getDimLevelType(), isOrderedDLT);
+  }
+  return true;
+}
+
+static bool hasSameDimOrdering(RankedTensorType rtp1, RankedTensorType rtp2) {
+  assert(rtp1.getRank() == rtp2.getRank());
+  AffineMap idMap =
+      AffineMap::getMultiDimIdentityMap(rtp1.getRank(), rtp1.getContext());
+
+  auto enc1 = getSparseTensorEncoding(rtp1);
+  auto enc2 = getSparseTensorEncoding(rtp2);
+
+  auto order1 = (enc1 && enc1.getDimOrdering()) ? enc1.getDimOrdering() : idMap;
+  auto order2 = (enc2 && enc2.getDimOrdering()) ? enc2.getDimOrdering() : idMap;
+
+  return order1 == order2;
+}
+
 // Helper method to find zero/uninitialized allocation.
 static bool isAlloc(OpOperand *op, bool isZero) {
   Value val = op->get();
@@ -732,7 +753,13 @@
     SmallVector<Value> srcSizes;
     sizesForTensor(rewriter, srcSizes, loc, srcTp, src);
     Value tmpCoo = Value();
-    if (!isUniqueCOOType(srcTp)) {
+    // We need a tmp COO buffer if and only if
+    // 1. the src tensor is not a COO and
+    // 2. the src tensor is not ordered in the same way as the target
+    // tensor (e.g., src tensor is not ordered or src tensor has a different
+    // dimOrdering).
+    if (!isUniqueCOOType(srcTp) &&
+        !(isAllDimOrdered(srcTp) && hasSameDimOrdering(srcTp, dstTp))) {
       // Construct a COO tensor from the src tensor.
       // TODO: there may be cases for which more efficiently without
       // going through an intermediate COO, such as cases that only change
@@ -754,32 +781,36 @@
       src = rewriter.create<LoadOp>(loc, foreachOp.getResult(0), true);
     }
 
-    // Sort the COO tensor so that its elements are ordered via increasing
-    // indices for the storage ordering of the dst tensor.
-    SparseTensorEncodingAttr encSrc = getSparseTensorEncoding(srcTp);
-    auto dynShape = {ShapedType::kDynamic};
-    auto indTp =
-        MemRefType::get(dynShape, getIndexOverheadType(rewriter, encSrc));
-    uint64_t rank = dstTp.getRank();
-    // Gather the indices-arrays in the dst tensor storage order.
-    SmallVector<Value> xs(rank, Value());
-    for (uint64_t i = 0; i < rank; i++) {
-      uint64_t orgDim = toOrigDim(encSrc, i);
-      xs[toStoredDim(encDst, orgDim)] = rewriter.create<ToIndicesOp>(
-          loc, indTp, src, rewriter.getIndexAttr(i));
-    }
+    // Only need to sort if the srcTp is not already sorted (we faithfully take
+    // the guarantee from the sparse tensor encoding).
+    if (!isAllDimOrdered(srcTp)) {
+      // Sort the COO tensor so that its elements are ordered via increasing
+      // indices for the storage ordering of the dst tensor.
+      SparseTensorEncodingAttr encSrc = getSparseTensorEncoding(srcTp);
+      auto dynShape = {ShapedType::kDynamic};
+      auto indTp =
+          MemRefType::get(dynShape, getIndexOverheadType(rewriter, encSrc));
+      uint64_t rank = dstTp.getRank();
+      // Gather the indices-arrays in the dst tensor storage order.
+      SmallVector<Value> xs(rank, Value());
+      for (uint64_t i = 0; i < rank; i++) {
+        uint64_t orgDim = toOrigDim(encSrc, i);
+        xs[toStoredDim(encDst, orgDim)] = rewriter.create<ToIndicesOp>(
+            loc, indTp, src, rewriter.getIndexAttr(i));
+      }
 
-    // Retrieve NNZ.
-    Value nnz = rewriter.create<NumberOfEntriesOp>(loc, src);
-    nnz =
-        rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), nnz);
+      // Retrieve NNZ.
+      Value nnz = rewriter.create<NumberOfEntriesOp>(loc, src);
+      nnz = rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(),
+                                                nnz);
 
-    // Retrieve the values-array.
-    auto valTp = MemRefType::get(dynShape, srcTp.getElementType());
-    Value y = rewriter.create<ToValuesOp>(loc, valTp, src);
+      // Retrieve the values-array.
+      auto valTp = MemRefType::get(dynShape, srcTp.getElementType());
+      Value y = rewriter.create<ToValuesOp>(loc, valTp, src);
 
-    // Sort the COO tensor.
-    rewriter.create<SortOp>(loc, nnz, xs, ValueRange{y});
+      // Sort the COO tensor.
+      rewriter.create<SortOp>(loc, nnz, xs, ValueRange{y});
+    }
 
     // For each element in the COO tensor, insert the element to the dst tensor.
     SmallVector<Value> dynDstSizes;
diff --git a/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir b/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir
--- a/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir
+++ b/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir
@@ -104,10 +104,6 @@
 // CHECK-RWT-SAME: %[[A:.*]]: tensor<?xf32, #{{.*}}>)
 // CHECK-RWT-DAG: %[[C0:.*]] = arith.constant 0 : index
 // CHECK-RWT: %[[D:.*]] = tensor.dim %[[A]], %[[C0]]
-// CHECK-RWT: %[[I0:.*]] = sparse_tensor.indices %[[A]] {dimension = 0 : index}
-// CHECK-RWT: %[[NNZ:.*]] = sparse_tensor.number_of_entries %[[A]]
-// CHECK-RWT: %[[V:.*]] = sparse_tensor.values %[[A]]
-// CHECK-RWT: sparse_tensor.sort %[[NNZ]], %[[I0]] jointly %[[V]]
 // CHECK-RWT: %[[DST:.*]] = bufferization.alloc_tensor(%[[D]])
 // CHECK-RWT: %[[RET:.*]] = sparse_tensor.foreach in %[[A]] init(%[[DST]])
 // CHECK-RWT: ^bb0(%[[FI2:.*]]: index, %[[FV2:.*]]: f32, %[[T:.*]]: tensor