diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp @@ -43,8 +43,8 @@ static constexpr const char kSortStableFuncNamePrefix[] = "_sparse_sort_stable_"; -using FuncGeneratorType = function_ref; +using FuncGeneratorType = function_ref; /// Constructs a function name with this format to facilitate quick sort: /// __..._ for sort @@ -69,15 +69,21 @@ /// parameters `nx` and `ny` tell the number of x and y values provided /// by the buffer in xStartIdx, and `isCoo` indicates whether the instruction /// being processed is a sparse_tensor.sort or sparse_tensor.sort_coo. +// +// All sorting function generators take (lo, hi, xs, ys) in `operands` as +// parameters for the sorting functions. Other parameters, such as the recursive +// call depth, are appended to the end of the parameter list as +// "trailing parameters". static FlatSymbolRefAttr getMangledSortHelperFunc(OpBuilder &builder, func::FuncOp insertPoint, TypeRange resultTypes, StringRef namePrefix, uint64_t nx, uint64_t ny, bool isCoo, - ValueRange operands, FuncGeneratorType createFunc) { + ValueRange operands, FuncGeneratorType createFunc, + uint32_t nTrailingP = 0) { SmallString<32> nameBuffer; llvm::raw_svector_ostream nameOstream(nameBuffer); getMangledSortHelperFuncName(nameOstream, namePrefix, nx, ny, isCoo, - operands); + operands.drop_back(nTrailingP)); ModuleOp module = insertPoint->getParentOfType(); MLIRContext *context = module.getContext(); @@ -93,7 +99,7 @@ loc, nameOstream.str(), FunctionType::get(context, operands.getTypes(), resultTypes)); func.setPrivate(); - createFunc(builder, module, func, nx, ny, isCoo); + createFunc(builder, module, func, nx, ny, isCoo, nTrailingP); } return result; @@ -242,7 +248,10 @@ // and so on ... static void createEqCompareFunc(OpBuilder &builder, ModuleOp unused, func::FuncOp func, uint64_t nx, uint64_t ny, - bool isCoo) { + bool isCoo, uint32_t nTrailingP = 0) { + // Compare functions don't use trailing parameters. + (void)nTrailingP; + assert(nTrailingP == 0); createCompareFuncImplementation(builder, unused, func, nx, ny, isCoo, createEqCompare); } @@ -302,7 +311,10 @@ // and so on ... static void createLessThanFunc(OpBuilder &builder, ModuleOp unused, func::FuncOp func, uint64_t nx, uint64_t ny, - bool isCoo) { + bool isCoo, uint32_t nTrailingP = 0) { + // Compare functions don't use trailing parameters. + (void)nTrailingP; + assert(nTrailingP == 0); createCompareFuncImplementation(builder, unused, func, nx, ny, isCoo, createLessThanCompare); } @@ -322,7 +334,10 @@ // static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module, func::FuncOp func, uint64_t nx, uint64_t ny, - bool isCoo) { + bool isCoo, uint32_t nTrailingP = 0) { + // Binary search doesn't use trailing parameters. + (void)nTrailingP; + assert(nTrailingP == 0); OpBuilder::InsertionGuard insertionGuard(builder); Block *entryBlock = func.addEntryBlock(); builder.setInsertionPointToStart(entryBlock); @@ -330,7 +345,7 @@ Location loc = func.getLoc(); ValueRange args = entryBlock->getArguments(); Value p = args[hiIdx]; - SmallVector types(2, p.getType()); // only two + SmallVector types(2, p.getType()); // Only two types. scf::WhileOp whileOp = builder.create( loc, types, SmallVector{args[loIdx], args[hiIdx]}); @@ -363,7 +378,7 @@ Type i1Type = IntegerType::get(module.getContext(), 1, IntegerType::Signless); FlatSymbolRefAttr lessThanFunc = getMangledSortHelperFunc( builder, func, {i1Type}, kLessThanFuncNamePrefix, nx, ny, isCoo, - compareOperands, createLessThanFunc); + compareOperands, createLessThanFunc, nTrailingP); Value cond2 = builder .create(loc, lessThanFunc, TypeRange{i1Type}, compareOperands) @@ -560,7 +575,10 @@ // } static void createPartitionFunc(OpBuilder &builder, ModuleOp module, func::FuncOp func, uint64_t nx, uint64_t ny, - bool isCoo) { + bool isCoo, uint32_t nTrailingP = 0) { + // Quick sort partition doesn't use trailing parameters. + (void)nTrailingP; + assert(nTrailingP == 0); OpBuilder::InsertionGuard insertionGuard(builder); Block *entryBlock = func.addEntryBlock(); @@ -675,7 +693,8 @@ // } static void createSortNonstableFunc(OpBuilder &builder, ModuleOp module, func::FuncOp func, uint64_t nx, uint64_t ny, - bool isCoo) { + bool isCoo, uint32_t nTrailingP) { + (void)nTrailingP; OpBuilder::InsertionGuard insertionGuard(builder); Block *entryBlock = func.addEntryBlock(); builder.setInsertionPointToStart(entryBlock); @@ -728,7 +747,10 @@ // } static void createSortStableFunc(OpBuilder &builder, ModuleOp module, func::FuncOp func, uint64_t nx, uint64_t ny, - bool isCoo) { + bool isCoo, uint32_t nTrailingP) { + // Stable sort function doesn't use trailing parameters. + (void)nTrailingP; + assert(nTrailingP == 0); OpBuilder::InsertionGuard insertionGuard(builder); Block *entryBlock = func.addEntryBlock(); builder.setInsertionPointToStart(entryBlock); @@ -821,9 +843,10 @@ : kSortNonstableFuncNamePrefix); FuncGeneratorType funcGenerator = isStable ? createSortStableFunc : createSortNonstableFunc; + uint32_t nTrailingP = 0; FlatSymbolRefAttr func = getMangledSortHelperFunc(rewriter, insertPoint, TypeRange(), funcName, nx, - ny, isCoo, operands, funcGenerator); + ny, isCoo, operands, funcGenerator, nTrailingP); rewriter.replaceOpWithNewOp(op, func, TypeRange(), operands); return success(); }