diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp @@ -669,16 +669,19 @@ } const auto encDst = dstTp.getEncoding(); - // We don't need a temporary COO tensor for dense => sparse conversion. - const RankedTensorType bufferTp = dstTp.getRankedTensorType(); + // We don't need a temporary COO tensor if the destination has an identity + // ordering. Otherwise, we use the destination ordering for the temporary + // COO tensor. + // TODO: enhance foreachOp to take ordering to remove the need of a + // temporary COO tensor here. + const RankedTensorType bufferTp = dstTp.isIdentity() + ? dstTp.getRankedTensorType() + : getUnorderedCOOFromTypeWithOrdering( + dstTp, dstTp.getDimToLvlMap()); auto buffer = rewriter.create(loc, bufferTp, dynSizes).getResult(); - AffineMapAttr foreachOrder = nullptr; - if (encDst.getDimOrdering()) - foreachOrder = AffineMapAttr::get(encDst.getDimOrdering()); - auto foreachOp = rewriter.create( - loc, src, buffer, foreachOrder, + loc, src, buffer, [&](OpBuilder &builder, Location loc, ValueRange indices, Value v, ValueRange reduc) { Value input = reduc.front(); @@ -706,8 +709,14 @@ }); rewriter.setInsertionPointAfter(op); src = rewriter.create(loc, foreachOp.getResult(0), true); + if (bufferTp != dstTp) { + rewriter.replaceOpWithNewOp(op, dstTp.getRankedTensorType(), + src); + rewriter.create(loc, src); + } else { + rewriter.replaceOp(op, src); + } - rewriter.replaceOp(op, src); return success(); } diff --git a/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir b/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir --- a/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir +++ b/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir @@ -183,17 +183,30 @@ return %1 : tensor<8x7xf32, #CSR> } -// CHECK-RWT-LABEL: func.func @sparse_constant_csc() -> tensor<8x7xf32, -// CHECK-RWT: %[[VAL_0:.*]] = arith.constant sparse<{{\[\[}}0, 0], [1, 6]], [1.000000e+00, 5.000000e+00]> : tensor<8x7xf32> -// CHECK-RWT: %[[VAL_1:.*]] = bufferization.alloc_tensor() : -// CHECK-RWT: %[[VAL_2:.*]] = sparse_tensor.foreach in %[[VAL_0]] init(%[[VAL_1]]) {order = #map} : tensor<8x7xf32>, -// CHECK-RWT: ^bb0(%[[VAL_3:.*]]: index, %[[VAL_4:.*]]: index, %[[VAL_5:.*]]: f32, %[[VAL_6:.*]]: tensor -// CHECK-RWT: %[[VAL_7:.*]] = sparse_tensor.insert %[[VAL_5]] into %[[VAL_6]]{{\[}}%[[VAL_4]], %[[VAL_3]]] : -// CHECK-RWT: sparse_tensor.yield %[[VAL_7]] : -// CHECK-RWT: } -// CHECK-RWT: %[[VAL_8:.*]] = sparse_tensor.load %[[VAL_9:.*]] hasInserts : -// CHECK-RWT: return %[[VAL_8]] : -// CHECK-RWT: } +// CHECK-RWT-LABEL: func.func @sparse_constant_csc() -> tensor<8x7xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d1, d0)> }>> { +// CHECK-RWT: %[[F0:.*]] = arith.constant sparse<{{\[\[}}0, 0], [1, 6]], [1.000000e+00, 5.000000e+00]> : tensor<8x7xf32> +// CHECK-RWT: %[[T0:.*]] = bufferization.alloc_tensor() +// CHECK-RWT: %[[T1:.*]] = sparse_tensor.foreach in %[[F0]] init(%[[T0]]) +// CHECK-RWT: ^bb0(%[[L0I0:.*]]: index, %[[L0I1:.*]]: index, %[[L0V:.*]]: f32, %[[L0T:.*]]: tensor +// CHECK-RWT: %[[L0T2:.*]] = sparse_tensor.insert %[[L0V]] into %[[L0T]]{{\[}}%[[L0I1]], %[[L0I0]]] +// CHECK-RWT: sparse_tensor.yield %[[L0T2]] +// CHECK-RWT: } +// CHECK-RWT: %[[COO:.*]] = sparse_tensor.load %[[T1]] hasInserts +// CHECK-RWT: %[[NNZ:.*]] = sparse_tensor.number_of_entries %[[COO]] +// CHECK-RWT: %[[V:.*]] = sparse_tensor.values %[[COO]] +// CHECK-RWT: %[[I:.*]] = sparse_tensor.indices_buffer %[[COO]] +// CHECK-RWT: sparse_tensor.sort_coo hybrid_quick_sort %[[NNZ]], %[[I]] jointly %[[V]] {nx = 2 : index, ny = 0 : index} +// CHECK-RWT: %[[T3:.*]] = bufferization.alloc_tensor() +// CHECK-RWT: %[[T4:.*]] = sparse_tensor.foreach in %[[COO]] init(%[[T3]]) +// CHECK-RWT: ^bb0(%[[L1I0:.*]]: index, %[[L1I1:.*]]: index, %[[L1V:.*]]: f32, %[[L1T:.*]]: tensor +// CHECK-RWT: %[[L1T2:.*]] = sparse_tensor.insert %[[L1V]] into %[[L1T]]{{\[}}%[[L1I1]], %[[L1I0]]] +// CHECK-RWT: sparse_tensor.yield %[[L1T2]] +// CHECK-RWT: } +// CHECK-RWT: %[[T5:.*]] = sparse_tensor.load %[[T4]] hasInserts +// CHECK-RWT: %[[T6:.*]] = sparse_tensor.convert %[[T5]] +// CHECK-RWT: bufferization.dealloc_tensor %[[COO]] +// CHECK-RWT: return %[[T6]] +// CHECK-RWT: } func.func @sparse_constant_csc() -> tensor<8x7xf32, #CSC>{ // Initialize a tensor. %0 = arith.constant sparse<[[0, 0], [1, 6]], [1.0, 5.0]> : tensor<8x7xf32>