diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -388,6 +388,43 @@
   let assemblyFormat = "$tensor `,` $dest attr-dict `:` type($tensor) `,` type($dest)";
 }
 
+def SparseTensor_SortOp : SparseTensor_Op<"sort", [AttrSizedOperandSegments]>,
+    Arguments<(ins Index:$n,
+               Variadic<StridedMemRefRankOf<[AnyInteger, Index], [1]>>:$xs,
+               Variadic<StridedMemRefRankOf<[AnyType], [1]>>:$ys)> {
+  string summary = "Sorts values based on coordinates";
+  string description = [{
+    Lexicographically sort the first `n` values in `xs` along with the values in
+    `ys`. Values in `ys` need to be sorted along with values in `xs` but don't
+    affect the lexicographical order. This operator updates `xs` and `ys` in
+    place with the result of the sort.
+
+    For example, assume x1=[4, 3], x2=[1, 2], y1=[10, 5], then the output of
+    "sort 2, x1, x2 jointly y1" is x1=[3, 4], x2=[2, 1], y1=[5, 10] while the
+    output of "sort 2, x2, x1 jointly y1" is x2=[1, 2], x1=[4, 3], y1=[10, 5].
+
+    Buffers in `xs` need to have the same element type while buffers in `ys`
+    can have different element types. All buffers in `xs` and `ys` should have
+    the same dimension at runtime. The operator requires at least one buffer in
+    `xs` while `ys` can be empty.
+
+    Note that this operation is "impure" in the sense that its behavior is
+    solely defined by side-effects and not SSA values. The semantics may be
+    refined over time as our sparse abstractions evolve.
+
+    Example:
+
+    ```mlir
+    sparse_tensor.sort %n, %x1, %x2 jointly %y1, %y2
+      : memref<?xindex>, memref<?xindex> jointly memref<?xindex>, memref<?xf32>
+    ```
+  }];
+  let assemblyFormat = "$n `,` $xs (`jointly` $ys^)? attr-dict"
+                       "`:` type($xs) (`jointly` type($ys)^)?";
+
+  let hasVerifier = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // Sparse Tensor Syntax Operations.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -505,6 +505,51 @@
   return success();
 }
 
+/// Verifies a sort op: there must be at least one `xs` buffer, every `xs`
+/// buffer must share the element type of the first one, and all statically
+/// known sizes of the `xs` and `ys` buffers must agree.
+LogicalResult SortOp::verify() {
+  if (getXs().empty())
+    return emitError("need at least one xs buffer.");
+
+  auto mtp = getXs().front().getType().cast<MemRefType>();
+  Type etp = mtp.getElementType();
+  // Starts as the first buffer's size and is refined to the first static size
+  // encountered if it is dynamic.
+  int64_t dim = mtp.getShape()[0];
+  auto checkTypes = [&](ValueRange operands,
+                        bool checkEleType = true) -> LogicalResult {
+    for (Value opnd : operands) {
+      auto curMtp = opnd.getType().cast<MemRefType>();
+      // Keep the size signed like `dim`, so the comparisons and the assignment
+      // to `dim` below don't mix signed and unsigned operands.
+      int64_t curDim = curMtp.getShape()[0];
+      // We can't check the size of dynamic dimension at compile-time, but the
+      // definition of the operator requires all xs and ys have the same
+      // dimension at runtime.
+      if (dim == ShapedType::kDynamicSize &&
+          curDim != ShapedType::kDynamicSize)
+        dim = curDim;
+      if (dim != curDim && curDim != ShapedType::kDynamicSize)
+        return emitError(
+            llvm::formatv("xs and ys need to have matching dimension"
+                          ": {0} != {1}",
+                          dim, curDim));
+      if (checkEleType && etp != curMtp.getElementType())
+        return emitError("mismatch xs element types");
+    }
+    return success();
+  };
+
+  // The first xs buffer seeded `etp` and `dim`, so only the rest are checked.
+  LogicalResult result = checkTypes(getXs().drop_front());
+  if (failed(result))
+    return result;
+
+  // For ys buffers the element type check is skipped; only sizes are compared.
+  return checkTypes(getYs(), false);
+}
+
 LogicalResult YieldOp::verify() {
   // Check for compatible parent.
auto *parentOp = (*this)->getParentOp();
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -501,3 +501,26 @@
   }
   return
 }
+
+// -----
+
+func.func @sparse_sort_x_type( %arg0: index, %arg1: memref<?xf32>) {
+  // expected-error@+1 {{operand #1 must be 1D memref of integer or index values}}
+  sparse_tensor.sort %arg0, %arg1: memref<?xf32>
+}
+
+// -----
+
+func.func @sparse_sort_mismatched_dims(%arg0: index, %arg1: memref<10xindex>, %arg2: memref<20xf32>) {
+  // expected-error@+1 {{xs and ys need to have matching dimension: 10 != 20}}
+  sparse_tensor.sort %arg0, %arg1 jointly %arg2 : memref<10xindex> jointly memref<20xf32>
+  return
+}
+
+// -----
+
+func.func @sparse_sort_mismatch_x_type(%arg0: index, %arg1: memref<10xindex>, %arg2: memref<10xi8>) {
+  // expected-error@+1 {{mismatch xs element types}}
+  sparse_tensor.sort %arg0, %arg1, %arg2 : memref<10xindex>, memref<10xi8>
+  return
+}
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -362,3 +362,43 @@
   }
   return
 }
+
+// -----
+
+// CHECK-LABEL: func @sparse_sort_1d0v(
+// CHECK-SAME: %[[A:.*]]: index,
+// CHECK-SAME: %[[B:.*]]: memref<?xindex>)
+// CHECK: sparse_tensor.sort %[[A]], %[[B]] : memref<?xindex>
+// CHECK: return %[[B]]
+func.func @sparse_sort_1d0v(%arg0: index, %arg1: memref<?xindex>) -> (memref<?xindex>) {
+  sparse_tensor.sort %arg0, %arg1 : memref<?xindex>
+  return %arg1 : memref<?xindex>
+}
+
+// -----
+
+// CHECK-LABEL: func @sparse_sort_1d2v(
+// CHECK-SAME: %[[A:.*]]: index,
+// CHECK-SAME: %[[B:.*]]: memref<?xindex>,
+// CHECK-SAME: %[[C:.*]]: memref<10xindex>,
+// CHECK-SAME: %[[D:.*]]: memref<?xf32>)
+// CHECK: sparse_tensor.sort %[[A]], %[[B]] jointly %[[C]], %[[D]] : memref<?xindex> jointly memref<10xindex>, memref<?xf32>
+// CHECK: return %[[B]], %[[C]], %[[D]]
+func.func @sparse_sort_1d2v(%arg0: index, %arg1: memref<?xindex>, %arg2: memref<10xindex>, %arg3: memref<?xf32>) -> (memref<?xindex>, memref<10xindex>, memref<?xf32>) {
+  sparse_tensor.sort %arg0, %arg1 jointly %arg2, %arg3 : memref<?xindex> jointly memref<10xindex>, memref<?xf32>
+  return %arg1, %arg2, %arg3 : memref<?xindex>, memref<10xindex>, memref<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @sparse_sort_2d1v(
+// CHECK-SAME: %[[A:.*]]: index,
+// CHECK-SAME: %[[B:.*]]: memref<10xi8>,
+// CHECK-SAME: %[[C:.*]]: memref<?xi8>,
+// CHECK-SAME: %[[D:.*]]: memref<10xf64>)
+// CHECK: sparse_tensor.sort %[[A]], %[[B]], %[[C]] jointly %[[D]] : memref<10xi8>, memref<?xi8> jointly memref<10xf64>
+// CHECK: return %[[B]], %[[C]], %[[D]]
+func.func @sparse_sort_2d1v(%arg0: index, %arg1: memref<10xi8>, %arg2: memref<?xi8>, %arg3: memref<10xf64>) -> (memref<10xi8>, memref<?xi8>, memref<10xf64>) {
+  sparse_tensor.sort %arg0, %arg1, %arg2 jointly %arg3 : memref<10xi8>, memref<?xi8> jointly memref<10xf64>
+  return %arg1, %arg2, %arg3 : memref<10xi8>, memref<?xi8>, memref<10xf64>
+}