diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp @@ -536,17 +536,20 @@ RankedTensorType cooTp = getUnorderedCOOFromType(dstTp); auto cooBuffer = rewriter.create(loc, cooTp, dynSizes).getResult(); - unsigned rank = dstTp.cast().getRank(); - - genDenseTensorOrSparseConstantIterLoop( - rewriter, loc, src, rank, - [&](OpBuilder &builder, Location loc, Value val, ValueRange indices) { - builder.create(loc, val, cooBuffer, indices); + auto foreachOp = rewriter.create( + loc, src, cooBuffer, + [&](OpBuilder &builder, Location loc, ValueRange indices, Value v, + ValueRange reduc) { + builder.create( + loc, builder.create(loc, v, reduc.front(), indices)); }); + // replaceOpWithNewOp erases `op`; move the insertion point past it so the + // dealloc created afterwards is not inserted at a dangling position. rewriter.setInsertionPointAfter(op); - rewriter.replaceOpWithNewOp(op, dstTp, cooBuffer); - rewriter.create(loc, cooBuffer); + src = rewriter.create(loc, foreachOp.getResult(0), true); + rewriter.replaceOpWithNewOp(op, dstTp, src); + rewriter.create(loc, src); return success(); } diff --git a/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir b/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir --- a/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir +++ b/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir @@ -104,38 +104,31 @@ // CHECK: return %[[T]] : !llvm.ptr // CHECK-RWT-LABEL: func.func @sparse_convert_2d( -// CHECK-RWT-SAME: %[[A:.*]]: tensor<2x4xf64>) -> tensor<2x4xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>> { -// CHECK-RWT-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-RWT-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK-RWT-DAG: %[[C2:.*]] = arith.constant 2 : index -// CHECK-RWT-DAG: %[[C4:.*]] = arith.constant 4 : index -// CHECK-RWT-DAG: %[[F0:.*]] = arith.constant 0.000000e+00 : f64 -// CHECK-RWT:
%[[COO:.*]] = bufferization.alloc_tensor() -// CHECK-RWT: scf.for %[[FI:.*]] = %[[C0]] to %[[C2]] step %[[C1]] { -// CHECK-RWT: scf.for %[[FJ:.*]] = %[[C0]] to %[[C4]] step %[[C1]] { -// CHECK-RWT: %[[V:.*]] = tensor.extract %[[A]]{{\[}}%[[FI]], %[[FJ]]] : tensor<2x4xf64> -// CHECK-RWT: %[[NZ:.*]] = arith.cmpf une, %[[V]], %[[F0]] : f64 -// CHECK-RWT: scf.if %[[NZ]] { -// // FIXME: the SSA chain is broken here! -// CHECK-RWT: %{{.*}} = sparse_tensor.insert %[[V]] into %[[COO]]{{\[}}%[[FI]], %[[FJ]]] -// CHECK-RWT: } -// CHECK-RWT: } -// CHECK-RWT: } -// CHECK-RWT: %[[I0:.*]] = sparse_tensor.indices %[[COO]] {dimension = 0 : index} -// CHECK-RWT: %[[I1:.*]] = sparse_tensor.indices %[[COO]] {dimension = 1 : index} -// CHECK-RWT: %[[NNZ:.*]] = memref.load %[[I0]]{{\[}}%[[C1]]] : memref -// CHECK-RWT: %[[V2:.*]] = sparse_tensor.values %[[COO]] -// CHECK-RWT: sparse_tensor.sort %[[NNZ]], %[[I0]], %[[I1]] jointly %[[V2]] -// CHECK-RWT: %[[DST:.*]] = bufferization.alloc_tensor() -// CHECK-RWT: %[[NEW_T:.*]] = sparse_tensor.foreach in %[[COO]] init(%[[DST]]) -// CHECK-RWT: ^bb0(%[[FI0:.*]]: index, %[[FI1:.*]]: index, %[[FV:.*]]: f64, %[[R0:.*]]: tensor -// CHECK-RWT: %[[RET:.*]] = sparse_tensor.insert %[[FV]] into %[[R0]]{{\[}}%[[FI0]], %[[FI1]]] -// CHECK-RWT: sparse_tensor.yield %[[RET]] -// CHECK-RWT: } -// CHECK-RWT: %[[NT:.*]] = sparse_tensor.load %[[NEW_T]] hasInserts -// CHECK-RWT: %[[R:.*]] = sparse_tensor.convert %[[NT]] -// CHECK-RWT: bufferization.dealloc_tensor %[[COO]] -// CHECK-RWT: return %[[R]] : tensor<2x4xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>> +// CHECK-RWT-SAME: %[[T0:.*]]: tensor<2x4xf64>) -> tensor<2x4xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>> { +// CHECK-RWT: %[[VAL_1:.*]] = arith.constant 1 : index +// CHECK-RWT: %[[VAL_2:.*]] = bufferization.alloc_tensor() +// CHECK-RWT: %[[VAL_3:.*]] = sparse_tensor.foreach in %[[T0]] init(%[[VAL_2]]) +// CHECK-RWT: ^bb0(%[[VAL_4:.*]]: index, 
%[[VAL_5:.*]]: index, %[[VAL_6:.*]]: f64, %[[VAL_7:.*]]: tensor +// CHECK-RWT: %[[VAL_8:.*]] = sparse_tensor.insert %[[VAL_6]] into %[[VAL_7]]{{\[}}%[[VAL_4]], %[[VAL_5]]] +// CHECK-RWT: sparse_tensor.yield %[[VAL_8]] +// CHECK-RWT: } +// CHECK-RWT: %[[COO:.*]] = sparse_tensor.load %[[VAL_3]] hasInserts +// CHECK-RWT: %[[I0:.*]] = sparse_tensor.indices %[[COO]] {dimension = 0 : index} +// CHECK-RWT: %[[I1:.*]] = sparse_tensor.indices %[[COO]] {dimension = 1 : index} +// CHECK-RWT: %[[VAL_13:.*]] = memref.load %[[I0]]{{\[}}%[[VAL_1]]] : memref +// CHECK-RWT: %[[VAL_14:.*]] = sparse_tensor.values %[[COO]] +// CHECK-RWT: sparse_tensor.sort %[[VAL_13]], %[[I0]], %[[I1]] jointly %[[VAL_14]] : memref, memref jointly memref +// CHECK-RWT: %[[VAL_15:.*]] = bufferization.alloc_tensor() +// CHECK-RWT: %[[VAL_16:.*]] = sparse_tensor.foreach in %[[COO]] init(%[[VAL_15]]) +// CHECK-RWT: ^bb0(%[[VAL_17:.*]]: index, %[[VAL_18:.*]]: index, %[[VAL_19:.*]]: f64, %[[VAL_20:.*]]: tensor +// CHECK-RWT: %[[VAL_21:.*]] = sparse_tensor.insert %[[VAL_19]] into %[[VAL_20]]{{\[}}%[[VAL_17]], %[[VAL_18]]] +// CHECK-RWT: sparse_tensor.yield %[[VAL_21]] +// CHECK-RWT: } +// CHECK-RWT: %[[VAL_22:.*]] = sparse_tensor.load %[[VAL_16]] hasInserts +// CHECK-RWT: %[[VAL_24:.*]] = sparse_tensor.convert %[[VAL_22]] +// CHECK-RWT: bufferization.dealloc_tensor %[[COO]] +// CHECK-RWT: return %[[VAL_24]] +// CHECK-RWT: } func.func @sparse_convert_2d(%arg0: tensor<2x4xf64>) -> tensor<2x4xf64, #CSR> { %0 = sparse_tensor.convert %arg0 : tensor<2x4xf64> to tensor<2x4xf64, #CSR> return %0 : tensor<2x4xf64, #CSR> @@ -159,9 +152,9 @@ // CHECK: %[[N:.*]] = memref.cast %[[M]] : memref<2xindex> to memref // CHECK: %[[BUF:.*]] = memref.alloca() : memref // CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[C2]] step %[[C1]] { -// CHECK-DAG: memref.store %{{.*}}, %[[M]][%[[C0]]] : memref<2xindex> -// CHECK-DAG: memref.store %{{.*}}, %[[M]][%[[C1]]] : memref<2xindex> -// CHECK-DAG: %[[V:.*]] = tensor.extract %{{.*}}[%[[I]]] : 
tensor<2xf32> +// CHECK-DAG: memref.store %{{.*}}, %[[M]][%[[C0]]] : memref<2xindex> +// CHECK-DAG: memref.store %{{.*}}, %[[M]][%[[C1]]] : memref<2xindex> +// CHECK-DAG: %[[V:.*]] = tensor.extract %{{.*}}[%[[I]]] : tensor<2xf32> // CHECK: memref.store %[[V]], %[[BUF]][] : memref // CHECK: call @addEltF32(%{{.*}}, %[[BUF]], %[[N]], %{{.*}}) // CHECK: } @@ -169,37 +162,32 @@ // CHECK: call @delSparseTensorCOOF32(%[[C]]) // CHECK: return %[[T]] : !llvm.ptr -// CHECK-RWT-LABEL: func.func @sparse_constant() -// CHECK-RWT-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-RWT-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK-RWT-DAG: %[[SI:.*]] = arith.constant dense<{{\[\[}}0, 0], [1, 6]]> : tensor<2x2xi64> -// CHECK-RWT-DAG: %[[SV:.*]] = arith.constant dense<[1.000000e+00, 5.000000e+00]> : tensor<2xf32> -// CHECK-RWT-DAG: %[[C2:.*]] = arith.constant 2 : index -// CHECK-RWT: %[[COO:.*]] = bufferization.alloc_tensor() -// CHECK-RWT: scf.for %[[FI:.*]] = %[[C0]] to %[[C2]] step %[[C1]] { -// CHECK-RWT: %[[I0r:.*]] = tensor.extract %[[SI]]{{\[}}%[[FI]], %[[C0]]] : tensor<2x2xi64> -// CHECK-RWT: %[[I0:.*]] = arith.index_cast %[[I0r]] : i64 to index -// CHECK-RWT: %[[I1r:.*]] = tensor.extract %[[SI]]{{\[}}%[[FI]], %[[C1]]] : tensor<2x2xi64> -// CHECK-RWT: %[[I1:.*]] = arith.index_cast %[[I1r]] : i64 to index -// CHECK-RWT: %[[V:.*]] = tensor.extract %[[SV]]{{\[}}%[[FI]]] : tensor<2xf32> -// // FIXME: the SSA chain is broken here! 
-// CHECK-RWT: sparse_tensor.insert %[[V]] into %[[COO]]{{\[}}%[[I0]], %[[I1]]] -// CHECK-RWT: } -// CHECK-RWT: %[[TI0:.*]] = sparse_tensor.indices %[[COO]] {dimension = 0 : index} -// CHECK-RWT: %[[TI1:.*]] = sparse_tensor.indices %[[COO]] {dimension = 1 : index} -// CHECK-RWT: %[[NNZ:.*]] = memref.load %[[TI0]]{{\[}}%[[C1]]] : memref -// CHECK-RWT: %[[TV:.*]] = sparse_tensor.values %[[COO]] -// CHECK-RWT: sparse_tensor.sort %[[NNZ]], %[[TI0]], %[[TI1]] jointly %[[TV]] -// CHECK-RWT: %[[DST:.*]] = bufferization.alloc_tensor() -// CHECK-RWT: %[[RET:.*]] = sparse_tensor.foreach in %[[COO]] init(%[[DST]]) -// CHECK-RWT: ^bb0(%[[F2I0:.*]]: index, %[[F2I1:.*]]: index, %[[F2V:.*]]: f32, %[[R0:.*]]: tensor -// CHECK-RWT: %[[NEW_T:.*]] = sparse_tensor.insert %[[F2V]] into %[[R0]]{{\[}}%[[F2I0]], %[[F2I1]]] -// CHECK-RWT: sparse_tensor.yield %[[NEW_T]] -// CHECK-RWT: } -// CHECK-RWT: %[[T:.*]] = sparse_tensor.load %[[RET]] hasInserts -// CHECK-RWT: %[[R:.*]] = sparse_tensor.convert %[[T]] -// CHECK-RWT: bufferization.dealloc_tensor %[[COO]] -// CHECK-RWT: return %[[R]] : tensor<8x7xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>> +// CHECK-RWT-LABEL: func.func @sparse_constant() -> tensor<8x7xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>> { +// CHECK-RWT: %[[VAL_0:.*]] = arith.constant 1 : index +// CHECK-RWT: %[[VAL_1:.*]] = arith.constant sparse<{{\[\[}}0, 0], [1, 6]], [1.000000e+00, 5.000000e+00]> : tensor<8x7xf32> +// CHECK-RWT: %[[T0:.*]] = bufferization.alloc_tensor() +// CHECK-RWT: %[[T1:.*]] = sparse_tensor.foreach in %[[VAL_1]] init(%[[T0]]) +// CHECK-RWT: ^bb0(%[[VAL_4:.*]]: index, %[[VAL_5:.*]]: index, %[[VAL_6:.*]]: f32, %[[VAL_7:.*]]: tensor +// CHECK-RWT: %[[T2:.*]] = sparse_tensor.insert %[[VAL_6]] into %[[VAL_7]]{{\[}}%[[VAL_4]], %[[VAL_5]]] +// CHECK-RWT: sparse_tensor.yield %[[T2]] +// CHECK-RWT: } +// CHECK-RWT: %[[COO:.*]] = sparse_tensor.load %[[T1]] hasInserts +// CHECK-RWT: %[[I0:.*]] = 
sparse_tensor.indices %[[COO]] {dimension = 0 : index} +// CHECK-RWT: %[[I1:.*]] = sparse_tensor.indices %[[COO]] {dimension = 1 : index} +// CHECK-RWT: %[[VAL_13:.*]] = memref.load %[[I0]]{{\[}}%[[VAL_0]]] : memref +// CHECK-RWT: %[[V:.*]] = sparse_tensor.values %[[COO]] +// CHECK-RWT: sparse_tensor.sort %[[VAL_13]], %[[I0]], %[[I1]] jointly %[[V]] : memref, memref jointly memref +// CHECK-RWT: %[[VAL_15:.*]] = bufferization.alloc_tensor() +// CHECK-RWT: %[[VAL_16:.*]] = sparse_tensor.foreach in %[[COO]] init(%[[VAL_15]]) +// CHECK-RWT: ^bb0(%[[VAL_17:.*]]: index, %[[VAL_18:.*]]: index, %[[VAL_19:.*]]: f32, %[[VAL_20:.*]]: tensor +// CHECK-RWT: %[[VAL_21:.*]] = sparse_tensor.insert %[[VAL_19]] into %[[VAL_20]]{{\[}}%[[VAL_17]], %[[VAL_18]]] +// CHECK-RWT: sparse_tensor.yield %[[VAL_21]] +// CHECK-RWT: } +// CHECK-RWT: %[[VAL_22:.*]] = sparse_tensor.load %[[VAL_16]] hasInserts +// CHECK-RWT: %[[VAL_24:.*]] = sparse_tensor.convert %[[VAL_22]] +// CHECK-RWT: bufferization.dealloc_tensor %[[COO]] +// CHECK-RWT: return %[[VAL_24]] +// CHECK-RWT: } func.func @sparse_constant() -> tensor<8x7xf32, #CSR>{ // Initialize a tensor. %0 = arith.constant sparse<[[0, 0], [1, 6]], [1.0, 5.0]> : tensor<8x7xf32>