diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -325,4 +325,39 @@
 def AnyRankedSparseTensor : RankedSparseTensorOf<[AnyType]>;
 
+//===----------------------------------------------------------------------===//
+// Sparse Tensor Sorting Algorithm Attribute.
+//===----------------------------------------------------------------------===//
+
+// TODO: Currently, we only provide four implementations, and expose the
+// implementations via the `algorithm` attribute. In the future, if we need
+// to support both stable and non-stable quick sort, we may add a
+// quick_sort_nonstable case to the enum. Alternatively, we may use two
+// attributes, (stable|nonstable, algorithm), to specify a sorting
+// implementation.
+//
+// --------------------------------------------------------------------------
+// |            | hybrid_quick_sort | insertion_sort | quick_sort | heap_sort |
+// | non-stable | Impl              | X              | Impl       | Impl      |
+// | stable     | X                 | Impl           | Not Impl   | X         |
+// --------------------------------------------------------------------------
+
+// The C++ enum for sparse tensor sort kind.
+def SparseTensorSortKindEnum
+    : I32EnumAttr<"SparseTensorSortKind", "sparse tensor sort algorithm", [
+        I32EnumAttrCase<"HybridQuickSort", 0, "hybrid_quick_sort">,
+        I32EnumAttrCase<"InsertionSortStable", 1, "insertion_sort_stable">,
+        I32EnumAttrCase<"QuickSort", 2, "quick_sort">,
+        I32EnumAttrCase<"HeapSort", 3, "heap_sort">,
+      ]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = SparseTensor_Dialect.cppNamespace;
+}
+
+// Define the enum sparse tensor sort kind attribute.
+def SparseTensorSortKindAttr
+    : EnumAttr {
+}
+
 #endif // SPARSETENSOR_ATTRDEFS
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -573,10 +573,15 @@
   // TODO: May want to extend tablegen with
   // class NonemptyVariadic : Variadic { let minSize = 1; }
   // and then use NonemptyVariadic<...>:$xs here.
+  //
+  // TODO: Currently tablegen doesn't support the assembly syntax when
+  // `algorithm` is an optional enum attribute. We may want to use an optional
+  // enum attribute when this is fixed in tablegen.
+  //
   Arguments<(ins Index:$n,
                  Variadic>:$xs,
                  Variadic>:$ys,
-                 UnitAttr:$stable)> {
+                 SparseTensorSortKindAttr:$algorithm)> {
   string summary = "Sorts the arrays in xs and ys lexicographically on the "
                    "integral values found in the xs list";
   string description = [{
@@ -598,8 +603,9 @@
     is undefined if this condition is not met. The operator requires at least
     one buffer in `xs` while `ys` can be empty.
 
-    The `stable` attribute indicates whether a stable sorting algorithm should
-    be used to implement the operator.
+    The enum attribute `algorithm` indicates the sorting algorithm used to
+    implement the operator: hybrid_quick_sort, insertion_sort_stable,
+    quick_sort, or heap_sort.
 
     Note that this operation is "impure" in the sense that its behavior is
     solely defined by side-effects and not SSA values.
@@ -607,17 +613,17 @@
     Example:
 
     ```mlir
-    sparse_tensor.sort %n, %x1, %x2 jointly y1, %y2
+    sparse_tensor.sort insertion_sort_stable %n, %x1, %x2 jointly %y1, %y2
       : memref, memref jointly memref, memref
     ```
 
     ```mlir
-    sparse_tensor.sort stable %n, %x1, %x2 jointly y1, %y2
+    sparse_tensor.sort hybrid_quick_sort %n, %x1, %x2
+      jointly %y1, %y2
       : memref, memref jointly memref, memref
     ```
   }];
-  let assemblyFormat = "(`stable` $stable^)? $n"
-                       "`,`$xs (`jointly` $ys^)? attr-dict"
+  let assemblyFormat = "$algorithm $n `,` $xs (`jointly` $ys^)? attr-dict"
                        "`:` type($xs) (`jointly` type($ys)^)?";
   let hasVerifier = 1;
 }
@@ -626,7 +632,7 @@
   Arguments<(ins Index:$n, StridedMemRefRankOf<[AnyInteger, Index], [1]>:$xy,
                  Variadic>:$ys,
                  OptionalAttr:$nx, OptionalAttr:$ny,
-                 UnitAttr:$stable)> {
+                 SparseTensorSortKindAttr:$algorithm)> {
   let summary = "Sorts the arrays in xs and ys lexicographically on the "
                 "integral values found in the xs list";
   let description = [{
@@ -645,17 +651,18 @@
     Example:
 
     ```mlir
-    sparse_tensor.sort_coo %n, %x { nx = 2 : index}
+    sparse_tensor.sort_coo insertion_sort_stable %n, %x { nx = 2 : index}
       : memref
     ```
 
     ```mlir
-    sparse_tensor.sort %n, %xy jointly %y1 { nx = 2 : index, ny = 2 : index}
+    sparse_tensor.sort_coo hybrid_quick_sort %n, %xy jointly %y1
+      { nx = 2 : index, ny = 2 : index}
       : memref jointly memref
     ```
   }];
 
-  let assemblyFormat = "(`stable` $stable^)? $n"
+  let assemblyFormat = "$algorithm $n"
                        "`,`$xy (`jointly` $ys^)? attr-dict"
                        "`:` type($xy) (`jointly` type($ys)^)?";
   let hasVerifier = 1;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
@@ -727,11 +727,15 @@
     }
     operands.push_back(v);
   }
+  // Only the stability of `algorithm` is honored so far: every kind other
+  // than InsertionSortStable is lowered with createSortNonstableFunc.
+  bool isStable =
+      (op.getAlgorithm() == SparseTensorSortKind::InsertionSortStable);
   auto insertPoint = op->template getParentOfType();
-  SmallString<32> funcName(op.getStable() ? kSortStableFuncNamePrefix
-                                          : kSortNonstableFuncNamePrefix);
+  SmallString<32> funcName(isStable ? kSortStableFuncNamePrefix
+                                    : kSortNonstableFuncNamePrefix);
   FuncGeneratorType funcGenerator =
-      op.getStable() ? createSortStableFunc : createSortNonstableFunc;
+      isStable ? createSortStableFunc : createSortNonstableFunc;
   FlatSymbolRefAttr func =
       getMangledSortHelperFunc(rewriter, insertPoint, TypeRange(), funcName,
                                nx, ny, isCoo, operands, funcGenerator);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -819,7 +819,8 @@
     // in the "added" array prior to applying the compression.
     unsigned rank = dstType.getShape().size();
     if (isOrderedDim(dstType, rank - 1))
-      rewriter.create(loc, count, ValueRange{added}, ValueRange{});
+      rewriter.create(loc, count, ValueRange{added}, ValueRange{},
+                      SparseTensorSortKind::HybridQuickSort);
     // While performing the insertions, we also need to reset the elements
     // of the values/filled-switch by only iterating over the set elements,
     // to ensure that the runtime complexity remains proportional to the
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -866,9 +866,9 @@
           get1DMemRefType(getIndexOverheadType(rewriter, encSrc),
                           /*withLayout=*/false);
       Value xs = rewriter.create(loc, indTp, src);
-      rewriter.create(loc, nnz, xs, ValueRange{y},
-                      rewriter.getIndexAttr(rank),
-                      rewriter.getIndexAttr(0));
+      rewriter.create(
+          loc, nnz, xs, ValueRange{y}, rewriter.getIndexAttr(rank),
+          rewriter.getIndexAttr(0), SparseTensorSortKind::HybridQuickSort);
     } else {
       // Gather the indices-arrays in the dst tensor storage order.
       SmallVector xs(rank, Value());
@@ -877,7 +877,8 @@
         xs[toStoredDim(encDst, orgDim)] =
             genToIndices(rewriter, loc, src, i, /*cooStart=*/0);
       }
-      rewriter.create(loc, nnz, xs, ValueRange{y});
+      rewriter.create(loc, nnz, xs, ValueRange{y},
+                      SparseTensorSortKind::HybridQuickSort);
     }
   }
 
diff --git a/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir b/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
--- a/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
+++ b/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
@@ -55,7 +55,7 @@
 // CHECK: return %[[M]], %[[S2]] : memref, index
 func.func @sparse_push_back_n(%arg0: index, %arg1: memref, %arg2: f64, %arg3: index) -> (memref, index) {
   %0:2 = sparse_tensor.push_back %arg0, %arg1, %arg2, %arg3 : index, memref, f64, index
-  return %0#0, %0#1 : memref, index
+  return %0#0, %0#1 : memref, index
 }
 
 // -----
@@ -155,7 +155,7 @@
 // CHECK: }
 func.func @sparse_sort_1d2v(%arg0: index, %arg1: memref<10xi8>, %arg2: memref, %arg3: memref<10xindex>)
    -> (memref<10xi8>, memref, memref<10xindex>) {
-  sparse_tensor.sort %arg0, %arg1 jointly %arg2, %arg3 : memref<10xi8> jointly memref, memref<10xindex>
+  sparse_tensor.sort hybrid_quick_sort %arg0, %arg1 jointly %arg2, %arg3 : memref<10xi8> jointly memref, memref<10xindex>
   return %arg1, %arg2, %arg3 : memref<10xi8>, memref, memref<10xindex>
 }
 
@@ -170,7 +170,7 @@
 // CHECK-DAG: func.func private @_sparse_sort_nonstable_3_index(%arg0: index, %arg1: index, %arg2: memref, %arg3: memref, %arg4: memref) {
 // CHECK-LABEL: func.func @sparse_sort_3d
 func.func @sparse_sort_3d(%arg0: index, %arg1: memref<10xindex>, %arg2: memref, %arg3: memref<10xindex>) -> (memref<10xindex>, memref, memref<10xindex>) {
-  sparse_tensor.sort %arg0, %arg1, %arg2, %arg3 : memref<10xindex>, memref, memref<10xindex>
+  sparse_tensor.sort hybrid_quick_sort %arg0, %arg1, %arg2, %arg3 : memref<10xindex>, memref, memref<10xindex>
   return %arg1, %arg2, %arg3 : memref<10xindex>, memref, memref<10xindex>
 }
 
@@ -184,7 +184,7 @@
 // CHECK-DAG: func.func private @_sparse_sort_stable_3_index(%arg0: index, %arg1: index, %arg2: memref, %arg3: memref, %arg4: memref) {
 // CHECK-LABEL: func.func @sparse_sort_3d_stable
 func.func @sparse_sort_3d_stable(%arg0: index, %arg1: memref<10xindex>, %arg2: memref, %arg3: memref<10xindex>) -> (memref<10xindex>, memref, memref<10xindex>) {
-  sparse_tensor.sort stable %arg0, %arg1, %arg2, %arg3 : memref<10xindex>, memref, memref<10xindex>
+  sparse_tensor.sort insertion_sort_stable %arg0, %arg1, %arg2, %arg3 : memref<10xindex>, memref, memref<10xindex>
   return %arg1, %arg2, %arg3 : memref<10xindex>, memref, memref<10xindex>
 }
 
@@ -199,7 +199,7 @@
 // CHECK-DAG: func.func private @_sparse_sort_nonstable_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref, %arg3: memref, %arg4: memref) {
 // CHECK-LABEL: func.func @sparse_sort_coo
 func.func @sparse_sort_coo(%arg0: index, %arg1: memref<100xindex>, %arg2: memref, %arg3: memref<10xi32>) -> (memref<100xindex>, memref, memref<10xi32>) {
-  sparse_tensor.sort_coo %arg0, %arg1 jointly %arg2, %arg3 {nx = 2 : index, ny = 1: index} : memref<100xindex> jointly memref, memref<10xi32>
+  sparse_tensor.sort_coo hybrid_quick_sort %arg0, %arg1 jointly %arg2, %arg3 {nx = 2 : index, ny = 1: index} : memref<100xindex> jointly memref, memref<10xi32>
   return %arg1, %arg2, %arg3 : memref<100xindex>, memref, memref<10xi32>
 }
 
@@ -213,7 +213,7 @@
 // CHECK-DAG: func.func private @_sparse_sort_stable_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref, %arg3: memref, %arg4: memref) {
 // CHECK-LABEL: func.func @sparse_sort_coo_stable
 func.func @sparse_sort_coo_stable(%arg0: index, %arg1: memref<100xindex>, %arg2: memref, %arg3: memref<10xi32>) -> (memref<100xindex>, memref, memref<10xi32>) {
-  sparse_tensor.sort_coo stable %arg0, %arg1 jointly %arg2, %arg3 {nx = 2 : index, ny = 1: index} : memref<100xindex> jointly memref, memref<10xi32>
+  sparse_tensor.sort_coo insertion_sort_stable %arg0, %arg1 jointly %arg2, %arg3 {nx = 2 : index, ny = 1: index} : memref<100xindex> jointly memref, memref<10xi32>
   return %arg1, %arg2, %arg3 : memref<100xindex>, memref, memref<10xi32>
 }
 
diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -430,7 +430,7 @@
 // CHECK-DAG: %[[A9:.*]] = arith.constant 0.000000e+00 : f64
 // CHECK-DAG: %[[A10:.*]] = arith.constant 1 : index
 // CHECK-DAG: %[[A11:.*]] = arith.constant 0 : index
-// CHECK: sparse_tensor.sort %[[A7]], %[[A6]] : memref
+// CHECK: sparse_tensor.sort hybrid_quick_sort %[[A7]], %[[A6]] : memref
 // CHECK: %[[A12:.*]]:4 = scf.for %[[A13:.*]] = %[[A11]] to %[[A7]] step %[[A10]] iter_args(%[[A14:.*]] = %[[A0]], %[[A15:.*]] = %[[A1]], %[[A16:.*]] = %[[A2]], %[[A17:.*]] = %[[A3]])
 // CHECK: %[[A18:.*]] = memref.load %[[A6]]{{\[}}%[[A13]]] : memref
 // CHECK: %[[A19:.*]] = memref.load %[[A4]]{{\[}}%[[A18]]] : memref
@@ -478,7 +478,7 @@
 // CHECK: %[[A11:.*]] = arith.constant 0.000000e+00 : f64
 // CHECK: %[[A12:.*]] = arith.constant 1 : index
 // CHECK: %[[A13:.*]] = arith.constant 0 : index
-// CHECK: sparse_tensor.sort %[[A7]], %[[A6]] : memref
+// CHECK: sparse_tensor.sort hybrid_quick_sort %[[A7]], %[[A6]] : memref
 // CHECK: %[[A14:.*]]:4 = scf.for %[[A15:.*]] = %[[A13]] to %[[A7]] step %[[A12]] iter_args(%[[A16:.*]] = %[[A0]], %[[A17:.*]] = %[[A1]], %[[A18:.*]] = %[[A2]], %[[A19:.*]] = %[[A3]]) -> (memref, memref, memref, !sparse_tensor.storage_specifier
 // CHECK: %[[A20:.*]] = memref.load %[[A6]]{{\[}}%[[A15]]] : memref
 // CHECK: %[[A21:.*]] = memref.load %[[A4]]{{\[}}%[[A20]]] : memref
diff --git a/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir b/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir
--- a/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir
+++ b/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir
@@ -195,7 +195,7 @@
 // CHECK-RWT: %[[NNZ:.*]] = sparse_tensor.number_of_entries %[[COO]]
 // CHECK-RWT: %[[V:.*]] = sparse_tensor.values %[[COO]]
 // CHECK-RWT: %[[I:.*]] = sparse_tensor.indices_buffer %[[COO]]
-// CHECK-RWT: sparse_tensor.sort_coo %[[NNZ]], %[[I]] jointly %[[V]] {nx = 2 : index, ny = 0 : index}
+// CHECK-RWT: sparse_tensor.sort_coo hybrid_quick_sort %[[NNZ]], %[[I]] jointly %[[V]] {nx = 2 : index, ny = 0 : index}
 // CHECK-RWT: %[[T3:.*]] = bufferization.alloc_tensor()
 // CHECK-RWT: %[[T4:.*]] = sparse_tensor.foreach in %[[COO]] init(%[[T3]])
 // CHECK-RWT: ^bb0(%[[L1I0:.*]]: index, %[[L1I1:.*]]: index, %[[L1V:.*]]: f32, %[[L1T:.*]]: tensor
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -696,7 +696,7 @@
 
 func.func @sparse_sort_x_type( %arg0: index, %arg1: memref) {
   // expected-error@+1 {{operand #1 must be 1D memref of integer or index values}}
-  sparse_tensor.sort %arg0, %arg1: memref
+  sparse_tensor.sort hybrid_quick_sort %arg0, %arg1: memref
 }
 
 // -----
@@ -704,7 +704,7 @@
 func.func @sparse_sort_dim_too_small(%arg0: memref<10xindex>) {
   %i20 = arith.constant 20 : index
   // expected-error@+1 {{xs and ys need to have a dimension >= n: 10 < 20}}
-  sparse_tensor.sort %i20, %arg0 : memref<10xindex>
+  sparse_tensor.sort insertion_sort_stable %i20, %arg0 : memref<10xindex>
   return
 }
 
@@ -712,7 +712,7 @@
 
 func.func @sparse_sort_mismatch_x_type(%arg0: index, %arg1: memref<10xindex>, %arg2: memref<10xi8>) {
   // expected-error@+1 {{mismatch xs element types}}
-  sparse_tensor.sort %arg0, %arg1, %arg2 : memref<10xindex>, memref<10xi8>
+  sparse_tensor.sort hybrid_quick_sort %arg0, %arg1, %arg2 : memref<10xindex>, memref<10xi8>
   return
 }
 
@@ -720,7 +720,7 @@
 
 func.func @sparse_sort_coo_x_type( %arg0: index, %arg1: memref) {
   // expected-error@+1 {{operand #1 must be 1D memref of integer or index values}}
-  sparse_tensor.sort_coo %arg0, %arg1: memref
+  sparse_tensor.sort_coo insertion_sort_stable %arg0, %arg1: memref
   return
 }
 
@@ -729,7 +729,7 @@
 func.func @sparse_sort_coo_x_too_small(%arg0: memref<50xindex>) {
   %i20 = arith.constant 20 : index
   // expected-error@+1 {{Expected dimension(xy) >= n * (nx + ny) got 50 < 60}}
-  sparse_tensor.sort_coo %i20, %arg0 {nx = 2 : index, ny = 1 : index} : memref<50xindex>
+  sparse_tensor.sort_coo hybrid_quick_sort %i20, %arg0 {nx = 2 : index, ny = 1 : index} : memref<50xindex>
   return
 }
 
@@ -738,7 +738,7 @@
 func.func @sparse_sort_coo_y_too_small(%arg0: memref<60xindex>, %arg1: memref<10xf32>) {
   %i20 = arith.constant 20 : index
   // expected-error@+1 {{Expected dimension(y) >= n got 10 < 20}}
-  sparse_tensor.sort_coo %i20, %arg0 jointly %arg1 {nx = 2 : index, ny = 1 : index} : memref<60xindex> jointly memref<10xf32>
+  sparse_tensor.sort_coo insertion_sort_stable %i20, %arg0 jointly %arg1 {nx = 2 : index, ny = 1 : index} : memref<60xindex> jointly memref<10xf32>
   return
 }
 
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -504,10 +504,10 @@
 // CHECK-LABEL: func @sparse_sort_1d0v(
 // CHECK-SAME: %[[A:.*]]: index,
 // CHECK-SAME: %[[B:.*]]: memref)
-// CHECK: sparse_tensor.sort %[[A]], %[[B]] : memref
+// CHECK: sparse_tensor.sort hybrid_quick_sort %[[A]], %[[B]] : memref
 // CHECK: return %[[B]]
 func.func @sparse_sort_1d0v(%arg0: index, %arg1: memref) -> (memref) {
-  sparse_tensor.sort %arg0, %arg1 : memref
+  sparse_tensor.sort hybrid_quick_sort %arg0, %arg1 : memref
   return %arg1 : memref
 }
 
@@ -518,10 +518,10 @@
 // CHECK-SAME: %[[B:.*]]: memref<20xindex>,
 // CHECK-SAME: %[[C:.*]]: memref<10xindex>,
 // CHECK-SAME: %[[D:.*]]: memref)
-// CHECK: sparse_tensor.sort %[[A]], %[[B]] jointly %[[C]], %[[D]] : memref<20xindex> jointly memref<10xindex>, memref
+// CHECK: sparse_tensor.sort hybrid_quick_sort %[[A]], %[[B]] jointly %[[C]], %[[D]] : memref<20xindex> jointly memref<10xindex>, memref
 // CHECK: return %[[B]], %[[C]], %[[D]]
 func.func @sparse_sort_1d2v(%arg0: index, %arg1: memref<20xindex>, %arg2: memref<10xindex>, %arg3: memref) -> (memref<20xindex>, memref<10xindex>, memref) {
-  sparse_tensor.sort %arg0, %arg1 jointly %arg2, %arg3 : memref<20xindex> jointly memref<10xindex>, memref
+  sparse_tensor.sort hybrid_quick_sort %arg0, %arg1 jointly %arg2, %arg3 : memref<20xindex> jointly memref<10xindex>, memref
   return %arg1, %arg2, %arg3 : memref<20xindex>, memref<10xindex>, memref
 }
 
@@ -532,10 +532,10 @@
 // CHECK-SAME: %[[B:.*]]: memref<10xi8>,
 // CHECK-SAME: %[[C:.*]]: memref<20xi8>,
 // CHECK-SAME: %[[D:.*]]: memref<10xf64>)
-// CHECK: sparse_tensor.sort %[[A]], %[[B]], %[[C]] jointly %[[D]] : memref<10xi8>, memref<20xi8> jointly memref<10xf64>
+// CHECK: sparse_tensor.sort hybrid_quick_sort %[[A]], %[[B]], %[[C]] jointly %[[D]] : memref<10xi8>, memref<20xi8> jointly memref<10xf64>
 // CHECK: return %[[B]], %[[C]], %[[D]]
 func.func @sparse_sort_2d1v(%arg0: index, %arg1: memref<10xi8>, %arg2: memref<20xi8>, %arg3: memref<10xf64>) -> (memref<10xi8>, memref<20xi8>, memref<10xf64>) {
-  sparse_tensor.sort %arg0, %arg1, %arg2 jointly %arg3 : memref<10xi8>, memref<20xi8> jointly memref<10xf64>
+  sparse_tensor.sort hybrid_quick_sort %arg0, %arg1, %arg2 jointly %arg3 : memref<10xi8>, memref<20xi8> jointly memref<10xf64>
   return %arg1, %arg2, %arg3 : memref<10xi8>, memref<20xi8>, memref<10xf64>
 }
 
@@ -546,23 +546,34 @@
 // CHECK-SAME: %[[B:.*]]: memref<10xi8>,
 // CHECK-SAME: %[[C:.*]]: memref<20xi8>,
 // CHECK-SAME: %[[D:.*]]: memref<10xf64>)
-// CHECK: sparse_tensor.sort stable %[[A]], %[[B]], %[[C]] jointly %[[D]] : memref<10xi8>, memref<20xi8> jointly memref<10xf64>
+// CHECK: sparse_tensor.sort insertion_sort_stable %[[A]], %[[B]], %[[C]] jointly %[[D]] : memref<10xi8>, memref<20xi8> jointly memref<10xf64>
 // CHECK: return %[[B]], %[[C]], %[[D]]
 func.func @sparse_sort_stable(%arg0: index, %arg1: memref<10xi8>, %arg2: memref<20xi8>, %arg3: memref<10xf64>) -> (memref<10xi8>, memref<20xi8>, memref<10xf64>) {
-  sparse_tensor.sort stable %arg0, %arg1, %arg2 jointly %arg3 : memref<10xi8>, memref<20xi8> jointly memref<10xf64>
+  sparse_tensor.sort insertion_sort_stable %arg0, %arg1, %arg2 jointly %arg3 : memref<10xi8>, memref<20xi8> jointly memref<10xf64>
   return %arg1, %arg2, %arg3 : memref<10xi8>, memref<20xi8>, memref<10xf64>
 }
 
 // -----
 
+// CHECK-LABEL: func @sparse_sort_coo(
+// CHECK-SAME: %[[A:.*]]: index,
+// CHECK-SAME: %[[B:.*]]: memref)
+// CHECK: sparse_tensor.sort_coo hybrid_quick_sort %[[A]], %[[B]] {nx = 2 : index, ny = 1 : index} : memref
+// CHECK: return %[[B]]
 func.func @sparse_sort_coo(%arg0: index, %arg1: memref) -> (memref) {
-  sparse_tensor.sort_coo %arg0, %arg1 { nx=2 : index, ny=1 : index}: memref
+  sparse_tensor.sort_coo hybrid_quick_sort %arg0, %arg1 {nx = 2 : index, ny = 1 : index}: memref
   return %arg1 : memref
 }
 
 // -----
 
+// CHECK-LABEL: func @sparse_sort_coo_stable(
+// CHECK-SAME: %[[A:.*]]: index,
+// CHECK-SAME: %[[B:.*]]: memref,
+// CHECK-SAME: %[[C:.*]]: memref)
+// CHECK: sparse_tensor.sort_coo insertion_sort_stable %[[A]], %[[B]] jointly %[[C]] {nx = 2 : index, ny = 1 : index}
+// CHECK: return %[[B]], %[[C]]
 func.func @sparse_sort_coo_stable(%arg0: index, %arg1: memref, %arg2: memref) -> (memref, memref) {
-  sparse_tensor.sort_coo stable %arg0, %arg1 jointly %arg2 { nx=2 : index, ny=1 : index}: memref jointly memref
+  sparse_tensor.sort_coo insertion_sort_stable %arg0, %arg1 jointly %arg2 {nx = 2 : index, ny = 1 : index}: memref jointly memref
   return %arg1, %arg2 : memref, memref
 }
diff --git a/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir b/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir
--- a/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir
@@ -125,7 +125,7 @@
 // CHECK: } {"Emitted from" = "linalg.generic"}
 // CHECK: scf.yield %[[VAL_70:.*]] : index
 // CHECK: } {"Emitted from" = "linalg.generic"}
-// CHECK: sparse_tensor.sort %[[VAL_71:.*]], %[[VAL_39]] : memref
+// CHECK: sparse_tensor.sort hybrid_quick_sort %[[VAL_71:.*]], %[[VAL_39]] : memref
 // CHECK: %[[VAL_72:.*]]:4 = scf.for %[[VAL_73:.*]] = %[[VAL_11]] to %[[VAL_71]] step %[[VAL_12]] iter_args(%[[VAL_74:.*]] = %[[VAL_42]], %[[VAL_75:.*]] = %[[VAL_43]], %[[VAL_76:.*]] = %[[VAL_44]], %[[VAL_77:.*]] = %[[VAL_45]]) -> (memref, memref, memref, !sparse_tensor.storage_specifier
 // CHECK: %[[VAL_78:.*]] = memref.load %[[VAL_38]]{{\[}}%[[VAL_73]]] : memref<4xindex>
 // CHECK: %[[VAL_79:.*]] = memref.load %[[VAL_36]]{{\[}}%[[VAL_78]]] : memref<4xf64>
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir
@@ -50,22 +50,22 @@
 
     // Sort 0 elements.
     // CHECK: [10, 2, 0, 5, 1]
-    sparse_tensor.sort %i0, %x0 : memref
+    sparse_tensor.sort hybrid_quick_sort %i0, %x0 : memref
     call @printMemref1dI32(%x0) : (memref) -> ()
     // Stable sort.
     // CHECK: [10, 2, 0, 5, 1]
-    sparse_tensor.sort stable %i0, %x0 : memref
+    sparse_tensor.sort insertion_sort_stable %i0, %x0 : memref
     call @printMemref1dI32(%x0) : (memref) -> ()
 
     // Sort the first 4 elements, with the last valid value untouched.
     // CHECK: [0, 2, 5, 10, 1]
-    sparse_tensor.sort %i4, %x0 : memref
+    sparse_tensor.sort hybrid_quick_sort %i4, %x0 : memref
     call @printMemref1dI32(%x0) : (memref) -> ()
     // Stable sort.
     // CHECK: [0, 2, 5, 10, 1]
     call @storeValuesTo(%x0, %c10, %c2, %c0, %c5, %c1)
       : (memref, i32, i32, i32, i32, i32) -> ()
-    sparse_tensor.sort stable %i4, %x0 : memref
+    sparse_tensor.sort insertion_sort_stable %i4, %x0 : memref
     call @printMemref1dI32(%x0) : (memref) -> ()
 
     // Prepare more buffers of different dimensions.
@@ -89,7 +89,7 @@
       : (memref, i32, i32, i32, i32, i32) -> ()
     call @storeValuesTo(%y0, %c6, %c10, %c8, %c9, %c7)
       : (memref, i32, i32, i32, i32, i32) -> ()
-    sparse_tensor.sort %i5, %x0, %x1, %x2 jointly %y0
+    sparse_tensor.sort hybrid_quick_sort %i5, %x0, %x1, %x2 jointly %y0
       : memref, memref, memref jointly memref
     call @printMemref1dI32(%x0) : (memref) -> ()
     call @printMemref1dI32(%x1) : (memref) -> ()
@@ -108,7 +108,7 @@
       : (memref, i32, i32, i32, i32, i32) -> ()
     call @storeValuesTo(%y0, %c6, %c10, %c8, %c9, %c7)
       : (memref, i32, i32, i32, i32, i32) -> ()
-    sparse_tensor.sort stable %i5, %x0, %x1, %x2 jointly %y0
+    sparse_tensor.sort insertion_sort_stable %i5, %x0, %x1, %x2 jointly %y0
       : memref, memref, memref jointly memref
     call @printMemref1dI32(%x0) : (memref) -> ()
     call @printMemref1dI32(%x1) : (memref) -> ()
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir
@@ -92,7 +92,7 @@
       : (memref>, i32, i32, i32, i32, i32) -> ()
     call @storeValuesTo(%y1, %c5, %c7, %c4, %c9, %c7)
       : (memref, i32, i32, i32, i32, i32) -> ()
-    sparse_tensor.sort_coo %i5, %xy jointly %y1 {nx = 3 : index, ny = 1 : index}
+    sparse_tensor.sort_coo hybrid_quick_sort %i5, %xy jointly %y1 {nx = 3 : index, ny = 1 : index}
       : memref jointly memref
     %x0v = vector.transfer_read %x0[%i0], %c100: memref>, vector<5xi32>
     vector.print %x0v : vector<5xi32>
@@ -120,7 +120,7 @@
       : (memref>, i32, i32, i32, i32, i32) -> ()
     call @storeValuesTo(%y1, %c5, %c7, %c4, %c9, %c7)
       : (memref, i32, i32, i32, i32, i32) -> ()
-    sparse_tensor.sort_coo stable %i5, %xy jointly %y1 {nx = 3 : index, ny = 1 : index}
+    sparse_tensor.sort_coo insertion_sort_stable %i5, %xy jointly %y1 {nx = 3 : index, ny = 1 : index}
       : memref jointly memref
     %x0v2 = vector.transfer_read %x0[%i0], %c100: memref>, vector<5xi32>
     vector.print %x0v2 : vector<5xi32>