diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -166,6 +166,9 @@
 std::unique_ptr<Pass> createDenseBufferizationPass(
     const bufferization::OneShotBufferizationOptions &options);
 
+void populateSparseBufferRewriting(RewritePatternSet &patterns);
+std::unique_ptr<Pass> createSparseBufferRewritePass();
+
 //===----------------------------------------------------------------------===//
 // Registration.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -178,4 +178,19 @@
   ];
 }
 
+def SparseBufferRewrite : Pass<"sparse-buffer-rewrite", "ModuleOp"> {
+  let summary = "Rewrite sparse primitives on buffers to actual code";
+  let description = [{
+    A pass that rewrites sparse primitives on buffers to the MLIR implementation
+    of the primitives. For example, the sparse_tensor.sort operator is
+    implemented in this pass.
+  }];
+  let constructor = "mlir::createSparseBufferRewritePass()";
+  let dependentDialects = [
+    "arith::ArithmeticDialect",
+    "memref::MemRefDialect",
+    "scf::SCFDialect",
+    "sparse_tensor::SparseTensorDialect",
+  ];
+}
 #endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
--- a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
@@ -64,6 +64,7 @@
         options.sparseTensorConversionOptions()));
   else
     pm.addPass(createSparseTensorCodegenPass());
+  pm.addPass(createSparseBufferRewritePass());
   pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
   pm.addPass(createDenseBufferizationPass(
       getBufferizationOptions(/*analysisOnly=*/false)));
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
@@ -3,6 +3,7 @@
   CodegenUtils.cpp
   DenseBufferizationPass.cpp
   Sparsification.cpp
+  SparseBufferRewriting.cpp
   SparseTensorCodegen.cpp
   SparseTensorConversion.cpp
   SparseTensorPasses.cpp
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
@@ -0,0 +1,382 @@
+//===- SparseBufferRewriting.cpp - Sparse buffer rewriting rules ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements rewriting rules that are specific to sparse tensor
+// primitives with memref operands.
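+// The rewriting lowers the sparse_tensor.sort operation into calls to
+// generated helper functions that implement quick sort over the operand
+// buffers.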
+//
+//===----------------------------------------------------------------------===//
+
+#include "CodegenUtils.h"
+
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
+#include "mlir/Support/LLVM.h"
+
+using namespace mlir;
+using namespace mlir::sparse_tensor;
+
+//===---------------------------------------------------------------------===//
+// Helper methods for the actual rewriting rules.
+//===---------------------------------------------------------------------===//
+
+constexpr uint64_t loIdx = 0;
+constexpr uint64_t hiIdx = 1;
+constexpr uint64_t xStartIdx = 2;
+
+typedef function_ref<void(OpBuilder &, ModuleOp, func::FuncOp, size_t)>
+    FuncGeneratorType;
+
+/// Constructs a function name with this format to facilitate quick sort:
+///   <namePrefix><dim>_<x type>_<y0 type>..._<yn type>
+static void getMangledSortHelperFuncName(llvm::raw_svector_ostream &nameOstream,
+                                         StringRef namePrefix, size_t dim,
+                                         ValueRange operands) {
+  nameOstream
+      << namePrefix << dim << "_"
+      << operands[xStartIdx].getType().cast<MemRefType>().getElementType();
+
+  for (Value v : operands.drop_front(xStartIdx + dim))
+    nameOstream << "_" << v.getType().cast<MemRefType>().getElementType();
+}
+
+/// Looks up a function that is appropriate for the given operands being
+/// sorted, and creates such a function if it doesn't exist yet.
+static FlatSymbolRefAttr
+getMangledSortHelperFunc(OpBuilder &builder, func::FuncOp insertPoint,
+                         TypeRange resultTypes, StringRef namePrefix,
+                         size_t dim, ValueRange operands,
+                         FuncGeneratorType createFunc) {
+  SmallString<32> nameBuffer;
+  llvm::raw_svector_ostream nameOstream(nameBuffer);
+  getMangledSortHelperFuncName(nameOstream, namePrefix, dim, operands);
+
+  ModuleOp module = insertPoint->getParentOfType<ModuleOp>();
+  MLIRContext *context = module.getContext();
+  auto result = SymbolRefAttr::get(context, nameOstream.str());
+  auto func = module.lookupSymbol<func::FuncOp>(result.getAttr());
+
+  if (!func) {
+    // Create the function.
+    OpBuilder::InsertionGuard insertionGuard(builder);
+    builder.setInsertionPoint(insertPoint);
+    Location loc = insertPoint.getLoc();
+    func = builder.create<func::FuncOp>(
+        loc, nameOstream.str(),
+        FunctionType::get(context, operands.getTypes(), resultTypes));
+    func.setPrivate();
+    createFunc(builder, module, func, dim);
+  }
+
+  return result;
+}
+
+/// Creates a function for swapping the values at index i and j for all the
+/// buffers.
+//
+// The generated IR corresponds to this C-like algorithm:
+// if (i != j) {
+//   swap(x0[i], x0[j]);
+//   swap(x1[i], x1[j]);
+//   ...
+//   swap(xn[i], xn[j]);
+//   swap(y0[i], y0[j]);
+//   ...
+//   swap(yn[i], yn[j]);
+// }
+static void createMaySwapFunc(OpBuilder &builder, ModuleOp unused,
+                              func::FuncOp func, size_t dim) {
+  OpBuilder::InsertionGuard insertionGuard(builder);
+
+  Block *entryBlock = func.addEntryBlock();
+  builder.setInsertionPointToStart(entryBlock);
+
+  Location loc = func.getLoc();
+  ValueRange args = entryBlock->getArguments();
+  Value i = args[0];
+  Value j = args[1];
+  Value cond =
+      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, i, j);
+  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, cond, /*else=*/false);
+
+  // If i != j, swap the values in all the buffers.
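+  // Note that both the xs (the sort keys) and the ys (the associated values)
+  // are swapped, so that the y buffers remain aligned with the sorted keys.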
+  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+  for (auto arg : args.drop_front(xStartIdx)) {
+    Value vi = builder.create<memref::LoadOp>(loc, arg, i);
+    Value vj = builder.create<memref::LoadOp>(loc, arg, j);
+    builder.create<memref::StoreOp>(loc, vj, arg, i);
+    builder.create<memref::StoreOp>(loc, vi, arg, j);
+  }
+
+  builder.setInsertionPointAfter(ifOp);
+  builder.create<func::ReturnOp>(loc);
+}
+
+/// Generates an if-statement to compare x[i] and x[j].
+static scf::IfOp createLessThanCompare(OpBuilder &builder, Location loc,
+                                       Value i, Value j, Value x,
+                                       bool isLastDim) {
+  Value f = constantI1(builder, loc, false);
+  Value t = constantI1(builder, loc, true);
+  Value vi = builder.create<memref::LoadOp>(loc, x, i);
+  Value vj = builder.create<memref::LoadOp>(loc, x, j);
+
+  Value cond =
+      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, vi, vj);
+  scf::IfOp ifOp =
+      builder.create<scf::IfOp>(loc, f.getType(), cond, /*else=*/true);
+  // If (x[i] < x[j]).
+  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+  builder.create<scf::YieldOp>(loc, t);
+
+  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
+  if (isLastDim) {
+    // Finish checking all dimensions.
+    builder.create<scf::YieldOp>(loc, f);
+  } else {
+    cond =
+        builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, vj, vi);
+    scf::IfOp ifOp2 =
+        builder.create<scf::IfOp>(loc, f.getType(), cond, /*else=*/true);
+    // Otherwise if (x[j] < x[i]).
+    builder.setInsertionPointToStart(&ifOp2.getThenRegion().front());
+    builder.create<scf::YieldOp>(loc, f);
+
+    // Otherwise check the remaining dimensions.
+    builder.setInsertionPointAfter(ifOp2);
+    builder.create<scf::YieldOp>(loc, ifOp2.getResult(0));
+    // Set up the insertion point for the nested if-stmt that checks the
+    // remaining dimensions.
+    builder.setInsertionPointToStart(&ifOp2.getElseRegion().front());
+  }
+
+  return ifOp;
+}
+
+/// Creates a function to compare the xs values at index i and j for all the
+/// dimensions. The function returns true iff xs[i] < xs[j].
+//
+// The generated IR corresponds to this C-like algorithm:
+// if (x0[i] < x0[j])
+//   return true;
+// else if (x0[j] < x0[i])
+//   return false;
+// else
+//   if (x1[i] < x1[j])
+//     return true;
+//   else if (x1[j] < x1[i])
+//     and so on ...
+static void createLessThanFunc(OpBuilder &builder, ModuleOp unused,
+                               func::FuncOp func, size_t dim) {
+  OpBuilder::InsertionGuard insertionGuard(builder);
+
+  Block *entryBlock = func.addEntryBlock();
+  builder.setInsertionPointToStart(entryBlock);
+  Location loc = func.getLoc();
+  ValueRange args = entryBlock->getArguments();
+
+  scf::IfOp topIfOp;
+  for (const auto &item : llvm::enumerate(args.slice(xStartIdx, dim))) {
+    scf::IfOp ifOp =
+        createLessThanCompare(builder, loc, args[0], args[1], item.value(),
+                              (item.index() == dim - 1));
+    if (item.index() == 0) {
+      topIfOp = ifOp;
+    } else {
+      OpBuilder::InsertionGuard insertionGuard(builder);
+      builder.setInsertionPointAfter(ifOp);
+      builder.create<scf::YieldOp>(loc, ifOp.getResult(0));
+    }
+  }
+
+  builder.setInsertionPointAfter(topIfOp);
+  builder.create<func::ReturnOp>(loc, topIfOp.getResult(0));
+}
+
+/// Creates a function to perform quick sort partition on the values in the
+/// range of index [lo, hi), assuming lo < hi.
+//
+// The generated IR corresponds to this C-like algorithm:
+// int partition(lo, hi, data) {
+//   pivot = data[hi - 1];
+//   i = (lo - 1)  // RHS of the pivot found so far.
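+//   // Invariant: data[lo..i] holds the values seen so far that compare
+//   // smaller than the pivot.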
+//   for (j = lo; j < hi - 1; j++) {
+//     if (data[j] < pivot) {
+//       i++;
+//       swap data[i] and data[j]
+//     }
+//   }
+//   i++
+//   swap data[i] and data[hi-1]
+//   return i
+// }
+static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
+                                func::FuncOp func, size_t dim) {
+  OpBuilder::InsertionGuard insertionGuard(builder);
+
+  Block *entryBlock = func.addEntryBlock();
+  builder.setInsertionPointToStart(entryBlock);
+
+  MLIRContext *context = module.getContext();
+  Location loc = func.getLoc();
+  ValueRange args = entryBlock->getArguments();
+  Value lo = args[loIdx];
+  Value c1 = constantIndex(builder, loc, 1);
+  Value i = builder.create<arith::SubIOp>(loc, lo, c1);
+  Value him1 = builder.create<arith::SubIOp>(loc, args[hiIdx], c1);
+  scf::ForOp forOp =
+      builder.create<scf::ForOp>(loc, lo, him1, c1, ValueRange{i});
+
+  // Start the for-stmt body.
+  builder.setInsertionPointToStart(forOp.getBody());
+  Value j = forOp.getInductionVar();
+  SmallVector<Value> compareOperands{j, him1};
+  ValueRange xs = args.slice(xStartIdx, dim);
+  compareOperands.append(xs.begin(), xs.end());
+  Type i1Type = IntegerType::get(context, 1, IntegerType::Signless);
+  FlatSymbolRefAttr lessThanFunc =
+      getMangledSortHelperFunc(builder, func, {i1Type}, "_sparse_less_than_",
+                               dim, compareOperands, createLessThanFunc);
+  Value cond = builder
+                   .create<func::CallOp>(loc, lessThanFunc, TypeRange{i1Type},
+                                         compareOperands)
+                   .getResult(0);
+  scf::IfOp ifOp =
+      builder.create<scf::IfOp>(loc, i.getType(), cond, /*else=*/true);
+
+  // The if-stmt true branch: i++; swap(data[i], data[j]); yield i.
+  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+  Value i1 =
+      builder.create<arith::AddIOp>(loc, forOp.getRegionIterArgs().front(), c1);
+  SmallVector<Value> swapOperands{i1, j};
+  swapOperands.append(args.begin() + xStartIdx, args.end());
+  FlatSymbolRefAttr swapFunc =
+      getMangledSortHelperFunc(builder, func, TypeRange(), "_sparse_may_swap_",
+                               dim, swapOperands, createMaySwapFunc);
+  builder.create<func::CallOp>(loc, swapFunc, TypeRange(), swapOperands);
+  builder.create<scf::YieldOp>(loc, i1);
+
+  // The if-stmt false branch: yield i.
+  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
+  builder.create<scf::YieldOp>(loc, forOp.getRegionIterArgs().front());
+
+  // After the if-stmt, yield the updated i value to end the for-stmt body.
+  builder.setInsertionPointAfter(ifOp);
+  builder.create<scf::YieldOp>(loc, ifOp.getResult(0));
+
+  // After the for-stmt: i++; swap(data[i], data[him1]); return i.
+  builder.setInsertionPointAfter(forOp);
+  i1 = builder.create<arith::AddIOp>(loc, forOp.getResult(0), c1);
+  swapOperands[0] = i1;
+  swapOperands[1] = him1;
+  builder.create<func::CallOp>(loc, swapFunc, TypeRange(), swapOperands);
+  builder.create<func::ReturnOp>(loc, i1);
+}
+
+/// Creates a function to perform quick sort on the values in the range of
+/// index [lo, hi).
+//
+// The generated IR corresponds to this C-like algorithm:
+// void quickSort(lo, hi, data) {
+//   if (lo < hi) {
+//     p = partition(lo, hi, data);
+//     quickSort(lo, p, data);
+//     quickSort(p + 1, hi, data);
+//   }
+// }
+static void createSortFunc(OpBuilder &builder, ModuleOp module,
+                           func::FuncOp func, size_t dim) {
+  OpBuilder::InsertionGuard insertionGuard(builder);
+  Block *entryBlock = func.addEntryBlock();
+  builder.setInsertionPointToStart(entryBlock);
+
+  MLIRContext *context = module.getContext();
+  Location loc = func.getLoc();
+  ValueRange args = entryBlock->getArguments();
+  Value lo = args[loIdx];
+  Value hi = args[hiIdx];
+  Value cond =
+      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, lo, hi);
+  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, cond, /*else=*/false);
+
+  // The if-stmt true branch.
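+  // Partition the range [lo, hi) and then recursively sort the two
+  // sub-ranges [lo, p) and [p + 1, hi) by calling this function itself.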
+  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+  FlatSymbolRefAttr partitionFunc = getMangledSortHelperFunc(
+      builder, func, {IndexType::get(context)}, "_sparse_partition_", dim, args,
+      createPartitionFunc);
+  auto p = builder.create<func::CallOp>(
+      loc, partitionFunc, TypeRange{IndexType::get(context)}, ValueRange(args));
+
+  SmallVector<Value> lowOperands{lo, p.getResult(0)};
+  lowOperands.append(args.begin() + xStartIdx, args.end());
+  builder.create<func::CallOp>(loc, func, lowOperands);
+
+  SmallVector<Value> highOperands{
+      builder.create<arith::AddIOp>(loc, p.getResult(0),
+                                    constantIndex(builder, loc, 1)),
+      hi};
+  highOperands.append(args.begin() + xStartIdx, args.end());
+  builder.create<func::CallOp>(loc, func, highOperands);
+
+  // After the if-stmt.
+  builder.setInsertionPointAfter(ifOp);
+  builder.create<func::ReturnOp>(loc);
+}
+
+//===---------------------------------------------------------------------===//
+// The actual sparse buffer rewriting rules.
+//===---------------------------------------------------------------------===//
+
+namespace {
+
+/// Sparse rewriting rule for the sort operator.
+struct SortRewriter : public OpRewritePattern<SortOp> {
+public:
+  using OpRewritePattern<SortOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(SortOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    SmallVector<Value> operands{constantIndex(rewriter, loc, 0), op.getN()};
+
+    // Convert `values` to have dynamic shape and append them to `operands`.
+    auto addValues = [&](ValueRange values) {
+      for (Value v : values) {
+        auto mtp = v.getType().cast<MemRefType>();
+        if (!mtp.isDynamicDim(0)) {
+          auto new_mtp =
+              MemRefType::get({ShapedType::kDynamicSize}, mtp.getElementType());
+          v = rewriter.create<memref::CastOp>(loc, new_mtp, v);
+        }
+        operands.push_back(v);
+      }
+    };
+    ValueRange xs = op.getXs();
+    addValues(xs);
+    addValues(op.getYs());
+    auto insertPoint = op->getParentOfType<func::FuncOp>();
+    FlatSymbolRefAttr func = getMangledSortHelperFunc(
+        rewriter, insertPoint, TypeRange(), "_sparse_sort_", xs.size(),
+        operands, createSortFunc);
+    rewriter.replaceOpWithNewOp<func::CallOp>(op, func, TypeRange(), operands);
+    return success();
+  }
+};
+
+} // namespace
+
+//===---------------------------------------------------------------------===//
+// Methods that add patterns described in this file to a pattern list.
+//===---------------------------------------------------------------------===//
+
+void mlir::populateSparseBufferRewriting(RewritePatternSet &patterns) {
+  patterns.add<SortRewriter>(patterns.getContext());
+}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -24,6 +24,7 @@
 #define GEN_PASS_DEF_SPARSIFICATIONPASS
 #define GEN_PASS_DEF_SPARSETENSORCONVERSIONPASS
 #define GEN_PASS_DEF_SPARSETENSORCODEGEN
+#define GEN_PASS_DEF_SPARSEBUFFERREWRITE
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
 } // namespace mlir
 
@@ -198,6 +199,20 @@
   }
 };
 
+struct SparseBufferRewritePass
+    : public impl::SparseBufferRewriteBase<SparseBufferRewritePass> {
+
+  SparseBufferRewritePass() = default;
+  SparseBufferRewritePass(const SparseBufferRewritePass &pass) = default;
+
+  void runOnOperation() override {
+    auto *ctx = &getContext();
+    RewritePatternSet patterns(ctx);
+    populateSparseBufferRewriting(patterns);
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+  }
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -241,3 +256,7 @@
 std::unique_ptr<Pass> mlir::createSparseTensorCodegenPass() {
   return std::make_unique<SparseTensorCodegenPass>();
 }
+
+std::unique_ptr<Pass> mlir::createSparseBufferRewritePass() {
+  return std::make_unique<SparseBufferRewritePass>();
+}
diff --git a/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir b/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
@@ -0,0 +1,107 @@
+// RUN: mlir-opt %s --sparse-buffer-rewrite --canonicalize --cse | FileCheck %s
+
+// CHECK-LABEL: func.func private @_sparse_less_than_1_i8(
+// CHECK-SAME:    %[[I:arg0]]: index,
+// CHECK-SAME:    %[[J:.*]]: index,
+// CHECK-SAME:    %[[X0:.*]]: memref<?xi8>) -> i1 {
+// CHECK:   %[[VI:.*]] = memref.load %[[X0]]{{\[}}%[[I]]]
+// CHECK:   %[[VJ:.*]] = memref.load %[[X0]]{{\[}}%[[J]]]
+// CHECK:   %[[C:.*]] = arith.cmpi ult, %[[VI]], %[[VJ]]
+// CHECK:   return %[[C]]
+// CHECK: }
+
+// CHECK-LABEL: func.func private @_sparse_may_swap_1_i8_f32_index(
+// CHECK-SAME:    %[[I:arg0]]: index,
+// CHECK-SAME:    %[[J:.*]]: index,
+// CHECK-SAME:    %[[X0:.*]]: memref<?xi8>,
+// CHECK-SAME:    %[[Y0:.*]]: memref<?xf32>,
+// CHECK-SAME:    %[[Y1:.*]]: memref<?xindex>) {
+// CHECK:   %[[C:.*]] = arith.cmpi ne, %[[I]], %[[J]]
+// CHECK:   scf.if %[[C]] {
+// CHECK:     %[[Vx0i:.*]] = memref.load %[[X0]]{{\[}}%[[I]]]
+// CHECK:     %[[Vx0j:.*]] = memref.load %[[X0]]{{\[}}%[[J]]]
+// CHECK:     memref.store %[[Vx0j]], %[[X0]]{{\[}}%[[I]]]
+// CHECK:     memref.store %[[Vx0i]], %[[X0]]{{\[}}%[[J]]]
+// CHECK:     %[[Vy0i:.*]] = memref.load %[[Y0]]{{\[}}%[[I]]]
+// CHECK:     %[[Vy0j:.*]] = memref.load %[[Y0]]{{\[}}%[[J]]]
+// CHECK:     memref.store %[[Vy0j]], %[[Y0]]{{\[}}%[[I]]]
+// CHECK:     memref.store %[[Vy0i]], %[[Y0]]{{\[}}%[[J]]]
+// CHECK:     %[[Vy1i:.*]] = memref.load %[[Y1]]{{\[}}%[[I]]]
+// CHECK:     %[[Vy1j:.*]] = memref.load %[[Y1]]{{\[}}%[[J]]]
+// CHECK:     memref.store %[[Vy1j]], %[[Y1]]{{\[}}%[[I]]]
+// CHECK:     memref.store %[[Vy1i]], %[[Y1]]{{\[}}%[[J]]]
+// CHECK:   }
+// CHECK:   return
+// CHECK: }
+
+// CHECK-LABEL: func.func private @_sparse_partition_1_i8_f32_index(
+// CHECK-SAME:    %[[L:arg0]]: index,
+// CHECK-SAME:    %[[H:.*]]: index,
+// CHECK-SAME:    %[[X0:.*]]: memref<?xi8>,
+// CHECK-SAME:    %[[Y0:.*]]: memref<?xf32>,
+// CHECK-SAME:    %[[Y1:.*]]: memref<?xindex>) -> index {
+// CHECK:   %[[C1:.*]] = arith.constant 1
+// CHECK:   %[[I:.*]] = arith.subi %[[L]], %[[C1]]
+// CHECK:   %[[Hm1:.*]] = arith.subi %[[H]], %[[C1]]
+// CHECK:   %[[I3:.*]] = scf.for %[[J:.*]] = %[[L]] to %[[Hm1]] step %[[C1]] iter_args(%[[I2:.*]] = %[[I]]) -> (index) {
+// CHECK:     %[[COND:.*]] = func.call @_sparse_less_than_1_i8(%[[J]], %[[Hm1]], %[[X0]])
+// CHECK:     %[[IF:.*]] = scf.if %[[COND]] -> (index) {
+// CHECK:       %[[Ip1:.*]] = arith.addi %[[I2]], %[[C1]]
+// CHECK:       func.call @_sparse_may_swap_1_i8_f32_index(%[[Ip1]], %[[J]], %[[X0]], %[[Y0]], %[[Y1]])
+// CHECK:       scf.yield %[[Ip1]]
+// CHECK:     } else {
+// CHECK:       scf.yield %[[I2]]
+// CHECK:     }
+// CHECK:     scf.yield %[[IF:.*]]
+// CHECK:   }
+// CHECK:   %[[I3p1:.*]] = arith.addi %[[I3:.*]], %[[C1]] : index
+// CHECK:   call @_sparse_may_swap_1_i8_f32_index(%[[I3p1]], %[[Hm1]], %[[X0]], %[[Y0]], %[[Y1]])
+// CHECK:   return %[[I3p1]]
+// CHECK: }
+
+// CHECK-LABEL: func.func private @_sparse_sort_1_i8_f32_index(
+// CHECK-SAME:    %[[L:arg0]]: index,
+// CHECK-SAME:    %[[H:.*]]: index,
+// CHECK-SAME:    %[[X0:.*]]: memref<?xi8>,
+// CHECK-SAME:    %[[Y0:.*]]: memref<?xf32>,
+// CHECK-SAME:    %[[Y1:.*]]: memref<?xindex>) {
+// CHECK:   %[[C1:.*]] = arith.constant 1
+// CHECK:   %[[COND:.*]] = arith.cmpi ult, %[[L]], %[[H]]
+// CHECK:   scf.if %[[COND]] {
+// CHECK:     %[[P:.*]] = func.call @_sparse_partition_1_i8_f32_index(%[[L]], %[[H]], %[[X0]], %[[Y0]], %[[Y1]])
+// CHECK:     func.call @_sparse_sort_1_i8_f32_index(%[[L]], %[[P]], %[[X0]], %[[Y0]], %[[Y1]])
+// CHECK:     %[[P2:.*]] = arith.addi %[[P]], %[[C1]] : index
+// CHECK:     func.call @_sparse_sort_1_i8_f32_index(%[[P2]], %[[H]], %[[X0]], %[[Y0]], %[[Y1]])
+// CHECK:   }
+// CHECK:   return
+// CHECK: }
+
+// CHECK-LABEL: func.func @sparse_sort_1d2v(
+// CHECK-SAME:    %[[N:.*]]: index,
+// CHECK-SAME:    %[[X0:.*]]: memref<10xi8>,
+// CHECK-SAME:    %[[Y0:.*]]: memref<?xf32>,
+// CHECK-SAME:    %[[Y1:.*]]: memref<10xindex>) -> (memref<10xi8>, memref<?xf32>, memref<10xindex>) {
+// CHECK:   %[[C0:.*]] = arith.constant 0
+// CHECK:   %[[DX0:.*]] = memref.cast %[[X0]] : memref<10xi8> to memref<?xi8>
+// CHECK:   %[[DY1:.*]] = memref.cast %[[Y1]] : memref<10xindex> to memref<?xindex>
+// CHECK:   call @_sparse_sort_1_i8_f32_index(%[[C0]], %[[N]], %[[DX0]], %[[Y0]], %[[DY1]])
+// CHECK:   return %[[X0]], %[[Y0]], %[[Y1]]
+// CHECK: }
+func.func @sparse_sort_1d2v(%arg0: index, %arg1: memref<10xi8>, %arg2: memref<?xf32>, %arg3: memref<10xindex>)
+    -> (memref<10xi8>, memref<?xf32>, memref<10xindex>) {
+  sparse_tensor.sort %arg0, %arg1 jointly %arg2, %arg3 : memref<10xi8> jointly memref<?xf32>, memref<10xindex>
+  return %arg1, %arg2, %arg3 : memref<10xi8>, memref<?xf32>, memref<10xindex>
+}
+
+// Only check the generated supporting functions here. We have an integration
+// test to verify correctness of the generated code.
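+// The helper functions are specialized per operand signature; sorting three
+// index buffers therefore generates the _3_index variants checked below.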
+//
+// CHECK-DAG: func.func private @_sparse_less_than_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> i1 {
+// CHECK-DAG: func.func private @_sparse_may_swap_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) {
+// CHECK-DAG: func.func private @_sparse_partition_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> index {
+// CHECK-DAG: func.func private @_sparse_sort_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) {
+// CHECK-LABEL: func.func @sparse_sort_3d
+func.func @sparse_sort_3d(%arg0: index, %arg1: memref<10xindex>, %arg2: memref<?xindex>, %arg3: memref<10xindex>) -> (memref<10xindex>, memref<?xindex>, memref<10xindex>) {
+  sparse_tensor.sort %arg0, %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
+  return %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
+}
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir
@@ -0,0 +1,100 @@
+// RUN: mlir-opt %s --sparse-compiler=enable-runtime-library=false | \
+// RUN: mlir-cpu-runner \
+// RUN:  -e entry -entry-point-result=void \
+// RUN:  -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+module {
+  // Stores 5 values to the memref buffer.
+  func.func @storeValuesTo(%b: memref<?xi32>, %v0: i32, %v1: i32, %v2: i32,
+                           %v3: i32, %v4: i32) -> () {
+    %i0 = arith.constant 0 : index
+    %i1 = arith.constant 1 : index
+    %i2 = arith.constant 2 : index
+    %i3 = arith.constant 3 : index
+    %i4 = arith.constant 4 : index
+    memref.store %v0, %b[%i0] : memref<?xi32>
+    memref.store %v1, %b[%i1] : memref<?xi32>
+    memref.store %v2, %b[%i2] : memref<?xi32>
+    memref.store %v3, %b[%i3] : memref<?xi32>
+    memref.store %v4, %b[%i4] : memref<?xi32>
+    return
+  }
+
+  // The main driver.
+  func.func @entry() {
+    %c0 = arith.constant 0 : i32
+    %c1 = arith.constant 1 : i32
+    %c2 = arith.constant 2 : i32
+    %c3 = arith.constant 3 : i32
+    %c4 = arith.constant 4 : i32
+    %c5 = arith.constant 5 : i32
+    %c6 = arith.constant 6 : i32
+    %c7 = arith.constant 7 : i32
+    %c8 = arith.constant 8 : i32
+    %c9 = arith.constant 9 : i32
+    %c10 = arith.constant 10 : i32
+    %c100 = arith.constant 100 : i32
+
+    %i0 = arith.constant 0 : index
+    %i4 = arith.constant 4 : index
+    %i5 = arith.constant 5 : index
+
+    // Prepare a buffer.
+    %x0s = memref.alloc() : memref<5xi32>
+    %x0 = memref.cast %x0s : memref<5xi32> to memref<?xi32>
+    call @storeValuesTo(%x0, %c10, %c2, %c0, %c5, %c1)
+      : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
+
+    // Sort 0 elements.
+    // CHECK: ( 10, 2, 0, 5, 1 )
+    sparse_tensor.sort %i0, %x0 : memref<?xi32>
+    %x0v0 = vector.transfer_read %x0[%i0], %c100: memref<?xi32>, vector<5xi32>
+    vector.print %x0v0 : vector<5xi32>
+
+    // Sort the first 4 elements, with the last valid value untouched.
+    // CHECK: ( 0, 2, 5, 10, 1 )
+    sparse_tensor.sort %i4, %x0 : memref<?xi32>
+    %x0v1 = vector.transfer_read %x0[%i0], %c100: memref<?xi32>, vector<5xi32>
+    vector.print %x0v1 : vector<5xi32>
+
+    // Prepare more buffers of different dimensions.
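+    // Each buffer has a different static size, but only the first 5 elements
+    // of each are initialized and sorted.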
+    %x1s = memref.alloc() : memref<10xi32>
+    %x1 = memref.cast %x1s : memref<10xi32> to memref<?xi32>
+    %x2s = memref.alloc() : memref<6xi32>
+    %x2 = memref.cast %x2s : memref<6xi32> to memref<?xi32>
+    %y0s = memref.alloc() : memref<7xi32>
+    %y0 = memref.cast %y0s : memref<7xi32> to memref<?xi32>
+    call @storeValuesTo(%x0, %c10, %c2, %c0, %c5, %c1)
+      : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
+    call @storeValuesTo(%x1, %c1, %c1, %c3, %c10, %c3)
+      : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
+    call @storeValuesTo(%x2, %c2, %c4, %c4, %c7, %c9)
+      : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
+    call @storeValuesTo(%y0, %c6, %c10, %c8, %c9, %c7)
+      : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
+
+    // Sort "parallel arrays".
+    // CHECK: ( 0, 1, 2, 5, 10 )
+    // CHECK: ( 3, 3, 1, 10, 1 )
+    // CHECK: ( 4, 9, 4, 7, 2 )
+    // CHECK: ( 8, 7, 10, 9, 6 )
+    sparse_tensor.sort %i5, %x0, %x1, %x2 jointly %y0
+      : memref<?xi32>, memref<?xi32>, memref<?xi32> jointly memref<?xi32>
+    %x0v2 = vector.transfer_read %x0[%i0], %c100: memref<?xi32>, vector<5xi32>
+    vector.print %x0v2 : vector<5xi32>
+    %x1v = vector.transfer_read %x1[%i0], %c100: memref<?xi32>, vector<5xi32>
+    vector.print %x1v : vector<5xi32>
+    %x2v = vector.transfer_read %x2[%i0], %c100: memref<?xi32>, vector<5xi32>
+    vector.print %x2v : vector<5xi32>
+    %y0v = vector.transfer_read %y0[%i0], %c100: memref<?xi32>, vector<5xi32>
+    vector.print %y0v : vector<5xi32>
+
+    // Release the buffers.
+    memref.dealloc %x0 : memref<?xi32>
+    memref.dealloc %x1 : memref<?xi32>
+    memref.dealloc %x2 : memref<?xi32>
+    memref.dealloc %y0 : memref<?xi32>
+    return
+  }
+}