diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -674,6 +674,92 @@
   }
 
 private:
+  // Handles sparse constant or dense tensor to all-dense annotated sparse
+  // tensor conversion as follows:
+  //   If the source is a dense tensor:
+  //     if (destination has an identity ordering)
+  //       copy src to dst_tensor_values
+  //     else for i1 in dim1
+  //      ..
+  //       for ik in dimk
+  //         val = src[i1,..,ik]
+  //         if val != 0
+  //           dst_tensor_values[i1,..,ik] = val
+  //
+  //   If the source is a sparse constant in COO format:
+  //     for i in range(NNZ)
+  //       val = values[i]
+  //       [i1,..,ik] = indices[i]
+  //       dst_tensor_values[i1,..,ik] = val
+  LogicalResult dense2AnnotateAllDenseRewrite(ConvertOp op,
+                                              PatternRewriter &rewriter) const {
+    Location loc = op.getLoc();
+    Value src = op.getSource();
+    auto dstTp = getRankedTensorType(op);
+    SmallVector<Value> sizes;
+    sizesFromSrc(rewriter, sizes, loc, src);
+    SmallVector<Value> dynSizes;
+    getDynamicSizes(dstTp, sizes, dynSizes);
+
+    bool fromSparseConst = false;
+    if (auto constOp = op.getSource().getDefiningOp<arith::ConstantOp>()) {
+      if (constOp.getValue().dyn_cast<SparseElementsAttr>()) {
+        fromSparseConst = true;
+      }
+    }
+
+    SparseTensorEncodingAttr encDst = getSparseTensorEncoding(dstTp);
+    assert(encDst.isAllDense());
+
+    Value tensor =
+        rewriter.create<AllocTensorOp>(loc, dstTp, dynSizes).getResult();
+    int64_t rank = dstTp.getRank();
+    // Create a view of the values buffer to match the unannotated dense
+    // tensor.
+    Value valuesBuffer = genToValues(rewriter, loc, tensor);
+    Value idxBuffer = genAlloca(rewriter, loc, rank, rewriter.getIndexType(),
+                                /*staticShape=*/true);
+    Value dst = reshapeValuesToLevels(rewriter, loc, encDst, sizes,
+                                      valuesBuffer, idxBuffer);
+
+    if (!fromSparseConst && encDst.hasIdDimOrdering()) {
+      auto srcTp = MemRefType::get(
+          op.getSource().getType().cast<RankedTensorType>().getShape(),
+          dstTp.getElementType());
+      src = rewriter.create<bufferization::ToMemrefOp>(loc, srcTp, src);
+      rewriter.create<memref::CopyOp>(loc, src, dst);
+      dst = rewriter.create<ConvertOp>(loc, dstTp, tensor).getResult();
+      rewriter.replaceOp(op, dst);
+      return success();
+    }
+
+    rewriter.create<ForeachOp>(
+        loc, src, std::nullopt,
+        [&](OpBuilder &builder, Location loc, ValueRange indices, Value v,
+            ValueRange) {
+          uint64_t rank = dstTp.getRank();
+          SmallVector<Value> indicesArray(rank, Value());
+          for (uint64_t i = 0; i < rank; i++)
+            indicesArray[toStoredDim(encDst, i)] = indices[i];
+          if (fromSparseConst) {
+            builder.create<memref::StoreOp>(loc, v, dst, indicesArray);
+          } else {
+            Value cond = genIsNonzero(builder, loc, v);
+            auto ifOp = builder.create<scf::IfOp>(loc, TypeRange(), cond,
+                                                  /*else=*/false);
+            builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+            builder.create<memref::StoreOp>(loc, v, dst, indicesArray);
+            builder.setInsertionPointAfter(ifOp);
+          }
+          builder.create<sparse_tensor::YieldOp>(loc);
+        });
+    rewriter.setInsertionPointAfter(op);
+    dst = rewriter.create<ConvertOp>(loc, dstTp, tensor).getResult();
+    rewriter.replaceOp(op, dst);
+
+    return success();
+  }
+
   // Handles sparse constant to sparse tensor or dense tensor to sparse tensor
   // conversion as follows:
   //   t = new sparse COO tensor
@@ -711,6 +797,9 @@
     }
 
     SparseTensorEncodingAttr encDst = getSparseTensorEncoding(dstTp);
+    if (encDst.isAllDense())
+      return dense2AnnotateAllDenseRewrite(op, rewriter);
+
     // We don't need a temporary COO tensor if the destination has an identity
     // ordering. Otherwise, we use the destination ordering for the temporary
     // COO tensor.
diff --git a/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir b/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir
--- a/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir
+++ b/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir
@@ -10,6 +10,10 @@
   dimLevelType = ["dense", "compressed"]
 }>
+#DD = #sparse_tensor.encoding<{
+  dimLevelType = ["dense", "dense"]
+}>
+
 #CSC = #sparse_tensor.encoding<{
   dimLevelType = [ "dense", "compressed" ],
   dimOrdering = affine_map<(i, j) -> (j, i)>
 }>
@@ -134,6 +138,28 @@
   return %0 : tensor<2x4xf64, #CSR>
 }
 
+// CHECK-RWT-LABEL: func.func @sparse_convert_2d_dd(
+// CHECK-RWT-SAME: %[[T0:.*]]: tensor<2x4xf64>) -> tensor<2x4xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "dense" ] }>> {
+// CHECK-RWT-DAG: %[[TMP_c2:.*]] = arith.constant 2 : index
+// CHECK-RWT-DAG: %[[TMP_c4:.*]] = arith.constant 4 : index
+// CHECK-RWT-DAG: %[[TMP_c0:.*]] = arith.constant 0 : index
+// CHECK-RWT-DAG: %[[TMP_c1:.*]] = arith.constant 1 : index
+// CHECK-RWT: %[[T1:.*]] = bufferization.alloc_tensor()
+// CHECK-RWT: %[[VAL_0:.*]] = sparse_tensor.values %[[T1]]
+// CHECK-RWT: %[[DIM_0:.*]] = memref.alloca() : memref<2xindex>
+// CHECK-RWT: memref.store %[[TMP_c2]], %[[DIM_0]][%[[TMP_c0]]]
+// CHECK-RWT: memref.store %[[TMP_c4]], %[[DIM_0]][%[[TMP_c1]]]
+// CHECK-RWT: %[[VAL_1:.*]] = memref.reshape %[[VAL_0]](%[[DIM_0]])
+// CHECK-RWT: %[[VAL_2:.*]] = bufferization.to_memref %[[T0]]
+// CHECK-RWT: memref.copy %[[VAL_2]], %[[VAL_1]]
+// CHECK-RWT: %[[R:.*]] = sparse_tensor.convert %[[T1]]
+// CHECK-RWT: return %[[R]]
+// CHECK-RWT: }
+func.func @sparse_convert_2d_dd(%arg0: tensor<2x4xf64>) -> tensor<2x4xf64, #DD> {
+  %0 = sparse_tensor.convert %arg0 : tensor<2x4xf64> to tensor<2x4xf64, #DD>
+  return %0 : tensor<2x4xf64, #DD>
+}
+
 // CHECK-LABEL: func @sparse_constant() -> !llvm.ptr<i8> {
 // CHECK-DAG: %[[EmptyCOO:.*]] = arith.constant 4 : i32
 // CHECK-DAG: %[[FromCOO:.*]] = arith.constant 2 : i32
@@ -183,6 +209,33 @@
   return %1 : tensor<8x7xf32, #CSR>
 }
 
+// CHECK-RWT-LABEL: func.func @sparse_constant_dd() -> tensor<8x7xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "dense" ] }>> {
+// CHECK-RWT-DAG: %[[TMP_c8:.*]] = arith.constant 8 : index
+// CHECK-RWT-DAG: %[[TMP_c7:.*]] = arith.constant 7 : index
+// CHECK-RWT-DAG: %[[TMP_c0:.*]] = arith.constant 0 : index
+// CHECK-RWT-DAG: %[[TMP_c1:.*]] = arith.constant 1 : index
+// CHECK-RWT: %[[F0:.*]] = arith.constant sparse<{{\[\[}}0, 0], [1, 6]], [1.000000e+00, 5.000000e+00]> : tensor<8x7xf32>
+// CHECK-RWT: %[[T0:.*]] = bufferization.alloc_tensor()
+// CHECK-RWT: %[[VAL_0:.*]] = sparse_tensor.values %[[T0]]
+// CHECK-RWT: %[[DIM_0:.*]] = memref.alloca()
+// CHECK-RWT: memref.store %[[TMP_c8]], %[[DIM_0]][%[[TMP_c0]]]
+// CHECK-RWT: memref.store %[[TMP_c7]], %[[DIM_0]][%[[TMP_c1]]]
+// CHECK-RWT: %[[VAL_1:.*]] = memref.reshape %[[VAL_0]](%[[DIM_0]])
+// CHECK-RWT: sparse_tensor.foreach in %[[F0]]
+// CHECK-RWT: ^bb0(%[[L0I0:.*]]: index, %[[L0I1:.*]]: index, %[[L0V:.*]]: f32):
+// CHECK-RWT: memref.store %[[L0V]], %[[VAL_1]]{{\[}}%[[L0I0]], %[[L0I1]]]
+// CHECK-RWT: }
+// CHECK-RWT: %[[R:.*]] = sparse_tensor.convert %[[T0]]
+// CHECK-RWT: return %[[R]]
+// CHECK-RWT: }
+func.func @sparse_constant_dd() -> tensor<8x7xf32, #DD> {
+  // Initialize a tensor.
+  %0 = arith.constant sparse<[[0, 0], [1, 6]], [1.0, 5.0]> : tensor<8x7xf32>
+  // Convert the tensor to a sparse tensor.
+  %1 = sparse_tensor.convert %0 : tensor<8x7xf32> to tensor<8x7xf32, #DD>
+  return %1 : tensor<8x7xf32, #DD>
+}
+
 // CHECK-RWT-LABEL: func.func @sparse_constant_csc() -> tensor<8x7xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d1, d0)> }>> {
 // CHECK-RWT: %[[F0:.*]] = arith.constant sparse<{{\[\[}}0, 0], [1, 6]], [1.000000e+00, 5.000000e+00]> : tensor<8x7xf32>
 // CHECK-RWT: %[[T0:.*]] = bufferization.alloc_tensor()
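Note, not part of the patch above: the two new tests exercise the identity-ordering cases only (memref.copy for a dense source, sparse_tensor.foreach for a sparse constant), but not the dense-source fallback that permutes indices through toStoredDim and guards each store with a nonzero check. A minimal sketch of such a case is given below; the #DDP encoding and the @sparse_convert_2d_dd_perm name are hypothetical, and the matching CHECK-RWT lines would have to be generated against an actual build.

#DDP = #sparse_tensor.encoding<{
  dimLevelType = [ "dense", "dense" ],
  dimOrdering = affine_map<(i, j) -> (j, i)>
}>

// Because #DDP has a non-identity dimOrdering, the rewrite cannot memref.copy the source
// buffer directly; it should instead emit a sparse_tensor.foreach that stores src[i, j]
// into the reshaped values buffer at the permuted position [j, i], with an scf.if guard
// that skips zero elements.
func.func @sparse_convert_2d_dd_perm(%arg0: tensor<2x4xf64>) -> tensor<2x4xf64, #DDP> {
  %0 = sparse_tensor.convert %arg0 : tensor<2x4xf64> to tensor<2x4xf64, #DDP>
  return %0 : tensor<2x4xf64, #DDP>
}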