diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -387,6 +387,88 @@
   }
 };
 
+/// Rewrites a sparse reduction with a scalar result that would not sparsify
+/// directly, since doing so would only iterate over the stored elements and
+/// ignore the implicit zeros, into a semi-ring. Applies to prod/and/min/max,
+/// although probably only useful for the latter two. Note that reductions
+/// like add/sub/or/xor can directly be sparsified since the implicit zeros
+/// do not contribute to the final result.
+///
+/// TODO: this essentially "densifies" the operation; we want to implement
+///       this much more efficiently by performing the reduction over the
+///       stored values, and feed in the zero once if there were *any*
+///       implicit zeros as well; but for now, at least we provide
+///       the functionality.
+///
+struct GenSemiRingReduction : public OpRewritePattern {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(GenericOp op,
+                                PatternRewriter &rewriter) const override {
+    // Reject non-reductions, as well as reductions into non-scalar outputs.
+    if (!op.hasTensorSemantics() || op.getNumDpsInputs() != 1 ||
+        op.getNumReductionLoops() == 0 || op.getNumResults() != 1)
+      return failure();
+    auto *inp = op.getDpsInputOperand(0);
+    auto *init = op.getDpsInitOperand(0);
+    if (!isSparseTensor(inp) || op.getRank(init) != 0)
+      return failure();
+    // Look for direct x = x OP y; a yielded block argument has no defining op.
+    auto *red = cast(op.getRegion().front().getTerminator())
+                    .getOperand(0)
+                    .getDefiningOp();
+    if (!red || (!isa(red) && !isa(red) &&
+                 !isa(red) && !isa(red) &&
+                 !isa(red) && !isa(red) &&
+                 !isa(red) && !isa(red) &&
+                 !isa(red)))
+      return failure();
+    Value s0 = op.getBlock()->getArgument(0);
+    Value s1 = op.getBlock()->getArgument(1);
+    if ((red->getOperand(0) != s0 || red->getOperand(1) != s1) &&
+        (red->getOperand(0) != s1 || red->getOperand(1) != s0))
+      return failure();
+    // Identity: the scalar init value (extracted without any indices).
+    Location loc = op.getLoc();
+    Value identity =
+        rewriter.create(loc, init->get(), ValueRange());
+    // Unary {
+    //    present -> value
+    //    absent  -> zero.
+    // }
+    Type rtp = s0.getType();
+    rewriter.setInsertionPointToStart(&op.getRegion().front());
+    auto semiring = rewriter.create(loc, rtp, s0);
+    Block *present =
+        rewriter.createBlock(&semiring.getPresentRegion(), {}, rtp, loc);
+    rewriter.setInsertionPointToStart(&semiring.getPresentRegion().front());
+    rewriter.create(loc, present->getArgument(0));
+    rewriter.createBlock(&semiring.getAbsentRegion(), {}, {}, {});
+    rewriter.setInsertionPointToStart(&semiring.getAbsentRegion().front());
+    auto zero =
+        rewriter.create(loc, rewriter.getZeroAttr(rtp));
+    rewriter.create(loc, zero);
+    rewriter.setInsertionPointAfter(semiring);
+    // CustomReduce {
+    //    x = x REDUC y, identity
+    // }
+    auto custom = rewriter.create(
+        loc, rtp, semiring.getResult(), s1, identity);
+    Block *region =
+        rewriter.createBlock(&custom.getRegion(), {}, {rtp, rtp}, {loc, loc});
+    rewriter.setInsertionPointToStart(&custom.getRegion().front());
+    IRMapping irMap;
+    irMap.map(red->getOperand(0), region->getArgument(0));
+    irMap.map(red->getOperand(1), region->getArgument(1));
+    auto *cloned = rewriter.clone(*red, irMap);
+    rewriter.create(loc, cloned->getResult(0));
+    rewriter.setInsertionPointAfter(custom);
+    rewriter.replaceOp(red, custom.getResult());
+    return success();
+  }
+};
+
 /// Sparse rewriting rule for sparse-to-sparse reshape operator.
struct TensorReshapeRewriter : public OpRewritePattern {
 public:
@@ -1262,8 +1344,8 @@
 //===---------------------------------------------------------------------===//
 
 void mlir::populatePreSparsificationRewriting(RewritePatternSet &patterns) {
-  patterns.add(
-      patterns.getContext());
+  patterns.add(patterns.getContext());
 }
 
 void mlir::populatePostSparsificationRewriting(RewritePatternSet &patterns,
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reductions_min.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reductions_min.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reductions_min.mlir
@@ -0,0 +1,128 @@
+// DEFINE: %{option} = enable-runtime-library=true
+// DEFINE: %{command} = mlir-opt %s --sparse-compiler=%{option} | \
+// DEFINE: mlir-cpu-runner \
+// DEFINE: -e entry -entry-point-result=void \
+// DEFINE: -shared-libs=%mlir_c_runner_utils | \
+// DEFINE: FileCheck %s
+//
+// RUN: %{command}
+//
+// Do the same run, but now with direct IR generation.
+// REDEFINE: %{option} = enable-runtime-library=false
+// RUN: %{command}
+//
+// Do the same run, but now with direct IR generation and vectorization.
+// REDEFINE: %{option} = "enable-runtime-library=false vl=2 reassociate-fp-reductions=true enable-index-optimizations=true"
+// RUN: %{command}
+
+#SV = #sparse_tensor.encoding<{ lvlTypes = [ "compressed" ] }>
+
+#trait_reduction = {
+  indexing_maps = [
+    affine_map<(i) -> (i)>,  // a
+    affine_map<(i) -> ()>    // x (scalar out)
+  ],
+  iterator_types = ["reduction"],
+  doc = "x += MIN_i a(i)"
+}
+
+// Examples of sparse vector MIN reductions.
+module {
+
+  // Custom MIN reduction: stored i32 elements only (no implicit zeros).
+  func.func @min1(%arga: tensor<32xi32, #SV>, %argx: tensor) -> tensor {
+    %c = tensor.extract %argx[] : tensor
+    %0 = linalg.generic #trait_reduction
+      ins(%arga: tensor<32xi32, #SV>)
+      outs(%argx: tensor) {
+        ^bb(%a: i32, %b: i32):
+          %1 = sparse_tensor.reduce %a, %b, %c : i32 {
+            ^bb0(%x: i32, %y: i32):
+              %m = arith.minsi %x, %y : i32
+              sparse_tensor.yield %m : i32
+            }
+          linalg.yield %1 : i32
+    } -> tensor
+    return %0 : tensor
+  }
+
+  // Regular MIN reduction: stored i32 elements AND implicit zeros.
+  // The implicit zeros are accounted for by the sparse compiler,
+  // which preserves the semantics of the "original" dense reduction.
+  func.func @min2(%arga: tensor<32xi32, #SV>, %argx: tensor) -> tensor {
+    %c = tensor.extract %argx[] : tensor
+    %0 = linalg.generic #trait_reduction
+      ins(%arga: tensor<32xi32, #SV>)
+      outs(%argx: tensor) {
+        ^bb(%a: i32, %b: i32):
+          %m = arith.minsi %a, %b : i32
+          linalg.yield %m : i32
+    } -> tensor
+    return %0 : tensor
+  }
+
+  func.func @dump_i32(%arg0 : tensor) {
+    %v = tensor.extract %arg0[] : tensor
+    vector.print %v : i32
+    return
+  }
+
+  func.func @entry() {
+    %ri = arith.constant dense<999> : tensor
+
+    // Vector with a few zeros.
+    %c_0_i32 = arith.constant dense<[
+      2, 2, 7, 2, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2,
+      2, 2, 2, 2, 3, 0, 9, 2, 2, 2, 2, 0, 5, 1, 7, 3
+    ]> : tensor<32xi32>
+
+    // Vector with no zeros.
+    %c_1_i32 = arith.constant dense<[
+      2, 2, 7, 2, 2, 2, 2, 2, 2, 2, 2, 4, 2, 2, 2, 2,
+      2, 2, 2, 2, 3, 2, 7, 2, 2, 2, 2, 2, 2, 1, 7, 3
+    ]> : tensor<32xi32>
+
+    // Convert constants to annotated tensors. Note that this
+    // particular conversion only stores nonzero elements,
+    // so we will have no explicit zeros, only implicit zeros.
+    %sv0 = sparse_tensor.convert %c_0_i32
+      : tensor<32xi32> to tensor<32xi32, #SV>
+    %sv1 = sparse_tensor.convert %c_1_i32
+      : tensor<32xi32> to tensor<32xi32, #SV>
+
+    // Special case: construct a sparse vector with an explicit zero.
+    %v = arith.constant sparse< [ [1], [7] ], [ 0, 22 ] > : tensor<32xi32>
+    %sv2 = sparse_tensor.convert %v: tensor<32xi32> to tensor<32xi32, #SV>
+
+    // Call the kernels.
+    %0 = call @min1(%sv0, %ri) : (tensor<32xi32, #SV>, tensor) -> tensor
+    %1 = call @min1(%sv1, %ri) : (tensor<32xi32, #SV>, tensor) -> tensor
+    %2 = call @min1(%sv2, %ri) : (tensor<32xi32, #SV>, tensor) -> tensor
+    %3 = call @min2(%sv0, %ri) : (tensor<32xi32, #SV>, tensor) -> tensor
+    %4 = call @min2(%sv1, %ri) : (tensor<32xi32, #SV>, tensor) -> tensor
+    %5 = call @min2(%sv2, %ri) : (tensor<32xi32, #SV>, tensor) -> tensor
+
+    // Verify results.
+    // min1 skips the implicit zeros, whereas min2 includes them.
+    // CHECK: 1
+    // CHECK: 1
+    // CHECK: 0
+    // CHECK: 0
+    // CHECK: 1
+    // CHECK: 0
+    //
+    call @dump_i32(%0) : (tensor) -> ()
+    call @dump_i32(%1) : (tensor) -> ()
+    call @dump_i32(%2) : (tensor) -> ()
+    call @dump_i32(%3) : (tensor) -> ()
+    call @dump_i32(%4) : (tensor) -> ()
+    call @dump_i32(%5) : (tensor) -> ()
+
+    // Release the resources.
+    bufferization.dealloc_tensor %sv0 : tensor<32xi32, #SV>
+    bufferization.dealloc_tensor %sv1 : tensor<32xi32, #SV>
+    bufferization.dealloc_tensor %sv2 : tensor<32xi32, #SV>
+
+    return
+  }
+}