diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp @@ -573,8 +573,24 @@ rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments()); Block *after = rewriter.createBlock(&whileOp.getAfter(), {}, noTypes); rewriter.setInsertionPointToStart(after); + + // A source with a dense level may yield explicitly stored zeros; guard the + // loop body so that only nonzero elements are handed to the callback. + bool hasDenseDim = llvm::any_of( + enc.getDimLevelType(), [](DimLevelType dlt) { return isDenseDLT(dlt); }); + if (hasDenseDim) { + Value elemV = rewriter.create<memref::LoadOp>(loc, elemPtr); + Value isNonzero = genIsNonzero(rewriter, loc, elemV); + scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, isNonzero, /*else*/ false); + rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front()); + } // Callback here to build loop body. bodyBuilder(rewriter, loc, srcIdx, elemPtr); + + // Leave the IfOp's then-block so the loop terminator below lands in `after`. + if (hasDenseDim) + rewriter.setInsertionPointToEnd(after); + rewriter.create<scf::YieldOp>(loc); // Finish generating loop.
rewriter.setInsertionPointAfter(whileOp); diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/concatenate.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/concatenate.mlir --- a/mlir/test/Integration/Dialect/SparseTensor/CPU/concatenate.mlir +++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/concatenate.mlir @@ -169,9 +169,12 @@ %v = vector.transfer_read %c[%c0, %c0], %du: tensor<9x4xf64>, vector<9x4xf64> vector.print %v : vector<9x4xf64> + %n = sparse_tensor.number_of_entries %A : tensor<9x4xf64, #MAT_C_C> + vector.print %n : index + %1 = sparse_tensor.values %A : tensor<9x4xf64, #MAT_C_C> to memref<?xf64> - %2 = vector.transfer_read %1[%c0], %du: memref<?xf64>, vector<36xf64> - vector.print %2 : vector<36xf64> + %2 = vector.transfer_read %1[%c0], %du: memref<?xf64>, vector<18xf64> + vector.print %2 : vector<18xf64> return } @@ -184,9 +187,12 @@ %v = vector.transfer_read %c[%c0, %c0], %du: tensor<9x4xf64>, vector<9x4xf64> vector.print %v : vector<9x4xf64> + %n = sparse_tensor.number_of_entries %A : tensor<9x4xf64, #MAT_C_C_P> + vector.print %n : index + %1 = sparse_tensor.values %A : tensor<9x4xf64, #MAT_C_C_P> to memref<?xf64> - %2 = vector.transfer_read %1[%c0], %du: memref<?xf64>, vector<36xf64> - vector.print %2 : vector<36xf64> + %2 = vector.transfer_read %1[%c0], %du: memref<?xf64>, vector<18xf64> + vector.print %2 : vector<18xf64> return } @@ -209,9 +215,12 @@ %v = vector.transfer_read %c[%c0, %c0], %du: tensor<4x9xf64>, vector<4x9xf64> vector.print %v : vector<4x9xf64> + %n = sparse_tensor.number_of_entries %A : tensor<4x9xf64, #MAT_C_C> + vector.print %n : index + %1 = sparse_tensor.values %A : tensor<4x9xf64, #MAT_C_C> to memref<?xf64> - %2 = vector.transfer_read %1[%c0], %du: memref<?xf64>, vector<36xf64> - vector.print %2 : vector<36xf64> + %2 = vector.transfer_read %1[%c0], %du: memref<?xf64>, vector<18xf64> + vector.print %2 : vector<18xf64> return } @@ -224,9 +233,12 @@ %v = vector.transfer_read %c[%c0, %c0], %du: tensor<?x?xf64>, vector<4x9xf64> vector.print %v : vector<4x9xf64> + %n = 
sparse_tensor.number_of_entries %A : tensor<?x?xf64, #MAT_C_C> + vector.print %n : index + %1 = sparse_tensor.values %A : tensor<?x?xf64, #MAT_C_C> to memref<?xf64> - %2 = vector.transfer_read %1[%c0], %du: memref<?xf64>, vector<36xf64> - vector.print %2 : vector<36xf64> + %2 = vector.transfer_read %1[%c0], %du: memref<?xf64>, vector<18xf64> + vector.print %2 : vector<18xf64> return } @@ -239,9 +251,12 @@ %v = vector.transfer_read %c[%c0, %c0], %du: tensor<4x9xf64>, vector<4x9xf64> vector.print %v : vector<4x9xf64> + %n = sparse_tensor.number_of_entries %A : tensor<4x9xf64, #MAT_C_C_P> + vector.print %n : index + %1 = sparse_tensor.values %A : tensor<4x9xf64, #MAT_C_C_P> to memref<?xf64> - %2 = vector.transfer_read %1[%c0], %du: memref<?xf64>, vector<36xf64> - vector.print %2 : vector<36xf64> + %2 = vector.transfer_read %1[%c0], %du: memref<?xf64>, vector<18xf64> + vector.print %2 : vector<18xf64> return } @@ -297,7 +312,8 @@ %sm44dc_dyn = sparse_tensor.convert %m44 : tensor<4x4xf64> to tensor<?x?xf64, #MAT_D_C> // CHECK: ( ( 1, 0, 3, 0 ), ( 0, 2, 0, 0 ), ( 1, 0, 1, 1 ), ( 0, 0.5, 0, 0 ), ( 1, 5, 2, 0 ), ( 0, 0, 1.5, 1 ), ( 0, 3.5, 0, 0 ), ( 1, 5, 2, 0 ), ( 1, 0.5, 0, 0 ) ) - // CHECK-NEXT: ( 1, 3, 2, 1, 0, 1, 1, 0, 0.5, 0, 0, 1, 5, 2, 0, 1.5, 1, 3.5, 1, 5, 2, 1, 0.5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 ) + // CHECK-NEXT: 18 + // CHECK-NEXT: ( 1, 3, 2, 1, 1, 1, 0.5, 1, 5, 2, 1.5, 1, 3.5, 1, 5, 2, 1, 0.5 ) %0 = call @concat_sparse_sparse(%sm24cc, %sm34cd, %sm44dc) : (tensor<2x4xf64, #MAT_C_C>, tensor<3x4xf64, #MAT_C_D>, tensor<4x4xf64, #MAT_D_C>) -> tensor<9x4xf64, #MAT_C_C> call @dump_mat_9x4(%0) : (tensor<9x4xf64, #MAT_C_C>) -> () @@ -308,7 +324,8 @@ call @dump_mat_dense_9x4(%1) : (tensor<9x4xf64>) -> () // CHECK-NEXT: ( ( 1, 0, 3, 0 ), ( 0, 2, 0, 0 ), ( 1, 0, 1, 1 ), ( 0, 0.5, 0, 0 ), ( 1, 5, 2, 0 ), ( 0, 0, 1.5, 1 ), ( 0, 3.5, 0, 0 ), ( 1, 5, 2, 0 ), ( 1, 0.5, 0, 0 ) ) - // CHECK-NEXT: ( 1, 3, 2, 1, 0, 1, 1, 0, 0.5, 0, 0, 1, 5, 2, 0, 1.5, 1, 3.5, 1, 5, 2, 1, 0.5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 ) + // CHECK-NEXT: 18 + // 
CHECK-NEXT: ( 1, 3, 2, 1, 1, 1, 0.5, 1, 5, 2, 1.5, 1, 3.5, 1, 5, 2, 1, 0.5 ) %2 = call @concat_mix_sparse(%m24, %sm34cd, %sm44dc) : (tensor<2x4xf64>, tensor<3x4xf64, #MAT_C_D>, tensor<4x4xf64, #MAT_D_C>) -> tensor<9x4xf64, #MAT_C_C> call @dump_mat_9x4(%2) : (tensor<9x4xf64, #MAT_C_C>) -> () @@ -319,7 +336,8 @@ call @dump_mat_dense_9x4(%3) : (tensor<9x4xf64>) -> () // CHECK-NEXT: ( ( 1, 0, 3, 0 ), ( 0, 2, 0, 0 ), ( 1, 0, 1, 1 ), ( 0, 0.5, 0, 0 ), ( 1, 5, 2, 0 ), ( 0, 0, 1.5, 1 ), ( 0, 3.5, 0, 0 ), ( 1, 5, 2, 0 ), ( 1, 0.5, 0, 0 ) ) - // CHECK-NEXT: ( 1, 1, 0, 1, 1, 1, 2, 0, 0.5, 5, 3.5, 5, 0.5, 3, 1, 0, 2, 1.5, 2, 1, 0, 0, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 ) + // CHECK-NEXT: 18 + // CHECK-NEXT: ( 1, 1, 1, 1, 1, 2, 0.5, 5, 3.5, 5, 0.5, 3, 1, 2, 1.5, 2, 1, 1 ) %4 = call @concat_sparse_sparse_perm(%sm24ccp, %sm34cd, %sm44dc) : (tensor<2x4xf64, #MAT_C_C_P>, tensor<3x4xf64, #MAT_C_D>, tensor<4x4xf64, #MAT_D_C>) -> tensor<9x4xf64, #MAT_C_C_P> call @dump_mat_perm_9x4(%4) : (tensor<9x4xf64, #MAT_C_C_P>) -> () @@ -330,7 +348,8 @@ call @dump_mat_dense_9x4(%5) : (tensor<9x4xf64>) -> () // CHECK-NEXT: ( ( 1, 0, 3, 0 ), ( 0, 2, 0, 0 ), ( 1, 0, 1, 1 ), ( 0, 0.5, 0, 0 ), ( 1, 5, 2, 0 ), ( 0, 0, 1.5, 1 ), ( 0, 3.5, 0, 0 ), ( 1, 5, 2, 0 ), ( 1, 0.5, 0, 0 ) ) - // CHECK-NEXT: ( 1, 3, 2, 1, 0, 1, 1, 0, 0.5, 0, 0, 1, 5, 2, 0, 1.5, 1, 3.5, 1, 5, 2, 1, 0.5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 ) + // CHECK-NEXT: 18 + // CHECK-NEXT: ( 1, 3, 2, 1, 1, 1, 0.5, 1, 5, 2, 1.5, 1, 3.5, 1, 5, 2, 1, 0.5 ) %6 = call @concat_mix_sparse_perm(%m24, %sm34cdp, %sm44dc) : (tensor<2x4xf64>, tensor<3x4xf64, #MAT_C_D_P>, tensor<4x4xf64, #MAT_D_C>) -> tensor<9x4xf64, #MAT_C_C> call @dump_mat_9x4(%6) : (tensor<9x4xf64, #MAT_C_C>) -> () @@ -341,7 +360,8 @@ call @dump_mat_dense_9x4(%7) : (tensor<9x4xf64>) -> () // CHECK-NEXT: ( ( 1, 0, 1, 0, 1, 0, 0, 1.5, 1 ), ( 3.1, 0, 1, 0, 0.5, 0, 3.5, 0, 0 ), ( 0, 2, 0, 0, 1, 1, 5, 2, 0 ), ( 0, 0, 5, 2, 0, 1, 0.5, 0, 0 ) ) - // CHECK-NEXT: 
( 1, 1, 0, 1, 1.5, 1, 3.1, 1, 0, 0.5, 3.5, 2, 0, 0, 1, 1, 5, 2, 5, 2, 0, 1, 0.5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 ) + // CHECK-NEXT: 18 + // CHECK-NEXT: ( 1, 1, 1, 1.5, 1, 3.1, 1, 0.5, 3.5, 2, 1, 1, 5, 2, 5, 2, 1, 0.5 ) %8 = call @concat_sparse_sparse_dim1(%sm42cc, %sm43cd, %sm44dc) : (tensor<4x2xf64, #MAT_C_C>, tensor<4x3xf64, #MAT_C_D>, tensor<4x4xf64, #MAT_D_C>) -> tensor<4x9xf64, #MAT_C_C> call @dump_mat_4x9(%8) : (tensor<4x9xf64, #MAT_C_C>) -> () @@ -352,7 +372,8 @@ call @dump_mat_dense_4x9(%9) : (tensor<4x9xf64>) -> () // CHECK-NEXT: ( ( 1, 0, 1, 0, 1, 0, 0, 1.5, 1 ), ( 3.1, 0, 1, 0, 0.5, 0, 3.5, 0, 0 ), ( 0, 2, 0, 0, 1, 1, 5, 2, 0 ), ( 0, 0, 5, 2, 0, 1, 0.5, 0, 0 ) ) - // CHECK-NEXT: ( 1, 1, 0, 1, 1.5, 1, 3.1, 1, 0, 0.5, 3.5, 2, 0, 0, 1, 1, 5, 2, 5, 2, 0, 1, 0.5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 ) + // CHECK-NEXT: 18 + // CHECK-NEXT: ( 1, 1, 1, 1.5, 1, 3.1, 1, 0.5, 3.5, 2, 1, 1, 5, 2, 5, 2, 1, 0.5 ) %10 = call @concat_mix_sparse_dim1(%m42, %sm43cd, %sm44dc) : (tensor<4x2xf64>, tensor<4x3xf64, #MAT_C_D>, tensor<4x4xf64, #MAT_D_C>) -> tensor<4x9xf64, #MAT_C_C> call @dump_mat_4x9(%10) : (tensor<4x9xf64, #MAT_C_C>) -> () @@ -363,7 +384,8 @@ call @dump_mat_dense_4x9(%11) : (tensor<4x9xf64>) -> () // CHECK-NEXT: ( ( 1, 0, 1, 0, 1, 0, 0, 1.5, 1 ), ( 3.1, 0, 1, 0, 0.5, 0, 3.5, 0, 0 ), ( 0, 2, 0, 0, 1, 1, 5, 2, 0 ), ( 0, 0, 5, 2, 0, 1, 0.5, 0, 0 ) ) - // CHECK-NEXT: ( 1, 3.1, 2, 1, 1, 0, 5, 0, 0, 0, 2, 1, 0.5, 1, 0, 1, 1, 3.5, 5, 0.5, 1.5, 2, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 ) + // CHECK-NEXT: 18 + // CHECK-NEXT: ( 1, 3.1, 2, 1, 1, 5, 2, 1, 0.5, 1, 1, 1, 3.5, 5, 0.5, 1.5, 2, 1 ) %12 = call @concat_sparse_sparse_perm_dim1(%sm42ccp, %sm43cd, %sm44dc) : (tensor<4x2xf64, #MAT_C_C_P>, tensor<4x3xf64, #MAT_C_D>, tensor<4x4xf64, #MAT_D_C>) -> tensor<4x9xf64, #MAT_C_C_P> call @dump_mat_perm_4x9(%12) : (tensor<4x9xf64, #MAT_C_C_P>) -> () @@ -374,7 +396,8 @@ call @dump_mat_dense_4x9(%13) : (tensor<4x9xf64>) -> () // 
CHECK-NEXT: ( ( 1, 0, 1, 0, 1, 0, 0, 1.5, 1 ), ( 3.1, 0, 1, 0, 0.5, 0, 3.5, 0, 0 ), ( 0, 2, 0, 0, 1, 1, 5, 2, 0 ), ( 0, 0, 5, 2, 0, 1, 0.5, 0, 0 ) ) - // CHECK-NEXT: ( 1, 1, 0, 1, 1.5, 1, 3.1, 1, 0, 0.5, 3.5, 2, 0, 0, 1, 1, 5, 2, 5, 2, 0, 1, 0.5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 ) + // CHECK-NEXT: 18 + // CHECK-NEXT: ( 1, 1, 1, 1.5, 1, 3.1, 1, 0.5, 3.5, 2, 1, 1, 5, 2, 5, 2, 1, 0.5 ) %14 = call @concat_mix_sparse_perm_dim1(%m42, %sm43cdp, %sm44dc) : (tensor<4x2xf64>, tensor<4x3xf64, #MAT_C_D_P>, tensor<4x4xf64, #MAT_D_C>) -> tensor<4x9xf64, #MAT_C_C> call @dump_mat_4x9(%14) : (tensor<4x9xf64, #MAT_C_C>) -> () @@ -385,7 +408,8 @@ call @dump_mat_dense_4x9(%15) : (tensor<4x9xf64>) -> () // CHECK-NEXT: ( ( 1, 0, 1, 0, 1, 0, 0, 1.5, 1 ), ( 3.1, 0, 1, 0, 0.5, 0, 3.5, 0, 0 ), ( 0, 2, 0, 0, 1, 1, 5, 2, 0 ), ( 0, 0, 5, 2, 0, 1, 0.5, 0, 0 ) ) - // CHECK-NEXT: ( 1, 1, 0, 1, 1.5, 1, 3.1, 1, 0, 0.5, 3.5, 2, 0, 0, 1, 1, 5, 2, 5, 2, 0, 1, 0.5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 ) + // CHECK-NEXT: 18 + // CHECK-NEXT: ( 1, 1, 1, 1.5, 1, 3.1, 1, 0.5, 3.5, 2, 1, 1, 5, 2, 5, 2, 1, 0.5 ) %16 = call @concat_mix_sparse_dyn(%m42, %sm43cd, %sm44dc) : (tensor<4x2xf64>, tensor<4x3xf64, #MAT_C_D>, tensor<4x4xf64, #MAT_D_C>) -> tensor<?x?xf64, #MAT_C_C> call @dump_mat_dyn(%16) : (tensor<?x?xf64, #MAT_C_C>) -> ()