diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -15,6 +15,8 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include 
+
 #include "CodegenUtils.h"
 
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
@@ -312,6 +314,302 @@
   return forOp;
 }
 
+constexpr uint64_t loIdx = 0;
+constexpr uint64_t hiIdx = 1;
+constexpr uint64_t xStartIdx = 2;
+
+/// Constructs a slice that drops the first `n` values.
+inline ValueRange skipFront(ValueRange operands, size_t n) {
+  return operands.slice(n, operands.size() - n);
+}
+
+/// Constructs a function name with this format to facilitate quick sort:
+///   <namePrefix><dim>_<x0 type>_<y0 type>..._<yn type>
+static void getSortingFuncName(llvm::raw_svector_ostream &nameOstream,
+                               size_t dim, ValueRange operands,
+                               StringRef namePrefix) {
+  nameOstream
+      << namePrefix << dim << "_"
+      << operands[xStartIdx].getType().cast<MemRefType>().getElementType();
+
+  for (Value v : skipFront(operands, xStartIdx + dim))
+    nameOstream << "_" << v.getType().cast<MemRefType>().getElementType();
+}
+
+/// Looks up a function that is appropriate for the given operands being
+/// sorted, and creates such a function if it doesn't exist yet.
+static FlatSymbolRefAttr getSortingFunc(
+    OpBuilder &builder, func::FuncOp insertPoint, size_t dim,
+    TypeRange resultTypes, ValueRange operands,
+    function_ref<void(OpBuilder &, ModuleOp, func::FuncOp, size_t)> createFunc,
+    StringRef namePrefix) {
+  SmallString<32> nameBuffer;
+  llvm::raw_svector_ostream nameOstream(nameBuffer);
+  getSortingFuncName(nameOstream, dim, operands, namePrefix);
+
+  ModuleOp module = insertPoint->getParentOfType<ModuleOp>();
+  MLIRContext *context = module.getContext();
+  auto result = SymbolRefAttr::get(context, nameOstream.str());
+  auto func = module.lookupSymbol<func::FuncOp>(result.getAttr());
+
+  if (!func) {
+    // Create the function.
+    OpBuilder::InsertionGuard insertionGuard(builder);
+    builder.setInsertionPoint(insertPoint);
+    Location loc = insertPoint.getLoc();
+    func = builder.create<func::FuncOp>(
+        loc, nameOstream.str(),
+        FunctionType::get(context, operands.getTypes(), resultTypes));
+    func.setPrivate();
+    createFunc(builder, module, func, dim);
+  }
+
+  return result;
+}
+
+/// Creates a function for swapping the values at index i and j for all the
+/// buffers.
+//
+// The generated IR corresponds to this C-like algorithm:
+// if (i != j) {
+//   swap(x0[i], x0[j]);
+//   swap(x1[i], x1[j]);
+//   ...
+//   swap(xn[i], xn[j]);
+//   swap(y0[i], y0[j]);
+//   ...
+//   swap(yn[i], yn[j]);
+// }
+static void createSwapFunc(OpBuilder &builder, ModuleOp unused,
+                           func::FuncOp func, size_t dim) {
+  OpBuilder::InsertionGuard insertionGuard(builder);
+
+  Block *entryBlock = func.addEntryBlock();
+  builder.setInsertionPointToStart(entryBlock);
+
+  Location loc = func.getLoc();
+  ValueRange args = entryBlock->getArguments();
+  Value i = args[0];
+  Value j = args[1];
+  Value cond =
+      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, i, j);
+  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, cond, /*else=*/false);
+
+  // If i != j, swap the values in the buffers.
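+  // The loop below emits a load/load/store/store sequence per buffer, so both
+  // elements are read before either one is overwritten.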
+  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+  for (auto iter = args.begin() + xStartIdx; iter != args.end(); iter++) {
+    Value vi = builder.create<memref::LoadOp>(loc, *iter, i);
+    Value vj = builder.create<memref::LoadOp>(loc, *iter, j);
+    builder.create<memref::StoreOp>(loc, vj, *iter, i);
+    builder.create<memref::StoreOp>(loc, vi, *iter, j);
+  }
+
+  builder.setInsertionPointAfter(ifOp);
+  builder.create<func::ReturnOp>(loc);
+}
+
+/// Creates a block of code to compare the xs values at index i and j for all
+/// the dimensions in a recursive manner.
+//
+// The generated IR corresponds to this C-like algorithm:
+// if (x0(i) < x0(j))
+//   return true;
+// else if (x0(j) < x0(i))
+//   return false;
+// else
+//   return createLessThenCompare(x1..n)
+static scf::IfOp createLessThenCompare(OpBuilder &builder, Location loc,
+                                       Value i, Value j, ValueRange xs) {
+  OpBuilder::InsertionGuard insertionGuard(builder);
+
+  Value f = constantI1(builder, loc, false);
+  Value t = constantI1(builder, loc, true);
+  Value vi = builder.create<memref::LoadOp>(loc, xs[0], i);
+  Value vj = builder.create<memref::LoadOp>(loc, xs[0], j);
+
+  Value cond =
+      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, vi, vj);
+  scf::IfOp ifOp =
+      builder.create<scf::IfOp>(loc, f.getType(), cond, /*else=*/true);
+  // If (xs[0](i) < xs[0](j)).
+  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+  builder.create<scf::YieldOp>(loc, t);
+
+  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
+  if (xs.size() == 1) {
+    // Finish checking all dimensions.
+    builder.create<scf::YieldOp>(loc, f);
+  } else {
+    cond =
+        builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, vj, vi);
+    scf::IfOp ifOp2 =
+        builder.create<scf::IfOp>(loc, f.getType(), cond, /*else=*/true);
+    // Otherwise if (xs[0](j) < xs[0](i)).
+    builder.setInsertionPointToStart(&ifOp2.getThenRegion().front());
+    builder.create<scf::YieldOp>(loc, f);
+
+    // Otherwise check the remaining dimensions recursively.
+    builder.setInsertionPointToStart(&ifOp2.getElseRegion().front());
+    scf::IfOp ifOp3 =
+        createLessThenCompare(builder, loc, i, j, xs.drop_front());
+    builder.create<scf::YieldOp>(loc, ifOp3.getResult(0));
+
+    builder.setInsertionPointAfter(ifOp2);
+    builder.create<scf::YieldOp>(loc, ifOp2.getResult(0));
+  }
+
+  return ifOp;
+}
+
+/// Creates a function to compare the xs values at index i and j for all the
+/// dimensions. The function returns true iff xs(i) < xs(j).
+static void createLessThenFunc(OpBuilder &builder, ModuleOp unused,
+                               func::FuncOp func, size_t dim) {
+  OpBuilder::InsertionGuard insertionGuard(builder);
+
+  Block *entryBlock = func.addEntryBlock();
+  builder.setInsertionPointToStart(entryBlock);
+  Location loc = func.getLoc();
+  ValueRange args = entryBlock->getArguments();
+  scf::IfOp ifOp = createLessThenCompare(builder, loc, args[0], args[1],
+                                         skipFront(args, xStartIdx));
+
+  builder.setInsertionPointAfter(ifOp);
+  builder.create<func::ReturnOp>(loc, ifOp.getResult(0));
+}
+
+/// Creates a function to perform quick sort partition on the values in the
+/// range of index [lo, hi), assuming lo < hi.
+//
+// The generated IR corresponds to this C-like algorithm:
+// int partition(lo, hi, data) {
+//   pivot = data[hi - 1];
+//   i = (lo - 1)  // RHS of the pivot found so far.
+//   for (j = lo; j < hi - 1; j++) {
+//     if (data[j] < pivot) {
+//       i++;
+//       swap data[i] and data[j]
+//     }
+//   }
+//   i++
+//   swap data[i] and data[hi-1]
+//   return i
+// }
+static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
+                                func::FuncOp func, size_t dim) {
+  OpBuilder::InsertionGuard insertionGuard(builder);
+
+  Block *entryBlock = func.addEntryBlock();
+  builder.setInsertionPointToStart(entryBlock);
+
+  MLIRContext *context = module.getContext();
+  Location loc = func.getLoc();
+  ValueRange args = entryBlock->getArguments();
+  Value lo = args[loIdx];
+  Value c1 = constantIndex(builder, loc, 1);
+  Value i = builder.create<arith::SubIOp>(loc, lo, c1);
+  Value him1 = builder.create<arith::SubIOp>(loc, args[hiIdx], c1);
+  scf::ForOp forOp =
+      builder.create<scf::ForOp>(loc, lo, him1, c1, ValueRange{i});
+
+  // Start the for-stmt body.
+  builder.setInsertionPointToStart(forOp.getBody());
+  Value j = forOp.getInductionVar();
+  SmallVector<Value> compareOperands{j, him1};
+  ValueRange xs = args.slice(xStartIdx, dim);
+  compareOperands.append(xs.begin(), xs.end());
+  Type compareResultType = IntegerType::get(context, 1, IntegerType::Signless);
+  FlatSymbolRefAttr lessThenFunc =
+      getSortingFunc(builder, func, dim, {compareResultType}, compareOperands,
+                     createLessThenFunc, "_sparse_less_then_");
+  Value cond = builder
+                   .create<func::CallOp>(loc, lessThenFunc,
+                                         TypeRange{compareResultType},
+                                         compareOperands)
+                   .getResult(0);
+  scf::IfOp ifOp =
+      builder.create<scf::IfOp>(loc, i.getType(), cond, /*else=*/true);
+
+  // The if-stmt true branch: i++; swap(data[i], data[j]); yield i.
+  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+  Value i1 = builder.create<arith::AddIOp>(
+      loc, forOp.getRegionIterArgs().front(), c1);
+  SmallVector<Value> swapOperands{i1, j};
+  swapOperands.append(args.begin() + xStartIdx, args.end());
+  FlatSymbolRefAttr swapFunc =
+      getSortingFunc(builder, func, dim, TypeRange(), swapOperands,
+                     createSwapFunc, "_sparse_swap_");
+  builder.create<func::CallOp>(loc, swapFunc, TypeRange(), swapOperands);
+  builder.create<scf::YieldOp>(loc, i1);
+
+  // The if-stmt false branch: yield i.
+  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
+  builder.create<scf::YieldOp>(loc, forOp.getRegionIterArgs().front());
+
+  // After the if-stmt, yield the updated i value to end the for-stmt body.
+  builder.setInsertionPointAfter(ifOp);
+  builder.create<scf::YieldOp>(loc, ifOp.getResult(0));
+
+  // After the for-stmt: i++; swap(data[i], data[him1]); return i.
+  builder.setInsertionPointAfter(forOp);
+  i1 = builder.create<arith::AddIOp>(loc, forOp.getResult(0), c1);
+  swapOperands[0] = i1;
+  swapOperands[1] = him1;
+  builder.create<func::CallOp>(loc, swapFunc, TypeRange(), swapOperands);
+  builder.create<func::ReturnOp>(loc, i1);
+}
+
+/// Creates a function to perform quick sort on the values in the range of
+/// index [lo, hi).
+//
+// The generated IR corresponds to this C-like algorithm:
+// void quickSort(lo, hi, data) {
+//   if (lo < hi) {
+//     p = partition(lo, hi, data);
+//     quickSort(lo, p, data);
+//     quickSort(p + 1, hi, data);
+//   }
+// }
+static void createSortFunc(OpBuilder &builder, ModuleOp module,
+                           func::FuncOp func, size_t dim) {
+  OpBuilder::InsertionGuard insertionGuard(builder);
+  Block *entryBlock = func.addEntryBlock();
+  builder.setInsertionPointToStart(entryBlock);
+
+  MLIRContext *context = module.getContext();
+  Location loc = func.getLoc();
+  ValueRange args = entryBlock->getArguments();
+  Value lo = args[loIdx];
+  Value hi = args[hiIdx];
+  Value cond =
+      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, lo, hi);
+  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, cond, /*else=*/false);
+
+  // The if-stmt true branch.
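+  // Call the partition function on [lo, hi), then recursively sort the two
+  // sub-ranges [lo, p) and [p + 1, hi) by calling this function again.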
+  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+  FlatSymbolRefAttr partition_func =
+      getSortingFunc(builder, func, dim, {IndexType::get(context)}, args,
+                     createPartitionFunc, "_sparse_partition_");
+  auto p = builder.create<func::CallOp>(loc, partition_func,
+                                        TypeRange{IndexType::get(context)},
+                                        ValueRange(args));
+
+  SmallVector<Value> low_operands{lo, p.getResult(0)};
+  low_operands.append(args.begin() + xStartIdx, args.end());
+  builder.create<func::CallOp>(loc, func, low_operands);
+
+  SmallVector<Value> high_operands{
+      builder.create<arith::AddIOp>(loc, p.getResult(0),
+                                    constantIndex(builder, loc, 1)),
+      hi};
+  high_operands.append(args.begin() + xStartIdx, args.end());
+  builder.create<func::CallOp>(loc, func, high_operands);
+
+  // After the if-stmt.
+  builder.setInsertionPointAfter(ifOp);
+  builder.create<func::ReturnOp>(loc);
+}
+
 //===----------------------------------------------------------------------===//
 // Codegen rules.
 //===----------------------------------------------------------------------===//
@@ -729,6 +1027,40 @@
   }
 };
 
+/// Sparse codegen rule for the sort operator.
+class SparseSortConverter : public OpConversionPattern<SortOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(SortOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    SmallVector<Value> operands{constantIndex(rewriter, loc, 0),
+                                adaptor.getN()};
+
+    // Convert `values` to have dynamic shape and append them to `operands`.
+    auto addValues = [&](ValueRange values) {
+      for (Value v : values) {
+        auto mtp = v.getType().cast<MemRefType>();
+        if (!mtp.isDynamicDim(0)) {
+          auto new_mtp = MemRefType::get({ShapedType::kDynamicSize},
+                                         mtp.getElementType());
+          v = rewriter.create<memref::CastOp>(loc, new_mtp, v);
+        }
+        operands.push_back(v);
+      }
+    };
+    ValueRange xs = adaptor.getXs();
+    addValues(xs);
+    addValues(adaptor.getYs());
+    auto insertPoint = op->getParentOfType<func::FuncOp>();
+    FlatSymbolRefAttr func =
+        getSortingFunc(rewriter, insertPoint, xs.size(), TypeRange(), operands,
+                       createSortFunc, "_sparse_sort_");
+    rewriter.replaceOpWithNewOp<func::CallOp>(op, func, TypeRange(), operands);
+    return success();
+  }
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -752,7 +1084,7 @@
            SparseCastConverter, SparseTensorAllocConverter,
            SparseTensorDeallocConverter, SparseTensorLoadConverter,
            SparseExpandConverter, SparseCompressConverter,
-           SparsePushBackConverter, SparseToPointersConverter,
-           SparseToIndicesConverter, SparseToValuesConverter>(
-      typeConverter, patterns.getContext());
+           SparsePushBackConverter, SparseSortConverter,
+           SparseToPointersConverter, SparseToIndicesConverter,
+           SparseToValuesConverter>(typeConverter, patterns.getContext());
 }
diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -410,3 +410,15 @@
   %0 = sparse_tensor.push_back %arg0, %arg1, %arg2 {idx = 2 : index} : memref<?xindex>, memref<?xf64>, f64 to memref<?xf64>
   return %0 : memref<?xf64>
 }
+
+// For the sort operator, we only check the generated supporting function
+// prototypes. We have an integration test for correctness.
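+// The names of the private functions are derived from the number of xs
+// buffers and the operand element types (see getSortingFuncName).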
+// CHECK-DAG: func.func private @_sparse_sort_1_i8_f32_index(%arg0: index, %arg1: index, %arg2: memref<?xi8>, %arg3: memref<?xf32>, %arg4: memref<?xindex>) {
+// CHECK-DAG: func.func private @_sparse_partition_1_i8_f32_index(%arg0: index, %arg1: index, %arg2: memref<?xi8>, %arg3: memref<?xf32>, %arg4: memref<?xindex>) -> index {
+// CHECK-DAG: func.func private @_sparse_less_then_1_i8(%arg0: index, %arg1: index, %arg2: memref<?xi8>) -> i1 {
+// CHECK-DAG: func.func private @_sparse_swap_1_i8_f32_index(%arg0: index, %arg1: index, %arg2: memref<?xi8>, %arg3: memref<?xf32>, %arg4: memref<?xindex>) {
+// CHECK-LABEL: func.func @sparse_sort_1d2v
+func.func @sparse_sort_1d2v(%arg0: index, %arg1: memref<10xi8>, %arg2: memref<?xf32>, %arg3: memref<10xindex>)
+   -> (memref<10xi8>, memref<?xf32>, memref<10xindex>) {
+  sparse_tensor.sort %arg0, %arg1 jointly %arg2, %arg3 : memref<10xi8> jointly memref<?xf32>, memref<10xindex>
+  return %arg1, %arg2, %arg3 : memref<10xi8>, memref<?xf32>, memref<10xindex>
+}
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_codegen_sort.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_codegen_sort.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_codegen_sort.mlir
@@ -0,0 +1,93 @@
+// RUN: mlir-opt %s --sparse-compiler=enable-runtime-library=false | \
+// RUN: mlir-cpu-runner \
+// RUN:  -e entry -entry-point-result=void \
+// RUN:  -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+module {
+  // Stores 5 values to the memref buffer.
+  func.func @storeValuesTo(%b: memref<5xi32>, %v0: i32, %v1: i32, %v2: i32,
+                           %v3: i32, %v4: i32) -> () {
+    %i0 = arith.constant 0 : index
+    %i1 = arith.constant 1 : index
+    %i2 = arith.constant 2 : index
+    %i3 = arith.constant 3 : index
+    %i4 = arith.constant 4 : index
+    memref.store %v0, %b[%i0] : memref<5xi32>
+    memref.store %v1, %b[%i1] : memref<5xi32>
+    memref.store %v2, %b[%i2] : memref<5xi32>
+    memref.store %v3, %b[%i3] : memref<5xi32>
+    memref.store %v4, %b[%i4] : memref<5xi32>
+    return
+  }
+
+  // The main driver.
+  func.func @entry() {
+    %c0 = arith.constant 0 : i32
+    %c1 = arith.constant 1 : i32
+    %c2 = arith.constant 2 : i32
+    %c3 = arith.constant 3 : i32
+    %c4 = arith.constant 4 : i32
+    %c5 = arith.constant 5 : i32
+    %c6 = arith.constant 6 : i32
+    %c7 = arith.constant 7 : i32
+    %c8 = arith.constant 8 : i32
+    %c9 = arith.constant 9 : i32
+    %c10 = arith.constant 10 : i32
+    %c100 = arith.constant 100 : i32
+
+    %i0 = arith.constant 0 : index
+    %i4 = arith.constant 4 : index
+    %i5 = arith.constant 5 : index
+
+    %x0 = memref.alloc() : memref<5xi32>
+    call @storeValuesTo(%x0, %c10, %c2, %c0, %c5, %c1)
+      : (memref<5xi32>, i32, i32, i32, i32, i32) -> ()
+
+    // Sort 0 elements.
+    // CHECK: ( 10, 2, 0, 5, 1 )
+    sparse_tensor.sort %i0, %x0 : memref<5xi32>
+    %x0v0 = vector.transfer_read %x0[%i0], %c100: memref<5xi32>, vector<5xi32>
+    vector.print %x0v0 : vector<5xi32>
+
+    // Sort the first 4 elements, with the last valid value untouched.
+    // CHECK: ( 0, 2, 5, 10, 1 )
+    sparse_tensor.sort %i4, %x0 : memref<5xi32>
+    %x0v1 = vector.transfer_read %x0[%i0], %c100: memref<5xi32>, vector<5xi32>
+    vector.print %x0v1 : vector<5xi32>
+
+    // Sort "parallel arrays".
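+    // The rows are ordered lexicographically by (x0, x1, x2); y0 is permuted
+    // along with the keys, which is what the CHECK lines below verify.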
+ %x1 = memref.alloc() : memref<5xi32> + %x2 = memref.alloc() : memref<5xi32> + %y0 = memref.alloc() : memref<5xi32> + call @storeValuesTo(%x0, %c10, %c2, %c0, %c5, %c1) + : (memref<5xi32>, i32, i32, i32, i32, i32) -> () + call @storeValuesTo(%x1, %c1, %c1, %c3, %c10, %c3) + : (memref<5xi32>, i32, i32, i32, i32, i32) -> () + call @storeValuesTo(%x2, %c2, %c4, %c4, %c7, %c9) + : (memref<5xi32>, i32, i32, i32, i32, i32) -> () + call @storeValuesTo(%y0, %c6, %c10, %c8, %c9, %c7) + : (memref<5xi32>, i32, i32, i32, i32, i32) -> () + // CHECK: ( 0, 1, 2, 5, 10 ) + // CHECK: ( 3, 3, 1, 10, 1 ) + // CHECK: ( 4, 9, 4, 7, 2 ) + // CHECK: ( 8, 7, 10, 9, 6 ) + sparse_tensor.sort %i5, %x0, %x1, %x2 jointly %y0 + : memref<5xi32>, memref<5xi32>, memref<5xi32> jointly memref<5xi32> + %x0v2 = vector.transfer_read %x0[%i0], %c100: memref<5xi32>, vector<5xi32> + vector.print %x0v2 : vector<5xi32> + %x1v = vector.transfer_read %x1[%i0], %c100: memref<5xi32>, vector<5xi32> + vector.print %x1v : vector<5xi32> + %x2v = vector.transfer_read %x2[%i0], %c100: memref<5xi32>, vector<5xi32> + vector.print %x2v : vector<5xi32> + %y0v = vector.transfer_read %y0[%i0], %c100: memref<5xi32>, vector<5xi32> + vector.print %y0v : vector<5xi32> + + // Release the buffers. + memref.dealloc %x0 : memref<5xi32> + memref.dealloc %x1 : memref<5xi32> + memref.dealloc %x2 : memref<5xi32> + memref.dealloc %y0 : memref<5xi32> + return + } +}
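For readers following the generated IR, below is a minimal C++ reference sketch (not part of the patch) of the scheme that the emitted _sparse_less_then_*, _sparse_swap_*, _sparse_partition_*, and _sparse_sort_* functions mirror. The helper names (lessThan, swapRows, partitionRows, quickSortRows) and the use of uint64_t buffers are illustrative assumptions only; the generated code works directly on the memref buffers and compares keys with the unsigned ult predicate, which the unsigned element type here reflects.

#include <cstdint>
#include <utility>
#include <vector>

using Buffers = std::vector<std::vector<uint64_t>>;

// Lexicographic "less than" over the first `numXs` buffers (the sort keys),
// comparing rows i and j; mirrors createLessThenCompare.
static bool lessThan(const Buffers &bufs, size_t numXs, uint64_t i,
                     uint64_t j) {
  for (size_t d = 0; d < numXs; ++d) {
    if (bufs[d][i] < bufs[d][j])
      return true;
    if (bufs[d][j] < bufs[d][i])
      return false;
  }
  return false;
}

// Swaps rows i and j in every buffer (xs and ys); mirrors createSwapFunc.
static void swapRows(Buffers &bufs, uint64_t i, uint64_t j) {
  if (i == j)
    return;
  for (auto &b : bufs)
    std::swap(b[i], b[j]);
}

// Partitions rows [lo, hi) around the last row; mirrors createPartitionFunc.
static uint64_t partitionRows(Buffers &bufs, size_t numXs, uint64_t lo,
                              uint64_t hi) {
  uint64_t pivot = hi - 1;
  uint64_t i = lo; // Next slot for a row that is less than the pivot row.
  for (uint64_t j = lo; j < hi - 1; ++j) {
    if (lessThan(bufs, numXs, j, pivot)) {
      swapRows(bufs, i, j);
      ++i;
    }
  }
  swapRows(bufs, i, pivot);
  return i;
}

// Recursive quick sort over rows [lo, hi); mirrors createSortFunc.
static void quickSortRows(Buffers &bufs, size_t numXs, uint64_t lo,
                          uint64_t hi) {
  if (lo >= hi)
    return;
  uint64_t p = partitionRows(bufs, numXs, lo, hi);
  quickSortRows(bufs, numXs, lo, p);
  quickSortRows(bufs, numXs, p + 1, hi);
}

int main() {
  // Mirrors the "parallel arrays" case in the integration test above.
  Buffers bufs = {{10, 2, 0, 5, 1},  // x0
                  {1, 1, 3, 10, 3},  // x1
                  {2, 4, 4, 7, 9},   // x2
                  {6, 10, 8, 9, 7}}; // y0 (moved jointly, not a key)
  quickSortRows(bufs, /*numXs=*/3, 0, 5);
  // bufs is now {{0, 1, 2, 5, 10}, {3, 3, 1, 10, 1},
  //              {4, 9, 4, 7, 2}, {8, 7, 10, 9, 6}}.
  return 0;
}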