diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -163,6 +163,30 @@
   let hasVerifier = 1;
 }
 
+def SparseTensor_ConcatenateOp : SparseTensor_Op<"concatenate", []>,
+    Arguments<(ins Variadic<AnyRankedTensor>:$inputs,
+                   IndexAttr:$dimension)>,
+    Results<(outs AnyRankedTensor:$result)> {
+
+  let summary = "Concatenates a list of tensors into a single tensor";
+  let description = [{
+    Concatenates a list of input tensors into a single output tensor of the
+    same rank. The concatenation happens along the specified `dimension`
+    (0 <= dimension < rank). The size of the result along that `dimension` is
+    the sum of the corresponding input sizes, while all the other dimensions
+    must have the same size in the input and output tensors.
+
+    Example:
+
+    ```mlir
+    %0 = sparse_tensor.concatenate %1, %2 { dimension = 0 : index }
+      : tensor<64x64xf64, #CSR>, tensor<64x64xf64, #CSR> to tensor<128x64xf64, #CSR>
+    ```
+  }];
+
+  let assemblyFormat = "$inputs attr-dict `:` type($inputs) `to` type($result)";
+  let hasVerifier = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // Sparse Tensor Management Operations. These operations are "impure" in the
 // sense that they do not properly operate on SSA values. Instead, the behavior
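Not part of the patch, but a minimal sketch of the shape rules stated in the description, this time concatenating along `dimension = 1`: the non-concatenated dimension sizes must match, and the result size along dimension 1 is the sum of the inputs' sizes (the `#CSR` encoding and the values `%a`, `%b` are assumed for illustration):

```mlir
#CSR = #sparse_tensor.encoding<{dimLevelType = ["dense", "compressed"]}>

// 64x32 ++ 64x32 along dimension 1 gives 64x(32+32) = 64x64.
%0 = sparse_tensor.concatenate %a, %b { dimension = 1 : index }
   : tensor<64x32xf64, #CSR>, tensor<64x32xf64, #CSR> to tensor<64x64xf64, #CSR>
```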
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -352,6 +352,61 @@
   return success();
 }
 
+LogicalResult ConcatenateOp::verify() {
+  auto dstTp = getType().cast<RankedTensorType>();
+  uint64_t concatDim = getDimension().getZExtValue();
+  unsigned rank = dstTp.getRank();
+
+  // All input tensors and the output tensor must have the same rank.
+  for (auto input : getInputs()) {
+    auto inputRank = input.getType().cast<RankedTensorType>().getRank();
+    if (inputRank != rank)
+      return emitError(
+          "All input tensors and output tensor should have the same rank.");
+  }
+
+  for (unsigned i = 0; i < rank; i++) {
+    auto dstDim = dstTp.getShape()[i];
+    if (i == concatDim) {
+      if (dstDim != ShapedType::kDynamicSize) {
+        int64_t sumDim = 0;
+        for (auto src : getInputs()) {
+          auto d = src.getType().cast<RankedTensorType>().getShape()[i];
+          // The output tensor has a static concatenation dimension, so every
+          // input must also be static along that dimension.
+          if (d == ShapedType::kDynamicSize)
+            return emitError("Failed to verify the shaping rules with "
+                             "dynamically shaped inputs");
+          sumDim += d;
+        }
+        // If all dimensions are statically known, the sum of all the input
+        // dimensions must equal the output dimension.
+        if (sumDim != dstDim)
+          return emitError(
+              "The concatenation dimension of the output tensor should be the "
+              "sum of all the concatenation dimensions of the input tensors.");
+      }
+    } else {
+      // All other dimensions must agree across inputs and output; a dynamic
+      // size is compatible with anything.
+      int64_t prev = dstDim;
+      for (auto src : getInputs()) {
+        auto d = src.getType().cast<RankedTensorType>().getShape()[i];
+        if (d != ShapedType::kDynamicSize) {
+          if (prev != ShapedType::kDynamicSize && d != prev)
+            return emitError("All dimensions (except for the concatenating "
+                             "one) should be equal.");
+          prev = d;
+        }
+      }
+    }
+  }
+
+  return success();
+}
+
 LogicalResult ReduceOp::verify() {
   Type inputType = getX().getType();
   LogicalResult regionResult = success();
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -360,3 +360,62 @@
   }
   return %r : f64
 }
+
+// -----
+
+#C = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
+#DC = #sparse_tensor.encoding<{dimLevelType = ["dense", "compressed"]}>
+#DCC = #sparse_tensor.encoding<{dimLevelType = ["dense", "compressed", "compressed"]}>
+func.func @invalid_concat_rank_mismatch(%arg0: tensor<2xf64, #C>,
+                                        %arg1: tensor<3x4xf64, #DC>,
+                                        %arg2: tensor<4x4x4xf64, #DCC>) -> tensor<9x4xf64, #DC> {
+  // expected-error@+1 {{All input tensors and output tensor should have the same rank}}
+  %0 = sparse_tensor.concatenate %arg0, %arg1, %arg2 {dimension = 0 : index}
+       : tensor<2xf64, #C>,
+         tensor<3x4xf64, #DC>,
+         tensor<4x4x4xf64, #DCC> to tensor<9x4xf64, #DC>
+  return %0 : tensor<9x4xf64, #DC>
+}
+
+// -----
+
+#DC = #sparse_tensor.encoding<{dimLevelType = ["dense", "compressed"]}>
+func.func @invalid_concat_size_mismatch_dyn(%arg0: tensor<?x4xf64, #DC>,
+                                            %arg1: tensor<5x4xf64, #DC>,
+                                            %arg2: tensor<4x4xf64, #DC>) -> tensor<9x4xf64, #DC> {
+  // expected-error@+1 {{Failed to verify the shaping rules with dynamically shaped inputs}}
+  %0 = sparse_tensor.concatenate %arg0, %arg1, %arg2 {dimension = 0 : index}
+       : tensor<?x4xf64, #DC>,
+         tensor<5x4xf64, #DC>,
+         tensor<4x4xf64, #DC> to tensor<9x4xf64, #DC>
+  return %0 : tensor<9x4xf64, #DC>
+}
+
+// -----
+
+#DC = #sparse_tensor.encoding<{dimLevelType = ["dense", "compressed"]}>
+func.func @invalid_concat_size_mismatch(%arg0: tensor<3x4xf64, #DC>,
+                                        %arg1: tensor<5x4xf64, #DC>,
+                                        %arg2: tensor<4x4xf64, #DC>) -> tensor<9x4xf64, #DC> {
+  // expected-error@+1 {{The concatenation dimension of the output tensor should be the sum of}}
+  %0 = sparse_tensor.concatenate %arg0, %arg1, %arg2 {dimension = 0 : index}
+       : tensor<3x4xf64, #DC>,
+         tensor<5x4xf64, #DC>,
+         tensor<4x4xf64, #DC> to tensor<9x4xf64, #DC>
+  return %0 : tensor<9x4xf64, #DC>
+}
+
+// -----
+
+#DC = #sparse_tensor.encoding<{dimLevelType = ["dense", "compressed"]}>
+func.func @invalid_concat_dim_mismatch(%arg0: tensor<2x4xf64, #DC>,
+                                       %arg1: tensor<3x3xf64, #DC>,
+                                       %arg2: tensor<4x4xf64, #DC>) -> tensor<9x4xf64, #DC> {
+  // expected-error@+1 {{All dimensions (except for the concatenating one) should be equal}}
+  %0 = sparse_tensor.concatenate %arg0, %arg1, %arg2 {dimension = 0 : index}
+       : tensor<2x4xf64, #DC>,
+         tensor<3x3xf64, #DC>,
+         tensor<4x4xf64, #DC> to tensor<9x4xf64, #DC>
+  return %0 : tensor<9x4xf64, #DC>
+}
+
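One case the invalid tests above do not exercise, shown here as a hedged sketch rather than as patch content: the verifier enforces the static-sum rule only when the output's concatenation dimension is static, so a dynamically sized result is accepted even with mixed static and dynamic inputs (`%a` and `%b` are assumed values; the `#DC` encoding matches the tests above):

```mlir
#DC = #sparse_tensor.encoding<{dimLevelType = ["dense", "compressed"]}>

// Accepted: dimension 0 of the result is dynamic, so no static sum is
// enforced; dimension 1 must still agree (4 == 4 == 4).
%0 = sparse_tensor.concatenate %a, %b { dimension = 0 : index }
   : tensor<?x4xf64, #DC>, tensor<5x4xf64, #DC> to tensor<?x4xf64, #DC>
```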
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -289,4 +289,28 @@
     sparse_tensor.yield %x : f64
   }
   return %r : f64
-}
\ No newline at end of file
+}
+
+// -----
+
+#SparseMatrix = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}>
+
+// CHECK-LABEL: func @concat_sparse_sparse(
+// CHECK-SAME: %[[A0:.*]]: tensor<2x4xf64
+// CHECK-SAME: %[[A1:.*]]: tensor<3x4xf64
+// CHECK-SAME: %[[A2:.*]]: tensor<4x4xf64
+// CHECK: %[[TMP0:.*]] = sparse_tensor.concatenate %[[A0]], %[[A1]], %[[A2]] {dimension = 0 : index} :
+// CHECK-SAME: tensor<2x4xf64
+// CHECK-SAME: tensor<3x4xf64
+// CHECK-SAME: tensor<4x4xf64
+// CHECK-SAME: tensor<9x4xf64
+// CHECK: return %[[TMP0]] : tensor<9x4xf64
+func.func @concat_sparse_sparse(%arg0: tensor<2x4xf64, #SparseMatrix>,
+                                %arg1: tensor<3x4xf64, #SparseMatrix>,
+                                %arg2: tensor<4x4xf64, #SparseMatrix>) -> tensor<9x4xf64, #SparseMatrix> {
+  %0 = sparse_tensor.concatenate %arg0, %arg1, %arg2 {dimension = 0 : index}
+       : tensor<2x4xf64, #SparseMatrix>,
+         tensor<3x4xf64, #SparseMatrix>,
+         tensor<4x4xf64, #SparseMatrix> to tensor<9x4xf64, #SparseMatrix>
+  return %0 : tensor<9x4xf64, #SparseMatrix>
+}