diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
@@ -463,19 +463,99 @@
   return std::make_pair(whileOp.getResult(0), compareEq);
 }
 
-/// Creates a code block to swap the values so that data[mi] is the median among
-/// data[lo], data[hi], and data[mi].
-// The generated code corresponds to this C-like algorithm:
-//  median = mi
-//  if (data[mi] < data[lo]).                               (if1)
-//    if (data[hi] < data[lo])                              (if2)
-//      median = data[hi] < data[mi] ? mi : hi
-//    else
-//      median = lo
-//  else
-//    if data[hi] < data[mi]                                (if3)
-//      median = data[hi] < data[lo] ? lo : hi
-//  if median != mi swap data[median] with data[mi]
+/// Creates and returns an IfOp to compare two elements and swap the elements
+/// if compareFunc(data[b], data[a]) returns true. The insert point is right
+/// after the swap instructions.
+static scf::IfOp createCompareThenSwap(OpBuilder &builder, Location loc,
+                                       uint64_t nx, uint64_t ny, bool isCoo,
+                                       SmallVectorImpl<Value> &swapOperands,
+                                       SmallVectorImpl<Value> &compareOperands,
+                                       FlatSymbolRefAttr compareFunc, Value a,
+                                       Value b) {
+  SmallVector<Type, 1> resultTypes{builder.getIntegerType(1)};
+  // Compare(data[b], data[a]).
+  compareOperands[0] = b;
+  compareOperands[1] = a;
+  Value cond =
+      builder
+          .create<func::CallOp>(loc, compareFunc, resultTypes, compareOperands)
+          .getResult(0);
+  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, cond, /*else=*/false);
+  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+  swapOperands[0] = b;
+  swapOperands[1] = a;
+  createSwap(builder, loc, swapOperands, nx, ny, isCoo);
+  return ifOp;
+}
+
+/// Creates code to insert the 3rd element to a list of two sorted elements.
+static void createInsert3rd(OpBuilder &builder, Location loc, uint64_t nx,
+                            uint64_t ny, bool isCoo,
+                            SmallVectorImpl<Value> &swapOperands,
+                            SmallVectorImpl<Value> &compareOperands,
+                            FlatSymbolRefAttr lessThanFunc, Value v0, Value v1,
+                            Value v2) {
+  scf::IfOp ifOp =
+      createCompareThenSwap(builder, loc, nx, ny, isCoo, swapOperands,
+                            compareOperands, lessThanFunc, v1, v2);
+  createCompareThenSwap(builder, loc, nx, ny, isCoo, swapOperands,
+                        compareOperands, lessThanFunc, v0, v1);
+  builder.setInsertionPointAfter(ifOp);
+}
+
+/// Creates code to sort 3 elements.
+static void createSort3(OpBuilder &builder, Location loc, uint64_t nx,
+                        uint64_t ny, bool isCoo,
+                        SmallVectorImpl<Value> &swapOperands,
+                        SmallVectorImpl<Value> &compareOperands,
+                        FlatSymbolRefAttr lessThanFunc, Value v0, Value v1,
+                        Value v2) {
+  // Sort the first 2 elements.
+  scf::IfOp ifOp1 =
+      createCompareThenSwap(builder, loc, nx, ny, isCoo, swapOperands,
+                            compareOperands, lessThanFunc, v0, v1);
+  builder.setInsertionPointAfter(ifOp1);
+
+  // Insert the 3rd element.
+  createInsert3rd(builder, loc, nx, ny, isCoo, swapOperands, compareOperands,
+                  lessThanFunc, v0, v1, v2);
+}
+
+/// Creates code to sort 5 elements.
+static void createSort5(OpBuilder &builder, Location loc, uint64_t nx,
+                        uint64_t ny, bool isCoo,
+                        SmallVectorImpl<Value> &swapOperands,
+                        SmallVectorImpl<Value> &compareOperands,
+                        FlatSymbolRefAttr lessThanFunc, Value v0, Value v1,
+                        Value v2, Value v3, Value v4) {
+  // Sort the first 3 elements.
+  createSort3(builder, loc, nx, ny, isCoo, swapOperands, compareOperands,
+              lessThanFunc, v0, v1, v2);
+
+  auto insert4th = [&]() {
+    scf::IfOp ifOp =
+        createCompareThenSwap(builder, loc, nx, ny, isCoo, swapOperands,
+                              compareOperands, lessThanFunc, v2, v3);
+    createInsert3rd(builder, loc, nx, ny, isCoo, swapOperands, compareOperands,
+                    lessThanFunc, v0, v1, v2);
+    builder.setInsertionPointAfter(ifOp);
+  };
+
+  // Insert the 4th element.
+  insert4th();
+
+  // Insert the 5th element.
+  scf::IfOp ifOp =
+      createCompareThenSwap(builder, loc, nx, ny, isCoo, swapOperands,
+                            compareOperands, lessThanFunc, v3, v4);
+  insert4th();
+  builder.setInsertionPointAfter(ifOp);
+}
+
+/// Creates a code block to swap the values in indices lo, mi, and hi so that
+/// data[lo], data[mi] and data[hi] are sorted in non-decreasing values. When
+/// the number of values in range [lo, hi) is more than a threshold, we also
+/// include the middle of [lo, mi) and [mi, hi) and sort a total of five values.
 static void createChoosePivot(OpBuilder &builder, ModuleOp module,
                               func::FuncOp func, uint64_t nx, uint64_t ny,
                               bool isCoo, Value lo, Value hi, Value mi,
@@ -489,65 +569,37 @@
   FlatSymbolRefAttr lessThanFunc = getMangledSortHelperFunc(
       builder, func, cmpTypes, kLessThanFuncNamePrefix, nx, ny, isCoo,
       compareOperands, createLessThanFunc);
-  Location loc = func.getLoc();
-  // Compare data[mi] < data[lo].
-  Value cond1 =
-      builder.create<func::CallOp>(loc, lessThanFunc, cmpTypes, compareOperands)
-          .getResult(0);
-  SmallVector<Type, 1> ifTypes{lo.getType()};
-  scf::IfOp ifOp1 =
-      builder.create<scf::IfOp>(loc, ifTypes, cond1, /*else=*/true);
-
-  // Generate an if-stmt to find the median value, assuming we already know that
-  // data[b] < data[a] and we haven't compare data[c] yet.
-  auto createFindMedian = [&](Value a, Value b, Value c) -> scf::IfOp {
-    compareOperands[0] = c;
-    compareOperands[1] = a;
-    // Compare data[c]] < data[a].
-    Value cond2 =
-        builder
-            .create<func::CallOp>(loc, lessThanFunc, cmpTypes, compareOperands)
-            .getResult(0);
-    scf::IfOp ifOp2 =
-        builder.create<scf::IfOp>(loc, ifTypes, cond2, /*else=*/true);
-    builder.setInsertionPointToStart(&ifOp2.getThenRegion().front());
-    compareOperands[0] = c;
-    compareOperands[1] = b;
-    // Compare data[c] < data[b].
-    Value cond3 =
-        builder
-            .create<func::CallOp>(loc, lessThanFunc, cmpTypes, compareOperands)
-            .getResult(0);
-    builder.create<scf::YieldOp>(
-        loc, ValueRange{builder.create<arith::SelectOp>(loc, cond3, b, c)});
-    builder.setInsertionPointToStart(&ifOp2.getElseRegion().front());
-    builder.create<scf::YieldOp>(loc, ValueRange{a});
-    return ifOp2;
-  };
-
-  builder.setInsertionPointToStart(&ifOp1.getThenRegion().front());
-  scf::IfOp ifOp2 = createFindMedian(lo, mi, hi);
-  builder.setInsertionPointAfter(ifOp2);
-  builder.create<scf::YieldOp>(loc, ValueRange{ifOp2.getResult(0)});
-
-  builder.setInsertionPointToStart(&ifOp1.getElseRegion().front());
-  scf::IfOp ifOp3 = createFindMedian(mi, lo, hi);
-
-  builder.setInsertionPointAfter(ifOp3);
-  builder.create<scf::YieldOp>(loc, ValueRange{ifOp3.getResult(0)});
+  SmallVector<Value> swapOperands{mi, lo};
+  swapOperands.append(args.begin() + xStartIdx, args.end());
 
-  builder.setInsertionPointAfter(ifOp1);
-  Value median = ifOp1.getResult(0);
-  Value cond =
-      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, mi, median);
-  scf::IfOp ifOp =
-      builder.create<scf::IfOp>(loc, TypeRange(), cond, /*else=*/false);
+  Location loc = func.getLoc();
 
-  SmallVector<Value> swapOperands{median, mi};
-  swapOperands.append(args.begin() + xStartIdx, args.end());
-  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
-  createSwap(builder, loc, swapOperands, nx, ny, isCoo);
-  builder.setInsertionPointAfter(ifOp);
+  Value c1 = constantIndex(builder, loc, 1);
+  Value hiP1 = builder.create<arith::AddIOp>(loc, hi, c1);
+  Value len = builder.create<arith::SubIOp>(loc, hiP1, lo);
+  Value lenThreshold = constantIndex(builder, loc, 1000);
+  Value lenCond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
+                                                len, lenThreshold);
+  scf::IfOp lenIf = builder.create<scf::IfOp>(loc, lenCond, /*else=*/true);
+
+  // When len < 1000, choose pivot from median of 3 values.
+  builder.setInsertionPointToStart(&lenIf.getThenRegion().front());
+  createSort3(builder, loc, nx, ny, isCoo, swapOperands, compareOperands,
+              lessThanFunc, lo, mi, hi);
+
+  // When len >= 1000, choose pivot from median of 5 values.
+  builder.setInsertionPointToStart(&lenIf.getElseRegion().front());
+  Value miP1 = builder.create<arith::AddIOp>(loc, mi, c1);
+  Value a = builder.create<arith::AddIOp>(loc, lo, miP1);
+  // Value a is the middle between [lo, mi].
+  a = builder.create<arith::ShRUIOp>(loc, a, c1);
+  Value b = builder.create<arith::AddIOp>(loc, mi, hiP1);
+  // Value b is the middle between [mi, hi].
+  b = builder.create<arith::ShRUIOp>(loc, b, c1);
+  createSort5(builder, loc, nx, ny, isCoo, swapOperands, compareOperands,
+              lessThanFunc, lo, a, mi, b, hi);
+
+  builder.setInsertionPointAfter(lenIf);
 }
 
 /// Creates a function to perform quick sort partition on the values in the