diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -392,6 +392,50 @@
   let assemblyFormat = "$tensor `,` $dest attr-dict `:` type($tensor) `,` type($dest)";
 }
 
+def SparseTensor_SortOp : SparseTensor_Op<"sort", [AttrSizedOperandSegments]>,
+    // TODO: May want to extend tablegen with
+    // class NonemptyVariadic<Type type> : Variadic<type> { let minSize = 1; }
+    // and then use NonemptyVariadic<...>:$xs here.
+    Arguments<(ins Index:$n,
+               Variadic<MemRefRankOf<[AnyInteger, Index], [1]>>:$xs,
+               Variadic<MemRefRankOf<[AnyType], [1]>>:$ys)> {
+  string summary = "Sorts the arrays in xs and ys lexicographically on the "
+                   "integral values found in the xs list";
+  string description = [{
+    Lexicographically sort the first `n` values in `xs` along with the values
+    in `ys`. Conceptually, the values being sorted are tuples produced by
+    `zip(zip(xs), zip(ys))`. In particular, values in `ys` need to be sorted
+    along with values in `xs`, but values in `ys` don't affect the
+    lexicographical order. The order in which arrays appear in `xs` affects the
+    sorting result. The operator updates `xs` and `ys` in place with the result
+    of the sorting.
+
+    For example, assume x1=[4, 3], x2=[1, 2], y1=[10, 5]. Then the output of
+    "sort 2, x1, x2 jointly y1" is x1=[3, 4], x2=[2, 1], y1=[5, 10], while the
+    output of "sort 2, x2, x1 jointly y1" is x2=[1, 2], x1=[4, 3], y1=[10, 5].
+
+    Buffers in `xs` need to have the same integral element type, while buffers
+    in `ys` can have different numeric element types. All buffers in `xs` and
+    `ys` should have a dimension not less than `n`. The operator requires at
+    least one buffer in `xs`, while `ys` can be empty.
+
+    Note that this operation is "impure" in the sense that its behavior is
+    solely defined by side-effects and not SSA values. The semantics may be
+    refined over time as our sparse abstractions evolve.
+
+    Example:
+
+    ```mlir
+    sparse_tensor.sort %n, %x1, %x2 jointly %y1, %y2
+      : memref<?xindex>, memref<?xindex> jointly memref<?xindex>, memref<?xf32>
+    ```
+  }];
+  let assemblyFormat = "$n `,` $xs (`jointly` $ys^)? attr-dict"
+                       "`:` type($xs) (`jointly` type($ys)^)?";
+
+  let hasVerifier = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // Sparse Tensor Syntax Operations.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SparseTensor/IR/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/IR/CMakeLists.txt
--- a/mlir/lib/Dialect/SparseTensor/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/IR/CMakeLists.txt
@@ -9,6 +9,7 @@
   MLIRSparseTensorOpsIncGen
 
   LINK_LIBS PUBLIC
+  MLIRArithmeticDialect
   MLIRDialect
   MLIRIR
   MLIRInferTypeOpInterface
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
 
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/Matchers.h"
@@ -505,6 +506,41 @@
   return success();
 }
 
+LogicalResult SortOp::verify() {
+  if (getXs().empty())
+    return emitError("need at least one xs buffer.");
+
+  auto n = getN().getDefiningOp<arith::ConstantIndexOp>();
+
+  Type xtp = getXs().front().getType().cast<MemRefType>().getElementType();
+  auto checkTypes = [&](ValueRange operands,
+                        bool checkEleType = true) -> LogicalResult {
+    for (Value opnd : operands) {
+      MemRefType mtp = opnd.getType().cast<MemRefType>();
+      uint64_t dim = mtp.getShape()[0];
+      // We can't check the size of a dynamic dimension at compile-time, but
+      // all xs and ys should have a dimension not less than n at runtime.
+      if (n && dim != ShapedType::kDynamicSize && dim < n.value())
+        return emitError(llvm::formatv("xs and ys need to have a dimension >= n"
+                                       ": {0} < {1}",
+                                       dim, n.value()));
+
+      if (checkEleType && xtp != mtp.getElementType())
+        return emitError("mismatch xs element types");
+    }
+    return success();
+  };
+
+  LogicalResult result = checkTypes(getXs());
+  if (failed(result))
+    return result;
+
+  if (n)
+    return checkTypes(getYs(), false);
+
+  return success();
+}
+
 LogicalResult YieldOp::verify() {
   // Check for compatible parent.
   auto *parentOp = (*this)->getParentOp();
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -501,3 +501,29 @@
   }
   return
 }
+
+// -----
+
+// TODO: a test case with empty xs doesn't work due to some parser issues.
+
+func.func @sparse_sort_x_type( %arg0: index, %arg1: memref<?xf32>) {
+  // expected-error@+1 {{operand #1 must be 1D memref of integer or index values}}
+  sparse_tensor.sort %arg0, %arg1: memref<?xf32>
+}
+
+// -----
+
+func.func @sparse_sort_dim_too_small(%arg0: memref<10xindex>) {
+  %i20 = arith.constant 20 : index
+  // expected-error@+1 {{xs and ys need to have a dimension >= n: 10 < 20}}
+  sparse_tensor.sort %i20, %arg0 : memref<10xindex>
+  return
+}
+
+// -----
+
+func.func @sparse_sort_mismatch_x_type(%arg0: index, %arg1: memref<10xindex>, %arg2: memref<10xi8>) {
+  // expected-error@+1 {{mismatch xs element types}}
+  sparse_tensor.sort %arg0, %arg1, %arg2 : memref<10xindex>, memref<10xi8>
+  return
+}
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -362,3 +362,43 @@
   }
   return
 }
+
+// -----
+
+// CHECK-LABEL: func @sparse_sort_1d0v(
+//  CHECK-SAME: %[[A:.*]]: index,
+//  CHECK-SAME: %[[B:.*]]: memref<?xindex>)
+//       CHECK: sparse_tensor.sort %[[A]], %[[B]] : memref<?xindex>
+//       CHECK: return %[[B]]
+func.func @sparse_sort_1d0v(%arg0: index, %arg1: memref<?xindex>) -> (memref<?xindex>) {
+  sparse_tensor.sort %arg0, %arg1 : memref<?xindex>
+  return %arg1 : memref<?xindex>
+}
+
+// -----
+
+// CHECK-LABEL: func @sparse_sort_1d2v(
+//  CHECK-SAME: %[[A:.*]]: index,
+//  CHECK-SAME: %[[B:.*]]: memref<20xindex>,
+//  CHECK-SAME: %[[C:.*]]: memref<10xindex>,
+//  CHECK-SAME: %[[D:.*]]: memref<?xf32>)
+//       CHECK: sparse_tensor.sort %[[A]], %[[B]] jointly %[[C]], %[[D]] : memref<20xindex> jointly memref<10xindex>, memref<?xf32>
+//       CHECK: return %[[B]], %[[C]], %[[D]]
+func.func @sparse_sort_1d2v(%arg0: index, %arg1: memref<20xindex>, %arg2: memref<10xindex>, %arg3: memref<?xf32>) -> (memref<20xindex>, memref<10xindex>, memref<?xf32>) {
+  sparse_tensor.sort %arg0, %arg1 jointly %arg2, %arg3 : memref<20xindex> jointly memref<10xindex>, memref<?xf32>
+  return %arg1, %arg2, %arg3 : memref<20xindex>, memref<10xindex>, memref<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @sparse_sort_2d1v(
+//  CHECK-SAME: %[[A:.*]]: index,
+//  CHECK-SAME: %[[B:.*]]: memref<10xi8>,
+//  CHECK-SAME: %[[C:.*]]: memref<20xi8>,
+//  CHECK-SAME: %[[D:.*]]: memref<10xf64>)
+//       CHECK: sparse_tensor.sort %[[A]], %[[B]], %[[C]] jointly %[[D]] : memref<10xi8>, memref<20xi8> jointly memref<10xf64>
+//       CHECK: return %[[B]], %[[C]], %[[D]]
+func.func @sparse_sort_2d1v(%arg0: index, %arg1: memref<10xi8>, %arg2: memref<20xi8>, %arg3: memref<10xf64>) -> (memref<10xi8>, memref<20xi8>, memref<10xf64>) {
+  sparse_tensor.sort %arg0, %arg1, %arg2 jointly %arg3 : memref<10xi8>, memref<20xi8> jointly memref<10xf64>
+  return %arg1, %arg2, %arg3 : memref<10xi8>, memref<20xi8>, memref<10xf64>
+}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -2030,6 +2030,7 @@
     hdrs = ["include/mlir/Dialect/SparseTensor/IR/SparseTensor.h"],
    includes = ["include"],
    deps = [
+        ":ArithmeticDialect",
        ":IR",
        ":InferTypeOpInterface",
        ":SparseTensorAttrDefsIncGen",
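The sketch below is not part of the patch; it is a hypothetical usage example of the new operator, sorting the coordinate buffers of a COO-style representation while carrying a values buffer along. The function name `@sort_coo` and the buffer names/shapes are illustrative assumptions only; the op syntax follows the assemblyFormat and tests above.

```mlir
// Hypothetical sketch (not part of this patch): lexicographically sort the
// first %n (row, col) coordinate pairs; %rows is the primary key and %cols
// breaks ties, while %vals is permuted jointly without affecting the order.
func.func @sort_coo(%n: index,
                    %rows: memref<?xindex>,
                    %cols: memref<?xindex>,
                    %vals: memref<?xf64>) {
  sparse_tensor.sort %n, %rows, %cols jointly %vals
      : memref<?xindex>, memref<?xindex> jointly memref<?xf64>
  return
}
```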