diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td @@ -388,6 +388,28 @@ let assemblyFormat = "$tensor `,` $dest attr-dict `:` type($tensor) `,` type($dest)"; } +def SparseTensor_SortOp : SparseTensor_Op<"sort", []>, + Arguments<(ins Index:$n, + StridedMemRefRankOf<[AnyType], [1]>:$values, + Variadic<StridedMemRefRankOf<[AnyInteger, Index], [1]>>:$coordinates)> { + string summary = "Sorts COO data with values and coordinates"; + string description = [{ + Sort the first `n` values in `values` by non-decreasing coordinate values + provided by `coordinates`. This operator modifies `values` and `coordinates` + with the result of the sort. + + Example: + + ```mlir + sparse_tensor.sort %n, %values, %coordinate + : memref<?xf64>, memref<?xindex> + ``` + }]; + let assemblyFormat = "$n `,` $values (`,` $coordinates^)? attr-dict " + "`:` type($values) (`,` type($coordinates)^)?"; + let hasVerifier = 1; +} + //===----------------------------------------------------------------------===// // Sparse Tensor Custom Linalg.Generic Operations. 
//===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp --- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp @@ -452,6 +452,30 @@ return success(); } +LogicalResult SortOp::verify() { + if (getCoordinates().empty()) + return emitError("Need at least one coordinate buffer."); + + int64_t dim = getValues().getType().cast<MemRefType>().getShape()[0]; + Type etp = + getCoordinates().front().getType().cast<MemRefType>().getElementType(); + for (size_t i = 0; i < getCoordinates().size(); i++) { + Value coord = getCoordinates()[i]; + MemRefType mtp = coord.getType().cast<MemRefType>(); + uint64_t cur_dim = mtp.getShape()[0]; + if (dim == ShapedType::kDynamicSize || cur_dim == ShapedType::kDynamicSize) + dim = ShapedType::kDynamicSize; + else if (dim != cur_dim) + return emitError(llvm::formatv("Buffers for values and coordinates need " + "to have matching dimension: {0} != {1}", + dim, cur_dim)); + if (etp != mtp.getElementType()) + return emitError("Mismatch in index array element types"); + } + + return success(); +} + LogicalResult ReduceOp::verify() { Type inputType = getX().getType(); LogicalResult regionResult = success(); diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir --- a/mlir/test/Dialect/SparseTensor/invalid.mlir +++ b/mlir/test/Dialect/SparseTensor/invalid.mlir @@ -468,3 +468,35 @@ tensor<4x4xf64, #DC> to tensor<9x4xf64, #DC> return %0 : tensor<9x4xf64, #DC> } + +// ----- + +func.func @sparse_sort_coordinate_type(%arg0: index, %arg1: memref<?xf32>,%arg2: memref<?xf32>) { + // expected-error@+1 {{op operand #2 must be 1D memref of integer or index values}} + sparse_tensor.sort %arg0, %arg1, %arg2 : memref<?xf32>, memref<?xf32> +} + +// ----- + +func.func @sparse_sort_no_coordinate(%arg0: index, %arg1: memref<20xf32>) { + // expected-error@+1 
{{Need at least one coordinate buffer}} + sparse_tensor.sort %arg0, %arg1 : memref<20xf32> + return +} + +// ----- + +func.func @sparse_sort_inconsistent(%arg0: index, %arg1: memref<20xf32>, %arg2: memref<10xindex>) { + // expected-error@+1 {{Buffers for values and coordinates need to have matching dimension: 20 != 10}} + sparse_tensor.sort %arg0, %arg1, %arg2 : memref<20xf32>, memref<10xindex> + return +} + +// ----- + +func.func @sparse_sort_indices(%arg0: index, %arg1: memref<10xf64>, %arg2: memref<10xindex>, %arg3: memref<10xi8>) + -> (memref<10xf64>, memref<10xindex>, memref<10xi8>) { + // expected-error@+1 {{Mismatch in index array element types}} + sparse_tensor.sort %arg0, %arg1, %arg2, %arg3 : memref<10xf64>, memref<10xindex>, memref<10xi8> + return %arg1, %arg2, %arg3 : memref<10xf64>, memref<10xindex>, memref<10xi8> +} diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir --- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir +++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir @@ -347,3 +347,30 @@ tensor<4x4xf64, #SparseMatrix> to tensor<9x4xf64, #SparseMatrix> return %0 : tensor<9x4xf64, #SparseMatrix> } + +// ----- + +// CHECK-LABEL: func @sparse_sort_1d( +// CHECK-SAME: %[[A:.*]]: index, +// CHECK-SAME: %[[B:.*]]: memref<?xf64>, +// CHECK-SAME: %[[C:.*]]: memref<?xindex>) +// CHECK: sparse_tensor.sort %[[A]], %[[B]], %[[C]] : memref<?xf64>, memref<?xindex> +// CHECK: return %[[B]], %[[C]] +func.func @sparse_sort_1d(%arg0: index, %arg1: memref<?xf64>, %arg2: memref<?xindex>) -> (memref<?xf64>, memref<?xindex>) { + sparse_tensor.sort %arg0, %arg1, %arg2 : memref<?xf64>, memref<?xindex> + return %arg1, %arg2 : memref<?xf64>, memref<?xindex> +} + +// ----- + +// CHECK-LABEL: func @sparse_sort_2d( +// CHECK-SAME: %[[A:.*]]: index, +// CHECK-SAME: %[[B:.*]]: memref<10xf64>, +// CHECK-SAME: %[[C:.*]]: memref<?xi8>, +// CHECK-SAME: %[[D:.*]]: memref<10xi8>) +// CHECK: sparse_tensor.sort %[[A]], %[[B]], %[[C]], %[[D]] : memref<10xf64>, memref<?xi8>, memref<10xi8> +// CHECK: return %[[B]], %[[C]], %[[D]] +func.func 
@sparse_sort_2d(%arg0: index, %arg1: memref<10xf64>, %arg2: memref<?xi8>, %arg3: memref<10xi8>) -> (memref<10xf64>, memref<?xi8>, memref<10xi8>) { + sparse_tensor.sort %arg0, %arg1, %arg2, %arg3 : memref<10xf64>, memref<?xi8>, memref<10xi8> + return %arg1, %arg2, %arg3 : memref<10xf64>, memref<?xi8>, memref<10xi8> +}