diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h @@ -280,6 +280,11 @@ return builder.create(loc, i); } +/// Generates a constant of `i64` type. +inline Value constantI64(OpBuilder &builder, Location loc, int64_t i) { + return builder.create(loc, i, 64); +} + /// Generates a constant of `i32` type. inline Value constantI32(OpBuilder &builder, Location loc, int32_t i) { return builder.create(loc, i, 32); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp @@ -43,8 +43,8 @@ static constexpr const char kSortStableFuncNamePrefix[] = "_sparse_sort_stable_"; -using FuncGeneratorType = function_ref; +using FuncGeneratorType = function_ref; /// Constructs a function name with this format to facilitate quick sort: /// __..._ for sort @@ -70,15 +70,21 @@ /// parameters `nx` and `ny` tell the number of x and y values provided /// by the buffer in xStartIdx, and `isCoo` indicates whether the instruction /// being processed is a sparse_tensor.sort or sparse_tensor.sort_coo. +// +// All sorting function generators take (lo, hi, xs, ys) in `operands` as +// parameters for the sorting functions. Other parameters, such as the recursive +// call depth, are appended to the end of the parameter list as +// "trailing parameters". static FlatSymbolRefAttr getMangledSortHelperFunc(OpBuilder &builder, func::FuncOp insertPoint, TypeRange resultTypes, StringRef namePrefix, uint64_t nx, uint64_t ny, bool isCoo, - ValueRange operands, FuncGeneratorType createFunc) { + ValueRange operands, FuncGeneratorType createFunc, + uint32_t nTrailingP = 0) { SmallString<32> nameBuffer; llvm::raw_svector_ostream nameOstream(nameBuffer); getMangledSortHelperFuncName(nameOstream, namePrefix, nx, ny, isCoo, - operands); + operands.drop_back(nTrailingP)); ModuleOp module = insertPoint->getParentOfType(); MLIRContext *context = module.getContext(); @@ -94,7 +100,7 @@ loc, nameOstream.str(), FunctionType::get(context, operands.getTypes(), resultTypes)); func.setPrivate(); - createFunc(builder, module, func, nx, ny, isCoo); + createFunc(builder, module, func, nx, ny, isCoo, nTrailingP); } return result; @@ -243,7 +249,10 @@ // and so on ... static void createEqCompareFunc(OpBuilder &builder, ModuleOp unused, func::FuncOp func, uint64_t nx, uint64_t ny, - bool isCoo) { + bool isCoo, uint32_t nTrailingP = 0) { + // Compare functions don't use trailing parameters. + (void)nTrailingP; + assert(nTrailingP == 0); createCompareFuncImplementation(builder, unused, func, nx, ny, isCoo, createEqCompare); } @@ -303,7 +312,10 @@ // and so on ... static void createLessThanFunc(OpBuilder &builder, ModuleOp unused, func::FuncOp func, uint64_t nx, uint64_t ny, - bool isCoo) { + bool isCoo, uint32_t nTrailingP = 0) { + // Compare functions don't use trailing parameters. + (void)nTrailingP; + assert(nTrailingP == 0); createCompareFuncImplementation(builder, unused, func, nx, ny, isCoo, createLessThanCompare); } @@ -323,7 +335,10 @@ // static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module, func::FuncOp func, uint64_t nx, uint64_t ny, - bool isCoo) { + bool isCoo, uint32_t nTrailingP = 0) { + // Binary search doesn't use trailing parameters. + (void)nTrailingP; + assert(nTrailingP == 0); OpBuilder::InsertionGuard insertionGuard(builder); Block *entryBlock = func.addEntryBlock(); builder.setInsertionPointToStart(entryBlock); @@ -331,7 +346,7 @@ Location loc = func.getLoc(); ValueRange args = entryBlock->getArguments(); Value p = args[hiIdx]; - SmallVector types(2, p.getType()); // only two + SmallVector types(2, p.getType()); // Only two types. scf::WhileOp whileOp = builder.create( loc, types, SmallVector{args[loIdx], args[hiIdx]}); @@ -364,7 +379,7 @@ Type i1Type = IntegerType::get(module.getContext(), 1, IntegerType::Signless); FlatSymbolRefAttr lessThanFunc = getMangledSortHelperFunc( builder, func, {i1Type}, kLessThanFuncNamePrefix, nx, ny, isCoo, - compareOperands, createLessThanFunc); + compareOperands, createLessThanFunc, nTrailingP); Value cond2 = builder .create(loc, lessThanFunc, TypeRange{i1Type}, compareOperands) @@ -571,7 +586,10 @@ // } static void createPartitionFunc(OpBuilder &builder, ModuleOp module, func::FuncOp func, uint64_t nx, uint64_t ny, - bool isCoo) { + bool isCoo, uint32_t nTrailingP = 0) { + // Quick sort partition doesn't use trailing parameters. + (void)nTrailingP; + assert(nTrailingP == 0); OpBuilder::InsertionGuard insertionGuard(builder); Block *entryBlock = func.addEntryBlock(); @@ -686,7 +704,8 @@ // } static void createSortNonstableFunc(OpBuilder &builder, ModuleOp module, func::FuncOp func, uint64_t nx, uint64_t ny, - bool isCoo) { + bool isCoo, uint32_t nTrailingP) { + (void)nTrailingP; OpBuilder::InsertionGuard insertionGuard(builder); Block *entryBlock = func.addEntryBlock(); builder.setInsertionPointToStart(entryBlock); @@ -739,7 +758,10 @@ // } static void createSortStableFunc(OpBuilder &builder, ModuleOp module, func::FuncOp func, uint64_t nx, uint64_t ny, - bool isCoo) { + bool isCoo, uint32_t nTrailingP) { + // Stable sort function doesn't use trailing parameters. + (void)nTrailingP; + assert(nTrailingP == 0); OpBuilder::InsertionGuard insertionGuard(builder); Block *entryBlock = func.addEntryBlock(); builder.setInsertionPointToStart(entryBlock); @@ -830,9 +852,10 @@ : kSortNonstableFuncNamePrefix); FuncGeneratorType funcGenerator = op.getStable() ? createSortStableFunc : createSortNonstableFunc; + uint32_t nTrailingP = 0; FlatSymbolRefAttr func = getMangledSortHelperFunc(rewriter, insertPoint, TypeRange(), funcName, nx, - ny, isCoo, operands, funcGenerator); + ny, isCoo, operands, funcGenerator, nTrailingP); rewriter.replaceOpWithNewOp(op, func, TypeRange(), operands); return success(); }