diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
--- a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
@@ -704,12 +704,17 @@
   /// in the argument differ from those in the current cursor.
   uint64_t lexDiff(const uint64_t *lvlCoords) const {
     const uint64_t lvlRank = getLvlRank();
-    for (uint64_t l = 0; l < lvlRank; ++l)
-      if (lvlCoords[l] > lvlCursor[l])
+    for (uint64_t l = 0; l < lvlRank; ++l) {
+      const auto crd = lvlCoords[l];
+      const auto cur = lvlCursor[l];
+      if (crd > cur || (crd == cur && !isUniqueLvl(l)))
         return l;
-      else
-        assert(lvlCoords[l] == lvlCursor[l] && "non-lexicographic insertion");
-    assert(0 && "duplicate insertion");
+      if (crd < cur) {
+        assert(false && "non-lexicographic insertion");
+        return -1u;
+      }
+    }
+    assert(false && "duplicate insertion");
     return -1u;
   }

diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_coo_test.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_coo_test.mlir
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_coo_test.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_coo_test.mlir
@@ -87,6 +87,22 @@
     return %0 : tensor<8x8xf32>
   }

+  func.func @add_coo_coo_out_coo(%arga: tensor<8x8xf32, #SortedCOO>,
+                                 %argb: tensor<8x8xf32, #SortedCOO>)
+                                   -> tensor<8x8xf32, #SortedCOO> {
+    %init = tensor.empty() : tensor<8x8xf32, #SortedCOO>
+    %0 = linalg.generic #trait
+      ins(%arga, %argb: tensor<8x8xf32, #SortedCOO>,
+                        tensor<8x8xf32, #SortedCOO>)
+      outs(%init: tensor<8x8xf32, #SortedCOO>) {
+        ^bb(%a: f32, %b: f32, %x: f32):
+          %0 = arith.addf %a, %b : f32
+          linalg.yield %0 : f32
+    } -> tensor<8x8xf32, #SortedCOO>
+    return %0 : tensor<8x8xf32, #SortedCOO>
+  }
+
+
   func.func @add_coo_dense(%arga: tensor<8x8xf32>,
                            %argb: tensor<8x8xf32, #SortedCOO>)
                              -> tensor<8x8xf32> {
@@ -149,17 +165,21 @@
     %C3 = call @add_coo_coo(%COO_A, %COO_B) : (tensor<8x8xf32, #SortedCOO>,
                                                tensor<8x8xf32, #SortedCOO>)
                                            -> tensor<8x8xf32>
+    %COO_RET = call @add_coo_coo_out_coo(%COO_A, %COO_B) : (tensor<8x8xf32, #SortedCOO>,
+                                                            tensor<8x8xf32, #SortedCOO>)
+                                                        -> tensor<8x8xf32, #SortedCOO>
+    %C4 = sparse_tensor.convert %COO_RET : tensor<8x8xf32, #SortedCOO> to tensor<8x8xf32>
     //
    // Verify computed matrix C.
    //
-    // CHECK-COUNT-3: ( 8.8, 4.8, 6.8, 4.8, 8.8, 6.1, 14.8, 16.8 )
-    // CHECK-NEXT-COUNT-3: ( 4.4, 4.4, 4.4, 8.4, 8.4, 12.4, 16.4, 16.4 )
-    // CHECK-NEXT-COUNT-3: ( 8.8, 4.8, 6.8, 8.8, 8.8, 12.8, 14.8, 15.8 )
-    // CHECK-NEXT-COUNT-3: ( 4.3, 5.3, 6.3, 8.3, 8.3, 12.3, 14.3, 16.3 )
-    // CHECK-NEXT-COUNT-3: ( 4.5, 4.5, 6.5, 8.5, 8.5, 12.5, 14.5, 16.5 )
-    // CHECK-NEXT-COUNT-3: ( 9.9, 4.9, 6.9, 8.9, 8.9, 12.9, 15.9, 16.9 )
-    // CHECK-NEXT-COUNT-3: ( 12.1, 6.1, 5.1, 9.1, 9.1, 13.1, 15.1, 17.1 )
-    // CHECK-NEXT-COUNT-3: ( 15.4, 5.4, 7.4, 5.4, 11.4, 10.4, 11.4, 9.4 )
+    // CHECK-COUNT-4: ( 8.8, 4.8, 6.8, 4.8, 8.8, 6.1, 14.8, 16.8 )
+    // CHECK-NEXT-COUNT-4: ( 4.4, 4.4, 4.4, 8.4, 8.4, 12.4, 16.4, 16.4 )
+    // CHECK-NEXT-COUNT-4: ( 8.8, 4.8, 6.8, 8.8, 8.8, 12.8, 14.8, 15.8 )
+    // CHECK-NEXT-COUNT-4: ( 4.3, 5.3, 6.3, 8.3, 8.3, 12.3, 14.3, 16.3 )
+    // CHECK-NEXT-COUNT-4: ( 4.5, 4.5, 6.5, 8.5, 8.5, 12.5, 14.5, 16.5 )
+    // CHECK-NEXT-COUNT-4: ( 9.9, 4.9, 6.9, 8.9, 8.9, 12.9, 15.9, 16.9 )
+    // CHECK-NEXT-COUNT-4: ( 12.1, 6.1, 5.1, 9.1, 9.1, 13.1, 15.1, 17.1 )
+    // CHECK-NEXT-COUNT-4: ( 15.4, 5.4, 7.4, 5.4, 11.4, 10.4, 11.4, 9.4 )
     //
     %f0 = arith.constant 0.0 : f32
     scf.for %i = %c0 to %c8 step %c1 {
@@ -169,9 +189,12 @@
         : tensor<8x8xf32>, vector<8xf32>
       %v3 = vector.transfer_read %C3[%i, %c0], %f0
         : tensor<8x8xf32>, vector<8xf32>
+      %v4 = vector.transfer_read %C4[%i, %c0], %f0
+        : tensor<8x8xf32>, vector<8xf32>
      vector.print %v1 : vector<8xf32>
      vector.print %v2 : vector<8xf32>
      vector.print %v3 : vector<8xf32>
+      vector.print %v4 : vector<8xf32>
    }

    // Release resources.
@@ -181,6 +204,7 @@
    bufferization.dealloc_tensor %CSR_A : tensor<8x8xf32, #CSR>
    bufferization.dealloc_tensor %COO_A : tensor<8x8xf32, #SortedCOO>
    bufferization.dealloc_tensor %COO_B : tensor<8x8xf32, #SortedCOO>
+    bufferization.dealloc_tensor %COO_RET : tensor<8x8xf32, #SortedCOO>

    return
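
Note on the Storage.h hunk above: lexDiff now treats an equal coordinate at a non-unique level as the first "differing" level instead of asserting, which is what lets lexicographic insertion append repeated coordinates when the output itself is sorted COO. Below is a minimal standalone C++ sketch of that comparison logic, for illustration only: Cursor and uniqueLvl are hypothetical stand-ins for the real SparseTensorStorage and its isUniqueLvl() query, not the library API.

#include <cassert>
#include <cstdint>
#include <vector>

// Toy stand-in for the insertion cursor: lvlCursor holds the last inserted
// lvl-coordinates; uniqueLvl[l] says whether level l forbids repeats.
struct Cursor {
  std::vector<uint64_t> lvlCursor;
  std::vector<bool> uniqueLvl;

  uint64_t lexDiff(const uint64_t *lvlCoords) const {
    const uint64_t lvlRank = lvlCursor.size();
    for (uint64_t l = 0; l < lvlRank; ++l) {
      const uint64_t crd = lvlCoords[l];
      const uint64_t cur = lvlCursor[l];
      // A strictly larger coordinate, or an equal coordinate at a level
      // that allows duplicates, marks the first level that "differs".
      if (crd > cur || (crd == cur && !uniqueLvl[l]))
        return l;
      if (crd < cur) {
        assert(false && "non-lexicographic insertion");
        return -1u;
      }
    }
    assert(false && "duplicate insertion");
    return -1u;
  }
};

int main() {
  // Level 0 is non-unique, as in the compressed(nonunique) level of
  // sorted COO; level 1 is unique.
  Cursor c{{2, 3}, {false, true}};
  uint64_t next[] = {2, 5}; // same level-0 coordinate, larger level-1
  // The pre-patch logic reported level 1 here; with the non-unique check
  // the diff level is 0, so the level-0 coordinate is stored again for
  // this entry, as COO storage requires.
  return c.lexDiff(next) == 0 ? 0 : 1;
}

With strictly increasing coordinates the result is unchanged; only the equal-coordinate case at a non-unique level behaves differently, replacing the old "duplicate insertion" failure mode exercised by the new @add_coo_coo_out_coo test.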