diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
@@ -42,6 +42,8 @@
     "_sparse_sort_nonstable_";
 static constexpr const char kSortStableFuncNamePrefix[] =
     "_sparse_sort_stable_";
+static constexpr const char kShiftDownFuncNamePrefix[] = "_sparse_shift_down_";
+static constexpr const char kHeapSortFuncNamePrefix[] = "_sparse_heap_sort_";
 
 using FuncGeneratorType = function_ref<void(
     OpBuilder &, ModuleOp, func::FuncOp, uint64_t, uint64_t, bool, uint32_t)>;
@@ -680,6 +682,240 @@
   builder.create<func::ReturnOp>(loc, whileOp.getResult(2));
 }
 
+/// Computes (n-2)/2, assuming n has index type.
+static Value createSubTwoDividedByTwo(OpBuilder &builder, Location loc,
+                                      Value n) {
+  Value i2 = constantIndex(builder, loc, 2);
+  Value res = builder.create<arith::SubIOp>(loc, n, i2);
+  Value i1 = constantIndex(builder, loc, 1);
+  return builder.create<arith::ShRUIOp>(loc, res, i1);
+}
+
+/// Creates a function to heapify the subtree with root `start` within the full
+/// binary tree in the range of index [first, first + n).
+//
+// The generated IR corresponds to this C like algorithm:
+// void shiftDown(first, start, n, data) {
+//   if (n >= 2) {
+//     child = start - first
+//     if ((n-2)/2 >= child) {
+//       // Left child exists.
+//       child = child * 2 + 1 // Initialize the bigger child to left child.
+//       childIndex = child + first
+//       if (child+1 < n && data[childIndex] < data[childIndex+1])
+//         // Right child exists and is bigger.
+//         childIndex++; child++;
+//       // Shift data[start] down to where it belongs in the subtree.
+//       while (data[start] < data[childIndex]) {
+//         swap(data[start], data[childIndex])
+//         start = childIndex
+//         if ((n-2)/2 >= child) {
+//           // Left child exists.
+//           child = 2*child + 1
+//           childIndex = child + first
+//           if (child+1 < n && data[childIndex] < data[childIndex+1])
+//             childIndex++; child++;
+//         }
+//       }
+//     }
+//   }
+// }
+//
+static void createShiftDownFunc(OpBuilder &builder, ModuleOp module,
+                                func::FuncOp func, uint64_t nx, uint64_t ny,
+                                bool isCoo, uint32_t nTrailingP) {
+  // The value n is passed in as a trailing parameter.
+  assert(nTrailingP == 1);
+  OpBuilder::InsertionGuard insertionGuard(builder);
+  Block *entryBlock = func.addEntryBlock();
+  builder.setInsertionPointToStart(entryBlock);
+
+  Location loc = func.getLoc();
+  Value n = entryBlock->getArguments().back();
+  ValueRange args = entryBlock->getArguments().drop_back();
+  Value first = args[loIdx];
+  Value start = args[hiIdx];
+
+  // If (n >= 2).
+  Value c2 = constantIndex(builder, loc, 2);
+  Value condN =
+      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, n, c2);
+  scf::IfOp ifN = builder.create<scf::IfOp>(loc, condN, /*else=*/false);
+  builder.setInsertionPointToStart(&ifN.getThenRegion().front());
+  Value child = builder.create<arith::SubIOp>(loc, start, first);
+
+  // If ((n-2)/2 >= child).
+  Value t = createSubTwoDividedByTwo(builder, loc, n);
+  Value condNc =
+      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, t, child);
+  scf::IfOp ifNc = builder.create<scf::IfOp>(loc, condNc, /*else=*/false);
+
+  builder.setInsertionPointToStart(&ifNc.getThenRegion().front());
+  Value c1 = constantIndex(builder, loc, 1);
+  SmallVector<Value> compareOperands{start, start};
+  uint64_t numXBuffers = isCoo ? 1 : nx;
+  compareOperands.append(args.begin() + xStartIdx,
+                         args.begin() + xStartIdx + numXBuffers);
+  Type i1Type = IntegerType::get(module.getContext(), 1, IntegerType::Signless);
+  FlatSymbolRefAttr lessThanFunc = getMangledSortHelperFunc(
+      builder, func, {i1Type}, kLessThanFuncNamePrefix, nx, ny, isCoo,
+      compareOperands, createLessThanFunc);
+
+  // Generate code to inspect the children of 'r' and return the larger child
+  // as follows:
+  //   child = r * 2 + 1 // Left child.
+  //   childIndex = child + first
+  //   if (child+1 < n && data[childIndex] < data[childIndex+1])
+  //     childIndex++; child++; // Right child is bigger.
+  auto getLargerChild = [&](Value r) -> std::pair<Value, Value> {
+    Value lChild = builder.create<arith::ShLIOp>(loc, r, c1);
+    lChild = builder.create<arith::AddIOp>(loc, lChild, c1);
+    Value lChildIdx = builder.create<arith::AddIOp>(loc, lChild, first);
+    Value rChild = builder.create<arith::AddIOp>(loc, lChild, c1);
+    Value cond1 = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
+                                                rChild, n);
+    SmallVector<Type> ifTypes(2, r.getType());
+    scf::IfOp if1 =
+        builder.create<scf::IfOp>(loc, ifTypes, cond1, /*else=*/true);
+    builder.setInsertionPointToStart(&if1.getThenRegion().front());
+    Value rChildIdx = builder.create<arith::AddIOp>(loc, rChild, first);
+    // Compare data[left] < data[right].
+    compareOperands[0] = lChildIdx;
+    compareOperands[1] = rChildIdx;
+    Value cond2 = builder
+                      .create<func::CallOp>(loc, lessThanFunc,
+                                            TypeRange{i1Type}, compareOperands)
+                      .getResult(0);
+    scf::IfOp if2 =
+        builder.create<scf::IfOp>(loc, ifTypes, cond2, /*else=*/true);
+    builder.setInsertionPointToStart(&if2.getThenRegion().front());
+    builder.create<scf::YieldOp>(loc, ValueRange{rChild, rChildIdx});
+    builder.setInsertionPointToStart(&if2.getElseRegion().front());
+    builder.create<scf::YieldOp>(loc, ValueRange{lChild, lChildIdx});
+    builder.setInsertionPointAfter(if2);
+    builder.create<scf::YieldOp>(loc, if2.getResults());
+    builder.setInsertionPointToStart(&if1.getElseRegion().front());
+    builder.create<scf::YieldOp>(loc, ValueRange{lChild, lChildIdx});
+    builder.setInsertionPointAfter(if1);
+    return std::make_pair(if1.getResult(0), if1.getResult(1));
+  };
+
+  Value childIdx;
+  std::tie(child, childIdx) = getLargerChild(child);
+
+  // While (data[start] < data[childIndex]).
+  SmallVector<Type> types(3, child.getType());
+  scf::WhileOp whileOp = builder.create<scf::WhileOp>(
+      loc, types, SmallVector<Value>{start, child, childIdx});
+
+  // The before-region of the WhileOp.
+  SmallVector<Location> locs(3, loc);
+  Block *before = builder.createBlock(&whileOp.getBefore(), {}, types, locs);
+  builder.setInsertionPointToEnd(before);
+  start = before->getArgument(0);
+  childIdx = before->getArgument(2);
+  compareOperands[0] = start;
+  compareOperands[1] = childIdx;
+  Value cond = builder
+                   .create<func::CallOp>(loc, lessThanFunc, TypeRange{i1Type},
+                                         compareOperands)
+                   .getResult(0);
+  builder.create<scf::ConditionOp>(loc, cond, before->getArguments());
+
+  // The after-region of the WhileOp.
+  Block *after = builder.createBlock(&whileOp.getAfter(), {}, types, locs);
+  start = after->getArgument(0);
+  child = after->getArgument(1);
+  childIdx = after->getArgument(2);
+  SmallVector<Value> swapOperands{start, childIdx};
+  swapOperands.append(args.begin() + xStartIdx, args.end());
+  createSwap(builder, loc, swapOperands, nx, ny, isCoo);
+  start = childIdx;
+  Value cond2 =
+      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, t, child);
+  scf::IfOp if2 = builder.create<scf::IfOp>(
+      loc, TypeRange{child.getType(), child.getType()}, cond2, /*else=*/true);
+  builder.setInsertionPointToStart(&if2.getThenRegion().front());
+  auto [newChild, newChildIdx] = getLargerChild(child);
+  builder.create<scf::YieldOp>(loc, ValueRange{newChild, newChildIdx});
+  builder.setInsertionPointToStart(&if2.getElseRegion().front());
+  builder.create<scf::YieldOp>(loc, ValueRange{child, childIdx});
+  builder.setInsertionPointAfter(if2);
+  builder.create<scf::YieldOp>(
+      loc, ValueRange{start, if2.getResult(0), if2.getResult(1)});
+
+  builder.setInsertionPointAfter(ifN);
+  builder.create<func::ReturnOp>(loc);
+}
+
+/// Creates a function to perform heap sort on the values in the range of index
+/// [lo, hi) with the assumption hi - lo >= 2.
+//
+// The generated IR corresponds to this C like algorithm:
+// void heapSort(lo, hi, data) {
+//   n = hi - lo
+//   for i = (n-2)/2 downto 0
+//     shiftDown(lo, lo+i, n)
+//
+//   for l = n downto 2
+//     swap(lo, lo+l-1)
+//     shiftDown(lo, lo, l-1)
+// }
+static void createHeapSortFunc(OpBuilder &builder, ModuleOp module,
+                               func::FuncOp func, uint64_t nx, uint64_t ny,
+                               bool isCoo, uint32_t nTrailingP) {
+  // Heap sort function doesn't have trailing parameters.
+  (void)nTrailingP;
+  assert(nTrailingP == 0);
+  OpBuilder::InsertionGuard insertionGuard(builder);
+  Block *entryBlock = func.addEntryBlock();
+  builder.setInsertionPointToStart(entryBlock);
+
+  Location loc = func.getLoc();
+  ValueRange args = entryBlock->getArguments();
+  Value lo = args[loIdx];
+  Value hi = args[hiIdx];
+  Value n = builder.create<arith::SubIOp>(loc, hi, lo);
+
+  // For i = (n-2)/2 downto 0.
+  Value c0 = constantIndex(builder, loc, 0);
+  Value c1 = constantIndex(builder, loc, 1);
+  Value s = createSubTwoDividedByTwo(builder, loc, n);
+  Value up = builder.create<arith::AddIOp>(loc, s, c1);
+  scf::ForOp forI = builder.create<scf::ForOp>(loc, c0, up, c1);
+  builder.setInsertionPointToStart(forI.getBody());
+  Value i = builder.create<arith::SubIOp>(loc, s, forI.getInductionVar());
+  Value lopi = builder.create<arith::AddIOp>(loc, lo, i);
+  SmallVector<Value> shiftDownOperands = {lo, lopi};
+  shiftDownOperands.append(args.begin() + xStartIdx, args.end());
+  shiftDownOperands.push_back(n);
+  FlatSymbolRefAttr shiftDownFunc = getMangledSortHelperFunc(
+      builder, func, TypeRange(), kShiftDownFuncNamePrefix, nx, ny, isCoo,
+      shiftDownOperands, createShiftDownFunc, /*nTrailingP=*/1);
+  builder.create<func::CallOp>(loc, shiftDownFunc, TypeRange(),
+                               shiftDownOperands);
+
+  builder.setInsertionPointAfter(forI);
+  // For l = n downto 2.
+  up = builder.create<arith::SubIOp>(loc, n, c1);
+  scf::ForOp forL = builder.create<scf::ForOp>(loc, c0, up, c1);
+  builder.setInsertionPointToStart(forL.getBody());
+  Value l = builder.create<arith::SubIOp>(loc, n, forL.getInductionVar());
+  Value loplm1 = builder.create<arith::AddIOp>(loc, lo, l);
+  loplm1 = builder.create<arith::SubIOp>(loc, loplm1, c1);
+  SmallVector<Value> swapOperands{lo, loplm1};
+  swapOperands.append(args.begin() + xStartIdx, args.end());
+  createSwap(builder, loc, swapOperands, nx, ny, isCoo);
+  shiftDownOperands[1] = lo;
+  shiftDownOperands[shiftDownOperands.size() - 1] =
+      builder.create<arith::SubIOp>(loc, l, c1);
+  builder.create<func::CallOp>(loc, shiftDownFunc, TypeRange(),
+                               shiftDownOperands);
+
+  builder.setInsertionPointAfter(forL);
+  builder.create<func::ReturnOp>(loc);
+}
+
 /// Creates a function to perform quick sort on the value in the range of
 /// index [lo, hi).
 //
@@ -836,14 +1072,27 @@
     }
     operands.push_back(v);
   }
-  bool isStable =
-      (op.getAlgorithm() == SparseTensorSortKind::InsertionSortStable);
+  auto insertPoint = op->template getParentOfType<func::FuncOp>();
-  SmallString<32> funcName(isStable ? kSortStableFuncNamePrefix
-                                    : kSortNonstableFuncNamePrefix);
-  FuncGeneratorType funcGenerator =
-      isStable ? createSortStableFunc : createSortNonstableFunc;
+  SmallString<32> funcName;
+  FuncGeneratorType funcGenerator;
   uint32_t nTrailingP = 0;
+  switch (op.getAlgorithm()) {
+  case SparseTensorSortKind::HybridQuickSort:
+  case SparseTensorSortKind::QuickSort:
+    funcName = kSortNonstableFuncNamePrefix;
+    funcGenerator = createSortNonstableFunc;
+    break;
+  case SparseTensorSortKind::InsertionSortStable:
+    funcName = kSortStableFuncNamePrefix;
+    funcGenerator = createSortStableFunc;
+    break;
+  case SparseTensorSortKind::HeapSort:
+    funcName = kHeapSortFuncNamePrefix;
+    funcGenerator = createHeapSortFunc;
+    break;
+  }
+
   FlatSymbolRefAttr func = getMangledSortHelperFunc(
       rewriter, insertPoint, TypeRange(), funcName, nx, ny, isCoo, operands,
       funcGenerator, nTrailingP);
diff --git a/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir b/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
--- a/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
+++ b/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
@@ -190,6 +190,20 @@
 
 // -----
 
+// Only check the generated supporting functions. We have integration test to
+// verify correctness of the generated code.
+//
+// CHECK-DAG: func.func private @_sparse_less_than_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> i1 {
+// CHECK-DAG: func.func private @_sparse_shift_down_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>, %arg5: index) {
+// CHECK-DAG: func.func private @_sparse_heap_sort_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) {
+// CHECK-LABEL: func.func @sparse_sort_3d_heap
+func.func @sparse_sort_3d_heap(%arg0: index, %arg1: memref<10xindex>, %arg2: memref<?xindex>, %arg3: memref<10xindex>) -> (memref<10xindex>, memref<?xindex>, memref<10xindex>) {
+  sparse_tensor.sort heap_sort %arg0, %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
+  return %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
+}
+
+// -----
+
 // Only check the generated supporting functions. We have integration test to
 // verify correctness of the generated code.
 //
@@ -217,3 +231,16 @@
   return %arg1, %arg2, %arg3 : memref<100xindex>, memref<?xf32>, memref<10xi32>
 }
 
+// -----
+
+// Only check the generated supporting functions. We have integration test to
+// verify correctness of the generated code.
+//
+// CHECK-DAG: func.func private @_sparse_less_than_2_index_coo_1(%arg0: index, %arg1: index, %arg2: memref<?xindex>) -> i1 {
+// CHECK-DAG: func.func private @_sparse_shift_down_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>, %arg5: index) {
+// CHECK-DAG: func.func private @_sparse_heap_sort_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
+// CHECK-LABEL: func.func @sparse_sort_coo_heap
+func.func @sparse_sort_coo_heap(%arg0: index, %arg1: memref<100xindex>, %arg2: memref<?xf32>, %arg3: memref<10xi32>) -> (memref<100xindex>, memref<?xf32>, memref<10xi32>) {
+  sparse_tensor.sort_coo heap_sort %arg0, %arg1 jointly %arg2, %arg3 {nx = 2 : index, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
+  return %arg1, %arg2, %arg3 : memref<100xindex>, memref<?xf32>, memref<10xi32>
+}
\ No newline at end of file
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir
@@ -56,6 +56,10 @@
     // CHECK: [10, 2, 0, 5, 1]
     sparse_tensor.sort insertion_sort_stable %i0, %x0 : memref<?xi32>
     call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
+    // Heap sort.
+    // CHECK: [10, 2, 0, 5, 1]
+    sparse_tensor.sort heap_sort %i0, %x0 : memref<?xi32>
+    call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
 
     // Sort the first 4 elements, with the last valid value untouched.
     // CHECK: [0, 2, 5, 10, 1]
@@ -67,6 +71,12 @@
         : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
     sparse_tensor.sort insertion_sort_stable %i4, %x0 : memref<?xi32>
     call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
+    // Heap sort.
+    // CHECK: [0, 2, 5, 10, 1]
+    call @storeValuesTo(%x0, %c10, %c2, %c0, %c5, %c1)
+        : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
+    sparse_tensor.sort heap_sort %i4, %x0 : memref<?xi32>
+    call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
 
     // Prepare more buffers of different dimensions.
     %x1s = memref.alloc() : memref<10xi32>
@@ -114,6 +124,25 @@
     call @printMemref1dI32(%x1) : (memref<?xi32>) -> ()
     call @printMemref1dI32(%x2) : (memref<?xi32>) -> ()
     call @printMemref1dI32(%y0) : (memref<?xi32>) -> ()
+    // Heap sort.
+    // CHECK: [1, 1, 2, 5, 10]
+    // CHECK: [3, 3, 1, 10, 1]
+    // CHECK: [9, 9, 4, 7, 2]
+    // CHECK: [7, 8, 10, 9, 6]
+    call @storeValuesTo(%x0, %c10, %c2, %c1, %c5, %c1)
+        : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
+    call @storeValuesTo(%x1, %c1, %c1, %c3, %c10, %c3)
+        : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
+    call @storeValuesTo(%x2, %c2, %c4, %c9, %c7, %c9)
+        : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
+    call @storeValuesTo(%y0, %c6, %c10, %c8, %c9, %c7)
+        : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
+    sparse_tensor.sort heap_sort %i5, %x0, %x1, %x2 jointly %y0
+      : memref<?xi32>, memref<?xi32>, memref<?xi32> jointly memref<?xi32>
+    call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
+    call @printMemref1dI32(%x1) : (memref<?xi32>) -> ()
+    call @printMemref1dI32(%x2) : (memref<?xi32>) -> ()
+    call @printMemref1dI32(%y0) : (memref<?xi32>) -> ()
 
     // Release the buffers.
     memref.dealloc %x0 : memref<?xi32>
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir
@@ -132,6 +132,34 @@
     vector.print %y0v2 : vector<5xi32>
     %y1v2 = vector.transfer_read %y1[%i0], %c100: memref<?xi32>, vector<5xi32>
     vector.print %y1v2 : vector<5xi32>
+    // Heap sort.
+    // CHECK: ( 1, 1, 2, 5, 10 )
+    // CHECK: ( 3, 3, 1, 10, 1 )
+    // CHECK: ( 9, 9, 4, 7, 2 )
+    // CHECK: ( 7, 8, 10, 9, 6 )
+    // CHECK: ( 7, 4, 7, 9, 5 )
+    call @storeValuesToStrided(%x0, %c10, %c2, %c1, %c5, %c1)
+        : (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
+    call @storeValuesToStrided(%x1, %c1, %c1, %c3, %c10, %c3)
+        : (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
+    call @storeValuesToStrided(%x2, %c2, %c4, %c9, %c7, %c9)
+        : (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
+    call @storeValuesToStrided(%y0, %c6, %c10, %c8, %c9, %c7)
+        : (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
+    call @storeValuesTo(%y1, %c5, %c7, %c4, %c9, %c7)
+        : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
+    sparse_tensor.sort_coo heap_sort %i5, %xy jointly %y1 {nx = 3 : index, ny = 1 : index}
+      : memref<?xi32> jointly memref<?xi32>
+    %x0v3 = vector.transfer_read %x0[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
+    vector.print %x0v3 : vector<5xi32>
+    %x1v3 = vector.transfer_read %x1[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
+    vector.print %x1v3 : vector<5xi32>
+    %x2v3 = vector.transfer_read %x2[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
+    vector.print %x2v3 : vector<5xi32>
+    %y0v3 = vector.transfer_read %y0[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
+    vector.print %y0v3 : vector<5xi32>
+    %y1v3 = vector.transfer_read %y1[%i0], %c100: memref<?xi32>, vector<5xi32>
+    vector.print %y1v3 : vector<5xi32>
 
     // Release the buffers.
     memref.dealloc %xy : memref<?xi32>
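
For reference, the following standalone C++ program renders the shiftDown/heapSort
pair that the generated IR follows. It is a minimal sketch that mirrors the C-like
pseudocode in the patch comments, specialized to a single i32 buffer; the function
names, the std::vector storage, and the break where the pseudocode re-tests an
unchanged loop condition are all illustrative choices, not part of the patch or of
the rewriter's actual output.

#include <cassert>
#include <cstddef>
#include <cstdio>
#include <utility>
#include <vector>

// Heapify the subtree rooted at `start` within the heap occupying
// data[first, first + n), following the pseudocode of createShiftDownFunc.
static void shiftDown(size_t first, size_t start, size_t n,
                      std::vector<int> &data) {
  if (n < 2)
    return;
  size_t child = start - first;
  if ((n - 2) / 2 < child)
    return; // No left child: `start` is a leaf.
  child = child * 2 + 1; // Initialize the bigger child to the left child.
  size_t childIndex = child + first;
  if (child + 1 < n && data[childIndex] < data[childIndex + 1]) {
    childIndex++; // Right child exists and is bigger.
    child++;
  }
  // Shift data[start] down to where it belongs in the subtree.
  while (data[start] < data[childIndex]) {
    std::swap(data[start], data[childIndex]);
    start = childIndex;
    if ((n - 2) / 2 < child)
      break; // No left child anymore; the generated IR instead re-tests
             // data[start] < data[childIndex], which is now false.
    child = 2 * child + 1;
    childIndex = child + first;
    if (child + 1 < n && data[childIndex] < data[childIndex + 1]) {
      childIndex++;
      child++;
    }
  }
}

// Sort data[lo, hi), assuming hi - lo >= 2: build a max-heap, then repeatedly
// move the maximum to the end of the shrinking range and re-heapify, as in
// the pseudocode of createHeapSortFunc.
static void heapSort(size_t lo, size_t hi, std::vector<int> &data) {
  assert(hi - lo >= 2);
  size_t n = hi - lo;
  for (size_t i = (n - 2) / 2 + 1; i-- > 0;) // i = (n-2)/2 downto 0.
    shiftDown(lo, lo + i, n, data);
  for (size_t l = n; l >= 2; l--) { // l = n downto 2.
    std::swap(data[lo], data[lo + l - 1]);
    shiftDown(lo, lo, l - 1, data);
  }
}

int main() {
  // Same input as the integration test; prints 0 1 2 5 10.
  std::vector<int> data = {10, 2, 0, 5, 1};
  heapSort(0, data.size(), data);
  for (int v : data)
    std::printf("%d ", v);
  std::printf("\n");
  return 0;
}

Note also how createSubTwoDividedByTwo computes (n-2)/2 as an unsigned subtract
followed by a logical shift right by one, which agrees with the integer division
in the sketch for every n of index type with n >= 2.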