diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td @@ -167,6 +167,25 @@ let hasVerifier = 1; } +def SparseTensor_NumberOfEntriesOp : SparseTensor_Op<"number_of_entries", [Pure]>, + Arguments<(ins AnySparseTensor:$tensor)>, + Results<(outs Index:$result)> { + let summary = "Returns the number of entries that are stored in the tensor."; + let description = [{ + Returns the number of entries that are stored in the given sparse tensor. + Note that this is typically the number of nonzero elements in the tensor, + but since explicit zeros may appear in the storage formats, the more + accurate nomenclature is used. + + Example: + + ```mlir + %noe = sparse_tensor.number_of_entries %tensor : tensor<64x64xf64, #CSR> + ``` + }]; + let assemblyFormat = "$tensor attr-dict `:` type($tensor)"; +} + def SparseTensor_ConcatenateOp : SparseTensor_Op<"concatenate", [Pure]>, Arguments<(ins Variadic:$inputs, IndexAttr:$dimension)>, Results<(outs AnyRankedTensor:$result)> { diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir --- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir +++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir @@ -91,6 +91,19 @@ // ----- +#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}> + +// CHECK-LABEL: func @sparse_noe( +// CHECK-SAME: %[[A:.*]]: tensor<128xf64, #{{.*}}>) +// CHECK: %[[T:.*]] = sparse_tensor.number_of_entries %[[A]] : tensor<128xf64, #{{.*}}> +// CHECK: return %[[T]] : index +func.func @sparse_noe(%arg0: tensor<128xf64, #SparseVector>) -> index { + %0 = sparse_tensor.number_of_entries %arg0 : tensor<128xf64, #SparseVector> + return %0 : index +} + +// ----- + #DenseMatrix = #sparse_tensor.encoding<{dimLevelType = ["dense","dense"]}> // CHECK-LABEL: func @sparse_load(