diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
@@ -918,9 +918,13 @@
   builder.create<func::ReturnOp>(loc);
 }
 
-static void createQuickSort(OpBuilder &builder, ModuleOp module,
-                            func::FuncOp func, ValueRange args, uint64_t nx,
-                            uint64_t ny, bool isCoo, uint32_t nTrailingP) {
+/// A helper for generating code to perform quick sort. It partitions [lo, hi),
+/// recursively calls quick sort to process the smaller partition and returns
+/// the bigger partition to be processed by the enclosing while-loop.
+static std::pair<Value, Value>
+createQuickSort(OpBuilder &builder, ModuleOp module, func::FuncOp func,
+                ValueRange args, uint64_t nx, uint64_t ny, bool isCoo,
+                uint32_t nTrailingP) {
   MLIRContext *context = module.getContext();
   Location loc = func.getLoc();
   Value lo = args[loIdx];
@@ -928,20 +932,45 @@
   FlatSymbolRefAttr partitionFunc = getMangledSortHelperFunc(
       builder, func, {IndexType::get(context)}, kPartitionFuncNamePrefix, nx,
       ny, isCoo, args.drop_back(nTrailingP), createPartitionFunc);
-  auto p = builder.create<func::CallOp>(loc, partitionFunc,
-                                        TypeRange{IndexType::get(context)},
-                                        args.drop_back(nTrailingP));
-
-  SmallVector<Value> lowOperands{lo, p.getResult(0)};
-  lowOperands.append(args.begin() + xStartIdx, args.end());
-  builder.create<func::CallOp>(loc, func, lowOperands);
-
-  SmallVector<Value> highOperands{
-      builder.create<arith::AddIOp>(loc, p.getResult(0),
-                                    constantIndex(builder, loc, 1)),
-      hi};
-  highOperands.append(args.begin() + xStartIdx, args.end());
-  builder.create<func::CallOp>(loc, func, highOperands);
+  Value p = builder
+                .create<func::CallOp>(loc, partitionFunc,
+                                      TypeRange{IndexType::get(context)},
+                                      args.drop_back(nTrailingP))
+                .getResult(0);
+  Value pP1 =
+      builder.create<arith::AddIOp>(loc, p, constantIndex(builder, loc, 1));
+  Value lenLow = builder.create<arith::SubIOp>(loc, p, lo);
+  Value lenHigh = builder.create<arith::SubIOp>(loc, hi, p);
+  Value cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule,
+                                             lenLow, lenHigh);
+
+  SmallVector<Type, 2> types(2, lo.getType()); // Only two types.
+  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, /*else=*/true);
+
+  Value c0 = constantIndex(builder, loc, 0);
+  auto mayRecursion = [&](Value low, Value high, Value len) {
+    Value cond =
+        builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, len, c0);
+    scf::IfOp ifOp = builder.create<scf::IfOp>(loc, cond, /*else=*/false);
+    builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+    SmallVector<Value> operands{low, high};
+    operands.append(args.begin() + xStartIdx, args.end());
+    builder.create<func::CallOp>(loc, func, operands);
+    builder.setInsertionPointAfter(ifOp);
+  };
+
+  // Recursively call quickSort to process the smaller partition and return
+  // the bigger partition to be processed by the enclosing while-loop.
+  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+  mayRecursion(lo, p, lenLow);
+  builder.create<scf::YieldOp>(loc, ValueRange{pP1, hi});
+
+  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
+  mayRecursion(pP1, hi, lenHigh);
+  builder.create<scf::YieldOp>(loc, ValueRange{lo, p});
+
+  builder.setInsertionPointAfter(ifOp);
+  return std::make_pair(ifOp.getResult(0), ifOp.getResult(1));
 }
 
 /// Creates a function to perform insertion sort on the values in the range of
@@ -1036,16 +1065,21 @@
 //
 // When nTrailingP == 0, the generated IR corresponds to this C like algorithm:
 // void quickSort(lo, hi, data) {
-//   if (lo + 1 < hi) {
+//   while (lo + 1 < hi) {
 //     p = partition(low, high, data);
-//     quickSort(lo, p, data);
-//     quickSort(p + 1, hi, data);
+//     if (len(lo, p) < len(p+1, hi)) {
+//       quickSort(lo, p, data);
+//       lo = p+1;
+//     } else {
+//       quickSort(p + 1, hi, data);
+//       hi = p;
+//     }
 //   }
 // }
 //
 // When nTrailingP == 1, the generated IR corresponds to this C like algorithm:
 // void hybridQuickSort(lo, hi, data, depthLimit) {
-//   if (lo + 1 < hi) {
+//   while (lo + 1 < hi) {
 //     len = hi - lo;
 //     if (len <= limit) {
 //       insertionSort(lo, hi, data);
@@ -1055,10 +1089,14 @@
 //       heapSort(lo, hi, data);
 //     } else {
 //       p = partition(low, high, data);
-//       quickSort(lo, p, data);
-//       quickSort(p + 1, hi, data);
+//       if (len(lo, p) < len(p+1, hi)) {
+//         quickSort(lo, p, data, depthLimit);
+//         lo = p+1;
+//       } else {
+//         quickSort(p + 1, hi, data, depthLimit);
+//         hi = p;
+//       }
 //     }
-//     depthLimit ++;
 //   }
 // }
 // }
@@ -1073,70 +1111,98 @@
   builder.setInsertionPointToStart(entryBlock);
   Location loc = func.getLoc();
-  ValueRange args = entryBlock->getArguments();
+  SmallVector<Value> args;
+  args.append(entryBlock->getArguments().begin(),
+              entryBlock->getArguments().end());
   Value lo = args[loIdx];
   Value hi = args[hiIdx];
-  Value loCmp =
+  SmallVector<Type, 2> types(2, lo.getType()); // Only two types.
+  scf::WhileOp whileOp =
+      builder.create<scf::WhileOp>(loc, types, SmallVector<Value>{lo, hi});
+
+  // The before-region of the WhileOp.
+  Block *before =
+      builder.createBlock(&whileOp.getBefore(), {}, types, {loc, loc});
+  builder.setInsertionPointToEnd(before);
+  lo = before->getArgument(0);
+  hi = before->getArgument(1);
+  Value loP1 =
       builder.create<arith::AddIOp>(loc, lo, constantIndex(builder, loc, 1));
-  Value cond =
-      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, loCmp, hi);
-  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, cond, /*else=*/false);
+  Value needSort =
+      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, loP1, hi);
+  builder.create<scf::ConditionOp>(loc, needSort, before->getArguments());
 
-  // The if-stmt true branch.
-  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
-  Value pDepthLimit;
-  Value savedDepthLimit;
-  scf::IfOp depthIf;
+  // The after-region of the WhileOp.
+  Block *after =
+      builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc});
+  builder.setInsertionPointToEnd(after);
+  lo = after->getArgument(0);
+  hi = after->getArgument(1);
+  args[0] = lo;
+  args[1] = hi;
 
   if (isHybrid) {
     Value len = builder.create<arith::SubIOp>(loc, hi, lo);
     Value lenLimit = constantIndex(builder, loc, 30);
     Value lenCond = builder.create<arith::CmpIOp>(
         loc, arith::CmpIPredicate::ule, len, lenLimit);
-    scf::IfOp lenIf = builder.create<scf::IfOp>(loc, lenCond, /*else=*/true);
+    scf::IfOp lenIf =
+        builder.create<scf::IfOp>(loc, types, lenCond, /*else=*/true);
 
     // When len <= limit.
     builder.setInsertionPointToStart(&lenIf.getThenRegion().front());
     FlatSymbolRefAttr insertionSortFunc = getMangledSortHelperFunc(
         builder, func, TypeRange(), kSortStableFuncNamePrefix, nx, ny, isCoo,
-        args.drop_back(nTrailingP), createSortStableFunc);
+        ValueRange(args).drop_back(nTrailingP), createSortStableFunc);
     builder.create<func::CallOp>(loc, insertionSortFunc, TypeRange(),
-                                 ValueRange(args.drop_back(nTrailingP)));
+                                 ValueRange(args).drop_back(nTrailingP));
+    builder.create<scf::YieldOp>(loc, ValueRange{lo, lo});
 
     // When len > limit.
     builder.setInsertionPointToStart(&lenIf.getElseRegion().front());
-    pDepthLimit = args.back();
-    savedDepthLimit = builder.create<memref::LoadOp>(loc, pDepthLimit);
-    Value depthLimit = builder.create<arith::SubIOp>(
-        loc, savedDepthLimit, constantI64(builder, loc, 1));
-    builder.create<memref::StoreOp>(loc, depthLimit, pDepthLimit);
+    Value depthLimit = args.back();
+    depthLimit = builder.create<arith::SubIOp>(loc, depthLimit,
+                                               constantI64(builder, loc, 1));
     Value depthCond =
         builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule,
                                       depthLimit, constantI64(builder, loc, 0));
-    depthIf = builder.create<scf::IfOp>(loc, depthCond, /*else=*/true);
+    scf::IfOp depthIf =
+        builder.create<scf::IfOp>(loc, types, depthCond, /*else=*/true);
 
     // When depth exceeds limit.
     builder.setInsertionPointToStart(&depthIf.getThenRegion().front());
     FlatSymbolRefAttr heapSortFunc = getMangledSortHelperFunc(
         builder, func, TypeRange(), kHeapSortFuncNamePrefix, nx, ny, isCoo,
-        args.drop_back(nTrailingP), createHeapSortFunc);
+        ValueRange(args).drop_back(nTrailingP), createHeapSortFunc);
     builder.create<func::CallOp>(loc, heapSortFunc, TypeRange(),
-                                 ValueRange(args.drop_back(nTrailingP)));
+                                 ValueRange(args).drop_back(nTrailingP));
+    builder.create<scf::YieldOp>(loc, ValueRange{lo, lo});
 
     // When depth doesn't exceed limit.
     builder.setInsertionPointToStart(&depthIf.getElseRegion().front());
-  }
-
-  createQuickSort(builder, module, func, args, nx, ny, isCoo, nTrailingP);
+    args.back() = depthLimit;
+    std::tie(lo, hi) =
+        createQuickSort(builder, module, func, args, nx, ny, isCoo, nTrailingP);
+    builder.create<scf::YieldOp>(loc, ValueRange{lo, hi});
 
-  if (isHybrid) {
-    // Restore depthLimit.
     builder.setInsertionPointAfter(depthIf);
-    builder.create<memref::StoreOp>(loc, savedDepthLimit, pDepthLimit);
+    lo = depthIf.getResult(0);
+    hi = depthIf.getResult(1);
+    builder.create<scf::YieldOp>(loc, ValueRange{lo, hi});
+
+    builder.setInsertionPointAfter(lenIf);
+    lo = lenIf.getResult(0);
+    hi = lenIf.getResult(1);
+  } else {
+    std::tie(lo, hi) =
+        createQuickSort(builder, module, func, args, nx, ny, isCoo, nTrailingP);
   }
 
-  // After the if-stmt.
-  builder.setInsertionPointAfter(ifOp);
+  // New [lo, hi) for the next while-loop iteration.
+  builder.create<scf::YieldOp>(loc, ValueRange{lo, hi});
+
+  // After the while-loop.
+  builder.setInsertionPointAfter(whileOp);
   builder.create<func::ReturnOp>(loc);
 }
 
@@ -1171,9 +1237,6 @@
       funcName = kHybridQuickSortFuncNamePrefix;
       funcGenerator = createQuickSortFunc;
       nTrailingP = 1;
-      Value pDepthLimit = rewriter.create<memref::AllocaOp>(
-          loc, MemRefType::get({}, rewriter.getI64Type()));
-      operands.push_back(pDepthLimit);
      // As a heuristics, set depthLimit = 2 * log2(n).
      Value lo = operands[loIdx];
      Value hi = operands[hiIdx];
@@ -1183,9 +1246,7 @@
      Value depthLimit = rewriter.create<arith::SubIOp>(
          loc, constantI64(rewriter, loc, 64),
          rewriter.create<math::CountLeadingZerosOp>(loc, len));
-      depthLimit = rewriter.create<arith::ShLIOp>(loc, depthLimit,
-                                                  constantI64(rewriter, loc, 1));
-      rewriter.create<memref::StoreOp>(loc, depthLimit, pDepthLimit);
+      operands.push_back(depthLimit);
      break;
    }
    case SparseTensorSortKind::QuickSort:
diff --git a/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir b/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
--- a/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
+++ b/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
@@ -132,13 +132,24 @@
 // CHECK-SAME: %[[Y0:.*]]: memref<?xf32>,
 // CHECK-SAME: %[[Y1:.*]]: memref<?xindex>) {
 // CHECK: %[[C1:.*]] = arith.constant 1
-// CHECK: %[[Lb:.*]] = arith.addi %[[L]], %[[C1]]
-// CHECK: %[[COND:.*]] = arith.cmpi ult, %[[Lb]], %[[H]]
-// CHECK: scf.if %[[COND]] {
-// CHECK: %[[P:.*]] = func.call @_sparse_partition_1_i8_f32_index(%[[L]], %[[H]], %[[X0]], %[[Y0]], %[[Y1]])
-// CHECK: func.call @_sparse_qsort_1_i8_f32_index(%[[L]], %[[P]], %[[X0]], %[[Y0]], %[[Y1]])
-// CHECK: %[[P2:.*]] = arith.addi %[[P]], %[[C1]] : index
-// CHECK: func.call @_sparse_qsort_1_i8_f32_index(%[[P2]], %[[H]], %[[X0]], %[[Y0]], %[[Y1]])
+// CHECK: scf.while (%[[L2:.*]] = %[[L]], %[[H2:.*]] = %[[H]])
+// CHECK: %[[Lb:.*]] = arith.addi %[[L2]], %[[C1]]
+// CHECK: %[[COND:.*]] = arith.cmpi ult, %[[Lb]], %[[H2]]
+// CHECK: scf.condition(%[[COND]]) %[[L2]], %[[H2]]
+// CHECK: } do {
+// CHECK: ^bb0(%[[L3:.*]]: index, %[[H3:.*]]: index)
+// CHECK: %[[P:.*]] = func.call @_sparse_partition_1_i8_f32_index(%[[L3]], %[[H3]], %[[X0]], %[[Y0]], %[[Y1]])
+// CHECK: %[[PP1:.*]] = arith.addi %[[P]], %[[C1]] : index
+// CHECK: %[[LenL:.*]] = arith.subi %[[P]], %[[L3]]
+// CHECK: %[[LenH:.*]] = arith.subi %[[H3]], %[[P]]
+// CHECK: %[[Cmp:.*]] = arith.cmpi ule, %[[LenL]], %[[LenH]]
+// CHECK: %[[L4:.*]] = arith.select %[[Cmp]], %[[PP1]], %[[L3]]
+// CHECK: %[[H4:.*]] = arith.select %[[Cmp]], %[[H3]], %[[P]]
+// CHECK: scf.if %[[Cmp]]
+// CHECK: func.call @_sparse_qsort_1_i8_f32_index(%[[L3]], %[[P]], %[[X0]], %[[Y0]], %[[Y1]])
+// CHECK: else
+// CHECK: func.call @_sparse_qsort_1_i8_f32_index(%[[PP1]], %[[H3]], %[[X0]], %[[Y0]], %[[Y1]])
+// CHECK: scf.yield %[[L4]], %[[H4]]
 // CHECK: }
 // CHECK: return
 // CHECK: }
@@ -187,7 +198,7 @@
 // CHECK-DAG: func.func private @_sparse_heap_sort_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) {
 // CHECK-DAG: func.func private @_sparse_compare_eq_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> i1 {
 // CHECK-DAG: func.func private @_sparse_partition_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> index {
-// CHECK-DAG: func.func private @_sparse_hybrid_qsort_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>, %arg5: memref<i64>) {
+// CHECK-DAG: func.func private @_sparse_hybrid_qsort_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>, %arg5: i64) {
 // CHECK-LABEL: func.func @sparse_sort_3d_hybrid
 func.func @sparse_sort_3d_hybrid(%arg0: index, %arg1: memref<10xindex>, %arg2: memref<?xindex>, %arg3: memref<10xindex>) -> (memref<10xindex>, memref<?xindex>, memref<10xindex>) {
   sparse_tensor.sort hybrid_quick_sort %arg0, %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
@@ -249,7 +260,7 @@
 // CHECK-DAG: func.func private @_sparse_heap_sort_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
 // CHECK-DAG: func.func private @_sparse_compare_eq_2_index_coo_1(%arg0: index, %arg1: index, %arg2: memref<?xindex>) -> i1 {
 // CHECK-DAG: func.func private @_sparse_partition_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) -> index {
-// CHECK-DAG: func.func private @_sparse_hybrid_qsort_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>, %arg5: memref<i64>) {
+// CHECK-DAG: func.func private @_sparse_hybrid_qsort_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>, %arg5: i64) {
 // CHECK-LABEL: func.func @sparse_sort_coo_hybrid
 func.func @sparse_sort_coo_hybrid(%arg0: index, %arg1: memref<100xindex>, %arg2: memref<?xf32>, %arg3: memref<10xi32>) -> (memref<100xindex>, memref<?xf32>, memref<10xi32>) {
   sparse_tensor.sort_coo hybrid_quick_sort %arg0, %arg1 jointly %arg2, %arg3 {nx = 2 : index, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
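
For reference, the control-flow restructuring in this patch can be summarized outside of MLIR builder code. The standalone C++ sketch below is illustrative only and is not part of the patch: it mirrors the C-like pseudocode in the updated comments, using a simple Lomuto partition over a plain int array (both are assumptions made for a self-contained example; the generated helpers use their own partition scheme and operate on the sort operands). It shows the property the patch relies on: only the smaller partition is handled by a recursive call, while the bigger partition is fed back into the enclosing while-loop, so the recursion depth stays within O(log n) of the range length.

#include <cstddef>
#include <utility>

// Assumed stand-in for the generated partition helper: Lomuto partition of
// [lo, hi); returns the final position of the pivot element.
static size_t partition(size_t lo, size_t hi, int *data) {
  int pivot = data[hi - 1];
  size_t p = lo;
  for (size_t i = lo; i + 1 < hi; ++i)
    if (data[i] < pivot)
      std::swap(data[i], data[p++]);
  std::swap(data[p], data[hi - 1]);
  return p;
}

static void quickSort(size_t lo, size_t hi, int *data) {
  while (lo + 1 < hi) {
    size_t p = partition(lo, hi, data);
    size_t lenLow = p - lo;   // Length of [lo, p).
    size_t lenHigh = hi - p;  // Length of [p, hi).
    if (lenLow <= lenHigh) {
      if (lenLow != 0)
        quickSort(lo, p, data); // Recurse into the smaller partition.
      lo = p + 1;               // Keep iterating on the bigger partition.
    } else {
      if (lenHigh != 0)
        quickSort(p + 1, hi, data);
      hi = p;
    }
  }
}

The same driver shape carries over to the hybrid variant: because the loop now threads (lo, hi) through scf.while and depthLimit is decremented as an SSA value that is passed by value to the recursive call, the scratch memref that previously held the depth limit, along with its save/restore around the recursion, is no longer needed, which is why the helper's trailing parameter changes from a memref to a plain i64 in the tests.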