diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp @@ -43,19 +43,25 @@ static constexpr const char kSortStableFuncNamePrefix[] = "_sparse_sort_stable_"; -using FuncGeneratorType = - function_ref; +using FuncGeneratorType = function_ref; /// Constructs a function name with this format to facilitate quick sort: -/// <namePrefix><dim>_<x type>_<y0 type>..._<yn type> +/// <namePrefix><nx>_<x type>_<y0 type>..._<yn type> for sort +/// <namePrefix><nx>_<x type>_coo_<ny>_<y0 type>..._<yn type> for sort_coo static void getMangledSortHelperFuncName(llvm::raw_svector_ostream &nameOstream, - StringRef namePrefix, size_t dim, + StringRef namePrefix, uint64_t nx, + uint64_t ny, bool isCoo, ValueRange operands) { nameOstream - << namePrefix << dim << "_" + << namePrefix << nx << "_" << operands[xStartIdx].getType().cast().getElementType(); - for (Value v : operands.drop_front(xStartIdx + dim)) + if (isCoo) + nameOstream << "_coo_" << ny; + + uint64_t yBufferOffset = isCoo ? 
1 : nx; + for (Value v : operands.drop_front(xStartIdx + yBufferOffset)) nameOstream << "_" << v.getType().cast().getElementType(); } @@ -64,11 +70,12 @@ static FlatSymbolRefAttr getMangledSortHelperFunc(OpBuilder &builder, func::FuncOp insertPoint, TypeRange resultTypes, StringRef namePrefix, - size_t dim, ValueRange operands, - FuncGeneratorType createFunc) { + uint64_t nx, uint64_t ny, bool isCoo, + ValueRange operands, FuncGeneratorType createFunc) { SmallString<32> nameBuffer; llvm::raw_svector_ostream nameOstream(nameBuffer); - getMangledSortHelperFuncName(nameOstream, namePrefix, dim, operands); + getMangledSortHelperFuncName(nameOstream, namePrefix, nx, ny, isCoo, + operands); ModuleOp module = insertPoint->getParentOfType(); MLIRContext *context = module.getContext(); @@ -84,12 +91,61 @@ loc, nameOstream.str(), FunctionType::get(context, operands.getTypes(), resultTypes)); func.setPrivate(); - createFunc(builder, module, func, dim); + createFunc(builder, module, func, nx, ny, isCoo); } return result; } +/// Creates a code block to process each pair of (xs[i], xs[j]) for sorting. +/// The code to process the value pairs is generated by `bodyBuilder`. +static void forEachIJPairInXs( + OpBuilder &builder, Location loc, ValueRange args, uint64_t nx, uint64_t ny, + bool isCoo, function_ref bodyBuilder) { + Value iOffset, jOffset; + if (isCoo) { + Value cstep = constantIndex(builder, loc, nx + ny); + iOffset = builder.create(loc, args[0], cstep); + jOffset = builder.create(loc, args[1], cstep); + } + for (uint64_t k = 0; k < nx; k++) { + scf::IfOp ifOp; + Value i, j, buffer; + if (isCoo) { + Value ck = constantIndex(builder, loc, k); + i = builder.create(loc, ck, iOffset); + j = builder.create(loc, ck, jOffset); + buffer = args[xStartIdx]; + } else { + i = args[0]; + j = args[1]; + buffer = args[xStartIdx + k]; + } + bodyBuilder(k, i, j, buffer); + } +} + +/// Creates a code block to process each pair of (xys[i], xys[j]) for sorting. 
+/// The code to process the value pairs is generated by `bodyBuilder`. +static void forEachIJPairInAllBuffers( + OpBuilder &builder, Location loc, ValueRange args, uint64_t nx, uint64_t ny, + bool isCoo, function_ref bodyBuilder) { + + // Create code for the first (nx + ny) buffers. When isCoo==true, these + // logical buffers are all from the xy buffer of the sort_coo operator. + forEachIJPairInXs(builder, loc, args, nx + ny, 0, isCoo, bodyBuilder); + + uint64_t numHandledBuffers = isCoo ? 1 : nx + ny; + + // Create code for the remaining buffers. + Value i = args[0]; + Value j = args[1]; + for (const auto& arg : + llvm::enumerate(args.drop_front(xStartIdx + numHandledBuffers))) { + bodyBuilder(arg.index() + nx + ny, i, j, arg.value()); + } +} + /// Creates a code block for swapping the values in index i and j for all the /// buffers. // @@ -101,21 +157,23 @@ // swap(y0[i], y0[j]); // ... // swap(yn[i], yn[j]); -static void createSwap(OpBuilder &builder, Location loc, ValueRange args) { - Value i = args[0]; - Value j = args[1]; - for (auto arg : args.drop_front(xStartIdx)) { - Value vi = builder.create(loc, arg, i); - Value vj = builder.create(loc, arg, j); - builder.create(loc, vj, arg, i); - builder.create(loc, vi, arg, j); - } +static void createSwap(OpBuilder &builder, Location loc, ValueRange args, + uint64_t nx, uint64_t ny, bool isCoo) { + auto swapOnePair = [&](uint64_t unused, Value i, Value j, Value buffer) { + Value vi = builder.create(loc, buffer, i); + Value vj = builder.create(loc, buffer, j); + builder.create(loc, vj, buffer, i); + builder.create(loc, vi, buffer, j); + }; + + forEachIJPairInAllBuffers(builder, loc, args, nx, ny, isCoo, swapOnePair); } /// Creates a function to compare all the (xs[i], xs[j]) pairs. The method to /// compare each pair is create via `compareBuilder`. 
static void createCompareFuncImplementation( - OpBuilder &builder, ModuleOp unused, func::FuncOp func, size_t dim, + OpBuilder &builder, ModuleOp unused, func::FuncOp func, uint64_t nx, + uint64_t ny, bool isCoo, function_ref compareBuilder) { OpBuilder::InsertionGuard insertionGuard(builder); @@ -126,17 +184,18 @@ ValueRange args = entryBlock->getArguments(); scf::IfOp topIfOp; - for (const auto &item : llvm::enumerate(args.slice(xStartIdx, dim))) { - scf::IfOp ifOp = compareBuilder(builder, loc, args[0], args[1], - item.value(), (item.index() == dim - 1)); - if (item.index() == 0) { + auto bodyBuilder = [&](uint64_t k, Value i, Value j, Value buffer) { + scf::IfOp ifOp = compareBuilder(builder, loc, i, j, buffer, (k == nx - 1)); + if (k == 0) { topIfOp = ifOp; } else { OpBuilder::InsertionGuard insertionGuard(builder); builder.setInsertionPointAfter(ifOp); builder.create(loc, ifOp.getResult(0)); } - } + }; + + forEachIJPairInXs(builder, loc, args, nx, ny, isCoo, bodyBuilder); builder.setInsertionPointAfter(topIfOp); builder.create(loc, topIfOp.getResult(0)); @@ -180,8 +239,10 @@ // else if (x2[2] != x2[j])) // and so on ... static void createEqCompareFunc(OpBuilder &builder, ModuleOp unused, - func::FuncOp func, size_t dim) { - createCompareFuncImplementation(builder, unused, func, dim, createEqCompare); + func::FuncOp func, uint64_t nx, uint64_t ny, + bool isCoo) { + createCompareFuncImplementation(builder, unused, func, nx, ny, isCoo, + createEqCompare); } /// Generates an if-statement to compare whether x[i] is less than x[j]. @@ -238,8 +299,9 @@ // else if (x1[j] < x1[i])) // and so on ... 
static void createLessThanFunc(OpBuilder &builder, ModuleOp unused, - func::FuncOp func, size_t dim) { - createCompareFuncImplementation(builder, unused, func, dim, + func::FuncOp func, uint64_t nx, uint64_t ny, + bool isCoo) { + createCompareFuncImplementation(builder, unused, func, nx, ny, isCoo, createLessThanCompare); } @@ -257,7 +319,8 @@ // return lo; // static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module, - func::FuncOp func, size_t dim) { + func::FuncOp func, uint64_t nx, uint64_t ny, + bool isCoo) { OpBuilder::InsertionGuard insertionGuard(builder); Block *entryBlock = func.addEntryBlock(); builder.setInsertionPointToStart(entryBlock); @@ -292,12 +355,13 @@ // Compare xs[p] < xs[mid]. SmallVector compareOperands{p, mid}; + uint64_t numXBuffers = isCoo ? 1 : nx; compareOperands.append(args.begin() + xStartIdx, - args.begin() + xStartIdx + dim); + args.begin() + xStartIdx + numXBuffers); Type i1Type = IntegerType::get(module.getContext(), 1, IntegerType::Signless); - FlatSymbolRefAttr lessThanFunc = - getMangledSortHelperFunc(builder, func, {i1Type}, kLessThanFuncNamePrefix, - dim, compareOperands, createLessThanFunc); + FlatSymbolRefAttr lessThanFunc = getMangledSortHelperFunc( + builder, func, {i1Type}, kLessThanFuncNamePrefix, nx, ny, isCoo, + compareOperands, createLessThanFunc); Value cond2 = builder .create(loc, lessThanFunc, TypeRange{i1Type}, compareOperands) @@ -324,7 +388,8 @@ /// xs[i] == xs[p]. 
static std::pair createScanLoop(OpBuilder &builder, ModuleOp module, func::FuncOp func, - ValueRange xs, Value i, Value p, size_t dim, int step) { + ValueRange xs, Value i, Value p, uint64_t nx, uint64_t ny, + bool isCoo, int step) { Location loc = func.getLoc(); scf::WhileOp whileOp = builder.create(loc, TypeRange{i.getType()}, ValueRange{i}); @@ -344,9 +409,9 @@ compareOperands.append(xs.begin(), xs.end()); MLIRContext *context = module.getContext(); Type i1Type = IntegerType::get(context, 1, IntegerType::Signless); - FlatSymbolRefAttr lessThanFunc = - getMangledSortHelperFunc(builder, func, {i1Type}, kLessThanFuncNamePrefix, - dim, compareOperands, createLessThanFunc); + FlatSymbolRefAttr lessThanFunc = getMangledSortHelperFunc( + builder, func, {i1Type}, kLessThanFuncNamePrefix, nx, ny, isCoo, + compareOperands, createLessThanFunc); Value cond = builder .create(loc, lessThanFunc, TypeRange{i1Type}, compareOperands) @@ -365,8 +430,8 @@ compareOperands[0] = i; compareOperands[1] = p; FlatSymbolRefAttr compareEqFunc = getMangledSortHelperFunc( - builder, func, {i1Type}, kCompareEqFuncNamePrefix, dim, compareOperands, - createEqCompareFunc); + builder, func, {i1Type}, kCompareEqFuncNamePrefix, nx, ny, isCoo, + compareOperands, createEqCompareFunc); Value compareEq = builder .create(loc, compareEqFunc, TypeRange{i1Type}, @@ -405,7 +470,8 @@ // return p // } static void createPartitionFunc(OpBuilder &builder, ModuleOp module, - func::FuncOp func, size_t dim) { + func::FuncOp func, uint64_t nx, uint64_t ny, + bool isCoo) { OpBuilder::InsertionGuard insertionGuard(builder); Block *entryBlock = func.addEntryBlock(); @@ -442,11 +508,14 @@ j = after->getArgument(1); p = after->getArgument(2); - auto [iresult, iCompareEq] = createScanLoop( - builder, module, func, args.slice(xStartIdx, dim), i, p, dim, 1); + uint64_t numXBuffers = isCoo ? 
1 : nx; + auto [iresult, iCompareEq] = + createScanLoop(builder, module, func, args.slice(xStartIdx, numXBuffers), + i, p, nx, ny, isCoo, 1); i = iresult; - auto [jresult, jCompareEq] = createScanLoop( - builder, module, func, args.slice(xStartIdx, dim), j, p, dim, -1); + auto [jresult, jCompareEq] = + createScanLoop(builder, module, func, args.slice(xStartIdx, numXBuffers), + j, p, nx, ny, isCoo, -1); j = jresult; // If i < j: @@ -455,7 +524,7 @@ builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); SmallVector swapOperands{i, j}; swapOperands.append(args.begin() + xStartIdx, args.end()); - createSwap(builder, loc, swapOperands); + createSwap(builder, loc, swapOperands, nx, ny, isCoo); // If the pivot is moved, update p with the new pivot. Value icond = builder.create(loc, arith::CmpIPredicate::eq, i, p); @@ -515,7 +584,8 @@ // } // } static void createSortNonstableFunc(OpBuilder &builder, ModuleOp module, - func::FuncOp func, size_t dim) { + func::FuncOp func, uint64_t nx, uint64_t ny, + bool isCoo) { OpBuilder::InsertionGuard insertionGuard(builder); Block *entryBlock = func.addEntryBlock(); builder.setInsertionPointToStart(entryBlock); @@ -532,8 +602,8 @@ // The if-stmt true branch. 
builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); FlatSymbolRefAttr partitionFunc = getMangledSortHelperFunc( - builder, func, {IndexType::get(context)}, kPartitionFuncNamePrefix, dim, - args, createPartitionFunc); + builder, func, {IndexType::get(context)}, kPartitionFuncNamePrefix, nx, + ny, isCoo, args, createPartitionFunc); auto p = builder.create( loc, partitionFunc, TypeRange{IndexType::get(context)}, ValueRange(args)); @@ -567,7 +637,8 @@ // } // } static void createSortStableFunc(OpBuilder &builder, ModuleOp module, - func::FuncOp func, size_t dim) { + func::FuncOp func, uint64_t nx, uint64_t ny, + bool isCoo) { OpBuilder::InsertionGuard insertionGuard(builder); Block *entryBlock = func.addEntryBlock(); builder.setInsertionPointToStart(entryBlock); @@ -587,20 +658,23 @@ // Binary search to find the insertion point p. SmallVector operands{lo, i}; - operands.append(args.begin() + xStartIdx, args.begin() + xStartIdx + dim); + operands.append(args.begin() + xStartIdx, args.end()); FlatSymbolRefAttr searchFunc = getMangledSortHelperFunc( - builder, func, {IndexType::get(context)}, kBinarySearchFuncNamePrefix, - dim, operands, createBinarySearchFunc); + builder, func, {IndexType::get(context)}, kBinarySearchFuncNamePrefix, nx, + ny, isCoo, operands, createBinarySearchFunc); Value p = builder .create(loc, searchFunc, TypeRange{c1.getType()}, operands) .getResult(0); // Move the value at data[i] to a temporary location. - ValueRange data = args.drop_front(xStartIdx); + operands[0] = operands[1] = i; SmallVector d; - for (Value v : data) - d.push_back(builder.create(loc, v, i)); + forEachIJPairInAllBuffers( + builder, loc, operands, nx, ny, isCoo, + [&](uint64_t unused, Value i, Value unused2, Value buffer) { + d.push_back(builder.create(loc, buffer, i)); + }); // Start the inner for-stmt with induction variable j, for moving data[p..i) // to data[p+1..i+1). 
@@ -610,21 +684,58 @@ builder.setInsertionPointToStart(forOpJ.getBody()); Value j = forOpJ.getInductionVar(); Value imj = builder.create(loc, i, j); - Value imjm1 = builder.create(loc, imj, c1); - for (Value v : data) { - Value t = builder.create(loc, v, imjm1); - builder.create(loc, t, v, imj); - } + operands[1] = imj; + operands[0] = builder.create(loc, imj, c1); + forEachIJPairInAllBuffers( + builder, loc, operands, nx, ny, isCoo, + [&](uint64_t unused, Value imjm1, Value imj, Value buffer) { + Value t = builder.create(loc, buffer, imjm1); + builder.create(loc, t, buffer, imj); + }); // Store the value at data[i] to data[p]. builder.setInsertionPointAfter(forOpJ); - for (auto it : llvm::zip(d, data)) - builder.create(loc, std::get<0>(it), std::get<1>(it), p); + operands[0] = operands[1] = p; + forEachIJPairInAllBuffers( + builder, loc, operands, nx, ny, isCoo, + [&](uint64_t k, Value p, Value usused, Value buffer) { + builder.create(loc, d[k], buffer, p); + }); builder.setInsertionPointAfter(forOpI); builder.create(loc); } +/// Implements the rewriting for operator sort and sort_coo. +template +LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, uint64_t nx, + uint64_t ny, bool isCoo, + PatternRewriter &rewriter) { + Location loc = op.getLoc(); + SmallVector operands{constantIndex(rewriter, loc, 0), op.getN()}; + + // Convert `xys` to have dynamic shape and append them to `operands`. + for (Value v : xys) { + auto mtp = v.getType().cast(); + if (!mtp.isDynamicDim(0)) { + auto newMtp = + MemRefType::get({ShapedType::kDynamicSize}, mtp.getElementType()); + v = rewriter.create(loc, newMtp, v); + } + operands.push_back(v); + } + auto insertPoint = op->template getParentOfType(); + SmallString<32> funcName(op.getStable() ? kSortStableFuncNamePrefix + : kSortNonstableFuncNamePrefix); + FuncGeneratorType funcGenerator = + op.getStable() ? 
createSortStableFunc : createSortNonstableFunc; + FlatSymbolRefAttr func = + getMangledSortHelperFunc(rewriter, insertPoint, TypeRange(), funcName, nx, + ny, isCoo, operands, funcGenerator); + rewriter.replaceOpWithNewOp(op, func, TypeRange(), operands); + return success(); +} + //===---------------------------------------------------------------------===// // The actual sparse buffer rewriting rules. //===---------------------------------------------------------------------===// @@ -740,34 +851,33 @@ LogicalResult matchAndRewrite(SortOp op, PatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - SmallVector operands{constantIndex(rewriter, loc, 0), op.getN()}; - - // Convert `values` to have dynamic shape and append them to `operands`. - auto addValues = [&](ValueRange values) { - for (Value v : values) { - auto mtp = v.getType().cast(); - if (!mtp.isDynamicDim(0)) { - auto newMtp = - MemRefType::get({ShapedType::kDynamicSize}, mtp.getElementType()); - v = rewriter.create(loc, newMtp, v); - } - operands.push_back(v); - } - }; - ValueRange xs = op.getXs(); - addValues(xs); - addValues(op.getYs()); - auto insertPoint = op->getParentOfType(); - SmallString<32> funcName(op.getStable() ? kSortStableFuncNamePrefix - : kSortNonstableFuncNamePrefix); - FuncGeneratorType funcGenerator = - op.getStable() ? createSortStableFunc : createSortNonstableFunc; - FlatSymbolRefAttr func = - getMangledSortHelperFunc(rewriter, insertPoint, TypeRange(), funcName, - xs.size(), operands, funcGenerator); - rewriter.replaceOpWithNewOp(op, func, TypeRange(), operands); - return success(); + SmallVector xys(op.getXs()); + xys.append(op.getYs().begin(), op.getYs().end()); + return matchAndRewriteSortOp(op, xys, op.getXs().size(), /*ny=*/0, + /*isCoo=*/false, rewriter); + } +}; + +/// Sparse rewriting rule for the sort_coo operator. 
+struct SortCooRewriter : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(SortCooOp op, + PatternRewriter &rewriter) const override { + SmallVector xys; + xys.push_back(op.getXy()); + xys.append(op.getYs().begin(), op.getYs().end()); + uint64_t nx = 1; + if (auto nxAttr = op.getNxAttr()) + nx = nxAttr.getInt(); + + uint64_t ny = 0; + if (auto nyAttr = op.getNyAttr()) + ny = nyAttr.getInt(); + + return matchAndRewriteSortOp(op, xys, nx, ny, + /*isCoo=*/true, rewriter); } }; @@ -778,5 +888,6 @@ //===---------------------------------------------------------------------===// void mlir::populateSparseBufferRewriting(RewritePatternSet &patterns) { - patterns.add(patterns.getContext()); + patterns.add( + patterns.getContext()); } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp @@ -170,6 +170,7 @@ // Most ops in the sparse dialect must go! target.addIllegalDialect(); target.addLegalOp(); + target.addLegalOp(); target.addLegalOp(); // All dynamic rules below accept new function, call, return, and various // tensor and bufferization operations as legal output of the rewriting diff --git a/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir b/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir --- a/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir +++ b/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir @@ -194,3 +194,33 @@ sparse_tensor.sort stable %arg0, %arg1, %arg2, %arg3 : memref<10xindex>, memref, memref<10xindex> return %arg1, %arg2, %arg3 : memref<10xindex>, memref, memref<10xindex> } + +// ----- + +// Only check the generated supporting functions. We have integration test to +// verify correctness of the generated code. 
+// +// CHECK-DAG: func.func private @_sparse_less_than_2_index_coo_1(%arg0: index, %arg1: index, %arg2: memref) -> i1 { +// CHECK-DAG: func.func private @_sparse_compare_eq_2_index_coo_1(%arg0: index, %arg1: index, %arg2: memref) -> i1 { +// CHECK-DAG: func.func private @_sparse_partition_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref, %arg3: memref, %arg4: memref) -> index { +// CHECK-DAG: func.func private @_sparse_sort_nonstable_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref, %arg3: memref, %arg4: memref) { +// CHECK-LABEL: func.func @sparse_sort_coo +func.func @sparse_sort_coo(%arg0: index, %arg1: memref<100xindex>, %arg2: memref, %arg3: memref<10xi32>) -> (memref<100xindex>, memref, memref<10xi32>) { + sparse_tensor.sort_coo %arg0, %arg1 jointly %arg2, %arg3 {nx = 2 : index, ny = 1: index} : memref<100xindex> jointly memref, memref<10xi32> + return %arg1, %arg2, %arg3 : memref<100xindex>, memref, memref<10xi32> +} + +// ----- + +// Only check the generated supporting functions. We have integration test to +// verify correctness of the generated code. 
+// +// CHECK-DAG: func.func private @_sparse_less_than_2_index_coo_1(%arg0: index, %arg1: index, %arg2: memref) -> i1 { +// CHECK-DAG: func.func private @_sparse_binary_search_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref, %arg3: memref, %arg4: memref) -> index { +// CHECK-DAG: func.func private @_sparse_sort_stable_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref, %arg3: memref, %arg4: memref) { +// CHECK-LABEL: func.func @sparse_sort_coo_stable +func.func @sparse_sort_coo_stable(%arg0: index, %arg1: memref<100xindex>, %arg2: memref, %arg3: memref<10xi32>) -> (memref<100xindex>, memref, memref<10xi32>) { + sparse_tensor.sort_coo stable %arg0, %arg1 jointly %arg2, %arg3 {nx = 2 : index, ny = 1: index} : memref<100xindex> jointly memref, memref<10xi32> + return %arg1, %arg2, %arg3 : memref<100xindex>, memref, memref<10xi32> +} + diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir @@ -0,0 +1,134 @@ +// RUN: mlir-opt %s --sparse-compiler=enable-runtime-library=false | \ +// RUN: mlir-cpu-runner \ +// RUN: -e entry -entry-point-result=void \ +// RUN: -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \ +// RUN: FileCheck %s + +module { + // Stores 5 values to the memref buffer. + func.func @storeValuesTo(%b: memref, %v0: i32, %v1: i32, %v2: i32, + %v3: i32, %v4: i32) -> () { + %i0 = arith.constant 0 : index + %i1 = arith.constant 1 : index + %i2 = arith.constant 2 : index + %i3 = arith.constant 3 : index + %i4 = arith.constant 4 : index + memref.store %v0, %b[%i0] : memref + memref.store %v1, %b[%i1] : memref + memref.store %v2, %b[%i2] : memref + memref.store %v3, %b[%i3] : memref + memref.store %v4, %b[%i4] : memref + return + } + + // Stores 5 values to the memref buffer. 
+ func.func @storeValuesToStrided(%b: memref>, %v0: i32, %v1: i32, %v2: i32, + %v3: i32, %v4: i32) -> () { + %i0 = arith.constant 0 : index + %i1 = arith.constant 1 : index + %i2 = arith.constant 2 : index + %i3 = arith.constant 3 : index + %i4 = arith.constant 4 : index + memref.store %v0, %b[%i0] : memref> + memref.store %v1, %b[%i1] : memref> + memref.store %v2, %b[%i2] : memref> + memref.store %v3, %b[%i3] : memref> + memref.store %v4, %b[%i4] : memref> + return + } + + // The main driver. + func.func @entry() { + %c0 = arith.constant 0 : i32 + %c1 = arith.constant 1 : i32 + %c2 = arith.constant 2 : i32 + %c3 = arith.constant 3 : i32 + %c4 = arith.constant 4 : i32 + %c5 = arith.constant 5 : i32 + %c6 = arith.constant 6 : i32 + %c7 = arith.constant 7 : i32 + %c8 = arith.constant 8 : i32 + %c9 = arith.constant 9 : i32 + %c10 = arith.constant 10 : i32 + %c100 = arith.constant 100 : i32 + + %i0 = arith.constant 0 : index + %i1 = arith.constant 1 : index + %i2 = arith.constant 2 : index + %i3 = arith.constant 3 : index + %i4 = arith.constant 4 : index + %i5 = arith.constant 5 : index + + // Prepare a buffer for x0, x1, x2, y0 and a buffer for y1. + %xys = memref.alloc() : memref<20xi32> + %xy = memref.cast %xys : memref<20xi32> to memref + %x0 = memref.subview %xy[%i0][%i5][%i4] : memref to memref> + %x1 = memref.subview %xy[%i1][%i5][%i4] : memref to memref> + %x2 = memref.subview %xy[%i2][%i5][%i4] : memref to memref> + %y0 = memref.subview %xy[%i3][%i5][%i4] : memref to memref> + %y1s = memref.alloc() : memref<7xi32> + %y1 = memref.cast %y1s : memref<7xi32> to memref + + // Sort "parallel arrays". 
+ // CHECK: ( 1, 1, 2, 5, 10 ) + // CHECK: ( 3, 3, 1, 10, 1 ) + // CHECK: ( 9, 9, 4, 7, 2 ) + // CHECK: ( 8, 7, 10, 9, 6 ) + // CHECK: ( 4, 7, 7, 9, 5 ) + call @storeValuesToStrided(%x0, %c10, %c2, %c1, %c5, %c1) + : (memref>, i32, i32, i32, i32, i32) -> () + call @storeValuesToStrided(%x1, %c1, %c1, %c3, %c10, %c3) + : (memref>, i32, i32, i32, i32, i32) -> () + call @storeValuesToStrided(%x2, %c2, %c4, %c9, %c7, %c9) + : (memref>, i32, i32, i32, i32, i32) -> () + call @storeValuesToStrided(%y0, %c6, %c10, %c8, %c9, %c7) + : (memref>, i32, i32, i32, i32, i32) -> () + call @storeValuesTo(%y1, %c5, %c7, %c4, %c9, %c7) + : (memref, i32, i32, i32, i32, i32) -> () + sparse_tensor.sort_coo %i5, %xy jointly %y1 {nx = 3 : index, ny = 1 : index} + : memref jointly memref + %x0v = vector.transfer_read %x0[%i0], %c100: memref>, vector<5xi32> + vector.print %x0v : vector<5xi32> + %x1v = vector.transfer_read %x1[%i0], %c100: memref>, vector<5xi32> + vector.print %x1v : vector<5xi32> + %x2v = vector.transfer_read %x2[%i0], %c100: memref>, vector<5xi32> + vector.print %x2v : vector<5xi32> + %y0v = vector.transfer_read %y0[%i0], %c100: memref>, vector<5xi32> + vector.print %y0v : vector<5xi32> + %y1v = vector.transfer_read %y1[%i0], %c100: memref, vector<5xi32> + vector.print %y1v : vector<5xi32> + // Stable sort. 
+ // CHECK: ( 1, 1, 2, 5, 10 ) + // CHECK: ( 3, 3, 1, 10, 1 ) + // CHECK: ( 9, 9, 4, 7, 2 ) + // CHECK: ( 8, 7, 10, 9, 6 ) + // CHECK: ( 4, 7, 7, 9, 5 ) + call @storeValuesToStrided(%x0, %c10, %c2, %c1, %c5, %c1) + : (memref>, i32, i32, i32, i32, i32) -> () + call @storeValuesToStrided(%x1, %c1, %c1, %c3, %c10, %c3) + : (memref>, i32, i32, i32, i32, i32) -> () + call @storeValuesToStrided(%x2, %c2, %c4, %c9, %c7, %c9) + : (memref>, i32, i32, i32, i32, i32) -> () + call @storeValuesToStrided(%y0, %c6, %c10, %c8, %c9, %c7) + : (memref>, i32, i32, i32, i32, i32) -> () + call @storeValuesTo(%y1, %c5, %c7, %c4, %c9, %c7) + : (memref, i32, i32, i32, i32, i32) -> () + sparse_tensor.sort_coo stable %i5, %xy jointly %y1 {nx = 3 : index, ny = 1 : index} + : memref jointly memref + %x0v2 = vector.transfer_read %x0[%i0], %c100: memref>, vector<5xi32> + vector.print %x0v2 : vector<5xi32> + %x1v2 = vector.transfer_read %x1[%i0], %c100: memref>, vector<5xi32> + vector.print %x1v2 : vector<5xi32> + %x2v2 = vector.transfer_read %x2[%i0], %c100: memref>, vector<5xi32> + vector.print %x2v2 : vector<5xi32> + %y0v2 = vector.transfer_read %y0[%i0], %c100: memref>, vector<5xi32> + vector.print %y0v2 : vector<5xi32> + %y1v2 = vector.transfer_read %y1[%i0], %c100: memref, vector<5xi32> + vector.print %y1v2 : vector<5xi32> + + // Release the buffers. + memref.dealloc %xy : memref + memref.dealloc %y1 : memref + return + } +}