NOTE(review): this copy of the patch was recovered from a flattened paste.
Line breaks were reconstructed from the hunk headers; the leading whitespace
and wrapping of context lines is a best guess, so check it against the base
revision before applying. Angle-bracketed text also appears to have been
dropped in places (the operand constraints read `Variadic>:$xs`, and the
description examples read `memref, memref`); that text has to be restored
from the original change before this applies or builds. The pre-existing
description example still reads `jointly y1`; it should be `jointly %y1`,
like the corrected new example.

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -415,7 +415,8 @@
     // and then use NonemptyVariadic<...>:$xs here.
     Arguments<(ins Index:$n,
                Variadic>:$xs,
-               Variadic>:$ys)> {
+               Variadic>:$ys,
+               UnitAttr:$stable)> {
   string summary = "Sorts the arrays in xs and ys lexicographically on the "
                    "integral values found in the xs list";
   string description = [{
@@ -437,6 +438,9 @@
     is undefined if this condition is not met. The operator requires at least one
     buffer in `xs` while `ys` can be empty.
 
+    The `stable` attribute requests a stable sorting algorithm, i.e. one that
+    preserves the relative order of entries whose `xs` values compare equal.
+
     Note that this operation is "impure" in the sense that its behavior is
     solely defined by side-effects and not SSA values. The semantics may be
     refined over time as our sparse abstractions evolve.
@@ -447,10 +451,21 @@
     sparse_tensor.sort %n, %x1, %x2 jointly y1, %y2
       : memref, memref jointly memref, memref
     ```
+
+    ```mlir
+    sparse_tensor.sort stable %n, %x1, %x2 jointly %y1, %y2
+      : memref, memref jointly memref, memref
+    ```
   }];
-  let assemblyFormat = "$n `,` $xs (`jointly` $ys^)? attr-dict"
+  let assemblyFormat = "(`stable` $stable^)? $n `,` $xs "
+                       "(`jointly` $ys^)? attr-dict"
                        "`:` type($xs) (`jointly` type($ys)^)?";
 
+  // Convenience builder that creates the op without the `stable` attribute.
+  let builders = [
+    OpBuilder<(ins "Value":$n, "ValueRange":$xs, "ValueRange":$ys)>
+  ];
+
   let hasVerifier = 1;
 }
 
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -675,6 +675,13 @@
   return success();
 }
 
+/// Convenience builder that creates the op without the `stable` unit
+/// attribute (it forwards `stable=false` to the generated builder).
+void SortOp::build(OpBuilder &odsBuilder, OperationState &odsState, Value n,
+                   ValueRange xs, ValueRange ys) {
+  build(odsBuilder, odsState, n, xs, ys, /*stable=*/false);
+}
+
 LogicalResult SortOp::verify() {
   if (getXs().empty())
     return emitError("need at least one xs buffer.");
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -423,3 +423,17 @@
   sparse_tensor.sort %arg0, %arg1, %arg2 jointly %arg3 : memref<10xi8>, memref<20xi8> jointly memref<10xf64>
   return %arg1, %arg2, %arg3 : memref<10xi8>, memref<20xi8>, memref<10xf64>
 }
+
+// -----
+
+// CHECK-LABEL: func @sparse_sort_stable(
+// CHECK-SAME: %[[A:.*]]: index,
+// CHECK-SAME: %[[B:.*]]: memref<10xi8>,
+// CHECK-SAME: %[[C:.*]]: memref<20xi8>,
+// CHECK-SAME: %[[D:.*]]: memref<10xf64>)
+// CHECK: sparse_tensor.sort stable %[[A]], %[[B]], %[[C]] jointly %[[D]] : memref<10xi8>, memref<20xi8> jointly memref<10xf64>
+// CHECK: return %[[B]], %[[C]], %[[D]]
+func.func @sparse_sort_stable(%arg0: index, %arg1: memref<10xi8>, %arg2: memref<20xi8>, %arg3: memref<10xf64>) -> (memref<10xi8>, memref<20xi8>, memref<10xf64>) {
+  sparse_tensor.sort stable %arg0, %arg1, %arg2 jointly %arg3 : memref<10xi8>, memref<20xi8> jointly memref<10xf64>
+  return %arg1, %arg2, %arg3 : memref<10xi8>, memref<20xi8>, memref<10xf64>
+}