diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
@@ -28,9 +28,19 @@
 // Helper methods for the actual rewriting rules.
 //===---------------------------------------------------------------------===//
 
-constexpr uint64_t loIdx = 0;
-constexpr uint64_t hiIdx = 1;
-constexpr uint64_t xStartIdx = 2;
+static constexpr uint64_t loIdx = 0;
+static constexpr uint64_t hiIdx = 1;
+static constexpr uint64_t xStartIdx = 2;
+
+static constexpr const char kMaySwapFuncNamePrefix[] = "_sparse_may_swap_";
+static constexpr const char kLessThanFuncNamePrefix[] = "_sparse_less_than_";
+static constexpr const char kPartitionFuncNamePrefix[] = "_sparse_partition_";
+static constexpr const char kBinarySearchFuncNamePrefix[] =
+    "_sparse_binary_search_";
+static constexpr const char kSortNonstableFuncNamePrefix[] =
+    "_sparse_sort_nonstable_";
+static constexpr const char kSortStableFuncNamePrefix[] =
+    "_sparse_sort_stable_";
 
 typedef function_ref<void(OpBuilder &, ModuleOp, func::FuncOp, size_t)> FuncGeneratorType;
 
@@ -201,6 +211,79 @@
   builder.create<func::ReturnOp>(loc, topIfOp.getResult(0));
 }
 
+/// Creates a function that uses a binary search to find the insertion point
+/// for inserting xs[hi] into the sorted values xs[lo..hi).
+//
+// The generated IR corresponds to this C-like algorithm:
+//   p = hi
+//   while (lo < hi)
+//     mid = (lo + hi) >> 1
+//     if (xs[p] < xs[mid])
+//       hi = mid
+//     else
+//       lo = mid + 1
+//   return lo;
+//
+static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module,
+                                   func::FuncOp func, size_t dim) {
+  OpBuilder::InsertionGuard insertionGuard(builder);
+  Block *entryBlock = func.addEntryBlock();
+  builder.setInsertionPointToStart(entryBlock);
+
+  Location loc = func.getLoc();
+  ValueRange args = entryBlock->getArguments();
+  Value p = args[hiIdx];
+  SmallVector<Type, 2> types(2, p.getType());
+  scf::WhileOp whileOp = builder.create<scf::WhileOp>(
+      loc, types, SmallVector<Value, 2>{args[loIdx], args[hiIdx]});
+
+  // The before-region of the WhileOp.
+  Block *before =
+      builder.createBlock(&whileOp.getBefore(), {}, types, {loc, loc});
+  builder.setInsertionPointToEnd(before);
+  Value cond1 = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
+                                              before->getArgument(0),
+                                              before->getArgument(1));
+  builder.create<scf::ConditionOp>(loc, cond1, before->getArguments());
+
+  // The after-region of the WhileOp.
+  Block *after =
+      builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc});
+  builder.setInsertionPointToEnd(after);
+  Value lo = after->getArgument(0);
+  Value hi = after->getArgument(1);
+  // Compute mid = (lo + hi) >> 1.
+  Value c1 = constantIndex(builder, loc, 1);
+  Value mid = builder.create<arith::ShRUIOp>(
+      loc, builder.create<arith::AddIOp>(loc, lo, hi), c1);
+  Value midp1 = builder.create<arith::AddIOp>(loc, mid, c1);
+
+  // Compare xs[p] < xs[mid].
+  SmallVector<Value> compareOperands{p, mid};
+  compareOperands.append(args.begin() + xStartIdx,
+                         args.begin() + xStartIdx + dim);
+  Type i1Type = IntegerType::get(module.getContext(), 1, IntegerType::Signless);
+  FlatSymbolRefAttr lessThanFunc =
+      getMangledSortHelperFunc(builder, func, {i1Type}, kLessThanFuncNamePrefix,
+                               dim, compareOperands, createLessThanFunc);
+  Value cond2 = builder
+                    .create<func::CallOp>(loc, lessThanFunc, TypeRange{i1Type},
+                                          compareOperands)
+                    .getResult(0);
+
+  // Update lo and hi for the WhileOp as follows:
+  //   if (xs[p] < xs[mid])
+  //     hi = mid;
+  //   else
+  //     lo = mid + 1;
+  Value newLo = builder.create<arith::SelectOp>(loc, cond2, lo, midp1);
+  Value newHi = builder.create<arith::SelectOp>(loc, cond2, mid, hi);
+  builder.create<scf::YieldOp>(loc, ValueRange{newLo, newHi});
+
+  builder.setInsertionPointAfter(whileOp);
+  builder.create<func::ReturnOp>(loc, whileOp.getResult(0));
+}
+
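For cross-reference, the loop that createBinarySearchFunc emits corresponds to the standalone C++ sketch below. It is illustrative only: the function name and the single key vector are simplifications, since the generated code compares across the dim parallel xs buffers through the _sparse_less_than_ helper.

#include <cassert>
#include <cstddef>
#include <vector>

// Find the insertion point for xs[p] (with p == hi) within the sorted
// range xs[lo..hi), mirroring the while-loop built by createBinarySearchFunc.
static size_t binarySearchInsertionPoint(const std::vector<int> &xs,
                                         size_t lo, size_t hi) {
  size_t p = hi;
  while (lo < hi) {
    size_t mid = (lo + hi) >> 1;
    if (xs[p] < xs[mid])
      hi = mid; // Insertion point lies in the lower half.
    else
      lo = mid + 1; // Insertion point lies in the upper half.
  }
  return lo;
}

int main() {
  // Sorted prefix {2, 5, 10}; the value to insert (3) sits at index hi == 3.
  std::vector<int> xs = {2, 5, 10, 3};
  assert(binarySearchInsertionPoint(xs, 0, 3) == 1); // 3 belongs before 5.
  return 0;
}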
 /// Creates a function to perform quick sort partition on the values in the
 /// range of index [lo, hi), assuming lo < hi.
 //
@@ -243,7 +326,7 @@
   compareOperands.append(xs.begin(), xs.end());
   Type i1Type = IntegerType::get(context, 1, IntegerType::Signless);
   FlatSymbolRefAttr lessThanFunc =
-      getMangledSortHelperFunc(builder, func, {i1Type}, "_sparse_less_than_",
+      getMangledSortHelperFunc(builder, func, {i1Type}, kLessThanFuncNamePrefix,
                                dim, compareOperands, createLessThanFunc);
   Value cond = builder
                    .create<func::CallOp>(loc, lessThanFunc, TypeRange{i1Type},
@@ -258,9 +341,9 @@
       builder.create<arith::AddIOp>(loc, forOp.getRegionIterArgs().front(), c1);
   SmallVector<Value> swapOperands{i1, j};
   swapOperands.append(args.begin() + xStartIdx, args.end());
-  FlatSymbolRefAttr swapFunc =
-      getMangledSortHelperFunc(builder, func, TypeRange(), "_sparse_may_swap_",
-                               dim, swapOperands, createMaySwapFunc);
+  FlatSymbolRefAttr swapFunc = getMangledSortHelperFunc(
+      builder, func, TypeRange(), kMaySwapFuncNamePrefix, dim, swapOperands,
+      createMaySwapFunc);
   builder.create<func::CallOp>(loc, swapFunc, TypeRange(), swapOperands);
   builder.create<scf::YieldOp>(loc, i1);
 
@@ -292,8 +375,8 @@
 //     quickSort(p + 1, hi, data);
 //   }
 // }
-static void createSortFunc(OpBuilder &builder, ModuleOp module,
-                           func::FuncOp func, size_t dim) {
+static void createSortNonstableFunc(OpBuilder &builder, ModuleOp module,
+                                    func::FuncOp func, size_t dim) {
   OpBuilder::InsertionGuard insertionGuard(builder);
   Block *entryBlock = func.addEntryBlock();
   builder.setInsertionPointToStart(entryBlock);
@@ -310,8 +393,8 @@
   // The if-stmt true branch.
   builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
   FlatSymbolRefAttr partitionFunc = getMangledSortHelperFunc(
-      builder, func, {IndexType::get(context)}, "_sparse_partition_", dim, args,
-      createPartitionFunc);
+      builder, func, {IndexType::get(context)}, kPartitionFuncNamePrefix, dim,
+      args, createPartitionFunc);
   auto p = builder.create<func::CallOp>(
       loc, partitionFunc, TypeRange{IndexType::get(context)},
       ValueRange(args));
@@ -331,6 +414,78 @@
   builder.create<func::ReturnOp>(loc);
 }
 
+/// Creates a function to perform insertion sort on the values in the range of
+/// index [lo, hi).
+//
+// The generated IR corresponds to this C-like algorithm:
+// void insertionSort(lo, hi, data) {
+//   for (i = lo+1; i < hi; i++) {
+//     d = data[i];
+//     p = binarySearch(lo, i, data)
+//     for (j = 0; j < i - p; j++)
+//       data[i-j] = data[i-j-1]
+//     data[p] = d
+//   }
+// }
+static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
+                                 func::FuncOp func, size_t dim) {
+  OpBuilder::InsertionGuard insertionGuard(builder);
+  Block *entryBlock = func.addEntryBlock();
+  builder.setInsertionPointToStart(entryBlock);
+
+  MLIRContext *context = module.getContext();
+  Location loc = func.getLoc();
+  ValueRange args = entryBlock->getArguments();
+  Value c1 = constantIndex(builder, loc, 1);
+  Value lo = args[loIdx];
+  Value hi = args[hiIdx];
+  Value lop1 = builder.create<arith::AddIOp>(loc, lo, c1);
+
+  // Start the outer for-stmt with induction variable i.
+  scf::ForOp forOpI = builder.create<scf::ForOp>(loc, lop1, hi, c1);
+  builder.setInsertionPointToStart(forOpI.getBody());
+  Value i = forOpI.getInductionVar();
+
+  // Binary search to find the insertion point p.
+  SmallVector<Value> operands{lo, i};
+  operands.append(args.begin() + xStartIdx, args.begin() + xStartIdx + dim);
+  FlatSymbolRefAttr searchFunc = getMangledSortHelperFunc(
+      builder, func, {IndexType::get(context)}, kBinarySearchFuncNamePrefix,
+      dim, operands, createBinarySearchFunc);
+  Value p = builder
+                .create<func::CallOp>(loc, searchFunc, TypeRange{c1.getType()},
+                                      operands)
+                .getResult(0);
+
+  // Move the value at data[i] to a temporary location.
+  ValueRange data = args.drop_front(xStartIdx);
+  SmallVector<Value> d;
+  for (Value v : data)
+    d.push_back(builder.create<memref::LoadOp>(loc, v, i));
+
+  // Start the inner for-stmt with induction variable j, for moving data[p..i)
+  // to data[p+1..i+1).
+  Value imp = builder.create<arith::SubIOp>(loc, i, p);
+  Value c0 = constantIndex(builder, loc, 0);
+  scf::ForOp forOpJ = builder.create<scf::ForOp>(loc, c0, imp, c1);
+  builder.setInsertionPointToStart(forOpJ.getBody());
+  Value j = forOpJ.getInductionVar();
+  Value imj = builder.create<arith::SubIOp>(loc, i, j);
+  Value imjm1 = builder.create<arith::SubIOp>(loc, imj, c1);
+  for (Value v : data) {
+    Value t = builder.create<memref::LoadOp>(loc, v, imjm1);
+    builder.create<memref::StoreOp>(loc, t, v, imj);
+  }
+
+  // Store the value at data[i] to data[p].
+  builder.setInsertionPointAfter(forOpJ);
+  for (auto it : llvm::zip(d, data))
+    builder.create<memref::StoreOp>(loc, std::get<0>(it), std::get<1>(it), p);
+
+  builder.setInsertionPointAfter(forOpI);
+  builder.create<func::ReturnOp>(loc);
+}
+
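The same structure in plain C++, which makes it easy to check that the shift loop moves data[p..i) one slot to the right and that equal keys keep their original order. This is an illustrative sketch: a linear scan stands in for the generated binary-search call, and the names are made up.

#include <cassert>
#include <cstddef>
#include <vector>

static void insertionSortStable(std::vector<int> &data, size_t lo, size_t hi) {
  for (size_t i = lo + 1; i < hi; ++i) {
    int d = data[i];
    // Find the first p in [lo, i) with data[p] > d; inserting there keeps
    // ties in their original order (the generated code uses binarySearch).
    size_t p = lo;
    while (p < i && data[p] <= d)
      ++p;
    // Shift data[p..i) right by one, i.e. data[i-j] = data[i-j-1].
    for (size_t j = 0; j < i - p; ++j)
      data[i - j] = data[i - j - 1];
    data[p] = d;
  }
}

int main() {
  std::vector<int> v = {10, 2, 0, 5, 1}; // Same data as the integration test.
  insertionSortStable(v, 0, v.size());
  assert(v == std::vector<int>({0, 1, 2, 5, 10}));
  return 0;
}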
 //===---------------------------------------------------------------------===//
 // The actual sparse buffer rewriting rules.
 //===---------------------------------------------------------------------===//
@@ -425,9 +580,13 @@
     addValues(xs);
     addValues(op.getYs());
     auto insertPoint = op->getParentOfType<func::FuncOp>();
-    FlatSymbolRefAttr func = getMangledSortHelperFunc(
-        rewriter, insertPoint, TypeRange(), "_sparse_sort_", xs.size(),
-        operands, createSortFunc);
+    SmallString<32> funcName(op.getStable() ? kSortStableFuncNamePrefix
+                                            : kSortNonstableFuncNamePrefix);
+    FuncGeneratorType funcGenerator =
+        op.getStable() ? createSortStableFunc : createSortNonstableFunc;
+    FlatSymbolRefAttr func =
+        getMangledSortHelperFunc(rewriter, insertPoint, TypeRange(), funcName,
+                                 xs.size(), operands, funcGenerator);
     rewriter.replaceOpWithNewOp<func::CallOp>(op, func, TypeRange(), operands);
     return success();
   }
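The helper names checked in the tests below are produced by the existing getMangledSortHelperFunc from a prefix, the number of xs buffers, and the operand element types. The following is only a rough, hypothetical re-creation of that shape to help read the tests, not the actual implementation:

#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical sketch: prefix + <number of xs buffers>, then one "_<type>"
// suffix per operand element type.
static std::string mangleSortHelperName(const std::string &prefix,
                                        size_t numXBuffers,
                                        const std::vector<std::string> &types) {
  std::string name = prefix + std::to_string(numXBuffers);
  for (const std::string &t : types)
    name += "_" + t;
  return name;
}

int main() {
  assert(mangleSortHelperName("_sparse_sort_nonstable_", 1,
                              {"i8", "f32", "index"}) ==
         "_sparse_sort_nonstable_1_i8_f32_index");
  return 0;
}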
diff --git a/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir b/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
--- a/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
+++ b/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --sparse-buffer-rewrite --canonicalize --cse | FileCheck %s
+// RUN: mlir-opt %s -split-input-file --sparse-buffer-rewrite --canonicalize --cse | FileCheck %s
 
 // CHECK-LABEL: func @sparse_push_back(
 // CHECK-SAME: %[[A:.*]]: memref<?xindex>,
 // CHECK-SAME: %[[B:.*]]: memref<?xf64>,
@@ -26,6 +26,8 @@
   return %0 : memref<?xf64>
 }
 
+// -----
+
 // CHECK-LABEL: func @sparse_push_back_inbound(
 // CHECK-SAME: %[[A:.*]]: memref<?xindex>,
 // CHECK-SAME: %[[B:.*]]: memref<?xf64>,
@@ -42,6 +44,8 @@
   return %0 : memref<?xf64>
 }
 
+// -----
+
 // CHECK-LABEL: func.func private @_sparse_less_than_1_i8(
 // CHECK-SAME: %[[I:arg0]]: index,
 // CHECK-SAME: %[[J:.*]]: index,
@@ -101,7 +105,7 @@
 // CHECK: return %[[I3p1]]
 // CHECK: }
 
-// CHECK-LABEL: func.func private @_sparse_sort_1_i8_f32_index(
+// CHECK-LABEL: func.func private @_sparse_sort_nonstable_1_i8_f32_index(
 // CHECK-SAME: %[[L:arg0]]: index,
 // CHECK-SAME: %[[H:.*]]: index,
 // CHECK-SAME: %[[X0:.*]]: memref<?xi8>,
@@ -111,9 +115,9 @@
 // CHECK: %[[COND:.*]] = arith.cmpi ult, %[[L]], %[[H]]
 // CHECK: scf.if %[[COND]] {
 // CHECK: %[[P:.*]] = func.call @_sparse_partition_1_i8_f32_index(%[[L]], %[[H]], %[[X0]], %[[Y0]], %[[Y1]])
-// CHECK: func.call @_sparse_sort_1_i8_f32_index(%[[L]], %[[P]], %[[X0]], %[[Y0]], %[[Y1]])
+// CHECK: func.call @_sparse_sort_nonstable_1_i8_f32_index(%[[L]], %[[P]], %[[X0]], %[[Y0]], %[[Y1]])
 // CHECK: %[[P2:.*]] = arith.addi %[[P]], %[[C1]] : index
-// CHECK: func.call @_sparse_sort_1_i8_f32_index(%[[P2]], %[[H]], %[[X0]], %[[Y0]], %[[Y1]])
+// CHECK: func.call @_sparse_sort_nonstable_1_i8_f32_index(%[[P2]], %[[H]], %[[X0]], %[[Y0]], %[[Y1]])
 // CHECK: }
 // CHECK: return
 // CHECK: }
@@ -126,7 +130,7 @@
 // CHECK: %[[C0:.*]] = arith.constant 0
 // CHECK: %[[DX0:.*]] = memref.cast %[[X0]] : memref<10xi8> to memref<?xi8>
 // CHECK: %[[DY1:.*]] = memref.cast %[[Y1]] : memref<10xindex> to memref<?xindex>
-// CHECK: call @_sparse_sort_1_i8_f32_index(%[[C0]], %[[N]], %[[DX0]], %[[Y0]], %[[DY1]])
+// CHECK: call @_sparse_sort_nonstable_1_i8_f32_index(%[[C0]], %[[N]], %[[DX0]], %[[Y0]], %[[DY1]])
 // CHECK: return %[[X0]], %[[Y0]], %[[Y1]]
 // CHECK: }
 func.func @sparse_sort_1d2v(%arg0: index, %arg1: memref<10xi8>, %arg2: memref<?xf32>, %arg3: memref<10xindex>)
@@ -135,15 +139,31 @@
   return %arg1, %arg2, %arg3 : memref<10xi8>, memref<?xf32>, memref<10xindex>
 }
 
+// -----
+
 // Only check the generated supporting function now. We have integration test
 // to verify correctness of the generated code.
 //
 // CHECK-DAG: func.func private @_sparse_less_than_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> i1 {
 // CHECK-DAG: func.func private @_sparse_may_swap_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) {
 // CHECK-DAG: func.func private @_sparse_partition_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> index {
-// CHECK-DAG: func.func private @_sparse_sort_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) {
+// CHECK-DAG: func.func private @_sparse_sort_nonstable_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) {
 // CHECK-LABEL: func.func @sparse_sort_3d
 func.func @sparse_sort_3d(%arg0: index, %arg1: memref<10xindex>, %arg2: memref<?xindex>, %arg3: memref<10xindex>) -> (memref<10xindex>, memref<?xindex>, memref<10xindex>) {
   sparse_tensor.sort %arg0, %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
   return %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
 }
+
+// -----
+
+// Only check the generated supporting functions. We have integration test to
+// verify correctness of the generated code.
+//
+// CHECK-DAG: func.func private @_sparse_less_than_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> i1 {
+// CHECK-DAG: func.func private @_sparse_binary_search_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> index {
+// CHECK-DAG: func.func private @_sparse_sort_stable_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) {
+// CHECK-LABEL: func.func @sparse_sort_3d_stable
+func.func @sparse_sort_3d_stable(%arg0: index, %arg1: memref<10xindex>, %arg2: memref<?xindex>, %arg3: memref<10xindex>) -> (memref<10xindex>, memref<?xindex>, memref<10xindex>) {
+  sparse_tensor.sort stable %arg0, %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
+  return %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
+}
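The nonstable/stable split mirrors the std::sort versus std::stable_sort contract: with duplicate keys, an unstable quick sort may reorder ties, while the stable insertion sort must not. A small C++ demonstration on made-up data of why the integration test below needs separate CHECK blocks for the stable runs:

#include <algorithm>
#include <cstdio>
#include <utility>
#include <vector>

int main() {
  // Two entries share the key 1; the tags record their original order.
  std::vector<std::pair<int, char>> v = {{1, 'a'}, {0, 'b'}, {1, 'c'}};
  std::stable_sort(v.begin(), v.end(), [](const auto &l, const auto &r) {
    return l.first < r.first;
  });
  for (const auto &[k, tag] : v)
    std::printf("(%d,%c) ", k, tag); // (0,b) (1,a) (1,c): 'a' stays before 'c'.
  std::printf("\n");
  // std::sort would also be permitted to produce (0,b) (1,c) (1,a).
  return 0;
}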
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir
@@ -51,12 +51,24 @@
   sparse_tensor.sort %i0, %x0 : memref<?xi32>
   %x0v0 = vector.transfer_read %x0[%i0], %c100: memref<?xi32>, vector<5xi32>
   vector.print %x0v0 : vector<5xi32>
+  // Stable sort.
+  // CHECK: ( 10, 2, 0, 5, 1 )
+  sparse_tensor.sort stable %i0, %x0 : memref<?xi32>
+  %x0v0s = vector.transfer_read %x0[%i0], %c100: memref<?xi32>, vector<5xi32>
+  vector.print %x0v0s : vector<5xi32>
   // Sort the first 4 elements, with the last valid value untouched.
   // CHECK: ( 0, 2, 5, 10, 1 )
   sparse_tensor.sort %i4, %x0 : memref<?xi32>
   %x0v1 = vector.transfer_read %x0[%i0], %c100: memref<?xi32>, vector<5xi32>
   vector.print %x0v1 : vector<5xi32>
+  // Stable sort.
+  // CHECK: ( 0, 2, 5, 10, 1 )
+  call @storeValuesTo(%x0, %c10, %c2, %c0, %c5, %c1)
+    : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
+  sparse_tensor.sort stable %i4, %x0 : memref<?xi32>
+  %x0v1s = vector.transfer_read %x0[%i0], %c100: memref<?xi32>, vector<5xi32>
+  vector.print %x0v1s : vector<5xi32>
 
   // Prepare more buffers of different dimensions.
   %x1s = memref.alloc() : memref<10xi32>
@@ -65,20 +77,20 @@
   %x2 = memref.cast %x2s : memref<6xi32> to memref<?xi32>
   %y0s = memref.alloc() : memref<7xi32>
   %y0 = memref.cast %y0s : memref<7xi32> to memref<?xi32>
-  call @storeValuesTo(%x0, %c10, %c2, %c0, %c5, %c1)
+
+  // Sort "parallel arrays".
+  // CHECK: ( 1, 1, 2, 5, 10 )
+  // CHECK: ( 3, 3, 1, 10, 1 )
+  // CHECK: ( 9, 9, 4, 7, 2 )
+  // CHECK: ( 7, 8, 10, 9, 6 )
+  call @storeValuesTo(%x0, %c10, %c2, %c1, %c5, %c1)
     : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
   call @storeValuesTo(%x1, %c1, %c1, %c3, %c10, %c3)
     : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
-  call @storeValuesTo(%x2, %c2, %c4, %c4, %c7, %c9)
+  call @storeValuesTo(%x2, %c2, %c4, %c9, %c7, %c9)
     : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
   call @storeValuesTo(%y0, %c6, %c10, %c8, %c9, %c7)
     : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
-
-  // Sort "parallel arrays".
-  // CHECK: ( 0, 1, 2, 5, 10 )
-  // CHECK: ( 3, 3, 1, 10, 1 )
-  // CHECK: ( 4, 9, 4, 7, 2 )
-  // CHECK: ( 8, 7, 10, 9, 6 )
   sparse_tensor.sort %i5, %x0, %x1, %x2 jointly %y0
     : memref<?xi32>, memref<?xi32>, memref<?xi32> jointly memref<?xi32>
   %x0v2 = vector.transfer_read %x0[%i0], %c100: memref<?xi32>, vector<5xi32>
@@ -89,6 +101,29 @@
   vector.print %x2v : vector<5xi32>
   %y0v = vector.transfer_read %y0[%i0], %c100: memref<?xi32>, vector<5xi32>
   vector.print %y0v : vector<5xi32>
+  // Stable sort.
+  // CHECK: ( 1, 1, 2, 5, 10 )
+  // CHECK: ( 3, 3, 1, 10, 1 )
+  // CHECK: ( 9, 9, 4, 7, 2 )
+  // CHECK: ( 8, 7, 10, 9, 6 )
+  call @storeValuesTo(%x0, %c10, %c2, %c1, %c5, %c1)
+    : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
+  call @storeValuesTo(%x1, %c1, %c1, %c3, %c10, %c3)
+    : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
+  call @storeValuesTo(%x2, %c2, %c4, %c9, %c7, %c9)
+    : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
+  call @storeValuesTo(%y0, %c6, %c10, %c8, %c9, %c7)
+    : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
+  sparse_tensor.sort stable %i5, %x0, %x1, %x2 jointly %y0
+    : memref<?xi32>, memref<?xi32>, memref<?xi32> jointly memref<?xi32>
+  %x0v2s = vector.transfer_read %x0[%i0], %c100: memref<?xi32>, vector<5xi32>
+  vector.print %x0v2s : vector<5xi32>
+  %x1vs = vector.transfer_read %x1[%i0], %c100: memref<?xi32>, vector<5xi32>
+  vector.print %x1vs : vector<5xi32>
+  %x2vs = vector.transfer_read %x2[%i0], %c100: memref<?xi32>, vector<5xi32>
+  vector.print %x2vs : vector<5xi32>
+  %y0vs = vector.transfer_read %y0[%i0], %c100: memref<?xi32>, vector<5xi32>
+  vector.print %y0vs : vector<5xi32>
 
   // Release the buffers.
   memref.dealloc %x0 : memref<?xi32>
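As a host-side cross-check of the parallel-array expectations above (not part of the patch, and assuming sparse_tensor.sort orders the xs buffers lexicographically and applies the same permutation to the ys buffer), the following C++ program reproduces the stable-sort CHECK lines. The tied keys (1, 3, 9) at positions 2 and 4 are exactly where the y0 output differs from the unstable run:

#include <algorithm>
#include <array>
#include <cstdio>
#include <numeric>
#include <tuple>

int main() {
  std::array<int, 5> x0 = {10, 2, 1, 5, 1};
  std::array<int, 5> x1 = {1, 1, 3, 10, 3};
  std::array<int, 5> x2 = {2, 4, 9, 7, 9};
  std::array<int, 5> y0 = {6, 10, 8, 9, 7};

  // Stable-sort a permutation by the (x0, x1, x2) key tuples.
  std::array<int, 5> perm;
  std::iota(perm.begin(), perm.end(), 0);
  std::stable_sort(perm.begin(), perm.end(), [&](int a, int b) {
    return std::tie(x0[a], x1[a], x2[a]) < std::tie(x0[b], x1[b], x2[b]);
  });

  // Apply the permutation and print one line per buffer.
  auto print = [&](const std::array<int, 5> &v) {
    for (int i : perm)
      std::printf("%d ", v[i]);
    std::printf("\n");
  };
  print(x0); // 1 1 2 5 10
  print(x1); // 3 3 1 10 1
  print(x2); // 9 9 4 7 2
  print(y0); // 8 7 10 9 6  (an unstable sort may emit 7 8 10 9 6 instead)
  return 0;
}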