diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -388,6 +388,37 @@
   let assemblyFormat = "$tensor `,` $dest attr-dict `:` type($tensor) `,` type($dest)";
 }
 
+def SparseTensor_SortOp : SparseTensor_Op<"sort", [AttrSizedOperandSegments]>,
+    Arguments<(ins Index:$n,
+               Variadic<StridedMemRefRankOf<[AnyInteger, Index], [1]>>:$coordinates,
+               Variadic<StridedMemRefRankOf<[AnyType], [1]>>:$values)> {
+  string summary = "Sorts values based on coordinates";
+  string description = [{
+    Sort the first `n` values in buffers provided by `coordinates` and `values`
+    so that the resulting `coordinates` have non-decreasing values. The
+    `coordinates` need to have the same element type while the `values` can
+    have different element types. All buffers in `coordinates` and `values`
+    should have the same dimension. The operator requires at least one buffer
+    for `coordinates` while `values` can be empty. This operator modifies
+    `coordinates` and `values` to reflect the result of the sorting.
+
+    Note that this operation is "impure" in the sense that its behavior is
+    solely defined by side-effects and not SSA values. The semantics may be
+    refined over time as our sparse abstractions evolve.
+
+    Example:
+
+    ```mlir
+    sparse_tensor.sort %n, %coords1, %coords2 (%values1, %values2)
+      : memref<?xindex>, memref<?xindex> (memref<?xf32>, memref<?xi32>)
+    ```
+  }];
+  let assemblyFormat = "$n `,` $coordinates (`(`$values^`)`)? attr-dict "
+                       "`:` type($coordinates) (`(`type($values)^`)`)?";
+
+  let hasVerifier = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // Sparse Tensor Syntax Operations.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -505,6 +505,40 @@
   return success();
 }
 
+LogicalResult SortOp::verify() {
+  if (getCoordinates().empty())
+    return emitError("need at least one coordinate buffer");
+
+  auto mtp = getCoordinates().front().getType().cast<MemRefType>();
+  Type etp = mtp.getElementType();
+  // Verifies that every buffer has the dimension established so far; `dim` is
+  // refined from dynamic to static as soon as a static size is seen, so all
+  // static sizes must agree. Dimensions are kept signed (int64_t) because
+  // ShapedType::kDynamicSize is a negative sentinel; an unsigned type would
+  // make the comparisons below unreliable.
+  int64_t dim = mtp.getShape()[0];
+  auto checkTypes = [&](ValueRange operands,
+                        bool checkEleType = true) -> LogicalResult {
+    for (Value opnd : operands) {
+      auto memTp = opnd.getType().cast<MemRefType>();
+      int64_t memDim = memTp.getShape()[0];
+      if (dim == ShapedType::kDynamicSize && memDim != ShapedType::kDynamicSize)
+        dim = memDim;
+      if (dim != memDim && memDim != ShapedType::kDynamicSize)
+        return emitError(
+            llvm::formatv("buffers for values and coordinates need "
+                          "to have matching dimension: {0} != {1}",
+                          dim, memDim));
+      // Only the coordinate buffers must share a single element type.
+      if (checkEleType && etp != memTp.getElementType())
+        return emitError("mismatch in index array element types");
+    }
+    return success();
+  };
+
+  if (failed(checkTypes(getCoordinates().drop_front())))
+    return failure();
+
+  return checkTypes(getValues(), /*checkEleType=*/false);
+}
+
 LogicalResult YieldOp::verify() {
   // Check for compatible parent.
   auto *parentOp = (*this)->getParentOp();
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -501,3 +501,26 @@
   }
   return
 }
+
+// -----
+
+func.func @sparse_sort_coordinate_type(%arg0: index, %arg1: memref<?xf32>) {
+  // expected-error@+1 {{operand #1 must be 1D memref of integer or index values}}
+  sparse_tensor.sort %arg0, %arg1: memref<?xf32>
+}
+
+// -----
+
+func.func @sparse_sort_inconsistent(%arg0: index, %arg1: memref<10xindex>, %arg2: memref<20xf32>) {
+  // expected-error@+1 {{buffers for values and coordinates need to have matching dimension: 10 != 20}}
+  sparse_tensor.sort %arg0, %arg1 (%arg2) : memref<10xindex> (memref<20xf32>)
+  return
+}
+
+// -----
+
+func.func @sparse_sort_indices(%arg0: index, %arg1: memref<10xindex>, %arg2: memref<10xi8>) {
+  // expected-error@+1 {{mismatch in index array element types}}
+  sparse_tensor.sort %arg0, %arg1, %arg2 : memref<10xindex>, memref<10xi8>
+  return
+}
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -362,3 +362,43 @@
   }
   return
 }
+
+// -----
+
+// CHECK-LABEL: func @sparse_sort_1dv0(
+// CHECK-SAME: %[[A:.*]]: index,
+// CHECK-SAME: %[[B:.*]]: memref<?xindex>)
+// CHECK: sparse_tensor.sort %[[A]], %[[B]] : memref<?xindex>
+// CHECK: return %[[B]]
+func.func @sparse_sort_1dv0(%arg0: index, %arg1: memref<?xindex>) -> (memref<?xindex>) {
+  sparse_tensor.sort %arg0, %arg1 : memref<?xindex>
+  return %arg1 : memref<?xindex>
+}
+
+// -----
+
+// CHECK-LABEL: func @sparse_sort_1dv2(
+// CHECK-SAME: %[[A:.*]]: index,
+// CHECK-SAME: %[[B:.*]]: memref<?xindex>,
+// CHECK-SAME: %[[C:.*]]: memref<10xindex>,
+// CHECK-SAME: %[[D:.*]]: memref<?xf32>)
+// CHECK: sparse_tensor.sort %[[A]], %[[B]](%[[C]], %[[D]]) : memref<?xindex>(memref<10xindex>, memref<?xf32>)
+// CHECK: return %[[B]], %[[C]], %[[D]]
+func.func @sparse_sort_1dv2(%arg0: index, %arg1: memref<?xindex>, %arg2: memref<10xindex>, %arg3: memref<?xf32>) -> (memref<?xindex>, memref<10xindex>, memref<?xf32>) {
+  sparse_tensor.sort %arg0, %arg1 (%arg2, %arg3) : memref<?xindex> (memref<10xindex>, memref<?xf32>)
+  return %arg1, %arg2, %arg3 : memref<?xindex>, memref<10xindex>, memref<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @sparse_sort_2d(
+// CHECK-SAME: %[[A:.*]]: index,
+// CHECK-SAME: %[[B:.*]]: memref<10xi8>,
+// CHECK-SAME: %[[C:.*]]: memref<?xi8>,
+// CHECK-SAME: %[[D:.*]]: memref<10xf64>)
+// CHECK: sparse_tensor.sort %[[A]], %[[B]], %[[C]](%[[D]]) : memref<10xi8>, memref<?xi8>(memref<10xf64>)
+// CHECK: return %[[B]], %[[C]], %[[D]]
+func.func @sparse_sort_2d(%arg0: index, %arg1: memref<10xi8>, %arg2: memref<?xi8>, %arg3: memref<10xf64>) -> (memref<10xi8>, memref<?xi8>, memref<10xf64>) {
+  sparse_tensor.sort %arg0, %arg1, %arg2 (%arg3) : memref<10xi8>, memref<?xi8> (memref<10xf64>)
+  return %arg1, %arg2, %arg3 : memref<10xi8>, memref<?xi8>, memref<10xf64>
+}