diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -161,7 +161,7 @@
   let hasVerifier = 1;
 }
 
-def SparseTensor_ConcatenateOp : SparseTensor_Op<"concatenate", []>,
+def SparseTensor_ConcatenateOp : SparseTensor_Op<"concatenate", [NoSideEffect]>,
     Arguments<(ins Variadic<AnyRankedTensor>:$inputs, IndexAttr:$dimension)>,
     Results<(outs AnyRankedTensor:$result)> {
 
diff --git a/mlir/test/Dialect/SparseTensor/fold.mlir b/mlir/test/Dialect/SparseTensor/fold.mlir
--- a/mlir/test/Dialect/SparseTensor/fold.mlir
+++ b/mlir/test/Dialect/SparseTensor/fold.mlir
@@ -32,3 +32,17 @@
   %2 = sparse_tensor.values %arg0 : tensor<64xf32, #SparseVector> to memref<?xf32>
   return
 }
+
+// An unused concatenation is dead code once the op is marked NoSideEffect.
+// CHECK-LABEL: func @sparse_concat_dce(
+//   CHECK-NOT: sparse_tensor.concatenate
+//       CHECK: return
+func.func @sparse_concat_dce(%arg0: tensor<2xf64, #SparseVector>,
+                             %arg1: tensor<3xf64, #SparseVector>,
+                             %arg2: tensor<4xf64, #SparseVector>) {
+  %0 = sparse_tensor.concatenate %arg0, %arg1, %arg2 {dimension = 0 : index}
+       : tensor<2xf64, #SparseVector>,
+         tensor<3xf64, #SparseVector>,
+         tensor<4xf64, #SparseVector> to tensor<9xf64, #SparseVector>
+  return
+}