diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -533,6 +533,13 @@
     SmallVector<Value> dynSizes;
     getDynamicSizes(dstTp, sizes, dynSizes);
 
+    bool fromSparseConst = false;
+    if (auto constOp = op.getSource().getDefiningOp<arith::ConstantOp>()) {
+      if (constOp.getValue().dyn_cast<SparseElementsAttr>()) {
+        fromSparseConst = true;
+      }
+    }
+
     RankedTensorType cooTp = getUnorderedCOOFromType(dstTp);
     auto cooBuffer =
         rewriter.create<AllocTensorOp>(loc, cooTp, dynSizes).getResult();
@@ -540,8 +547,22 @@
         loc, src, cooBuffer,
         [&](OpBuilder &builder, Location loc, ValueRange indices, Value v,
             ValueRange reduc) {
-          builder.create<sparse_tensor::YieldOp>(
-              loc, builder.create<InsertOp>(loc, v, reduc.front(), indices));
+          Value input = reduc.front();
+          if (fromSparseConst) {
+            input = builder.create<InsertOp>(loc, v, input, indices);
+          } else {
+            Value cond = genIsNonzero(builder, loc, v);
+            auto ifOp = builder.create<scf::IfOp>(
+                loc, TypeRange(input.getType()), cond, /*else*/ true);
+            builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+            Value insert = builder.create<InsertOp>(loc, v, input, indices);
+            builder.create<scf::YieldOp>(loc, insert);
+            builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
+            builder.create<scf::YieldOp>(loc, input);
+            builder.setInsertionPointAfter(ifOp);
+            input = ifOp.getResult(0);
+          }
+          builder.create<sparse_tensor::YieldOp>(loc, input);
         });
     rewriter.setInsertionPointAfter(op);
     src = rewriter.create<LoadOp>(loc, foreachOp.getResult(0), true);
diff --git a/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir b/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir
--- a/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir
+++ b/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir
@@ -109,8 +109,14 @@
 // CHECK-RWT: %[[VAL_2:.*]] = bufferization.alloc_tensor()
 // CHECK-RWT: %[[VAL_3:.*]] = sparse_tensor.foreach in %[[T0]] init(%[[VAL_2]])
 // CHECK-RWT: ^bb0(%[[VAL_4:.*]]: index, %[[VAL_5:.*]]: index, %[[VAL_6:.*]]: f64, %[[VAL_7:.*]]: tensor
-// CHECK-RWT: %[[VAL_8:.*]] = sparse_tensor.insert %[[VAL_6]] into %[[VAL_7]]{{\[}}%[[VAL_4]], %[[VAL_5]]]
-// CHECK-RWT: sparse_tensor.yield %[[VAL_8]]
+// CHECK-RWT: %[[CMP:.*]] = arith.cmpf une, %[[VAL_6]]
+// CHECK-RWT: %[[IFR:.*]] = scf.if %[[CMP]]
+// CHECK-RWT: %[[Y1:.*]] = sparse_tensor.insert %[[VAL_6]] into %[[VAL_7]]
+// CHECK-RWT: scf.yield %[[Y1]]
+// CHECK-RWT: } else {
+// CHECK-RWT: scf.yield %[[VAL_7]]
+// CHECK-RWT: }
+// CHECK-RWT: sparse_tensor.yield %[[IFR]]
 // CHECK-RWT: }
 // CHECK-RWT: %[[COO:.*]] = sparse_tensor.load %[[VAL_3]] hasInserts
 // CHECK-RWT: %[[I0:.*]] = sparse_tensor.indices %[[COO]] {dimension = 0 : index}
@@ -166,7 +172,7 @@
 // CHECK-RWT: %[[VAL_0:.*]] = arith.constant 1 : index
 // CHECK-RWT: %[[VAL_1:.*]] = arith.constant sparse<{{\[\[}}0, 0], [1, 6]], [1.000000e+00, 5.000000e+00]> : tensor<8x7xf32>
 // CHECK-RWT: %[[T0:.*]] = bufferization.alloc_tensor()
-// CHECK-RWT: %[[T1:.*]] = sparse_tensor.foreach in %[[VAL_1]] init(%[[T0]]) 
+// CHECK-RWT: %[[T1:.*]] = sparse_tensor.foreach in %[[VAL_1]] init(%[[T0]])
 // CHECK-RWT: ^bb0(%[[VAL_4:.*]]: index, %[[VAL_5:.*]]: index, %[[VAL_6:.*]]: f32, %[[VAL_7:.*]]: tensor
 // CHECK-RWT: %[[T2:.*]] = sparse_tensor.insert %[[VAL_6]] into %[[VAL_7]]{{\[}}%[[VAL_4]], %[[VAL_5]]]
 // CHECK-RWT: sparse_tensor.yield %[[T2]]
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matmul.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matmul.mlir
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matmul.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matmul.mlir
@@ -116,6 +116,35 @@
   %b3 = sparse_tensor.convert %sb : tensor<8x4xf64> to tensor<8x4xf64, #CSR>
   %b4 = sparse_tensor.convert %sb : tensor<8x4xf64> to tensor<8x4xf64, #DCSR>
 
+  //
+  // Sanity check on stored entries before going into the computations.
+  //
+  // CHECK: 32
+  // CHECK-NEXT: 32
+  // CHECK-NEXT: 4
+  // CHECK-NEXT: 4
+  // CHECK-NEXT: 32
+  // CHECK-NEXT: 32
+  // CHECK-NEXT: 8
+  // CHECK-NEXT: 8
+  //
+  %noea1 = sparse_tensor.number_of_entries %a1 : tensor<4x8xf64, #CSR>
+  %noea2 = sparse_tensor.number_of_entries %a2 : tensor<4x8xf64, #DCSR>
+  %noea3 = sparse_tensor.number_of_entries %a3 : tensor<4x8xf64, #CSR>
+  %noea4 = sparse_tensor.number_of_entries %a4 : tensor<4x8xf64, #DCSR>
+  %noeb1 = sparse_tensor.number_of_entries %b1 : tensor<8x4xf64, #CSR>
+  %noeb2 = sparse_tensor.number_of_entries %b2 : tensor<8x4xf64, #DCSR>
+  %noeb3 = sparse_tensor.number_of_entries %b3 : tensor<8x4xf64, #CSR>
+  %noeb4 = sparse_tensor.number_of_entries %b4 : tensor<8x4xf64, #DCSR>
+  vector.print %noea1 : index
+  vector.print %noea2 : index
+  vector.print %noea3 : index
+  vector.print %noea4 : index
+  vector.print %noeb1 : index
+  vector.print %noeb2 : index
+  vector.print %noeb3 : index
+  vector.print %noeb4 : index
+
   // Call kernels with dense.
   %0 = call @matmul1(%da, %db, %zero)
     : (tensor<4x8xf64>, tensor<8x4xf64>, tensor<4x4xf64>) -> tensor<4x4xf64>
@@ -205,20 +234,20 @@
   vector.print %v5 : vector<4x4xf64>
 
   //
-  // CHECK: ( ( 0, 30.5, 4.2, 0 ), ( 0, 0, 0, 0 ), ( 0, 0, 4.6, 0 ), ( 0, 0, 7, 8 ) )
+  // CHECK-NEXT: ( ( 0, 30.5, 4.2, 0 ), ( 0, 0, 0, 0 ), ( 0, 0, 4.6, 0 ), ( 0, 0, 7, 8 ) )
   //
   %v6 = vector.transfer_read %6[%c0, %c0], %d1 : tensor<4x4xf64>, vector<4x4xf64>
   vector.print %v6 : vector<4x4xf64>
 
   //
-  // CHECK: ( ( 0, 30.5, 4.2, 0 ), ( 0, 0, 0, 0 ), ( 0, 0, 4.6, 0 ), ( 0, 0, 7, 8 ) )
+  // CHECK-NEXT: ( ( 0, 30.5, 4.2, 0 ), ( 0, 0, 0, 0 ), ( 0, 0, 4.6, 0 ), ( 0, 0, 7, 8 ) )
   //
   %c7 = sparse_tensor.convert %7 : tensor<4x4xf64, #CSR> to tensor<4x4xf64>
   %v7 = vector.transfer_read %c7[%c0, %c0], %d1 : tensor<4x4xf64>, vector<4x4xf64>
   vector.print %v7 : vector<4x4xf64>
 
   //
-  // CHECK: ( ( 0, 30.5, 4.2, 0 ), ( 0, 0, 0, 0 ), ( 0, 0, 4.6, 0 ), ( 0, 0, 7, 8 ) )
+  // CHECK-NEXT: ( ( 0, 30.5, 4.2, 0 ), ( 0, 0, 0, 0 ), ( 0, 0, 4.6, 0 ), ( 0, 0, 7, 8 ) )
   //
   %c8 = sparse_tensor.convert %8 : tensor<4x4xf64, #DCSR> to tensor<4x4xf64>
   %v8 = vector.transfer_read %c8[%c0, %c0], %d1 : tensor<4x4xf64>, vector<4x4xf64>
@@ -227,17 +256,26 @@
   //
   // Sanity check on nonzeros.
   //
-  // FIXME: bring this back once dense2sparse skips zeros
-  //
-  // C_HECK: ( 30.5, 4.2, 4.6, 7, 8 )
-  // C_HECK: ( 30.5, 4.2, 4.6, 7, 8 )
+  // CHECK-NEXT: ( 30.5, 4.2, 4.6, 7, 8 )
+  // CHECK-NEXT: ( 30.5, 4.2, 4.6, 7, 8 )
   //
   %val7 = sparse_tensor.values %7 : tensor<4x4xf64, #CSR> to memref<?xf64>
   %val8 = sparse_tensor.values %8 : tensor<4x4xf64, #DCSR> to memref<?xf64>
-  %nz7 = vector.transfer_read %val7[%c0], %d1 : memref<?xf64>, vector<8xf64>
-  %nz8 = vector.transfer_read %val8[%c0], %d1 : memref<?xf64>, vector<8xf64>
-  vector.print %nz7 : vector<8xf64>
-  vector.print %nz8 : vector<8xf64>
+  %nz7 = vector.transfer_read %val7[%c0], %d1 : memref<?xf64>, vector<5xf64>
+  %nz8 = vector.transfer_read %val8[%c0], %d1 : memref<?xf64>, vector<5xf64>
+  vector.print %nz7 : vector<5xf64>
+  vector.print %nz8 : vector<5xf64>
+
+  //
+  // Sanity check on stored entries after the computations.
+  //
+  // CHECK-NEXT: 5
+  // CHECK-NEXT: 5
+  //
+  %noe7 = sparse_tensor.number_of_entries %7 : tensor<4x4xf64, #CSR>
+  %noe8 = sparse_tensor.number_of_entries %8 : tensor<4x4xf64, #DCSR>
+  vector.print %noe7 : index
+  vector.print %noe8 : index
 
   // Release the resources.
   bufferization.dealloc_tensor %a1 : tensor<4x8xf64, #CSR>