diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp @@ -490,7 +490,10 @@ // %t = concatenate %s1, %s2, %s3 {dim = 1} // ==> // if (isSparseDst) - // %tmp = bufferization.alloc_tensor : unordered COO + // if (allDense) + // %tmp = bufferization.alloc_tensor dstTp + // else + // %tmp = bufferization.alloc_tensor : unordered COO // else // %tmp = memref.alloc : dense tensor // foreach in %s1 : insert d0, d1, %tmp @@ -499,11 +502,18 @@ // %t = convert_to_dest_tensor(%tmp) SparseTensorEncodingAttr encDst = getSparseTensorEncoding(dstTp); Value dst; // Destination tensor for inserting source tensor values. + bool allDense = false; if (encDst) { + allDense = llvm::all_of(encDst.getDimLevelType(), + [](DimLevelType dlt) { return isDenseDLT(dlt); }); SmallVector dynSizes; getDynamicSizes(dstTp, sizes, dynSizes); - RankedTensorType cooTp = getUnorderedCOOFromType(dstTp); - dst = rewriter.create(loc, cooTp, dynSizes).getResult(); + RankedTensorType tp = dstTp; + if (!allDense) { + tp = getUnorderedCOOFromType(dstTp); + encDst = getSparseTensorEncoding(tp); + } + dst = rewriter.create(loc, tp, dynSizes).getResult(); } else { // TODO: Dense buffers should be allocated/deallocated via the callback // in BufferizationOptions. @@ -523,13 +533,13 @@ loc, input, initArgs, [&](OpBuilder &builder, Location loc, ValueRange args, Value v, ValueRange reduc) { - SmallVector indices; + SmallVector indices(rank, Value()); for (int64_t i = 0; i < rank; i++) { Value idx = args[i]; if (i == static_cast(conDim)) // Transform coordinates for the concatenating dim. 
idx = builder.create(loc, idx, offset); - indices.push_back(idx); + indices[toStoredDim(encDst, i)] = idx; } if (encDst) { Value cond = genIsNonzero(rewriter, loc, v); @@ -563,9 +573,12 @@ if (encDst) { dst = rewriter.create(loc, dst, true); - Value converted = rewriter.create(loc, dstTp, dst).getResult(); - rewriter.create(loc, dst); - rewriter.replaceOp(op, converted); + if (!allDense) { + Value tmpCoo = dst; + dst = rewriter.create(loc, dstTp, tmpCoo).getResult(); + rewriter.create(loc, tmpCoo); + } + rewriter.replaceOp(op, dst); } else { rewriter.replaceOpWithNewOp(op, dstTp, dst); } diff --git a/mlir/test/Dialect/SparseTensor/sparse_concat_codegen.mlir b/mlir/test/Dialect/SparseTensor/sparse_concat_codegen.mlir --- a/mlir/test/Dialect/SparseTensor/sparse_concat_codegen.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_concat_codegen.mlir @@ -2,7 +2,11 @@ // RUN: | FileCheck %s #DCSR = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}> - +#DENSE = #sparse_tensor.encoding<{dimLevelType = ["dense", "dense"]}> +#DENSE_P = #sparse_tensor.encoding<{ + dimLevelType = ["dense", "dense"], + dimOrdering = affine_map<(i,j) -> (j,i)> +}> // CHECK-LABEL: @concat_sparse_sparse( // CHECK-SAME: %[[TMP_arg0:.*]]: tensor<2x4xf64, #sparse_tensor // CHECK-SAME: %[[TMP_arg1:.*]]: tensor<3x4xf64, #sparse_tensor @@ -256,3 +260,175 @@ tensor<4x4xf64, #DCSR> to tensor return %0 : tensor } + +// CHECK-LABEL: @concat_sparse_sparse_annotated_dense( +// CHECK-SAME: %[[TMP_arg0:.*]]: tensor<2x4xf64, #sparse_tensor +// CHECK-SAME: %[[TMP_arg1:.*]]: tensor<3x4xf64, #sparse_tensor +// CHECK-SAME: %[[TMP_arg2:.*]]: tensor<4x4xf64, #sparse_tensor +// CHECK-DAG: %[[TMP_c0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[TMP_c1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[TMP_c5:.*]] = arith.constant 5 : index +// CHECK-DAG: %[[TMP_c2:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[TMP_c9:.*]] = arith.constant 9 : index +// CHECK-DAG: %[[TMP_c4:.*]] = arith.constant 4 : 
index +// CHECK: %[[TMP_0:.*]] = bufferization.alloc_tensor(%[[TMP_c9]], %[[TMP_c4]]) : tensor +// CHECK: %[[TMP_7:.*]] = memref.load %[[TMP_1]][%[[TMP_c1]]] : memref +// CHECK: %[[RET_1:.*]] = scf.for %[[TMP_arg3:.*]] = %[[TMP_6]] to %[[TMP_7]] step %[[TMP_c1]] iter_args(%[[A0:.*]] = %[[TMP_0]]) +// CHECK: %[[TMP_23:.*]] = memref.load %[[TMP_2]][%[[TMP_arg3]]] : memref +// CHECK-DAG: %[[TMP_25:.*]] = memref.load %[[TMP_3]][%[[TMP_arg3]]] : memref +// CHECK-DAG: %[[TMP_24:.*]] = arith.addi %[[TMP_arg3]], %[[TMP_c1]] : index +// CHECK: %[[TMP_26:.*]] = memref.load %[[TMP_3]][%[[TMP_24]]] : memref +// CHECK: %[[RET_4:.*]] = scf.for %[[TMP_arg4:.*]] = %[[TMP_25]] to %[[TMP_26]] step %[[TMP_c1]] iter_args(%[[A1:.*]] = %[[A0]]) +// CHECK: %[[TMP_27:.*]] = memref.load %[[TMP_4]][%[[TMP_arg4]]] : memref +// CHECK: %[[TMP_28:.*]] = memref.load %[[TMP_5]][%[[TMP_arg4]]] : memref +// CHECK: %[[NEW_1:.*]] = sparse_tensor.insert %[[TMP_28]] into %[[A1]][%[[TMP_23]], %[[TMP_27]]] : tensor +// CHECK: %[[TMP_14:.*]] = memref.load %[[TMP_8]][%[[TMP_c1]]] : memref +// CHECK: %[[RET_2:.*]] = scf.for %[[TMP_arg3:.*]] = %[[TMP_13]] to %[[TMP_14]] step %[[TMP_c1]] iter_args(%[[A2:.*]] = %[[RET_1]]) +// CHECK: %[[TMP_23:.*]] = memref.load %[[TMP_9]][%[[TMP_arg3]]] : memref +// CHECK-DAG: %[[TMP_25:.*]] = memref.load %[[TMP_10]][%[[TMP_arg3]]] : memref +// CHECK-DAG: %[[TMP_24:.*]] = arith.addi %[[TMP_arg3]], %[[TMP_c1]] : index +// CHECK: %[[TMP_26:.*]] = memref.load %[[TMP_10]][%[[TMP_24]]] : memref +// CHECK: %[[RET_5:.*]] = scf.for %[[TMP_arg4:.*]] = %[[TMP_25]] to %[[TMP_26]] step %[[TMP_c1]] iter_args(%[[A3:.*]] = %[[A2]]) +// CHECK: %[[TMP_27:.*]] = memref.load %[[TMP_11]][%[[TMP_arg4]]] : memref +// CHECK: %[[TMP_28:.*]] = memref.load %[[TMP_12]][%[[TMP_arg4]]] : memref +// CHECK: %[[TMP_29:.*]] = arith.addi %[[TMP_23]], %[[TMP_c2]] : index +// CHECK: %[[NEW_2:.*]] = sparse_tensor.insert %[[TMP_28]] into %[[A3]][%[[TMP_29]], %[[TMP_27]]] : tensor +// CHECK: %[[TMP_21:.*]] = 
memref.load %[[TMP_15]][%[[TMP_c1]]] : memref +// CHECK: %[[RET_3:.*]] = scf.for %[[TMP_arg3:.*]] = %[[TMP_20]] to %[[TMP_21]] step %[[TMP_c1]] iter_args(%[[A4:.*]] = %[[RET_2]]) +// CHECK: %[[TMP_23:.*]] = memref.load %[[TMP_16]][%[[TMP_arg3]]] : memref +// CHECK: %[[TMP_25:.*]] = memref.load %[[TMP_17]][%[[TMP_arg3]]] : memref +// CHECK: %[[TMP_24:.*]] = arith.addi %[[TMP_arg3]], %[[TMP_c1]] : index +// CHECK: %[[TMP_26:.*]] = memref.load %[[TMP_17]][%[[TMP_24]]] : memref +// CHECK: %[[RET_6:.*]] = scf.for %[[TMP_arg4:.*]] = %[[TMP_25]] to %[[TMP_26]] step %[[TMP_c1]] iter_args(%[[A5:.*]] = %[[A4]]) +// CHECK: %[[TMP_27:.*]] = memref.load %[[TMP_18]][%[[TMP_arg4]]] : memref +// CHECK: %[[TMP_28:.*]] = memref.load %[[TMP_19]][%[[TMP_arg4]]] : memref +// CHECK: %[[TMP_29:.*]] = arith.addi %[[TMP_23]], %[[TMP_c5]] : index +// CHECK: %[[NEW_3:.*]] = sparse_tensor.insert %[[TMP_28]] into %[[A5]][%[[TMP_29]], %[[TMP_27]]] : tensor> +// CHECK: return %[[R]] : tensor> +func.func @concat_sparse_sparse_annotated_dense(%arg0: tensor<2x4xf64, #DCSR>, + %arg1: tensor<3x4xf64, #DCSR>, + %arg2: tensor<4x4xf64, #DCSR>) + -> tensor { + %0 = sparse_tensor.concatenate %arg0, %arg1, %arg2 {dimension = 0 : index} + : tensor<2x4xf64, #DCSR>, + tensor<3x4xf64, #DCSR>, + tensor<4x4xf64, #DCSR> to tensor + return %0 : tensor +} + +// CHECK-LABEL: @concat_sparse_sparse_annotated_dense_permute( +// CHECK-SAME: %[[TMP_arg0:.*]]: tensor<2x4xf64, #sparse_tensor +// CHECK-SAME: %[[TMP_arg1:.*]]: tensor<3x4xf64, #sparse_tensor +// CHECK-SAME: %[[TMP_arg2:.*]]: tensor<4x4xf64, #sparse_tensor +// CHECK-DAG: %[[TMP_c0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[TMP_c1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[TMP_c5:.*]] = arith.constant 5 : index +// CHECK-DAG: %[[TMP_c2:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[TMP_c9:.*]] = arith.constant 9 : index +// CHECK-DAG: %[[TMP_c4:.*]] = arith.constant 4 : index +// CHECK: %[[TMP_0:.*]] = bufferization.alloc_tensor(%[[TMP_c9]], 
%[[TMP_c4]]) : tensor +// CHECK: %[[TMP_7:.*]] = memref.load %[[TMP_1]][%[[TMP_c1]]] : memref +// CHECK: %[[RET_1:.*]] = scf.for %[[TMP_arg3:.*]] = %[[TMP_6]] to %[[TMP_7]] step %[[TMP_c1]] iter_args(%[[A0:.*]] = %[[TMP_0]]) +// CHECK: %[[TMP_23:.*]] = memref.load %[[TMP_2]][%[[TMP_arg3]]] : memref +// CHECK-DAG: %[[TMP_25:.*]] = memref.load %[[TMP_3]][%[[TMP_arg3]]] : memref +// CHECK-DAG: %[[TMP_24:.*]] = arith.addi %[[TMP_arg3]], %[[TMP_c1]] : index +// CHECK: %[[TMP_26:.*]] = memref.load %[[TMP_3]][%[[TMP_24]]] : memref +// CHECK: %[[RET_4:.*]] = scf.for %[[TMP_arg4:.*]] = %[[TMP_25]] to %[[TMP_26]] step %[[TMP_c1]] iter_args(%[[A1:.*]] = %[[A0]]) +// CHECK: %[[TMP_27:.*]] = memref.load %[[TMP_4]][%[[TMP_arg4]]] : memref +// CHECK: %[[TMP_28:.*]] = memref.load %[[TMP_5]][%[[TMP_arg4]]] : memref +// CHECK: %[[NEW_1:.*]] = sparse_tensor.insert %[[TMP_28]] into %[[A1]][%[[TMP_27]], %[[TMP_23]]] : tensor +// CHECK: %[[TMP_14:.*]] = memref.load %[[TMP_8]][%[[TMP_c1]]] : memref +// CHECK: %[[RET_2:.*]] = scf.for %[[TMP_arg3:.*]] = %[[TMP_13]] to %[[TMP_14]] step %[[TMP_c1]] iter_args(%[[A2:.*]] = %[[RET_1]]) +// CHECK: %[[TMP_23:.*]] = memref.load %[[TMP_9]][%[[TMP_arg3]]] : memref +// CHECK-DAG: %[[TMP_25:.*]] = memref.load %[[TMP_10]][%[[TMP_arg3]]] : memref +// CHECK-DAG: %[[TMP_24:.*]] = arith.addi %[[TMP_arg3]], %[[TMP_c1]] : index +// CHECK: %[[TMP_26:.*]] = memref.load %[[TMP_10]][%[[TMP_24]]] : memref +// CHECK: %[[RET_5:.*]] = scf.for %[[TMP_arg4:.*]] = %[[TMP_25]] to %[[TMP_26]] step %[[TMP_c1]] iter_args(%[[A3:.*]] = %[[A2]]) +// CHECK: %[[TMP_27:.*]] = memref.load %[[TMP_11]][%[[TMP_arg4]]] : memref +// CHECK: %[[TMP_28:.*]] = memref.load %[[TMP_12]][%[[TMP_arg4]]] : memref +// CHECK: %[[TMP_29:.*]] = arith.addi %[[TMP_23]], %[[TMP_c2]] : index +// CHECK: %[[NEW_2:.*]] = sparse_tensor.insert %[[TMP_28]] into %[[A3]][%[[TMP_27]], %[[TMP_29]]] : tensor +// CHECK: %[[TMP_21:.*]] = memref.load %[[TMP_15]][%[[TMP_c1]]] : memref +// CHECK: %[[RET_3:.*]] = 
scf.for %[[TMP_arg3:.*]] = %[[TMP_20]] to %[[TMP_21]] step %[[TMP_c1]] iter_args(%[[A4:.*]] = %[[RET_2]]) +// CHECK: %[[TMP_23:.*]] = memref.load %[[TMP_16]][%[[TMP_arg3]]] : memref +// CHECK: %[[TMP_25:.*]] = memref.load %[[TMP_17]][%[[TMP_arg3]]] : memref +// CHECK: %[[TMP_24:.*]] = arith.addi %[[TMP_arg3]], %[[TMP_c1]] : index +// CHECK: %[[TMP_26:.*]] = memref.load %[[TMP_17]][%[[TMP_24]]] : memref +// CHECK: %[[RET_6:.*]] = scf.for %[[TMP_arg4:.*]] = %[[TMP_25]] to %[[TMP_26]] step %[[TMP_c1]] iter_args(%[[A5:.*]] = %[[A4]]) +// CHECK: %[[TMP_27:.*]] = memref.load %[[TMP_18]][%[[TMP_arg4]]] : memref +// CHECK: %[[TMP_28:.*]] = memref.load %[[TMP_19]][%[[TMP_arg4]]] : memref +// CHECK: %[[TMP_29:.*]] = arith.addi %[[TMP_23]], %[[TMP_c5]] : index +// CHECK: %[[NEW_3:.*]] = sparse_tensor.insert %[[TMP_28]] into %[[A5]][%[[TMP_27]], %[[TMP_29]]] : tensor (d1, d0)> }>> +// CHECK: return %[[R]] : tensor (d1, d0)> }>> +func.func @concat_sparse_sparse_annotated_dense_permute(%arg0: tensor<2x4xf64, #DCSR>, + %arg1: tensor<3x4xf64, #DCSR>, + %arg2: tensor<4x4xf64, #DCSR>) + -> tensor { + %0 = sparse_tensor.concatenate %arg0, %arg1, %arg2 {dimension = 0 : index} + : tensor<2x4xf64, #DCSR>, + tensor<3x4xf64, #DCSR>, + tensor<4x4xf64, #DCSR> to tensor + return %0 : tensor +} \ No newline at end of file diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/concatenate.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/concatenate.mlir --- a/mlir/test/Integration/Dialect/SparseTensor/CPU/concatenate.mlir +++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/concatenate.mlir @@ -14,6 +14,10 @@ #MAT_C_C = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}> #MAT_D_C = #sparse_tensor.encoding<{dimLevelType = ["dense", "compressed"]}> #MAT_C_D = #sparse_tensor.encoding<{dimLevelType = ["compressed", "dense"]}> +#MAT_D_D = #sparse_tensor.encoding<{ + dimLevelType = ["dense", "dense"], + dimOrdering = affine_map<(i,j) -> (j,i)> +}> #MAT_C_C_P = 
#sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], @@ -49,6 +53,13 @@ return %0 : tensor<9x4xf64> } + // Concats all sparse matrices (with different encodings) to an annotated all dense matrix. + func.func @concat_sparse_annotated_dense(%arg0: tensor<2x4xf64, #MAT_C_C>, %arg1: tensor<3x4xf64, #MAT_C_D>, %arg2: tensor<4x4xf64, #MAT_D_C>) -> tensor<9x4xf64, #MAT_D_D> { + %0 = sparse_tensor.concatenate %arg0, %arg1, %arg2 {dimension = 0 : index} + : tensor<2x4xf64, #MAT_C_C>, tensor<3x4xf64, #MAT_C_D>, tensor<4x4xf64, #MAT_D_C> to tensor<9x4xf64, #MAT_D_D> + return %0 : tensor<9x4xf64, #MAT_D_D> + } + // Concats mix sparse and dense matrices to a sparse matrix func.func @concat_mix_sparse(%arg0: tensor<2x4xf64>, %arg1: tensor<3x4xf64, #MAT_C_D>, %arg2: tensor<4x4xf64, #MAT_D_C>) -> tensor<9x4xf64, #MAT_C_C> { %0 = sparse_tensor.concatenate %arg0, %arg1, %arg2 {dimension = 0 : index} @@ -214,6 +225,20 @@ return } + func.func @dump_mat_annotated_dense_9x4(%A: tensor<9x4xf64, #MAT_D_D>) { + %c0 = arith.constant 0 : index + %du = arith.constant -1.0 : f64 + + %n = sparse_tensor.number_of_entries %A : tensor<9x4xf64, #MAT_D_D> + vector.print %n : index + + %1 = sparse_tensor.values %A : tensor<9x4xf64, #MAT_D_D> to memref + %2 = vector.transfer_read %1[%c0], %du: memref, vector<36xf64> + vector.print %2 : vector<36xf64> + + return + } + func.func @dump_mat_4x9(%A: tensor<4x9xf64, #MAT_C_C>) { %c0 = arith.constant 0 : index %du = arith.constant -1.0 : f64 @@ -421,6 +446,13 @@ : (tensor<4x2xf64>, tensor<4x3xf64, #MAT_C_D>, tensor<4x4xf64, #MAT_D_C>) -> tensor call @dump_mat_dyn(%16) : (tensor) -> () + // CHECK-NEXT: 36 + // CHECK-NEXT: ( 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 2, 0, 0.5, 5, 0, 3.5, 5, 0.5, 3, 0, 1, 0, 2, 1.5, 0, 2, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0 ) + %17 = call @concat_sparse_annotated_dense(%sm24cc, %sm34cd, %sm44dc) + : (tensor<2x4xf64, #MAT_C_C>, tensor<3x4xf64, #MAT_C_D>, tensor<4x4xf64, #MAT_D_C>) -> tensor<9x4xf64, #MAT_D_D> + call 
@dump_mat_annotated_dense_9x4(%17) : (tensor<9x4xf64, #MAT_D_D>) -> () + + // Release resources. bufferization.dealloc_tensor %sm24cc : tensor<2x4xf64, #MAT_C_C> bufferization.dealloc_tensor %sm34cd : tensor<3x4xf64, #MAT_C_D> @@ -449,6 +481,7 @@ bufferization.dealloc_tensor %14 : tensor<4x9xf64, #MAT_C_C> bufferization.dealloc_tensor %15 : tensor<4x9xf64> bufferization.dealloc_tensor %16 : tensor + bufferization.dealloc_tensor %17 : tensor<9x4xf64, #MAT_D_D> return } }