diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -889,6 +889,11 @@
     would be equivalent to a union operation where non-overlapping values
     in the inputs are copied to the output unchanged.
 
+    Due to the possibility of empty regions, i.e. lack of a value for certain
+    cases, the result of this operation may only feed directly into the output
+    of the `linalg.generic` operation or into a custom reduction
+    `sparse_tensor.reduce` operation that follows in the same region.
+
     Example of isEqual applied to intersecting elements only:
 
     ```mlir
@@ -992,6 +997,11 @@
     A region may also be declared empty (i.e. `absent={}`), indicating that
     the region does not contribute to the output.
 
+    Due to the possibility of empty regions, i.e. lack of a value for certain
+    cases, the result of this operation may only feed directly into the output
+    of the `linalg.generic` operation or into a custom reduction
+    `sparse_tensor.reduce` operation that follows in the same region.
+
     Example of A+1, restricted to existing elements:
 
     ```mlir
@@ -1015,28 +1025,41 @@
 
     Example returning +1 for existing values and -1 for missing values:
 
     ```mlir
-    %result = sparse_tensor.unary %a : f64 to i32
-      present={
-      ^bb0(%x: f64):
-        %ret = arith.constant 1 : i32
+    %C = bufferization.alloc_tensor...
+    %1 = linalg.generic #trait
+       ins(%A: tensor)
+       outs(%C: tensor) {
+       ^bb0(%a: f64, %c: i32) :
+          %result = sparse_tensor.unary %a : f64 to i32
+            present={
+            ^bb0(%x: f64):
+              %ret = arith.constant 1 : i32
+              sparse_tensor.yield %ret : i32
+            }
+            absent={
+              %ret = arith.constant -1 : i32
              sparse_tensor.yield %ret : i32
-      }
-      absent={
-        %ret = arith.constant -1 : i32
-        sparse_tensor.yield %ret : i32
-      }
+            }
+          linalg.yield %result : i32
+     } -> tensor
     ```
 
     Example showing a structural inversion (existing values become missing in
     the output, while missing values are filled with 1):
 
     ```mlir
-    %result = sparse_tensor.unary %a : f64 to i64
-      present={}
-      absent={
-        %ret = arith.constant 1 : i64
-        sparse_tensor.yield %ret : i64
-      }
+    %C = bufferization.alloc_tensor...
+    %2 = linalg.generic #trait
+       ins(%A: tensor)
+       outs(%C: tensor) {
+       ^bb0(%a: f64, %c: i64) :
+          %result = sparse_tensor.unary %a : f64 to i64
+            present={}
+            absent={
+              %ret = arith.constant 1 : i64
+              sparse_tensor.yield %ret : i64
+            }
+          linalg.yield %result : i64
+     } -> tensor
     ```
   }];
@@ -1139,7 +1162,7 @@
 
     ```mlir
     %C = bufferization.alloc_tensor...
-    %0 = linalg.generic #trait
+    %1 = linalg.generic #trait
        ins(%A: tensor)
        outs(%C: tensor) {
        ^bb0(%a: f64, %c: f64) :
@@ -1173,10 +1196,12 @@
 
     ```mlir
     %0 = sparse_tensor.unary %a : i64 to i64 {
-      ^bb0(%arg0: i64):
-        %cst = arith.constant 1 : i64
-        %ret = arith.addi %arg0, %cst : i64
-        sparse_tensor.yield %ret : i64
+      present={
+      ^bb0(%arg0: i64):
+        %cst = arith.constant 1 : i64
+        %ret = arith.addi %arg0, %cst : i64
+        sparse_tensor.yield %ret : i64
+      }
     }
     ```
   }];
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -1049,11 +1049,13 @@
 /// Generates a store on a dense or sparse tensor.
 static void genTensorStore(CodegenEnv &env, OpBuilder &builder, ExprId exp,
                            Value rhs) {
-  // Only unary and binary are allowed to return uninitialized rhs
-  // to indicate missing output.
+  // Only unary and binary are allowed to return an uninitialized rhs
+  // to indicate missing output, or otherwise a custom reduction that
+  // received no value to accumulate.
   if (!rhs) {
     assert(env.exp(exp).kind == TensorExp::Kind::kUnary ||
-           env.exp(exp).kind == TensorExp::Kind::kBinary);
+           env.exp(exp).kind == TensorExp::Kind::kBinary ||
+           env.exp(exp).kind == TensorExp::Kind::kReduce);
     return;
   }
   // Test if this is a scalarized reduction.
@@ -1146,12 +1148,17 @@
 
   Value v0 = genExp(env, rewriter, exp.children.e0, ldx);
   Value v1 = genExp(env, rewriter, exp.children.e1, ldx);
-  Value ee = env.merger().buildExp(rewriter, loc, e, v0, v1);
-  if (ee &&
-      (kind == TensorExp::Kind::kUnary || kind == TensorExp::Kind::kBinary ||
-       kind == TensorExp::Kind::kBinaryBranch ||
-       kind == TensorExp::Kind::kReduce || kind == TensorExp::Kind::kSelect))
-    ee = relinkBranch(env, rewriter, ee.getParentBlock(), ee, ldx);
+  Value ee;
+  if (kind == TensorExp::Kind::kReduce && (!v0 || !v1)) {
+    // The custom reduce did not receive a value to accumulate; leave ee empty.
+  } else {
+    ee = env.merger().buildExp(rewriter, loc, e, v0, v1);
+    if (ee &&
+        (kind == TensorExp::Kind::kUnary || kind == TensorExp::Kind::kBinary ||
+         kind == TensorExp::Kind::kBinaryBranch ||
+         kind == TensorExp::Kind::kReduce || kind == TensorExp::Kind::kSelect))
+      ee = relinkBranch(env, rewriter, ee.getParentBlock(), ee, ldx);
+  }
 
   if (kind == TensorExp::Kind::kReduce)
     env.endCustomReduc(); // exit custom
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reduce_custom_sum.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reduce_custom_sum.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reduce_custom_sum.mlir
@@ -0,0 +1,170 @@
+// DEFINE: %{option} = enable-runtime-library=true
+// DEFINE: %{command} = mlir-opt %s --sparse-compiler=%{option} | \
+// DEFINE: mlir-cpu-runner \
+// DEFINE:  -e entry -entry-point-result=void \
+// DEFINE:  -shared-libs=%mlir_c_runner_utils | \
+// DEFINE: FileCheck %s
+//
+// RUN: %{command}
+//
+// Do the same run, but now with direct IR generation.
+// REDEFINE: %{option} = enable-runtime-library=false
+// RUN: %{command}
+//
+// Do the same run, but now with direct IR generation and vectorization.
+// REDEFINE: %{option} = "enable-runtime-library=false vl=2 reassociate-fp-reductions=true enable-index-optimizations=true"
+// RUN: %{command}
+
+#SV = #sparse_tensor.encoding<{ lvlTypes = [ "compressed" ] }>
+
+#trait_reduction = {
+  indexing_maps = [
+    affine_map<(i) -> (i)>,  // a
+    affine_map<(i) -> ()>    // x (scalar out)
+  ],
+  iterator_types = ["reduction"],
+  doc = "x += SUM_CUSTOM_i UNARY a(i)"
+}
+
+// Test of unary op feeding into custom sum reduction.
+module {
+
+  // Contrived example for stress testing, where neither branch feeds
+  // a value into the subsequent custom sum reduction. The generated code
+  // should simply fold to the initial value 1.
+  func.func @red0(%arga: tensor<8xi32, #SV>, %argx: tensor<i32>) -> tensor<i32> {
+    %c1 = arith.constant 1 : i32
+    %0 = linalg.generic #trait_reduction
+      ins(%arga: tensor<8xi32, #SV>)
+      outs(%argx: tensor<i32>) {
+      ^bb(%a: i32, %b: i32):
+        %u = sparse_tensor.unary %a : i32 to i32
+          present={ }
+          absent={ }
+        %r = sparse_tensor.reduce %u, %b, %c1 : i32 {
+          ^bb0(%x: i32, %y: i32):
+            %sum = arith.addi %x, %y : i32
+            sparse_tensor.yield %sum : i32
+        }
+        linalg.yield %r : i32
+    } -> tensor<i32>
+    return %0 : tensor<i32>
+  }
+
+  // Typical example where the present branch contributes a value
+  // into the subsequent custom sum reduction.
+  func.func @red1(%arga: tensor<8xi32, #SV>, %argx: tensor<i32>) -> tensor<i32> {
+    %c1 = arith.constant 1 : i32
+    %c2 = arith.constant 2 : i32
+    %0 = linalg.generic #trait_reduction
+      ins(%arga: tensor<8xi32, #SV>)
+      outs(%argx: tensor<i32>) {
+      ^bb(%a: i32, %b: i32):
+        %u = sparse_tensor.unary %a : i32 to i32
+          present={
+            ^bb(%p: i32):
+              sparse_tensor.yield %c2 : i32
+          }
+          absent={ }
+        %r = sparse_tensor.reduce %u, %b, %c1 : i32 {
+          ^bb0(%x: i32, %y: i32):
+            %sum = arith.addi %x, %y : i32
+            sparse_tensor.yield %sum : i32
+        }
+        linalg.yield %r : i32
+    } -> tensor<i32>
+    return %0 : tensor<i32>
+  }
+
+  // A complementing example where the absent branch contributes a value
+  // into the subsequent custom sum reduction.
+  func.func @red2(%arga: tensor<8xi32, #SV>, %argx: tensor<i32>) -> tensor<i32> {
+    %c1 = arith.constant 1 : i32
+    %c3 = arith.constant 3 : i32
+    %0 = linalg.generic #trait_reduction
+      ins(%arga: tensor<8xi32, #SV>)
+      outs(%argx: tensor<i32>) {
+      ^bb(%a: i32, %b: i32):
+        %u = sparse_tensor.unary %a : i32 to i32
+          present={ }
+          absent={
+            sparse_tensor.yield %c3 : i32
+          }
+        %r = sparse_tensor.reduce %u, %b, %c1 : i32 {
+          ^bb0(%x: i32, %y: i32):
+            %sum = arith.addi %x, %y : i32
+            sparse_tensor.yield %sum : i32
+        }
+        linalg.yield %r : i32
+    } -> tensor<i32>
+    return %0 : tensor<i32>
+  }
+
+  // An example where both the present and absent branches contribute values
+  // into the subsequent custom sum reduction.
+  func.func @red3(%arga: tensor<8xi32, #SV>, %argx: tensor<i32>) -> tensor<i32> {
+    %c1 = arith.constant 1 : i32
+    %c2 = arith.constant 2 : i32
+    %c3 = arith.constant 3 : i32
+    %0 = linalg.generic #trait_reduction
+      ins(%arga: tensor<8xi32, #SV>)
+      outs(%argx: tensor<i32>) {
+      ^bb(%a: i32, %b: i32):
+        %u = sparse_tensor.unary %a : i32 to i32
+          present={
+            ^bb(%p: i32):
+              sparse_tensor.yield %c2 : i32
+          }
+          absent={
+            sparse_tensor.yield %c3 : i32
+          }
+        %r = sparse_tensor.reduce %u, %b, %c1 : i32 {
+          ^bb0(%x: i32, %y: i32):
+            %sum = arith.addi %x, %y : i32
+            sparse_tensor.yield %sum : i32
+        }
+        linalg.yield %r : i32
+    } -> tensor<i32>
+    return %0 : tensor<i32>
+  }
+
+  func.func @dump_i32(%arg0 : tensor<i32>) {
+    %v = tensor.extract %arg0[] : tensor<i32>
+    vector.print %v : i32
+    return
+  }
+
+  func.func @entry() {
+    %ri = arith.constant dense<0> : tensor<i32>
+
+    // Sparse vector of length 8 with 2 stored elements (and thus 6 implicit zeros).
+    %v0 = arith.constant sparse< [ [4], [6] ], [ 99, 999 ] > : tensor<8xi32>
+    %s0 = sparse_tensor.convert %v0 : tensor<8xi32> to tensor<8xi32, #SV>
+
+    // Call the kernels.
+    %0 = call @red0(%s0, %ri) : (tensor<8xi32, #SV>, tensor<i32>) -> tensor<i32>
+    %1 = call @red1(%s0, %ri) : (tensor<8xi32, #SV>, tensor<i32>) -> tensor<i32>
+    %2 = call @red2(%s0, %ri) : (tensor<8xi32, #SV>, tensor<i32>) -> tensor<i32>
+    %3 = call @red3(%s0, %ri) : (tensor<8xi32, #SV>, tensor<i32>) -> tensor<i32>
+
+    // Verify results.
+    // 1 + nothing = 1
+    // 1 + 2 x present (2 stored entries) = 5
+    // 1 + 3 x absent (6 implicit zeros) = 19
+    // 1 + 2 x present + 3 x absent = 23
+    //
+    // CHECK: 1
+    // CHECK: 5
+    // CHECK: 19
+    // CHECK: 23
+    //
+    call @dump_i32(%0) : (tensor<i32>) -> ()
+    call @dump_i32(%1) : (tensor<i32>) -> ()
+    call @dump_i32(%2) : (tensor<i32>) -> ()
+    call @dump_i32(%3) : (tensor<i32>) -> ()
+
+    // Release the resources.
+    bufferization.dealloc_tensor %s0 : tensor<8xi32, #SV>
+
+    return
+  }
+}
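As a usage reference, here is a minimal standalone sketch (not part of the patch) of the pattern that the new documentation sentence permits: the result of a `sparse_tensor.unary` feeding directly into a `sparse_tensor.reduce` within the same `linalg.generic` region. It counts the stored entries of a sparse vector by yielding 1 for present values, contributing nothing for absent entries, and summing with identity 0. The `#SV` encoding and `#trait_reduction` trait are assumed to be the ones defined in the new test above; the function name and constants are illustrative only.

```mlir
// Sketch: count the stored entries of %arga (would evaluate to 2 for the
// test vector above). Present values map to 1, absent entries contribute
// nothing, and the custom reduction sums with identity 0.
func.func @count_stored(%arga: tensor<8xi32, #SV>,
                        %argx: tensor<i32>) -> tensor<i32> {
  %c0 = arith.constant 0 : i32   // identity of the custom sum
  %c1 = arith.constant 1 : i32   // contribution per stored entry
  %0 = linalg.generic #trait_reduction
    ins(%arga: tensor<8xi32, #SV>)
    outs(%argx: tensor<i32>) {
    ^bb(%a: i32, %b: i32):
      %u = sparse_tensor.unary %a : i32 to i32
        present={
          ^bb(%p: i32):
            sparse_tensor.yield %c1 : i32
        }
        absent={ }
      %r = sparse_tensor.reduce %u, %b, %c0 : i32 {
        ^bb0(%x: i32, %y: i32):
          %sum = arith.addi %x, %y : i32
          sparse_tensor.yield %sum : i32
      }
      linalg.yield %r : i32
  } -> tensor<i32>
  return %0 : tensor<i32>
}
```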