diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -379,4 +379,218 @@
   let hasVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// Sparse Tensor Custom Linalg.Generic Operations.
+//===----------------------------------------------------------------------===//
+
+def SparseTensor_BinaryOp : SparseTensor_Op<"binary", [NoSideEffect]>,
+    Arguments<(ins AnyType:$x, AnyType:$y, UnitAttr:$left_identity, UnitAttr:$right_identity)>,
+    Results<(outs AnyType:$output)> {
+  let summary = "Binary set operation utilized within linalg.generic";
+  let description = [{
+    Defines a computation within a `linalg.generic` operation that takes two
+    operands and executes one of the regions depending on whether both operands
+    or either operand is nonzero (i.e. stored explicitly in the sparse storage
+    format).
+
+    Three regions are defined for the operation and must appear in this order:
+    - overlap (elements present in both sparse tensors)
+    - left (elements only present in the left sparse tensor)
+    - right (elements only present in the right sparse tensor)
+
+    Each region contains a single block describing the computation and result.
+    The block must end with sparse_tensor.yield and the return type must
+    match the type of `output`. The overlap region's block has two arguments,
+    while the left and right regions' blocks have only one argument each.
+
+    A region may also be declared empty (i.e. `left={}`), indicating that the
+    region does not contribute to the output. For example, setting both
+    `left={}` and `right={}` is equivalent to the intersection of the two
+    inputs as only the overlap region will contribute values to the output.
+
+    As a convenience, there is also a special token `identity` which can be
+    used in place of the left or right region. This token indicates that
+    the return value is the input value (i.e. func(%x) => return %x).
+    As a practical example, setting `left=identity` and `right=identity`
+    would be equivalent to a union operation where non-overlapping values
+    in the inputs are copied to the output unchanged.
+
+    Example of isEqual applied to intersecting elements only:
+    ```mlir
+    %C = sparse_tensor.init...
+    %0 = linalg.generic #trait
+      ins(%A: tensor<?x?xf64, #SparseMatrix>, %B: tensor<?x?xf64, #SparseMatrix>)
+      outs(%C: tensor<?x?xi8, #SparseMatrix>) {
+      ^bb0(%a: f64, %b: f64, %c: i8) :
+        %result = sparse_tensor.binary %a, %b : f64, f64 to i8
+          overlap={
+            ^bb0(%arg0: f64, %arg1: f64):
+              %cmp = arith.cmpf "oeq", %arg0, %arg1 : f64
+              %ret_i8 = arith.extui %cmp : i1 to i8
+              sparse_tensor.yield %ret_i8 : i8
+          }
+          left={}
+          right={}
+        linalg.yield %result : i8
+    } -> tensor<?x?xi8, #SparseMatrix>
+    ```
+
+    Example of A+B in upper triangle, A-B in lower triangle:
+    ```mlir
+    %C = sparse_tensor.init...
+    %1 = linalg.generic #trait
+      ins(%A: tensor<?x?xf64, #SparseMatrix>, %B: tensor<?x?xf64, #SparseMatrix>)
+      outs(%C: tensor<?x?xf64, #SparseMatrix>) {
+      ^bb0(%a: f64, %b: f64, %c: f64) :
+        %row = linalg.index 0 : index
+        %col = linalg.index 1 : index
+        %result = sparse_tensor.binary %a, %b : f64, f64 to f64
+          overlap={
+            ^bb0(%x: f64, %y: f64):
+              %cmp = arith.cmpi "uge", %col, %row : index
+              %upperTriangleResult = arith.addf %x, %y : f64
+              %lowerTriangleResult = arith.subf %x, %y : f64
+              %ret = arith.select %cmp, %upperTriangleResult, %lowerTriangleResult : f64
+              sparse_tensor.yield %ret : f64
+          }
+          left=identity
+          right={
+            ^bb0(%y: f64):
+              %cmp = arith.cmpi "uge", %col, %row : index
+              %lowerTriangleResult = arith.negf %y : f64
+              %ret = arith.select %cmp, %y, %lowerTriangleResult : f64
+              sparse_tensor.yield %ret : f64
+          }
+        linalg.yield %result : f64
+    } -> tensor<?x?xf64, #SparseMatrix>
+    ```
+
+    Example of set difference. Return a copy of A where its sparse structure
+    is *not* overlapped by B. The element type of B can be different than A
+    because we never use its values, only its sparse structure.
+    ```mlir
+    %C = sparse_tensor.init...
+    %2 = linalg.generic #trait
+      ins(%A: tensor<?x?xf64, #SparseMatrix>, %B: tensor<?x?xi32, #SparseMatrix>)
+      outs(%C: tensor<?x?xf64, #SparseMatrix>) {
+      ^bb0(%a: f64, %b: i32, %c: f64) :
+        %result = sparse_tensor.binary %a, %b : f64, i32 to f64
+          overlap={}
+          left=identity
+          right={}
+        linalg.yield %result : f64
+    } -> tensor<?x?xf64, #SparseMatrix>
+    ```
+  }];
+
+  let regions = (region AnyRegion:$overlapRegion, AnyRegion:$leftRegion, AnyRegion:$rightRegion);
+  let assemblyFormat = [{
+    $x `,` $y `:` attr-dict type($x) `,` type($y) `to` type($output) `\n`
+    `overlap` `=` $overlapRegion `\n`
+    `left` `=` (`identity` $left_identity^):($leftRegion)? `\n`
+    `right` `=` (`identity` $right_identity^):($rightRegion)?
+  }];
+  let hasVerifier = 1;
+}
+
+def SparseTensor_UnaryOp : SparseTensor_Op<"unary", [NoSideEffect]>,
+    Arguments<(ins AnyType:$x)>,
+    Results<(outs AnyType:$output)> {
+  let summary = "Unary set operation utilized within linalg.generic";
+  let description = [{
+    Defines a computation within a `linalg.generic` operation that takes a
+    single operand and executes one of two regions depending on whether the
+    operand is nonzero (i.e. stored explicitly in the sparse storage format).
+
+    Two regions are defined for the operation and must appear in this order:
+    - present (elements present in the sparse tensor)
+    - absent (elements not present in the sparse tensor)
+
+    Each region contains a single block describing the computation and result.
+    The block must end with sparse_tensor.yield and the return type must match
+    the type of `output`. The present region's block has one argument, while
+    the absent region's block has zero arguments.
+
+    A region may also be declared empty (i.e. `absent={}`), indicating that the
+    region does not contribute to the output.
+
+    Example of A+1, restricted to existing elements:
+    ```mlir
+    %C = sparse_tensor.init...
+    %0 = linalg.generic #trait
+      ins(%A: tensor<?x?xf64, #SparseMatrix>)
+      outs(%C: tensor<?x?xf64, #SparseMatrix>) {
+      ^bb0(%a: f64, %c: f64) :
+        %result = sparse_tensor.unary %a : f64 to f64
+          present={
+            ^bb0(%arg0: f64):
+              %cf1 = arith.constant 1.0 : f64
+              %ret = arith.addf %arg0, %cf1 : f64
+              sparse_tensor.yield %ret : f64
+          }
+          absent={}
+        linalg.yield %result : f64
+    } -> tensor<?x?xf64, #SparseMatrix>
+    ```
+
+    Example returning +1 for existing values and -1 for missing values:
+    ```mlir
+    %result = sparse_tensor.unary %a : f64 to i32
+      present={
+        ^bb0(%x: f64):
+          %ret = arith.constant 1 : i32
+          sparse_tensor.yield %ret : i32
+      }
+      absent={
+        %ret = arith.constant -1 : i32
+        sparse_tensor.yield %ret : i32
+      }
+    ```
+
+    Example showing a structural inversion (existing values become missing in
+    the output, while missing values are filled with 1):
+    ```mlir
+    %result = sparse_tensor.unary %a : f64 to i64
+      present={}
+      absent={
+        %ret = arith.constant 1 : i64
+        sparse_tensor.yield %ret : i64
+      }
+    ```
+  }];
+
+  let regions = (region AnyRegion:$presentRegion, AnyRegion:$absentRegion);
+  let assemblyFormat = [{
+    $x attr-dict `:` type($x) `to` type($output) `\n`
+    `present` `=` $presentRegion `\n`
+    `absent` `=` $absentRegion
+  }];
+  let hasVerifier = 1;
+}
+
+def SparseTensor_YieldOp : SparseTensor_Op<"yield", [NoSideEffect, Terminator]>,
+    Arguments<(ins AnyType:$result)> {
+  let summary = "Yield from sparse_tensor set-like operations";
+  let description = [{
+    Yield a value from within a `binary` or `unary` block.
+
+    Example:
+    ```mlir
+    %0 = sparse_tensor.unary %a : i64 to i64
+      present={
+        ^bb0(%arg0: i64):
+          %cst = arith.constant 1 : i64
+          %ret = arith.addi %arg0, %cst : i64
+          sparse_tensor.yield %ret : i64
+      }
+      absent={}
+    ```
+  }];
+
+  let assemblyFormat = [{
+    $result attr-dict `:` type($result)
+  }];
+  let hasVerifier = 1;
+}
+
 #endif // SPARSETENSOR_OPS
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -333,6 +333,107 @@
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// TensorDialect Linalg.Generic Operations.
+//===----------------------------------------------------------------------===//
+
+/// Verifies one non-empty region of a sparse_tensor set operation: its entry
+/// block must take exactly `inputTypes` as arguments and must end with a
+/// sparse_tensor.yield whose operand has type `outputType`.
+template <class T>
+static LogicalResult verifyNumBlockArgs(T *op, Region &region,
+                                        const char *regionName,
+                                        TypeRange inputTypes, Type outputType) {
+  unsigned numArgs = region.getNumArguments();
+  unsigned expectedNum = inputTypes.size();
+  if (numArgs != expectedNum)
+    return op->emitError() << regionName << " region must have exactly "
+                           << expectedNum << " arguments";
+
+  for (unsigned i = 0; i < numArgs; i++) {
+    Type typ = region.getArgument(i).getType();
+    if (typ != inputTypes[i])
+      return op->emitError() << regionName << " region argument " << (i + 1)
+                             << " type mismatch";
+  }
+  // Avoid Block::getTerminator(): it asserts on an empty or unterminated
+  // block, and this check may run before the generic terminator verifier.
+  Block &block = region.front();
+  YieldOp yield = block.empty() ? YieldOp() : dyn_cast<YieldOp>(&block.back());
+  if (!yield)
+    return op->emitError() << regionName
+                           << " region must end with sparse_tensor.yield";
+  if (yield.getOperand().getType() != outputType)
+    return op->emitError() << regionName << " region yield type mismatch";
+
+  return success();
+}
+
+LogicalResult BinaryOp::verify() {
+  Type leftType = x().getType();
+  Type rightType = y().getType();
+  Type outputType = output().getType();
+  Region &overlap = overlapRegion();
+  Region &left = leftRegion();
+  Region &right = rightRegion();
+
+  // Check correct number of block arguments and return type for each
+  // non-empty region. An `identity` token forwards the input value unchanged,
+  // so the corresponding input must already have the output type.
+  if (!overlap.empty() &&
+      failed(verifyNumBlockArgs(this, overlap, "overlap",
+                                TypeRange{leftType, rightType}, outputType)))
+    return failure();
+  if (!left.empty()) {
+    if (failed(verifyNumBlockArgs(this, left, "left", TypeRange{leftType},
+                                  outputType)))
+      return failure();
+  } else if (left_identity() && leftType != outputType) {
+    return emitError("left=identity requires first argument to have the same "
+                     "type as the output");
+  }
+  if (!right.empty()) {
+    if (failed(verifyNumBlockArgs(this, right, "right", TypeRange{rightType},
+                                  outputType)))
+      return failure();
+  } else if (right_identity() && rightType != outputType) {
+    return emitError("right=identity requires second argument to have the "
+                     "same type as the output");
+  }
+
+  return success();
+}
+
+LogicalResult UnaryOp::verify() {
+  Type inputType = x().getType();
+  Type outputType = output().getType();
+
+  // Check correct number of block arguments and return type for each
+  // non-empty region.
+  Region &present = presentRegion();
+  if (!present.empty() &&
+      failed(verifyNumBlockArgs(this, present, "present",
+                                TypeRange{inputType}, outputType)))
+    return failure();
+  Region &absent = absentRegion();
+  if (!absent.empty() &&
+      failed(verifyNumBlockArgs(this, absent, "absent", TypeRange{},
+                                outputType)))
+    return failure();
+
+  return success();
+}
+
+LogicalResult YieldOp::verify() {
+  // Check for compatible parent. The yielded value's type is checked against
+  // the parent's output type by the parent op's verifier.
+  Operation *parentOp = (*this)->getParentOp();
+  if (parentOp && isa<BinaryOp, UnaryOp>(parentOp))
+    return success();
+
+  return emitOpError("expected parent op to be sparse_tensor binary or unary");
+}
+
 //===----------------------------------------------------------------------===//
 // TensorDialect Methods.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -212,3 +212,132 @@
   sparse_tensor.out %arg0, %arg1 : tensor<10xf64>, !llvm.ptr<i8>
   return
 }
+
+// -----
+
+func @invalid_binary_num_args_mismatch_overlap(%arg0: f64, %arg1: f64) -> f64 {
+  // expected-error@+1 {{overlap region must have exactly 2 arguments}}
+  %r = sparse_tensor.binary %arg0, %arg1 : f64, f64 to f64
+    overlap={
+      ^bb0(%x: f64):
+        sparse_tensor.yield %x : f64
+    }
+    left={}
+    right={}
+  return %r : f64
+}
+
+// -----
+
+func @invalid_binary_num_args_mismatch_right(%arg0: f64, %arg1: f64) -> f64 {
+  // expected-error@+1 {{right region must have exactly 1 arguments}}
+  %r = sparse_tensor.binary %arg0, %arg1 : f64, f64 to f64
+    overlap={}
+    left={}
+    right={
+      ^bb0(%x: f64, %y: f64):
+        sparse_tensor.yield %y : f64
+    }
+  return %r : f64
+}
+
+// -----
+
+func @invalid_binary_argtype_mismatch(%arg0: f64, %arg1: f64) -> f64 {
+  // expected-error@+1 {{overlap region argument 2 type mismatch}}
+  %r = sparse_tensor.binary %arg0, %arg1 : f64, f64 to f64
+    overlap={
+      ^bb0(%x: f64, %y: f32):
+        sparse_tensor.yield %x : f64
+    }
+    left=identity
+    right=identity
+  return %r : f64
+}
+
+// -----
+
+func @invalid_binary_wrong_return_type(%arg0: f64, %arg1: f64) -> f64 {
+  // expected-error@+1 {{left region yield type mismatch}}
+  %0 = sparse_tensor.binary %arg0, %arg1 : f64, f64 to f64
+    overlap={}
+    left={
+      ^bb0(%x: f64):
+        %1 = arith.constant 0.0 : f32
+        sparse_tensor.yield %1 : f32
+    }
+    right=identity
+  return %0 : f64
+}
+
+// -----
+
+func @invalid_binary_wrong_identity_type(%arg0: i64, %arg1: f64) -> f64 {
+  // expected-error@+1 {{left=identity requires first argument to have the same type as the output}}
+  %0 = sparse_tensor.binary %arg0, %arg1 : i64, f64 to f64
+    overlap={}
+    left=identity
+    right=identity
+  return %0 : f64
+}
+
+// -----
+
+func @invalid_unary_argtype_mismatch(%arg0: f64) -> f64 {
+  // expected-error@+1 {{present region argument 1 type mismatch}}
+  %r = sparse_tensor.unary %arg0 : f64 to f64
+    present={
+      ^bb0(%x: index):
+        sparse_tensor.yield %x : index
+    }
+    absent={}
+  return %r : f64
+}
+
+// -----
+
+func @invalid_unary_num_args_mismatch(%arg0: f64) -> f64 {
+  // expected-error@+1 {{absent region must have exactly 0 arguments}}
+  %r = sparse_tensor.unary %arg0 : f64 to f64
+    present={}
+    absent={
+      ^bb0(%x: f64):
+        sparse_tensor.yield %x : f64
+    }
+  return %r : f64
+}
+
+// -----
+
+func @invalid_unary_wrong_return_type(%arg0: f64) -> f64 {
+  // expected-error@+1 {{present region yield type mismatch}}
+  %0 = sparse_tensor.unary %arg0 : f64 to f64
+    present={
+      ^bb0(%x: f64):
+        %1 = arith.constant 0.0 : f32
+        sparse_tensor.yield %1 : f32
+    }
+    absent={}
+  return %0 : f64
+}
+
+// -----
+
+func @invalid_binary_missing_yield(%arg0: f64, %arg1: f64) -> f64 {
+  // expected-error@+1 {{left region must end with sparse_tensor.yield}}
+  %r = sparse_tensor.binary %arg0, %arg1 : f64, f64 to f64
+    overlap={}
+    left={
+      ^bb0(%x: f64):
+        %1 = arith.constant 0.0 : f64
+    }
+    right=identity
+  return %r : f64
+}
+
+// -----
+
+func @invalid_yield_parent(%arg0: f64) {
+  // expected-error@+1 {{expected parent op to be sparse_tensor binary or unary}}
+  sparse_tensor.yield %arg0 : f64
+}
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -193,3 +193,94 @@
   sparse_tensor.out %arg0, %arg1 : tensor<?x?xf64, #SparseMatrix>, !llvm.ptr<i8>
   return
 }
+
+// -----
+
+#SparseMatrix = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}>
+
+// CHECK-LABEL: func @sparse_binary(
+//  CHECK-SAME:   %[[A:.*]]: f64, %[[B:.*]]: i64) -> f64 {
+//       CHECK:   %[[Z:.*]] = arith.constant 0.000000e+00 : f64
+//       CHECK:   %[[C1:.*]] = sparse_tensor.binary %[[A]], %[[B]] : f64, i64 to f64
+//       CHECK:     overlap = {
+//       CHECK:       ^bb0(%[[A1:.*]]: f64, %[[B1:.*]]: i64):
+//       CHECK:         sparse_tensor.yield %[[A1]] : f64
+//       CHECK:     }
+//       CHECK:     left = identity
+//       CHECK:     right = {
+//       CHECK:       ^bb0(%[[A2:.*]]: i64):
+//       CHECK:         sparse_tensor.yield %[[Z]] : f64
+//       CHECK:     }
+//       CHECK:   return %[[C1]] : f64
+//       CHECK: }
+func @sparse_binary(%arg0: f64, %arg1: i64) -> f64 {
+  %cf0 = arith.constant 0.0 : f64
+  %r = sparse_tensor.binary %arg0, %arg1 : f64, i64 to f64
+    overlap={
+      ^bb0(%x: f64, %y: i64):
+        sparse_tensor.yield %x : f64
+    }
+    left=identity
+    right={
+      ^bb0(%y: i64):
+        sparse_tensor.yield %cf0 : f64
+    }
+  return %r : f64
+}
+
+// -----
+
+#SparseMatrix = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}>
+
+// CHECK-LABEL: func @sparse_unary(
+//  CHECK-SAME:   %[[A:.*]]: f64) -> f64 {
+//       CHECK:   %[[C1:.*]] = sparse_tensor.unary %[[A]] : f64 to f64
+//       CHECK:     present = {
+//       CHECK:       ^bb0(%[[A1:.*]]: f64):
+//       CHECK:         sparse_tensor.yield %[[A1]] : f64
+//       CHECK:     }
+//       CHECK:     absent = {
+//       CHECK:       %[[R:.*]] = arith.constant -1.000000e+00 : f64
+//       CHECK:       sparse_tensor.yield %[[R]] : f64
+//       CHECK:     }
+//       CHECK:   return %[[C1]] : f64
+//       CHECK: }
+func @sparse_unary(%arg0: f64) -> f64 {
+  %r = sparse_tensor.unary %arg0 : f64 to f64
+    present={
+      ^bb0(%x: f64):
+        sparse_tensor.yield %x : f64
+    } absent={
+      ^bb0:
+        %cf1 = arith.constant -1.0 : f64
+        sparse_tensor.yield %cf1 : f64
+    }
+  return %r : f64
+}
+
+// -----
+
+#SparseMatrix = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}>
+
+// CHECK-LABEL: func @sparse_unary(
+//  CHECK-SAME:   %[[A:.*]]: f64) -> i64 {
+//       CHECK:   %[[C1:.*]] = sparse_tensor.unary %[[A]] : f64 to i64
+//       CHECK:     present = {
+//       CHECK:       ^bb0(%[[A1:.*]]: f64):
+//       CHECK:         %[[R:.*]] = arith.fptosi %[[A1]] : f64 to i64
+//       CHECK:         sparse_tensor.yield %[[R]] : i64
+//       CHECK:     }
+//       CHECK:     absent = {
+//       CHECK:     }
+//       CHECK:   return %[[C1]] : i64
+//       CHECK: }
+func @sparse_unary(%arg0: f64) -> i64 {
+  %r = sparse_tensor.unary %arg0 : f64 to i64
+    present={
+      ^bb0(%x: f64):
+        %ret = arith.fptosi %x : f64 to i64
+        sparse_tensor.yield %ret : i64
+    }
+    absent={}
+  return %r : i64
+}