diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -544,6 +544,55 @@
   let hasVerifier = 1;
 }
 
+def SparseTensor_ReduceOp : SparseTensor_Op<"reduce", [NoSideEffect, SameOperandsAndResultType]>,
+    Arguments<(ins AnyType:$x, AnyType:$y, AnyType:$identity)>,
+    Results<(outs AnyType:$output)> {
+  let summary = "Custom reduction operation utilized within linalg.generic";
+  let description = [{
+      Defines a computation with a `linalg.generic` operation that takes two
+      operands and an identity value and reduces all values down to a single
+      result based on the computation in the region.
+
+      The region must contain exactly one block taking two arguments. The block
+      must end with a sparse_tensor.yield and the output must match the input
+      argument types.
+
+      Note that this operation is only required for custom reductions beyond the
+      standard operations (add, mul, and, or, etc). The `linalg.generic`
+      `iterator_types` defines which indices are being reduced. When the associated
+      operands are used in an operation, a reduction will occur. The use of this
+      explicit `reduce` operation is not required in most cases.
+
+      Example of Matrix->Vector reduction using min(product(x_i), 100):
+
+      ```mlir
+      %cf1 = arith.constant 1.0 : f64
+      %cf100 = arith.constant 100.0 : f64
+      %C = bufferization.alloc_tensor...
+      %0 = linalg.generic #trait
+         ins(%A: tensor<?x?xf64, #SparseMatrix>)
+        outs(%C: tensor<?xf64, #SparseVector>) {
+        ^bb0(%a: f64, %c: f64) :
+          %result = sparse_tensor.reduce %c, %a, %cf1 : f64 {
+              ^bb0(%arg0: f64, %arg1: f64):
+                %0 = arith.mulf %arg0, %arg1 : f64
+                %cmp = arith.cmpf "ogt", %0, %cf100 : f64
+                %ret = arith.select %cmp, %cf100, %0 : f64
+                sparse_tensor.yield %ret : f64
+            }
+          linalg.yield %result : f64
+      } -> tensor<?xf64, #SparseVector>
+      ```
+  }];
+
+  let regions = (region SizedRegion<1>:$region);
+
+  let assemblyFormat = [{
+         $x `,` $y `,` $identity attr-dict `:` type($output) $region
+  }];
+  let hasVerifier = 1;
+}
+
 def SparseTensor_YieldOp : SparseTensor_Op<"yield", [NoSideEffect, Terminator]>,
     Arguments<(ins AnyType:$result)> {
   let summary = "Yield from sparse_tensor set-like operations";
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -357,15 +357,27 @@
   return success();
 }
 
+/// Verifies the reduce region: when it is non-empty, it must take two block
+/// arguments of the operand type and yield a value of that same type.
+LogicalResult ReduceOp::verify() {
+  Type inputType = x().getType();
+
+  // Check correct number of block arguments and return type.
+  Region &formula = region();
+  if (formula.empty())
+    return success();
+  return verifyNumBlockArgs(this, formula, "reduce",
+                            TypeRange{inputType, inputType}, inputType);
+}
+
 LogicalResult YieldOp::verify() {
   // Check for compatible parent.
   auto *parentOp = (*this)->getParentOp();
-  if (auto binaryOp = dyn_cast<BinaryOp>(parentOp))
-    return success();
-  if (auto unaryOp = dyn_cast<UnaryOp>(parentOp))
+  if (isa<BinaryOp, UnaryOp, ReduceOp>(parentOp))
     return success();
 
-  return emitOpError("expected parent op to be sparse_tensor binary or unary");
+  return emitOpError(
+      "expected parent op to be sparse_tensor unary, binary, or reduce");
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -253,6 +253,20 @@
 
 // -----
 
+func.func @invalid_binary_wrong_yield(%arg0: f64, %arg1: f64) -> f64 {
+  // expected-error@+1 {{left region must end with sparse_tensor.yield}}
+  %0 = sparse_tensor.binary %arg0, %arg1 : f64, f64 to f64
+    overlap={}
+    left={
+      ^bb0(%x: f64):
+        tensor.yield %x : f64
+    }
+    right=identity
+  return %0 : f64
+}
+
+// -----
+
 func.func @invalid_unary_argtype_mismatch(%arg0: f64) -> f64 {
   // expected-error@+1 {{present region argument 1 type mismatch}}
   %r = sparse_tensor.unary %arg0 : f64 to f64
@@ -290,3 +304,67 @@
     absent={}
   return %0 : f64
 }
+
+// -----
+
+func.func @invalid_unary_wrong_yield(%arg0: f64) -> f64 {
+  // expected-error@+1 {{present region must end with sparse_tensor.yield}}
+  %0 = sparse_tensor.unary %arg0 : f64 to f64
+    present={
+      ^bb0(%x: f64):
+        tensor.yield %x : f64
+    }
+    absent={}
+  return %0 : f64
+}
+
+// -----
+
+func.func @invalid_reduce_num_args_mismatch(%arg0: f64, %arg1: f64) -> f64 {
+  %cf1 = arith.constant 1.0 : f64
+  // expected-error@+1 {{reduce region must have exactly 2 arguments}}
+  %r = sparse_tensor.reduce %arg0, %arg1, %cf1 : f64 {
+      ^bb0(%x: f64):
+        sparse_tensor.yield %x : f64
+    }
+  return %r : f64
+}
+
+// -----
+
+func.func @invalid_reduce_block_arg_type_mismatch(%arg0: i64, %arg1: i64) -> i64 {
+  %ci1 = arith.constant 1 : i64
+  // expected-error@+1 {{reduce region argument 1 type mismatch}}
+  %r = sparse_tensor.reduce %arg0, %arg1, %ci1 : i64 {
+      ^bb0(%x: f64, %y: f64):
+        %cst = arith.constant 2 : i64
+        sparse_tensor.yield %cst : i64
+    }
+  return %r : i64
+}
+
+// -----
+
+func.func @invalid_reduce_return_type_mismatch(%arg0: f64, %arg1: f64) -> f64 {
+  %cf1 = arith.constant 1.0 : f64
+  // expected-error@+1 {{reduce region yield type mismatch}}
+  %r = sparse_tensor.reduce %arg0, %arg1, %cf1 : f64 {
+      ^bb0(%x: f64, %y: f64):
+        %cst = arith.constant 2 : i64
+        sparse_tensor.yield %cst : i64
+    }
+  return %r : f64
+}
+
+// -----
+
+func.func @invalid_reduce_wrong_yield(%arg0: f64, %arg1: f64) -> f64 {
+  %cf1 = arith.constant 1.0 : f64
+  // expected-error@+1 {{reduce region must end with sparse_tensor.yield}}
+  %r = sparse_tensor.reduce %arg0, %arg1, %cf1 : f64 {
+      ^bb0(%x: f64, %y: f64):
+        %cst = arith.constant 2 : i64
+        tensor.yield %cst : i64
+    }
+  return %r : f64
+}
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -268,3 +268,25 @@
     absent={}
   return %r : i64
 }
+
+// -----
+
+#SparseMatrix = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}>
+
+// CHECK-LABEL: func @sparse_reduce_2d_to_1d(
+// CHECK-SAME: %[[A:.*]]: f64, %[[B:.*]]: f64) -> f64 {
+// CHECK: %[[Z:.*]] = arith.constant 0.000000e+00 : f64
+// CHECK: %[[C1:.*]] = sparse_tensor.reduce %[[A]], %[[B]], %[[Z]] : f64 {
+// CHECK: ^bb0(%[[A1:.*]]: f64, %[[B1:.*]]: f64):
+// CHECK: sparse_tensor.yield %[[A1]] : f64
+// CHECK: }
+// CHECK: return %[[C1]] : f64
+// CHECK: }
+func.func @sparse_reduce_2d_to_1d(%arg0: f64, %arg1: f64) -> f64 {
+  %cf0 = arith.constant 0.0 : f64
+  %r = sparse_tensor.reduce %arg0, %arg1, %cf0 : f64 {
+    ^bb0(%x: f64, %y: f64):
+      sparse_tensor.yield %x : f64
+  }
+  return %r : f64
+}