diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp @@ -645,12 +645,7 @@ } // Retrieve NNZ. - auto ptrTp = - MemRefType::get(dynShape, getPointerOverheadType(rewriter, encSrc)); - Value p0 = - rewriter.create(loc, ptrTp, src, rewriter.getIndexAttr(0)); - Value c1 = constantIndex(rewriter, loc, 1); - Value nnz = rewriter.create(loc, p0, c1); + Value nnz = rewriter.create(loc, src); nnz = rewriter.create(loc, rewriter.getIndexType(), nnz); diff --git a/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir b/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir --- a/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir +++ b/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir @@ -105,29 +105,28 @@ // CHECK-RWT-LABEL: func.func @sparse_convert_2d( // CHECK-RWT-SAME: %[[T0:.*]]: tensor<2x4xf64>) -> tensor<2x4xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>> { -// CHECK-RWT: %[[VAL_1:.*]] = arith.constant 1 : index -// CHECK-RWT: %[[VAL_2:.*]] = bufferization.alloc_tensor() -// CHECK-RWT: %[[VAL_3:.*]] = sparse_tensor.foreach in %[[T0]] init(%[[VAL_2]]) -// CHECK-RWT: ^bb0(%[[VAL_4:.*]]: index, %[[VAL_5:.*]]: index, %[[VAL_6:.*]]: f64, %[[VAL_7:.*]]: tensor -// CHECK-RWT: %[[VAL_8:.*]] = sparse_tensor.insert %[[VAL_6]] into %[[VAL_7]]{{\[}}%[[VAL_4]], %[[VAL_5]]] -// CHECK-RWT: sparse_tensor.yield %[[VAL_8]] +// CHECK-RWT: %[[T1:.*]] = bufferization.alloc_tensor() +// CHECK-RWT: %[[T2:.*]] = sparse_tensor.foreach in %[[T0]] init(%[[T1]]) +// CHECK-RWT: ^bb0(%[[L0I0:.*]]: index, %[[L0I1:.*]]: index, %[[L0V:.*]]: f64, %[[L0T:.*]]: tensor +// CHECK-RWT: %[[L0T2:.*]] = sparse_tensor.insert %[[L0V]] into %[[L0T]]{{\[}}%[[L0I0]], %[[L0I1]]] +// CHECK-RWT: sparse_tensor.yield %[[L0T2]] // CHECK-RWT: } -// CHECK-RWT: %[[COO:.*]] = sparse_tensor.load %[[VAL_3]] hasInserts +// CHECK-RWT: %[[COO:.*]] = sparse_tensor.load %[[T2]] hasInserts // CHECK-RWT: %[[I0:.*]] = sparse_tensor.indices %[[COO]] {dimension = 0 : index} // CHECK-RWT: %[[I1:.*]] = sparse_tensor.indices %[[COO]] {dimension = 1 : index} -// CHECK-RWT: %[[VAL_13:.*]] = memref.load %[[I0]]{{\[}}%[[VAL_1]]] : memref -// CHECK-RWT: %[[VAL_14:.*]] = sparse_tensor.values %[[COO]] -// CHECK-RWT: sparse_tensor.sort %[[VAL_13]], %[[I0]], %[[I1]] jointly %[[VAL_14]] : memref, memref jointly memref -// CHECK-RWT: %[[VAL_15:.*]] = bufferization.alloc_tensor() -// CHECK-RWT: %[[VAL_16:.*]] = sparse_tensor.foreach in %[[COO]] init(%[[VAL_15]]) -// CHECK-RWT: ^bb0(%[[VAL_17:.*]]: index, %[[VAL_18:.*]]: index, %[[VAL_19:.*]]: f64, %[[VAL_20:.*]]: tensor -// CHECK-RWT: %[[VAL_21:.*]] = sparse_tensor.insert %[[VAL_19]] into %[[VAL_20]]{{\[}}%[[VAL_17]], %[[VAL_18]]] -// CHECK-RWT: sparse_tensor.yield %[[VAL_21]] +// CHECK-RWT: %[[NNZ:.*]] = sparse_tensor.number_of_entries %[[COO]] +// CHECK-RWT: %[[V:.*]] = sparse_tensor.values %[[COO]] +// CHECK-RWT: sparse_tensor.sort %[[NNZ]], %[[I0]], %[[I1]] jointly %[[V]] : memref, memref jointly memref +// CHECK-RWT: %[[T3:.*]] = bufferization.alloc_tensor() +// CHECK-RWT: %[[T4:.*]] = sparse_tensor.foreach in %[[COO]] init(%[[T3]]) +// CHECK-RWT: ^bb0(%[[L1I0:.*]]: index, %[[L1I1:.*]]: index, %[[L1V:.*]]: f64, %[[L1T:.*]]: tensor +// CHECK-RWT: %[[L1T2:.*]] = sparse_tensor.insert %[[L1V]] into %[[L1T]]{{\[}}%[[L1I0]], %[[L1I1]]] +// CHECK-RWT: sparse_tensor.yield %[[L1T2]] // CHECK-RWT: } -// CHECK-RWT: %[[VAL_22:.*]] = sparse_tensor.load %[[VAL_16]] hasInserts -// CHECK-RWT: %[[VAL_24:.*]] = sparse_tensor.convert %[[VAL_22]] +// CHECK-RWT: %[[T5:.*]] = sparse_tensor.load %[[T4]] hasInserts +// CHECK-RWT: %[[T6:.*]] = sparse_tensor.convert %[[T5]] // CHECK-RWT: bufferization.dealloc_tensor %[[COO]] -// CHECK-RWT: return %[[VAL_24]] +// CHECK-RWT: return %[[T6]] // CHECK-RWT: } func.func @sparse_convert_2d(%arg0: tensor<2x4xf64>) -> tensor<2x4xf64, #CSR> { %0 = sparse_tensor.convert %arg0 : tensor<2x4xf64> to tensor<2x4xf64, #CSR> @@ -163,30 +162,29 @@ // CHECK: return %[[T]] : !llvm.ptr // CHECK-RWT-LABEL: func.func @sparse_constant() -> tensor<8x7xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>> { -// CHECK-RWT: %[[VAL_0:.*]] = arith.constant 1 : index -// CHECK-RWT: %[[VAL_1:.*]] = arith.constant sparse<{{\[\[}}0, 0], [1, 6]], [1.000000e+00, 5.000000e+00]> : tensor<8x7xf32> +// CHECK-RWT: %[[F0:.*]] = arith.constant sparse<{{\[\[}}0, 0], [1, 6]], [1.000000e+00, 5.000000e+00]> : tensor<8x7xf32> // CHECK-RWT: %[[T0:.*]] = bufferization.alloc_tensor() -// CHECK-RWT: %[[T1:.*]] = sparse_tensor.foreach in %[[VAL_1]] init(%[[T0]]) -// CHECK-RWT: ^bb0(%[[VAL_4:.*]]: index, %[[VAL_5:.*]]: index, %[[VAL_6:.*]]: f32, %[[VAL_7:.*]]: tensor -// CHECK-RWT: %[[T2:.*]] = sparse_tensor.insert %[[VAL_6]] into %[[VAL_7]]{{\[}}%[[VAL_4]], %[[VAL_5]]] -// CHECK-RWT: sparse_tensor.yield %[[T2]] +// CHECK-RWT: %[[T1:.*]] = sparse_tensor.foreach in %[[F0]] init(%[[T0]]) +// CHECK-RWT: ^bb0(%[[L0I0:.*]]: index, %[[L0I1:.*]]: index, %[[L0V:.*]]: f32, %[[L0T:.*]]: tensor +// CHECK-RWT: %[[L0T2:.*]] = sparse_tensor.insert %[[L0V]] into %[[L0T]]{{\[}}%[[L0I0]], %[[L0I1]]] +// CHECK-RWT: sparse_tensor.yield %[[L0T2]] // CHECK-RWT: } // CHECK-RWT: %[[COO:.*]] = sparse_tensor.load %[[T1]] hasInserts // CHECK-RWT: %[[I0:.*]] = sparse_tensor.indices %[[COO]] {dimension = 0 : index} // CHECK-RWT: %[[I1:.*]] = sparse_tensor.indices %[[COO]] {dimension = 1 : index} -// CHECK-RWT: %[[VAL_13:.*]] = memref.load %[[I0]]{{\[}}%[[VAL_0]]] : memref +// CHECK-RWT: %[[NNZ:.*]] = sparse_tensor.number_of_entries %[[COO]] // CHECK-RWT: %[[V:.*]] = sparse_tensor.values %[[COO]] -// CHECK-RWT: sparse_tensor.sort %[[VAL_13]], %[[I0]], %[[I1]] jointly %[[V]] : memref, memref jointly memref -// CHECK-RWT: %[[VAL_15:.*]] = bufferization.alloc_tensor() -// CHECK-RWT: %[[VAL_16:.*]] = sparse_tensor.foreach in %[[COO]] init(%[[VAL_15]]) -// CHECK-RWT: ^bb0(%[[VAL_17:.*]]: index, %[[VAL_18:.*]]: index, %[[VAL_19:.*]]: f32, %[[VAL_20:.*]]: tensor -// CHECK-RWT: %[[VAL_21:.*]] = sparse_tensor.insert %[[VAL_19]] into %[[VAL_20]]{{\[}}%[[VAL_17]], %[[VAL_18]]] -// CHECK-RWT: sparse_tensor.yield %[[VAL_21]] +// CHECK-RWT: sparse_tensor.sort %[[NNZ]], %[[I0]], %[[I1]] jointly %[[V]] : memref, memref jointly memref +// CHECK-RWT: %[[T3:.*]] = bufferization.alloc_tensor() +// CHECK-RWT: %[[T4:.*]] = sparse_tensor.foreach in %[[COO]] init(%[[T3]]) +// CHECK-RWT: ^bb0(%[[L1I0:.*]]: index, %[[L1I1:.*]]: index, %[[L1V:.*]]: f32, %[[L1T:.*]]: tensor +// CHECK-RWT: %[[L1T2:.*]] = sparse_tensor.insert %[[L1V]] into %[[L1T]]{{\[}}%[[L1I0]], %[[L1I1]]] +// CHECK-RWT: sparse_tensor.yield %[[L1T2]] // CHECK-RWT: } -// CHECK-RWT: %[[VAL_22:.*]] = sparse_tensor.load %[[VAL_16]] hasInserts -// CHECK-RWT: %[[VAL_24:.*]] = sparse_tensor.convert %[[VAL_22]] +// CHECK-RWT: %[[T5:.*]] = sparse_tensor.load %[[T4]] hasInserts +// CHECK-RWT: %[[T6:.*]] = sparse_tensor.convert %[[T5]] // CHECK-RWT: bufferization.dealloc_tensor %[[COO]] -// CHECK-RWT: return %[[VAL_24]] +// CHECK-RWT: return %[[T6]] // CHECK-RWT: } func.func @sparse_constant() -> tensor<8x7xf32, #CSR>{ // Initialize a tensor. diff --git a/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir b/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir --- a/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir +++ b/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir @@ -86,11 +86,9 @@ // CHECK-RWT-LABEL: func.func @sparse_convert( // CHECK-RWT-SAME: %[[A:.*]]: tensor>) // CHECK-RWT-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-RWT-DAG: %[[C1:.*]] = arith.constant 1 : index // CHECK-RWT: %[[D:.*]] = tensor.dim %[[A]], %[[C0]] // CHECK-RWT: %[[I0:.*]] = sparse_tensor.indices %[[A]] {dimension = 0 : index} -// CHECK-RWT: %[[NNZr:.*]] = memref.load %[[I0]]{{\[}}%[[C1]]] : memref -// CHECK-RWT: %[[NNZ:.*]] = arith.index_cast %[[NNZr]] : i64 to index +// CHECK-RWT: %[[NNZ:.*]] = sparse_tensor.number_of_entries %[[A]] // CHECK-RWT: %[[V:.*]] = sparse_tensor.values %[[A]] // CHECK-RWT: sparse_tensor.sort %[[NNZ]], %[[I0]] jointly %[[V]] // CHECK-RWT: %[[DST:.*]] = bufferization.alloc_tensor(%[[D]])