diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -154,6 +154,34 @@
   let hasVerifier = 1;
 }
 
+def SparseTensor_ToIndicesBufferOp : SparseTensor_Op<"indices_buffer", [Pure]>,
+    Arguments<(ins AnySparseTensor:$tensor)>,
+    Results<(outs AnyStridedMemRefOfRank<1>:$result)> {
+  let summary = "Extracts the linear indices array from a tensor";
+  let description = [{
+    Returns the linear indices array for a sparse tensor with a trailing COO
+    region with at least two dimensions. It is an error if the tensor doesn't
+    contain such a COO region. This is similar to the `bufferization.to_memref`
+    operation in the sense that it provides a bridge between a tensor world view
+    and a bufferized world view. Unlike the `bufferization.to_memref` operation,
+    however, this sparse operation actually lowers into code that extracts the
+    linear indices array from the sparse storage scheme that stores the indices
+    for the COO region as an array of structures. For example, a 2D COO sparse
+    tensor with two non-zero elements at coordinates (1, 3) and (4, 6) is
+    stored in a linear buffer as (1, 3, 4, 6) instead of two buffers as (1, 4)
+    and (3, 6).
+
+    Example:
+
+    ```mlir
+    %1 = sparse_tensor.indices_buffer %0
+       : tensor<64x64xf64, #COO> to memref<?xindex>
+    ```
+  }];
+  let assemblyFormat = "$tensor attr-dict `:` type($tensor) `to` type($result)";
+  let hasVerifier = 1;
+}
+
 def SparseTensor_ToValuesOp : SparseTensor_Op<"values", [Pure]>,
     Arguments<(ins AnySparseTensor:$tensor)>,
     Results<(outs AnyStridedMemRefOfRank<1>:$result)> {
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -496,6 +496,16 @@
   return success();
 }
 
+LogicalResult ToIndicesBufferOp::verify() {
+  // The AnySparseTensor operand constraint already rejects unencoded tensors,
+  // so `e` is non-null here. A COO start at or beyond the level count means
+  // the encoding has no (trailing) COO region to extract indices from.
+  auto e = getSparseTensorEncoding(getTensor().getType());
+  if (getCOOStart(e) >= e.getDimLevelType().size())
+    return emitError("expected sparse tensor with a COO region");
+  return success();
+}
+
 LogicalResult ToValuesOp::verify() {
   RankedTensorType ttp = getTensor().getType().cast<RankedTensorType>();
   MemRefType mtp = getResult().getType().cast<MemRefType>();
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -90,6 +90,24 @@
 
 #SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
 
+func.func @indices_buffer_noncoo(%arg0: tensor<128xf64, #SparseVector>) -> memref<?xindex> {
+  // expected-error@+1 {{expected sparse tensor with a COO region}}
+  %0 = sparse_tensor.indices_buffer %arg0 : tensor<128xf64, #SparseVector> to memref<?xindex>
+  return %0 : memref<?xindex>
+}
+
+// -----
+
+func.func @indices_buffer_dense(%arg0: tensor<1024xf32>) -> memref<?xindex> {
+  // expected-error@+1 {{must be sparse tensor of any type values}}
+  %0 = sparse_tensor.indices_buffer %arg0 : tensor<1024xf32> to memref<?xindex>
+  return %0 : memref<?xindex>
+}
+
+// -----
+
+#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
+
 func.func @mismatch_values_types(%arg0: tensor<?xf64, #SparseVector>) -> memref<?xf32> {
   // expected-error@+1 {{unexpected mismatch in element types}}
   %0 = sparse_tensor.values %arg0 : tensor<?xf64, #SparseVector> to memref<?xf32>
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -78,6 +78,19 @@
 
 // -----
 
+#COO = #sparse_tensor.encoding<{dimLevelType = ["compressed-nu", "singleton"]}>
+
+// CHECK-LABEL: func @sparse_indices_buffer(
+//  CHECK-SAME: %[[A:.*]]: tensor<?x?xf64, #COO>)
+//       CHECK: %[[T:.*]] = sparse_tensor.indices_buffer %[[A]] : tensor<?x?xf64, #COO> to memref<?xindex>
+//       CHECK: return %[[T]] : memref<?xindex>
+func.func @sparse_indices_buffer(%arg0: tensor<?x?xf64, #COO>) -> memref<?xindex> {
+  %0 = sparse_tensor.indices_buffer %arg0 : tensor<?x?xf64, #COO> to memref<?xindex>
+  return %0 : memref<?xindex>
+}
+
+// -----
+
 #SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
 
 // CHECK-LABEL: func @sparse_indices(