diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp @@ -443,6 +443,93 @@ return std::make_pair(whileOp.getResult(0), compareEq); } +/// Creates a code block to swap the values so that data[mi] is the median among +/// data[lo], data[hi], and data[mi]. +// The generated code corresponds to this C-like algorithm: +// median = mi +// if (data[mi] < data[lo]) (if1) +// if (data[hi] < data[lo]) (if2) +// median = data[hi] < data[mi] ? mi : hi +// else +// median = lo +// else +// if (data[hi] < data[mi]) (if3) +// median = data[hi] < data[lo] ? lo : hi +// if median != mi swap data[median] with data[mi] +static void createChoosePivot(OpBuilder &builder, ModuleOp module, + func::FuncOp func, uint64_t nx, uint64_t ny, + bool isCoo, Value lo, Value hi, Value mi, + ValueRange args) { + SmallVector compareOperands{mi, lo}; + uint64_t numXBuffers = isCoo ? 1 : nx; + compareOperands.append(args.begin() + xStartIdx, + args.begin() + xStartIdx + numXBuffers); + Type i1Type = IntegerType::get(module.getContext(), 1, IntegerType::Signless); + SmallVector cmpTypes{i1Type}; + FlatSymbolRefAttr lessThanFunc = getMangledSortHelperFunc( + builder, func, cmpTypes, kLessThanFuncNamePrefix, nx, ny, isCoo, + compareOperands, createLessThanFunc); + Location loc = func.getLoc(); + // Compare data[mi] < data[lo]. + Value cond1 = + builder.create(loc, lessThanFunc, cmpTypes, compareOperands) + .getResult(0); + SmallVector ifTypes{lo.getType()}; + scf::IfOp ifOp1 = + builder.create(loc, ifTypes, cond1, /*else=*/true); + + // Generate an if-stmt to find the median value, assuming we already know that + // data[b] <= data[a] and we haven't compared data[c] yet. 
+ auto createFindMedian = [&](Value a, Value b, Value c) -> scf::IfOp { + compareOperands[0] = c; + compareOperands[1] = a; + // Compare data[c] < data[a]. + Value cond2 = + builder + .create(loc, lessThanFunc, cmpTypes, compareOperands) + .getResult(0); + scf::IfOp ifOp2 = + builder.create(loc, ifTypes, cond2, /*else=*/true); + builder.setInsertionPointToStart(&ifOp2.getThenRegion().front()); + compareOperands[0] = c; + compareOperands[1] = b; + // Compare data[c] < data[b]. + Value cond3 = + builder + .create(loc, lessThanFunc, cmpTypes, compareOperands) + .getResult(0); + builder.create( + loc, ValueRange{builder.create(loc, cond3, b, c)}); + builder.setInsertionPointToStart(&ifOp2.getElseRegion().front()); + builder.create(loc, ValueRange{a}); + return ifOp2; + }; + + builder.setInsertionPointToStart(&ifOp1.getThenRegion().front()); + scf::IfOp ifOp2 = createFindMedian(lo, mi, hi); + builder.setInsertionPointAfter(ifOp2); + builder.create(loc, ValueRange{ifOp2.getResult(0)}); + + builder.setInsertionPointToStart(&ifOp1.getElseRegion().front()); + scf::IfOp ifOp3 = createFindMedian(mi, lo, hi); + + builder.setInsertionPointAfter(ifOp3); + builder.create(loc, ValueRange{ifOp3.getResult(0)}); + + builder.setInsertionPointAfter(ifOp1); + Value median = ifOp1.getResult(0); + Value cond = + builder.create(loc, arith::CmpIPredicate::ne, mi, median); + scf::IfOp ifOp = + builder.create(loc, TypeRange(), cond, /*else=*/false); + + SmallVector swapOperands{median, mi}; + swapOperands.append(args.begin() + xStartIdx, args.end()); + builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); + createSwap(builder, loc, swapOperands, nx, ny, isCoo); + builder.setInsertionPointAfter(ifOp); +} + /// Creates a function to perform quick sort partition on the values in the /// range of index [lo, hi), assuming lo < hi. 
// @@ -489,7 +576,8 @@ Value i = lo; Value j = builder.create(loc, hi, c1); - SmallVector operands{i, j, p}; // exactly three + createChoosePivot(builder, module, func, nx, ny, isCoo, i, j, p, args); + SmallVector operands{i, j, p}; // Exactly three values. SmallVector types{i.getType(), j.getType(), p.getType()}; scf::WhileOp whileOp = builder.create(loc, types, operands); diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir --- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir +++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir @@ -80,7 +80,7 @@ // CHECK: [1, 1, 2, 5, 10] // CHECK: [3, 3, 1, 10, 1 // CHECK: [9, 9, 4, 7, 2 - // CHECK: [8, 7, 10, 9, 6 + // CHECK: [7, 8, 10, 9, 6 call @storeValuesTo(%x0, %c10, %c2, %c1, %c5, %c1) : (memref, i32, i32, i32, i32, i32) -> () call @storeValuesTo(%x1, %c1, %c1, %c3, %c10, %c3) diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir --- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir +++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir @@ -80,8 +80,8 @@ // CHECK: ( 1, 1, 2, 5, 10 ) // CHECK: ( 3, 3, 1, 10, 1 ) // CHECK: ( 9, 9, 4, 7, 2 ) - // CHECK: ( 8, 7, 10, 9, 6 ) - // CHECK: ( 4, 7, 7, 9, 5 ) + // CHECK: ( 7, 8, 10, 9, 6 ) + // CHECK: ( 7, 4, 7, 9, 5 ) call @storeValuesToStrided(%x0, %c10, %c2, %c1, %c5, %c1) : (memref>, i32, i32, i32, i32, i32) -> () call @storeValuesToStrided(%x1, %c1, %c1, %c3, %c10, %c3)