diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -379,4 +379,184 @@
   let hasVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// Sparse Tensor Custom Linalg.Generic Operations.
+//===----------------------------------------------------------------------===//
+
+def SparseTensor_BinaryOp : SparseTensor_Op<"binary", [NoSideEffect, SameTypeOperands]>,
+    Arguments<(ins AnyType:$x, AnyType:$y, DefaultValuedAttr<BoolAttr, "false">:$include_index)>,
+    Results<(outs AnyType:$output)> {
+  let summary = "Binary set operation utilized within linalg.generic";
+  let description = [{
+      Fulfills the need to provide iteration instructions within `linalg.generic`
+      for binary operations while separating out the computation results.
+
+      Three regions must be defined in order:
+        - primary (elements present in both sparse tensors)
+        - left (elements only present in the left sparse tensor)
+        - right (elements only present in the right sparse tensor)
+
+      Each region contains a single block describing the computation and result for that
+      set region. The block must end with sparse_tensor.yield and the return type must match
+      the type of `output`.
+
+      Alternatively, the region may be left blank to indicate that the output
+      should not contain an entry for these set elements. This is typically used to indicate
+      an intersection by specifying `left={}` and `right={}`.
+
+      Another alternative is to use the special token `identity` in place of a region, which
+      indicates that the return value should be the same as the input value. This is
+      only available for `left` and `right`, which have a single input value. Setting both
+      `left=identity` and `right=identity` is equivalent to a union operation.
+
+      The number of block arguments depends on the number of available values present.
+      For the primary region, two arguments are included.
+      For either the left or right region, only one argument is included.
+
+      An optional attribute "include_index" can be set to true to augment the block arguments
+      with the index or indices of the element within the tensor.
+      For example, a rank 2 tensor in the primary region would have arguments (x, y, row, column).
+      A rank 1 tensor in the right region would have arguments (y, index).
+
+      Example of isEqual applied for intersecting elements only:
+      ```mlir
+      %C = sparse_tensor.init...
+      %0 = linalg.generic #trait
+        ins(%A: tensor<?x?xf64, #SparseMatrix>, %B: tensor<?x?xf64, #SparseMatrix>)
+        outs(%C: tensor<?x?xi8, #SparseMatrix>) {
+        ^bb0(%a: f64, %b: f64, %c: i8) :
+          %result = sparse_tensor.binary %a, %b : f64 to i8 {
+            ^bb0(%arg0: f64, %arg1: f64):
+              %cmp = arith.cmpf "oeq", %arg0, %arg1 : f64
+              %ret_i8 = arith.extui %cmp : i1 to i8
+              sparse_tensor.yield %ret_i8 : i8
+          }
+          left={}
+          right={}
+          linalg.yield %result : i8
+      } -> tensor<?x?xi8, #SparseMatrix>
+      ```
+
+      Example of A+B in upper triangle, A-B in lower triangle:
+      ```mlir
+      %C = sparse_tensor.init...
+      %1 = linalg.generic #trait
+        ins(%A: tensor<?x?xf64, #SparseMatrix>, %B: tensor<?x?xf64, #SparseMatrix>)
+        outs(%C: tensor<?x?xf64, #SparseMatrix>) {
+        ^bb0(%a: f64, %b: f64, %c: f64) :
+          %result = sparse_tensor.binary %a, %b {include_index=true} : f64 to f64 {
+            ^bb0(%x: f64, %y: f64, %row: index, %column: index):
+              %cmp = arith.cmpi "uge", %column, %row : index
+              %upperTriangleResult = arith.addf %x, %y : f64
+              %lowerTriangleResult = arith.subf %x, %y : f64
+              %ret = arith.select %cmp, %upperTriangleResult, %lowerTriangleResult : f64
+              sparse_tensor.yield %ret : f64
+          }
+          left=identity
+          right={
+            ^bb0(%y: f64, %row: index, %column: index):
+              %cmp = arith.cmpi "uge", %column, %row : index
+              %lowerTriangleResult = arith.negf %y : f64
+              %ret = arith.select %cmp, %y, %lowerTriangleResult : f64
+              sparse_tensor.yield %ret : f64
+          }
+          linalg.yield %result : f64
+      } -> tensor<?x?xf64, #SparseMatrix>
+      ```
+  }];
+
+  let regions = (region AnyRegion:$primaryRegion, AnyRegion:$leftRegion, AnyRegion:$rightRegion);
+  let hasCustomAssemblyFormat = 1;
+  let hasVerifier = 1;
+}
+
+def SparseTensor_UnaryOp : SparseTensor_Op<"unary", [NoSideEffect, SameTypeOperands]>,
+    Arguments<(ins AnyType:$x, DefaultValuedAttr<BoolAttr, "false">:$include_index)>,
+    Results<(outs AnyType:$output)> {
+  let summary = "Unary set operation utilized within linalg.generic";
+  let description = [{
+      Fulfills the need to provide iteration instructions within `linalg.generic`
+      for unary operations while separating out the computation results.
+
+      Two regions are defined:
+        - primary (elements present)
+        - missing (elements not present); optional
+
+      Each region contains a single block describing the computation and result for that
+      set region. The block must end with sparse_tensor.yield and the return type must match
+      the type of `output`.
+
+      Alternatively, the region may be left blank to indicate that the output
+      should not contain an entry for these set elements. This is the default behavior
+      for the missing region if not specified.
+
+      The primary region must always be present (it may be left blank) and takes a single block argument.
+      The missing region is optional and takes no block arguments (unless include_index is set).
+
+      An optional attribute "include_index" can be set to true to augment the block arguments
+      with the index or indices of the element within the tensor.
+      For example, a rank 1 tensor in the primary region would have arguments (x, index).
+      A rank 2 tensor in the missing region would have arguments (row, column).
+
+      Example of A+1, restricted to existing elements:
+      ```mlir
+      %C = sparse_tensor.init...
+      %0 = linalg.generic #trait
+        ins(%A: tensor<?xf64, #SparseVector>)
+        outs(%C: tensor<?xf64, #SparseVector>) {
+        ^bb0(%a: f64, %c: f64) :
+          %result = sparse_tensor.unary %a : f64 to f64 {
+            ^bb0(%arg0: f64):
+              %cf1 = arith.constant 1.0 : f64
+              %ret = arith.addf %arg0, %cf1 : f64
+              sparse_tensor.yield %ret : f64
+          }
+          linalg.yield %result : f64
+      } -> tensor<?xf64, #SparseVector>
+      ```
+
+      Example returning the column index for existing values and -1 for missing values:
+      ```mlir
+      %result = sparse_tensor.unary %a {include_index=true} : f64 to i64 {
+        ^bb0(%x: f64, %row: index, %column: index):
+          %ret = arith.index_cast %column : index to i64
+          sparse_tensor.yield %ret : i64
+      }
+      missing={
+        ^bb0(%row: index, %column: index):
+          %ret = arith.constant -1 : i64
+          sparse_tensor.yield %ret : i64
+      }
+      ```
+  }];
+
+  let regions = (region AnyRegion:$primaryRegion, AnyRegion:$missingRegion);
+  let hasCustomAssemblyFormat = 1;
+  let hasVerifier = 1;
+}
+
+def SparseTensor_YieldOp : SparseTensor_Op<"yield", [NoSideEffect, Terminator]> {
+  let summary = "Yield from sparse_tensor set-like operations";
+  let description = [{
+      Yield a value from within a block.
+      Used to terminate a block in sparse_tensor set-like operations, which
+      operate on different pieces of sparse tensor overlaps.
+
+      Example:
+      ```
+      {
+        ^bb0(%y: i64):
+          %cst = arith.constant 1 : i64
+          %ret = arith.addi %y, %cst : i64
+          sparse_tensor.yield %ret : i64
+      }
+      ```
+  }];
+
+  let arguments = (ins AnyType:$result);
+  let assemblyFormat = [{
+    $result attr-dict `:` type($result)
+  }];
+}
+
 #endif // SPARSETENSOR_OPS
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -333,6 +333,238 @@
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// TensorDialect Linalg.Generic Operations.
+//===----------------------------------------------------------------------===//
+
+/// Verifies one region of a sparse_tensor set operation: it must have exactly
+/// `expectedNum` leading block arguments of `inputType` (followed by one or more
+/// `index` arguments when `includeIndex` is set) and end in a yield of `outputType`.
+template <class T>
+LogicalResult verifyNumBlockArgs(T *op, Region &region, const char *regionName, unsigned expectedNum,
+                                 Type inputType, Type outputType, bool includeIndex) {
+  unsigned numArgs = region.getNumArguments();
+  if (!includeIndex) {
+    if (numArgs != expectedNum)
+      return op->emitError() << regionName << " region must have exactly " << expectedNum << " arguments";
+  } else {
+    if (numArgs <= expectedNum)
+      return op->emitError() << regionName << " region expected to have more than " << expectedNum << " arguments";
+  }
+  for (unsigned i = 0; i < numArgs; i++) {
+    Type typ = region.getArgument(i).getType();
+    if (i < expectedNum) {
+      if (typ != inputType)
+        return op->emitError() << regionName << " region argument " << (i+1) << " type mismatch";
+    } else {
+      if (!typ.isIndex())
+        return op->emitError() << regionName << " region argument " << (i+1) << " must be IndexType";
+    }
+  }
+  // Block::getTerminator() asserts when the block is empty or does not end in
+  // a terminator, so inspect the last operation directly and diagnose instead.
+  Block &body = region.front();
+  YieldOp yield = body.empty() ? YieldOp() : dyn_cast<YieldOp>(&body.back());
+  if (!yield)
+    return op->emitError() << regionName << " region must end with sparse_tensor.yield";
+  if (yield.getOperand().getType() != outputType)
+    return op->emitError() << regionName << " region yield type mismatch";
+
+  return success();
+}
+
+LogicalResult BinaryOp::verify() {
+  bool includeIndex = include_index();
+  NamedAttrList attrs = (*this)->getAttrs();
+  Type inputType = x().getType();
+  Type outputType = output().getType();
+  LogicalResult regionResult = success();
+
+  Region &primary = primaryRegion();
+  if (!primary.empty()) {
+    regionResult = verifyNumBlockArgs(this, primary, "primary", 2, inputType, outputType, includeIndex);
+    if (failed(regionResult))
+      return regionResult;
+  }
+  // The parser records `left=identity` / `right=identity` as boolean attributes.
+  Region &left = leftRegion();
+  if (!left.empty()) {
+    auto left_identity = attrs.get("left_identity").dyn_cast_or_null<BoolAttr>();
+    if (left_identity && left_identity.getValue())
+      return emitError("left_identity set with non-empty left region");
+    regionResult = verifyNumBlockArgs(this, left, "left", 1, inputType, outputType, includeIndex);
+    if (failed(regionResult))
+      return regionResult;
+  }
+  Region &right = rightRegion();
+  if (!right.empty()) {
+    auto right_identity = attrs.get("right_identity").dyn_cast_or_null<BoolAttr>();
+    if (right_identity && right_identity.getValue())
+      return emitError("right_identity set with non-empty right region");
+    regionResult = verifyNumBlockArgs(this, right, "right", 1, inputType, outputType, includeIndex);
+    if (failed(regionResult))
+      return regionResult;
+  }
+
+  return success();
+}
+
+ParseResult BinaryOp::parse(OpAsmParser &parser, OperationState &result) {
+  auto &builder = parser.getBuilder();
+
+  // Create the regions for 'primary', 'left', and 'right'
+  result.regions.reserve(3);
+  Region *primaryRegion = result.addRegion();
+  Region *leftRegion = result.addRegion();
+  Region *rightRegion = result.addRegion();
+
+  OpAsmParser::OperandType left, right;
+  if (parser.parseOperand(left) ||
+      parser.parseComma() || parser.parseOperand(right))
+    return failure();
+
+  // Parse the optional attribute list
+  if (parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+
+  Type inputType, outputType;
+  if (parser.parseColonType(inputType) ||
+      parser.parseKeywordType("to", outputType))
+    return failure();
+
+  if (parser.resolveOperand(left, inputType, result.operands) ||
+      parser.resolveOperand(right, inputType, result.operands))
+    return failure();
+  result.types.push_back(outputType);
+
+  // Parse the 'primary' region
+  // This region has an optional "primary=" keyword
+  if (succeeded(parser.parseOptionalKeyword("primary")))
+    if (parser.parseEqual())
+      return failure();
+  if (parser.parseRegion(*primaryRegion))
+    return failure();
+  // Parse the 'left' region; might be `left=identity` helper
+  if (parser.parseKeyword("left") || parser.parseEqual())
+    return failure();
+  if (succeeded(parser.parseOptionalKeyword("identity")))
+    result.attributes.append(StringRef("left_identity"), builder.getBoolAttr(true));
+  else if (parser.parseRegion(*leftRegion))
+    return failure();
+  // Parse the 'right' region; might be `right=identity` helper
+  if (parser.parseKeyword("right") || parser.parseEqual())
+    return failure();
+  if (succeeded(parser.parseOptionalKeyword("identity")))
+    result.attributes.append(StringRef("right_identity"), builder.getBoolAttr(true));
+  else if (parser.parseRegion(*rightRegion))
+    return failure();
+
+  return success();
+}
+
+void BinaryOp::print(OpAsmPrinter &p) {
+  p << " " << x() << ", " << y();
+  NamedAttrList attrs = (*this)->getAttrs();
+  auto left_identity = attrs.erase("left_identity").dyn_cast_or_null<BoolAttr>();
+  auto right_identity = attrs.erase("right_identity").dyn_cast_or_null<BoolAttr>();
+  p.printOptionalAttrDict(attrs);
+  p << ": " << x().getType() << " to " << output().getType();
+  p << ' ';
+  p.printRegion(primaryRegion());
+  p.printNewline();
+  // Print left region (condense if identity)
+  p << "left=";
+  if (left_identity && left_identity.getValue())
+    p << "identity";
+  else if (leftRegion().empty())
+    p << "{}";
+  else
+    p.printRegion(leftRegion());
+  p.printNewline();
+  // Print right region (condense if identity)
+  p << "right=";
+  if (right_identity && right_identity.getValue())
+    p << "identity";
+  else if (rightRegion().empty())
+    p << "{}";
+  else
+    p.printRegion(rightRegion());
+}
+
+LogicalResult UnaryOp::verify() {
+  bool includeIndex = include_index();
+  Type inputType = x().getType();
+  Type outputType = output().getType();
+  LogicalResult regionResult = success();
+
+  Region &primary = primaryRegion();
+  if (!primary.empty()) {
+    regionResult = verifyNumBlockArgs(this, primary, "primary", 1, inputType, outputType, includeIndex);
+    if (failed(regionResult))
+      return regionResult;
+  }
+  Region &missing = missingRegion();
+  if (!missing.empty()) {
+    regionResult = verifyNumBlockArgs(this, missing, "missing", 0, inputType, outputType, includeIndex);
+    if (failed(regionResult))
+      return regionResult;
+  }
+
+  return success();
+}
+
+ParseResult UnaryOp::parse(OpAsmParser &parser, OperationState &result) {
+  // Create the regions for 'primary' and 'missing'
+  result.regions.reserve(2);
+  Region *primaryRegion = result.addRegion();
+  Region *missingRegion = result.addRegion();
+
+  OpAsmParser::OperandType inp;
+  if (parser.parseOperand(inp))
+    return failure();
+
+  // Parse the optional attribute list
+  if (parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+
+  Type inputType, outputType;
+  if (parser.parseColonType(inputType) ||
+      parser.parseKeywordType("to", outputType))
+    return failure();
+
+  if (parser.resolveOperand(inp, inputType, result.operands))
+    return failure();
+  result.types.push_back(outputType);
+
+  // Parse the 'primary' region
+  // This region has an optional "primary=" keyword
+  if (succeeded(parser.parseOptionalKeyword("primary")))
+    if (parser.parseEqual())
+      return failure();
+  if (parser.parseRegion(*primaryRegion))
+    return failure();
+  // Parse the optional 'missing' region
+  if (succeeded(parser.parseOptionalKeyword("missing"))) {
+    if (parser.parseEqual() || parser.parseRegion(*missingRegion))
+      return failure();
+  }
+
+  return success();
+}
+
+void UnaryOp::print(OpAsmPrinter &p) {
+  p << " " << x();
+  p.printOptionalAttrDict((*this)->getAttrs());
+  p << ": " << x().getType() << " to " << output().getType();
+  p << ' ';
+  p.printRegion(primaryRegion());
+  if (!missingRegion().empty()) {
+    p.printNewline();
+    p << "missing=";
+    p.printRegion(missingRegion());
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // TensorDialect Methods.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -212,3 +212,93 @@
   sparse_tensor.out %arg0, %arg1 : tensor<10xf64>, !llvm.ptr<i8>
   return
 }
+
+// -----
+
+func @invalid_binary_num_args_mismatch(%arg0: f64, %arg1: f64) -> f64 {
+  // expected-error@+1 {{primary region must have exactly 2 arguments}}
+  %r = sparse_tensor.binary %arg0, %arg1 : f64 to f64 {
+    ^bb0(%x: f64):
+      sparse_tensor.yield %x : f64
+  }
+  left={}
+  right={}
+  return %r : f64
+}
+
+// -----
+
+func @invalid_binary_argtype_mismatch(%arg0: f64, %arg1: f64) -> f64 {
+  // expected-error@+1 {{primary region argument 2 type mismatch}}
+  %r = sparse_tensor.binary %arg0, %arg1 : f64 to f64 {
+    ^bb0(%x: f64, %y: index):
+      sparse_tensor.yield %x : f64
+  }
+  left={}
+  right={}
+  return %r : f64
+}
+
+// -----
+
+func @invalid_binary_argtype_mismatch2(%arg0: f64, %arg1: f64) -> f64 {
+  // expected-error@+1 {{right region argument 2 must be IndexType}}
+  %r = sparse_tensor.binary %arg0, %arg1 {include_index=true} : f64 to f64
+    primary={}
+    left={}
+    right={
+      ^bb0(%x: f64, %y: f64, %idx: f64):
+        sparse_tensor.yield %y : f64
+    }
+  return %r : f64
+}
+
+// -----
+
+func @invalid_binary_wrong_return_type(%arg0: f64, %arg1: f64) -> f64 {
+  // expected-error@+1 {{left region yield type mismatch}}
+  %0 = sparse_tensor.binary %arg0, %arg1 : f64 to f64 {}
+  left={
+    ^bb0(%x: f64):
+      %1 = arith.constant 0.0 : f32
+      sparse_tensor.yield %1 : f32
+  }
+  right={}
+  return %0 : f64
+}
+
+// -----
+
+func @invalid_unary_argtype_mismatch(%arg0: f64) -> f64 {
+  // expected-error@+1 {{primary region argument 1 type mismatch}}
+  %r = sparse_tensor.unary %arg0 : f64 to f64 {
+    ^bb0(%x: index):
+      sparse_tensor.yield %x : index
+  }
+  return %r : f64
+}
+
+// -----
+
+func @invalid_unary_argtype_mismatch2(%arg0: f64) -> f64 {
+  // expected-error@+1 {{missing region argument 1 must be IndexType}}
+  %r = sparse_tensor.unary %arg0 {include_index=true} : f64 to f64
+    primary={}
+    missing={
+      ^bb0(%idx: f64):
+        sparse_tensor.yield %idx : f64
+    }
+  return %r : f64
+}
+
+// -----
+
+func @invalid_unary_wrong_return_type(%arg0: f64) -> f64 {
+  // expected-error@+1 {{primary region yield type mismatch}}
+  %0 = sparse_tensor.unary %arg0 : f64 to f64 {
+    ^bb0(%x: f64):
+      %1 = arith.constant 0.0 : f32
+      sparse_tensor.yield %1 : f32
+  }
+  return %0 : f64
+}
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -193,3 +193,84 @@
   sparse_tensor.out %arg0, %arg1 : tensor<?x?xf64, #SparseMatrix>, !llvm.ptr<i8>
   return
 }
+
+// -----
+
+#SparseMatrix = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}>
+
+// CHECK-LABEL: func @sparse_binary(
+//  CHECK-SAME:   %[[A:.*]]: f64, %[[B:.*]]: f64) -> f64 {
+//       CHECK:   %[[C1:.*]] = sparse_tensor.binary %[[A]], %[[B]] {include_index = true}: f64 to f64 {
+//       CHECK:     ^bb0(%[[A1:.*]]: f64, %[[B1:.*]]: f64, %[[I1:.*]]: index):
+//       CHECK:       sparse_tensor.yield %[[A1]] : f64
+//       CHECK:     }
+//       CHECK:     left=identity
+//       CHECK:     right={
+//       CHECK:       ^bb0(%[[A2:.*]]: f64, %[[I2:.*]]: index):
+//       CHECK:         sparse_tensor.yield %[[A2]] : f64
+//       CHECK:     }
+//       CHECK:   return %[[C1]] : f64
+//       CHECK: }
+func @sparse_binary(%arg0: f64, %arg1: f64) -> f64 {
+  %r = sparse_tensor.binary %arg0, %arg1 {include_index=true}: f64 to f64 {
+    ^bb0(%x: f64, %y: f64, %idx: index):
+      sparse_tensor.yield %x : f64
+  }
+  left=identity
+  right={
+    ^bb0(%y: f64, %idx: index):
+      sparse_tensor.yield %y : f64
+  }
+  return %r : f64
+}
+
+// -----
+
+#SparseMatrix = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}>
+
+// CHECK-LABEL: func @sparse_unary(
+//  CHECK-SAME:   %[[A:.*]]: f64) -> f64 {
+//       CHECK:   %[[C1:.*]] = sparse_tensor.unary %[[A]]: f64 to f64 {
+//       CHECK:     ^bb0(%[[A1:.*]]: f64):
+//       CHECK:       sparse_tensor.yield %[[A1]] : f64
+//       CHECK:     }
+//       CHECK:     missing={
+//       CHECK:       %[[R:.*]] = arith.constant -1.000000e+00 : f64
+//       CHECK:       sparse_tensor.yield %[[R]] : f64
+//       CHECK:     }
+//       CHECK:   return %[[C1]] : f64
+//       CHECK: }
+func @sparse_unary(%arg0: f64) -> f64 {
+  %r = sparse_tensor.unary %arg0 : f64 to f64 {
+    ^bb0(%x: f64):
+      sparse_tensor.yield %x : f64
+  }
+  missing={
+    ^bb0:
+      %cf1 = arith.constant -1.0 : f64
+      sparse_tensor.yield %cf1 : f64
+  }
+  return %r : f64
+}
+
+// -----
+
+#SparseMatrix = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}>
+
+// CHECK-LABEL: func @sparse_unary(
+//  CHECK-SAME:   %[[A:.*]]: f64) -> i64 {
+//       CHECK:   %[[C1:.*]] = sparse_tensor.unary %[[A]] {include_index = true}: f64 to i64 {
+//       CHECK:     ^bb0(%[[A1:.*]]: f64, %[[I1:.*]]: index, %[[I2:.*]]: index):
+//       CHECK:       %[[R:.*]] = arith.index_cast %[[I2]] : index to i64
+//       CHECK:       sparse_tensor.yield %[[R]] : i64
+//       CHECK:     }
+//       CHECK:   return %[[C1]] : i64
+//       CHECK: }
+func @sparse_unary(%arg0: f64) -> i64 {
+  %r = sparse_tensor.unary %arg0 {include_index=true}: f64 to i64 {
+    ^bb0(%x: f64, %row: index, %col: index):
+      %ret = arith.index_cast %col : index to i64
+      sparse_tensor.yield %ret : i64
+  }
+  return %r : i64
+}