diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -713,51 +713,95 @@
     }
     SparseTensorEncodingAttr encDst = getSparseTensorEncoding(dstTp);
-    // We don't need a temporary COO tensor if the destination has an identity
-    // ordering. Otherwise, we use the destination ordering for the temporary
-    // COO tensor.
+    bool allDense = encDst.isAllDense();
+    // We don't need a temporary COO tensor if the destination has all dense
+    // dimensions or has an identity ordering. Otherwise, we use the
+    // destination ordering for the temporary COO tensor.
     // TODO: enhance foreachOp to take ordering to remove the need of a
     // temporary COO tensor here.
-    RankedTensorType bufferTp = encDst.hasIdDimOrdering()
-                                    ? dstTp
-                                    : getUnorderedCOOFromTypeWithOrdering(
-                                          dstTp, encDst.getDimOrdering());
-    auto buffer =
+    bool needTmpCOO = !(allDense || encDst.hasIdDimOrdering());
+    RankedTensorType bufferTp = needTmpCOO
+                                    ? getUnorderedCOOFromTypeWithOrdering(
+                                          dstTp, encDst.getDimOrdering())
+                                    : dstTp;
+    Value dst =
         rewriter.create<AllocTensorOp>(loc, bufferTp, dynSizes).getResult();
+    Value allDenseTensor;
+    int64_t rank = dstTp.getRank();
+    SmallVector<Value> initArgs;
+    if (allDense) {
+      // Create a view of the values buffer to match the unannotated dense
+      // tensor.
+      Value valuesBuffer = genToValues(rewriter, loc, dst);
+      Value idxBuffer = genAlloca(rewriter, loc, rank, rewriter.getIndexType(),
+                                  /*staticShape=*/true);
+      allDenseTensor = dst;
+      dst = reshapeValuesToLevels(rewriter, loc, encDst, sizes, valuesBuffer,
+                                  idxBuffer);
+    } else {
+      initArgs.push_back(dst);
+    }
+
+    if (allDense && !fromSparseConst && encDst.hasIdDimOrdering()) {
+      auto srcTp = MemRefType::get(
+          op.getSource().getType().cast<RankedTensorType>().getShape(),
+          dstTp.getElementType());
+      src = rewriter.create<bufferization::ToMemrefOp>(loc, srcTp, src);
+      rewriter.create<memref::CopyOp>(loc, src, dst);
+      dst = rewriter.create<ConvertOp>(loc, dstTp, allDenseTensor).getResult();
+      rewriter.replaceOp(op, dst);
+      return success();
+    }
+
     auto foreachOp = rewriter.create<ForeachOp>(
-        loc, src, buffer,
+        loc, src, initArgs,
         [&](OpBuilder &builder, Location loc, ValueRange indices, Value v,
             ValueRange reduc) {
-          Value input = reduc.front();
+          Value tensor = allDense ? dst : reduc.front();
           uint64_t rank = dstTp.getRank();
           SmallVector<Value> indicesArray(rank, Value());
           for (uint64_t i = 0; i < rank; i++)
             indicesArray[toStoredDim(encDst, i)] = indices[i];
           if (fromSparseConst) {
-            input = builder.create<InsertOp>(loc, v, input, indicesArray);
+            if (allDense)
+              builder.create<memref::StoreOp>(loc, v, tensor, indicesArray);
+            else
+              tensor = builder.create<InsertOp>(loc, v, tensor, indicesArray);
           } else {
             Value cond = genIsNonzero(builder, loc, v);
             auto ifOp = builder.create<scf::IfOp>(
-                loc, TypeRange(input.getType()), cond, /*else*/ true);
+                loc, allDense ? TypeRange() : TypeRange(tensor.getType()), cond,
+                /*else=*/!allDense);
             builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
-            Value insert =
-                builder.create<InsertOp>(loc, v, input, indicesArray);
-            builder.create<scf::YieldOp>(loc, insert);
-            builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
-            builder.create<scf::YieldOp>(loc, input);
+            if (allDense) {
+              builder.create<memref::StoreOp>(loc, v, tensor, indicesArray);
+            } else {
+              Value updatedTensor =
+                  builder.create<InsertOp>(loc, v, tensor, indicesArray);
+              builder.create<scf::YieldOp>(loc, updatedTensor);
+              builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
+              builder.create<scf::YieldOp>(loc, tensor);
+              tensor = ifOp.getResult(0);
+            }
             builder.setInsertionPointAfter(ifOp);
-            input = ifOp.getResult(0);
           }
-          builder.create<sparse_tensor::YieldOp>(loc, input);
+          if (allDense)
+            builder.create<sparse_tensor::YieldOp>(loc);
+          else
+            builder.create<sparse_tensor::YieldOp>(loc, tensor);
        });
     rewriter.setInsertionPointAfter(op);
-    src = rewriter.create<LoadOp>(loc, foreachOp.getResult(0), true);
-    if (bufferTp != dstTp) {
-      rewriter.replaceOpWithNewOp<ConvertOp>(op, dstTp, src);
-      rewriter.create<DeallocTensorOp>(loc, src);
+    if (allDense) {
+      dst = rewriter.create<ConvertOp>(loc, dstTp, allDenseTensor).getResult();
     } else {
-      rewriter.replaceOp(op, src);
+      dst = rewriter.create<LoadOp>(loc, foreachOp.getResult(0), true);
+      if (needTmpCOO) {
+        Value tmpCoo = dst;
+        dst = rewriter.create<ConvertOp>(loc, dstTp, tmpCoo).getResult();
+        rewriter.create<DeallocTensorOp>(loc, tmpCoo);
+      }
     }
+    rewriter.replaceOp(op, dst);
     return success();
   }
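
Note: with an all-dense destination and an identity ordering, the rewriting above degenerates into a single buffer copy. A simplified sketch of the resulting IR (mirroring the sparse_convert_2d_dd test below; the SSA names and the %sizes shape buffer are illustrative, not the pass's verbatim output):

  %dst  = bufferization.alloc_tensor() : tensor<2x4xf64, #DD>
  %vals = sparse_tensor.values %dst : tensor<2x4xf64, #DD> to memref<?xf64>
  %view = memref.reshape %vals(%sizes)
        : (memref<?xf64>, memref<2xindex>) -> memref<2x4xf64>
  %src  = bufferization.to_memref %arg0 : memref<2x4xf64>
  memref.copy %src, %view : memref<2x4xf64> to memref<2x4xf64>
  %res  = sparse_tensor.convert %dst : tensor<2x4xf64, #DD> to tensor<2x4xf64, #DD>

An all-dense annotated tensor stores exactly its values array, so neither a temporary COO tensor nor a final sort is needed.
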
diff --git a/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir b/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir
--- a/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir
+++ b/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir
@@ -1,6 +1,6 @@
-// RxUN: mlir-opt %s --sparse-tensor-conversion --canonicalize --cse | FileCheck %s
+// RUN: mlir-opt %s --sparse-tensor-conversion --canonicalize --cse | FileCheck %s
 // RUN: mlir-opt %s --post-sparsification-rewrite="enable-runtime-library=false enable-foreach=false" \
-// RUN:   --canonicalize --cse
+// RUN:   --canonicalize --cse | FileCheck %s --check-prefix=CHECK-RWT
 #SparseVector = #sparse_tensor.encoding<{
   dimLevelType = ["compressed"]
 }>
@@ -10,6 +10,10 @@
   dimLevelType = ["dense", "compressed"]
 }>
 
+#DD = #sparse_tensor.encoding<{
+  dimLevelType = ["dense", "dense"]
+}>
+
 #CSC = #sparse_tensor.encoding<{
   dimLevelType = [ "dense", "compressed" ],
   dimOrdering = affine_map<(i, j) -> (j, i)>
@@ -134,6 +138,28 @@
   return %0 : tensor<2x4xf64, #CSR>
 }
 
+// CHECK-RWT-LABEL: func.func @sparse_convert_2d_dd(
+// CHECK-RWT-SAME: %[[T0:.*]]: tensor<2x4xf64>) -> tensor<2x4xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "dense" ] }>> {
+// CHECK-RWT-DAG: %[[TMP_c2:.*]] = arith.constant 2 : index
+// CHECK-RWT-DAG: %[[TMP_c4:.*]] = arith.constant 4 : index
+// CHECK-RWT-DAG: %[[TMP_c0:.*]] = arith.constant 0 : index
+// CHECK-RWT-DAG: %[[TMP_c1:.*]] = arith.constant 1 : index
+// CHECK-RWT: %[[T1:.*]] = bufferization.alloc_tensor()
+// CHECK-RWT: %[[VAL_0:.*]] = sparse_tensor.values %[[T1]]
+// CHECK-RWT: %[[DIM_0:.*]] = memref.alloca() : memref<2xindex>
+// CHECK-RWT: memref.store %[[TMP_c2]], %[[DIM_0]][%[[TMP_c0]]]
+// CHECK-RWT: memref.store %[[TMP_c4]], %[[DIM_0]][%[[TMP_c1]]]
+// CHECK-RWT: %[[VAL_1:.*]] = memref.reshape %[[VAL_0]](%[[DIM_0]])
+// CHECK-RWT: %[[VAL_2:.*]] = bufferization.to_memref %[[T0]]
+// CHECK-RWT: memref.copy %[[VAL_2]], %[[VAL_1]]
+// CHECK-RWT: %[[R:.*]] = sparse_tensor.convert %[[T1]]
+// CHECK-RWT: return %[[R]]
+// CHECK-RWT: }
+func.func @sparse_convert_2d_dd(%arg0: tensor<2x4xf64>) -> tensor<2x4xf64, #DD> {
+  %0 = sparse_tensor.convert %arg0 : tensor<2x4xf64> to tensor<2x4xf64, #DD>
+  return %0 : tensor<2x4xf64, #DD>
+}
+
 // CHECK-LABEL: func @sparse_constant() -> !llvm.ptr<i8> {
 // CHECK-DAG: %[[EmptyCOO:.*]] = arith.constant 4 : i32
 // CHECK-DAG: %[[FromCOO:.*]] = arith.constant 2 : i32
@@ -183,6 +209,33 @@
   return %1 : tensor<8x7xf32, #CSR>
 }
 
+// CHECK-RWT-LABEL: func.func @sparse_constant_dd() -> tensor<8x7xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "dense" ] }>> {
+// CHECK-RWT-DAG: %[[TMP_c8:.*]] = arith.constant 8 : index
+// CHECK-RWT-DAG: %[[TMP_c7:.*]] = arith.constant 7 : index
+// CHECK-RWT-DAG: %[[TMP_c0:.*]] = arith.constant 0 : index
+// CHECK-RWT-DAG: %[[TMP_c1:.*]] = arith.constant 1 : index
+// CHECK-RWT: %[[F0:.*]] = arith.constant sparse<{{\[\[}}0, 0], [1, 6]], [1.000000e+00, 5.000000e+00]> : tensor<8x7xf32>
+// CHECK-RWT: %[[T0:.*]] = bufferization.alloc_tensor()
+// CHECK-RWT: %[[VAL_0:.*]] = sparse_tensor.values %[[T0]]
+// CHECK-RWT: %[[DIM_0:.*]] = memref.alloca()
+// CHECK-RWT: memref.store %[[TMP_c8]], %[[DIM_0]][%[[TMP_c0]]]
+// CHECK-RWT: memref.store %[[TMP_c7]], %[[DIM_0]][%[[TMP_c1]]]
+// CHECK-RWT: %[[VAL_1:.*]] = memref.reshape %[[VAL_0]](%[[DIM_0]])
+// CHECK-RWT: sparse_tensor.foreach in %[[F0]]
+// CHECK-RWT: ^bb0(%[[L0I0:.*]]: index, %[[L0I1:.*]]: index, %[[L0V:.*]]: f32):
+// CHECK-RWT: memref.store %[[L0V]], %[[VAL_1]]{{\[}}%[[L0I0]], %[[L0I1]]]
+// CHECK-RWT: }
+// CHECK-RWT: %[[R:.*]] = sparse_tensor.convert %[[T0]]
+// CHECK-RWT: return %[[R]]
+// CHECK-RWT: }
+func.func @sparse_constant_dd() -> tensor<8x7xf32, #DD> {
+  // Initialize a tensor.
+  %0 = arith.constant sparse<[[0, 0], [1, 6]], [1.0, 5.0]> : tensor<8x7xf32>
+  // Convert the tensor to a sparse tensor.
+  %1 = sparse_tensor.convert %0 : tensor<8x7xf32> to tensor<8x7xf32, #DD>
+  return %1 : tensor<8x7xf32, #DD>
+}
+
 // CHECK-RWT-LABEL: func.func @sparse_constant_csc() -> tensor<8x7xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d1, d0)> }>> {
 // CHECK-RWT: %[[F0:.*]] = arith.constant sparse<{{\[\[}}0, 0], [1, 6]], [1.000000e+00, 5.000000e+00]> : tensor<8x7xf32>
 // CHECK-RWT: %[[T0:.*]] = bufferization.alloc_tensor()
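
Note: for a sparse constant source, the all-dense path keeps the element-wise sparse_tensor.foreach but replaces the sparse_tensor.insert chain with plain stores into the reshaped values view; the scf.if nonzero guard is also skipped, since a sparse constant enumerates only its stored entries. A simplified sketch, abbreviated from the sparse_constant_dd CHECK lines above (names illustrative):

  sparse_tensor.foreach in %cst : tensor<8x7xf32> do {
  ^bb0(%i: index, %j: index, %v: f32):
    memref.store %v, %view[%i, %j] : memref<8x7xf32>
  }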