diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/SparseTensor/IR/CMakeLists.txt --- a/mlir/include/mlir/Dialect/SparseTensor/IR/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/CMakeLists.txt @@ -1,6 +1,10 @@ add_mlir_dialect(SparseTensorOps sparse_tensor) add_mlir_doc(SparseTensorOps SparseTensorOps Dialects/ -gen-dialect-doc) +set(LLVM_TARGET_DEFINITIONS SparseTensorOps.td) +mlir_tablegen(SparseTensorOpsEnums.h.inc -gen-enum-decls) +mlir_tablegen(SparseTensorOpsEnums.cpp.inc -gen-enum-defs) + set(LLVM_TARGET_DEFINITIONS SparseTensorAttrDefs.td) mlir_tablegen(SparseTensorAttrDefs.h.inc -gen-attrdef-decls) mlir_tablegen(SparseTensorAttrDefs.cpp.inc -gen-attrdef-defs) diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h @@ -16,6 +16,8 @@ #include "mlir/IR/TensorEncoding.h" #include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsEnums.h.inc" + #define GET_ATTRDEF_CLASSES #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.h.inc" diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td @@ -13,6 +13,7 @@ include "mlir/Dialect/SparseTensor/IR/SparseTensorBase.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/IR/EnumAttr.td" //===----------------------------------------------------------------------===// // Base class. 
@@ -379,4 +380,192 @@ let hasVerifier = 1; } +//===----------------------------------------------------------------------===// +// Sparse Tensor Custom Linalg.Generic Operations. +//===----------------------------------------------------------------------===// + +def OverlapKindIntersection : I32EnumAttrCase<"Intersection", 0, "intersection">; +def OverlapKindLeftUnion : I32EnumAttrCase<"LeftUnion", 1, "left_union">; +def OverlapKindRightUnion : I32EnumAttrCase<"RightUnion", 2, "right_union">; +def OverlapKindUnion : I32EnumAttrCase<"Union", 3, "union">; + +/// Enum attribute of the different kinds of overlap for binary regions. +def OverlapKindAttr : I32EnumAttr<"OverlapKind", "sparse_tensor binary overlap kind", + [OverlapKindIntersection, OverlapKindLeftUnion, OverlapKindRightUnion, OverlapKindUnion]> { + let cppNamespace = "::mlir::sparse_tensor"; +} + +def SparseTensor_BinaryOp : SparseTensor_Op<"binary", [NoSideEffect, SameTypeOperands]>, + Arguments<(ins AnyType:$x, AnyType:$y, OverlapKindAttr:$kind)>, + Results<(outs AnyType:$output)> { + let summary = "Binary set operation utilized within linalg.generic"; + let description = [{ + Defines a computation within a `linalg.generic` operation that takes two operands and executes + one of the regions depending on whether both operands or either operand is nonzero (i.e. stored + explicitly in sparse storage format). + + Three regions are defined for the operation and must appear in this order (if present): + - primary (elements present in both sparse tensors) + - left (elements only present in the left sparse tensor) + - right (elements only present in the right sparse tensor) + + Each region contains a single block describing the computation and result. + The block must end with sparse_tensor.yield and the return type must match the type of `output`. + The primary region's block has two arguments, while the left and right regions' blocks + have only one argument. + + A region may also be declared empty (i.e.
`left={}`), implying that the output is a missing value. + + The `kind` attribute provides restrictions and default behaviors for the `left` and + `right` regions, which are both optional. + - intersection (left and right must be empty) + - left_union (left region returns identity if not defined; right must be empty) + - right_union (right region returns identity if not defined; left must be empty) + - union (left and right return identity if not defined) + + Example of isEqual applied to intersecting elements only: + ```mlir + %C = sparse_tensor.init... + %0 = linalg.generic #trait + ins(%A: tensor<?xf64, #SparseVector>, %B: tensor<?xf64, #SparseVector>) + outs(%C: tensor<?xi8, #SparseVector>) { + ^bb0(%a: f64, %b: f64, %c: i8) : + %result = sparse_tensor.binary intersection %a, %b : f64 to i8 { + ^bb0(%arg0: f64, %arg1: f64): + %cmp = arith.cmpf "oeq", %arg0, %arg1 : f64 + %ret_i8 = arith.extui %cmp : i1 to i8 + sparse_tensor.yield %ret_i8 : i8 + } + linalg.yield %result : i8 + } -> tensor<?xi8, #SparseVector> + ``` + + Example of A+B in upper triangle, A-B in lower triangle: + ```mlir + %C = sparse_tensor.init...
+ %1 = linalg.generic #trait + ins(%A: tensor<?x?xf64, #SparseMatrix>, %B: tensor<?x?xf64, #SparseMatrix>) + outs(%C: tensor<?x?xf64, #SparseMatrix>) { + ^bb0(%a: f64, %b: f64, %c: f64) : + %row = linalg.index 0 : index + %col = linalg.index 1 : index + %result = sparse_tensor.binary union %a, %b : f64 to f64 { + ^bb0(%x: f64, %y: f64): + %cmp = arith.cmpi "uge", %col, %row : index + %upperTriangleResult = arith.addf %x, %y : f64 + %lowerTriangleResult = arith.subf %x, %y : f64 + %ret = arith.select %cmp, %upperTriangleResult, %lowerTriangleResult : f64 + sparse_tensor.yield %ret : f64 + } right={ + ^bb0(%y: f64): + %cmp = arith.cmpi "uge", %col, %row : index + %lowerTriangleResult = arith.negf %y : f64 + %ret = arith.select %cmp, %y, %lowerTriangleResult : f64 + sparse_tensor.yield %ret : f64 + } + linalg.yield %result : f64 + } -> tensor<?x?xf64, #SparseMatrix> + ``` + }]; + + let regions = (region AnyRegion:$primaryRegion, AnyRegion:$leftRegion, AnyRegion:$rightRegion); + let assemblyFormat = [{ + $kind $x `,` $y attr-dict `:` type($x) `to` type($output) $primaryRegion (`left` `=` $leftRegion^)? (`right` `=` $rightRegion^)? + }]; + let hasVerifier = 1; +} + +def SparseTensor_UnaryOp : SparseTensor_Op<"unary", [NoSideEffect]>, + Arguments<(ins AnyType:$x)>, + Results<(outs AnyType:$output)> { + let summary = "Unary set operation utilized within linalg.generic"; + let description = [{ + Defines a computation within a `linalg.generic` operation that takes a single operand and executes + one of two regions depending on whether the operand is nonzero (i.e. stored explicitly in the sparse + storage format). + + Two regions are defined for the operation and must appear in this order (if present): + - primary (elements present in the sparse tensor) + - missing (elements not present in the sparse tensor) + + Each region contains a single block describing the computation and result. + The block must end with sparse_tensor.yield and the return type must match the type of `output`.
+ The primary region's block has one argument, while the missing region's block + has zero arguments. + + A region may also be empty, implying that the output is a missing value. + + The primary region must always be written, but it may be empty. + The missing region is optional and is assumed to be empty if not defined. + + Example of A+1, restricted to existing elements: + ```mlir + %C = sparse_tensor.init... + %0 = linalg.generic #trait + ins(%A: tensor<?xf64, #SparseVector>) + outs(%C: tensor<?xf64, #SparseVector>) { + ^bb0(%a: f64, %c: f64) : + %result = sparse_tensor.unary %a : f64 to f64 { + ^bb0(%arg0: f64): + %cf1 = arith.constant 1.0 : f64 + %ret = arith.addf %arg0, %cf1 : f64 + sparse_tensor.yield %ret : f64 + } + linalg.yield %result : f64 + } -> tensor<?xf64, #SparseVector> + ``` + + Example returning +1 for existing values and -1 for missing values: + ```mlir + %result = sparse_tensor.unary %a : f64 to i64 { + ^bb0(%x: f64): + %ret = arith.constant 1 : i64 + sparse_tensor.yield %ret : i64 + } missing={ + %ret = arith.constant -1 : i64 + sparse_tensor.yield %ret : i64 + } + ``` + + Example showing a structural inversion (existing values become missing in the output, + while missing values are filled with 1): + ```mlir + %result = sparse_tensor.unary %a : f64 to i64 { + } missing={ + %ret = arith.constant 1 : i64 + sparse_tensor.yield %ret : i64 + } + ``` + }]; + + let regions = (region AnyRegion:$primaryRegion, AnyRegion:$missingRegion); + let assemblyFormat = [{ + $x attr-dict `:` type($x) `to` type($output) $primaryRegion (`missing` `=` $missingRegion^)? + }]; + let hasVerifier = 1; +} + +def SparseTensor_YieldOp : SparseTensor_Op<"yield", [NoSideEffect, Terminator]>, + Arguments<(ins AnyType:$result)> { + let summary = "Yield from sparse_tensor set-like operations"; + let description = [{ + Yield a value from within a `binary` or `unary` block.
+ + Example: + ```mlir + %0 = sparse_tensor.unary %a : i64 to i64 { + ^bb0(%arg0: i64): + %cst = arith.constant 1 : i64 + %ret = arith.addi %arg0, %cst : i64 + sparse_tensor.yield %ret : i64 + } + ``` + }]; + + let assemblyFormat = [{ + $result attr-dict `:` type($result) + }]; + let hasVerifier = 1; +} + #endif // SPARSETENSOR_OPS diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp --- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp @@ -21,6 +21,8 @@ // TensorDialect Attribute Methods. //===----------------------------------------------------------------------===// +#include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsEnums.cpp.inc" + #define GET_ATTRDEF_CLASSES #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc" @@ -333,6 +335,119 @@ return success(); } +//===----------------------------------------------------------------------===// +// TensorDialect Linalg.Generic Operations.
+//===----------------------------------------------------------------------===//
+
+/// Verifies one region of a sparse_tensor.binary/unary op `op`: its entry block
+/// must have exactly `expectedNum` arguments, all of type `inputType`, and end
+/// with a sparse_tensor.yield whose operand has type `outputType`. Emits an
+/// error on `op` and returns failure on the first violation.
+template <class T>
+static LogicalResult
+verifyNumBlockArgs(T *op, Region &region, const char *regionName,
+                   unsigned expectedNum, Type inputType, Type outputType) {
+  unsigned numArgs = region.getNumArguments();
+  if (numArgs != expectedNum)
+    return op->emitError() << regionName << " region must have exactly "
+                           << expectedNum << " arguments";
+
+  for (unsigned i = 0; i < numArgs; i++) {
+    Type typ = region.getArgument(i).getType();
+    if (typ != inputType)
+      return op->emitError() << regionName << " region argument " << (i + 1)
+                             << " type mismatch";
+  }
+  // Avoid Block::getTerminator(): it asserts on an empty or unterminated block,
+  // and this hook may run before the nested blocks themselves are verified.
+  Block &entry = region.front();
+  Operation *term = entry.empty() ? nullptr : &entry.back();
+  YieldOp yield = dyn_cast_or_null<YieldOp>(term);
+  if (!yield)
+    return op->emitError() << regionName
+                           << " region must end with sparse_tensor.yield";
+  if (yield.getOperand().getType() != outputType)
+    return op->emitError() << regionName << " region yield type mismatch";
+
+  return success();
+}
+
+/// Verifies that the regions `kind` requires to be empty are empty, and that
+/// each non-empty region has the expected arguments and yield type: the
+/// primary region takes both operands, the left and right regions take one.
+LogicalResult BinaryOp::verify() {
+  Type inputType = x().getType();
+  Type outputType = output().getType();
+  OverlapKind overlap = kind();
+  Region &primary = primaryRegion();
+  Region &left = leftRegion();
+  Region &right = rightRegion();
+
+  // Verify that regions expected to be empty (based on kind) are empty.
+  if (overlap == OverlapKind::Intersection) {
+    if (!left.empty() || !right.empty())
+      return emitError("left and right region must be empty for intersection");
+  } else if (overlap == OverlapKind::LeftUnion) {
+    if (!right.empty())
+      return emitError("right region must be empty for left_union");
+  } else if (overlap == OverlapKind::RightUnion) {
+    if (!left.empty())
+      return emitError("left region must be empty for right_union");
+  }
+
+  // Check the number of arguments and the yield type of each non-empty region.
+  if (!primary.empty() &&
+      failed(verifyNumBlockArgs(this, primary, "primary", 2, inputType,
+                                outputType)))
+    return failure();
+  if (!left.empty() &&
+      failed(verifyNumBlockArgs(this, left, "left", 1, inputType, outputType)))
+    return failure();
+  if (!right.empty() &&
+      failed(verifyNumBlockArgs(this, right, "right", 1, inputType,
+                                outputType)))
+    return failure();
+
+  return success();
+}
+
+/// Verifies each non-empty region of the op: the primary region takes the
+/// operand and the missing region takes no arguments; both must yield a value
+/// of the result type.
+LogicalResult UnaryOp::verify() {
+  Type inputType = x().getType();
+  Type outputType = output().getType();
+
+  // Empty regions are permitted (an empty region denotes a missing value) and
+  // are not checked further.
+  Region &primary = primaryRegion();
+  if (!primary.empty() &&
+      failed(verifyNumBlockArgs(this, primary, "primary", 1, inputType,
+                                outputType)))
+    return failure();
+  Region &missing = missingRegion();
+  if (!missing.empty() &&
+      failed(verifyNumBlockArgs(this, missing, "missing", 0, inputType,
+                                outputType)))
+    return failure();
+
+  return success();
+}
+
+/// Verifies that the yield is nested directly in a sparse_tensor.binary or
+/// sparse_tensor.unary op.
+LogicalResult YieldOp::verify() {
+  // Check for compatible parent.
+  Operation *parentOp = (*this)->getParentOp();
+  if (llvm::isa_and_nonnull<BinaryOp, UnaryOp>(parentOp))
+    return success();
+
+  // NOTE: The yielded type is checked by the parent op's verifier, which
+  // knows the expected output type.
+
+  return emitOpError("expected parent op to be sparse_tensor binary or unary");
+}
+
 //===----------------------------------------------------------------------===//
 // TensorDialect Methods.
//===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir --- a/mlir/test/Dialect/SparseTensor/invalid.mlir +++ b/mlir/test/Dialect/SparseTensor/invalid.mlir @@ -212,3 +212,97 @@ sparse_tensor.out %arg0, %arg1 : tensor<10xf64>, !llvm.ptr return } + +// ----- + +func @invalid_binary_kind(%arg0: f64, %arg1: f64) -> f64 { + // expected-error@+1 {{custom op 'sparse_tensor.binary' expected string or keyword containing one of the following enum values for attribute 'kind' [intersection, left_union, right_union, union]}} + %r = sparse_tensor.binary %arg0, %arg1 : f64 to f64 { + ^bb0(%x: f64): + sparse_tensor.yield %x : f64 + } + return %r : f64 +} + +// ----- + +func @invalid_binary_num_args_mismatch(%arg0: f64, %arg1: f64) -> f64 { + // expected-error@+1 {{primary region must have exactly 2 arguments}} + %r = sparse_tensor.binary intersection %arg0, %arg1 : f64 to f64 { + ^bb0(%x: f64): + sparse_tensor.yield %x : f64 + } + return %r : f64 +} + +// ----- + +func @invalid_binary_argtype_mismatch(%arg0: f64, %arg1: f64) -> f64 { + // expected-error@+1 {{primary region argument 2 type mismatch}} + %r = sparse_tensor.binary left_union %arg0, %arg1 : f64 to f64 { + ^bb0(%x: f64, %y: index): + sparse_tensor.yield %x : f64 + } + return %r : f64 +} + +// ----- + +func @invalid_binary_region_override(%arg0: f64, %arg1: f64) -> f64 { + // expected-error@+1 {{right region must be empty for left_union}} + %r = sparse_tensor.binary left_union %arg0, %arg1 : f64 to f64 { + } left={ + } right={ + ^bb0(%y: f64): + sparse_tensor.yield %y : f64 + } + return %r : f64 +} + +// ----- + +func @invalid_binary_wrong_return_type(%arg0: f64, %arg1: f64) -> f64 { + // expected-error@+1 {{right region yield type mismatch}} + %0 = sparse_tensor.binary right_union %arg0, %arg1 : f64 to f64 { + } right={ + ^bb0(%x: f64): + %1 = arith.constant 0.0 : f32 + sparse_tensor.yield %1 :
f32 + } + return %0 : f64 +} + +// ----- + +func @invalid_unary_argtype_mismatch(%arg0: f64) -> f64 { + // expected-error@+1 {{primary region argument 1 type mismatch}} + %r = sparse_tensor.unary %arg0 : f64 to f64 { + ^bb0(%x: index): + sparse_tensor.yield %x : index + } + return %r : f64 +} + +// ----- + +func @invalid_unary_num_args_mismatch(%arg0: f64) -> f64 { + // expected-error@+1 {{missing region must have exactly 0 arguments}} + %r = sparse_tensor.unary %arg0 : f64 to f64 { + } missing={ + ^bb0(%x: f64): + sparse_tensor.yield %x : f64 + } + return %r : f64 +} + +// ----- + +func @invalid_unary_wrong_return_type(%arg0: f64) -> f64 { + // expected-error@+1 {{primary region yield type mismatch}} + %0 = sparse_tensor.unary %arg0 : f64 to f64 { + ^bb0(%x: f64): + %1 = arith.constant 0.0 : f32 + sparse_tensor.yield %1 : f32 + } + return %0 : f64 +} diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir --- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir +++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir @@ -193,3 +193,78 @@ sparse_tensor.out %arg0, %arg1 : tensor, !llvm.ptr return } + +// ----- + +#SparseMatrix = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}> + +// CHECK-LABEL: func @sparse_binary( +// CHECK-SAME: %[[A:.*]]: f64, %[[B:.*]]: f64) -> f64 { +// CHECK: %[[C1:.*]] = sparse_tensor.binary right_union %[[A]], %[[B]] : f64 to f64 { +// CHECK: ^bb0(%[[A1:.*]]: f64, %[[B1:.*]]: f64): +// CHECK: sparse_tensor.yield %[[A1]] : f64 +// CHECK: } right = { +// CHECK: ^bb0(%[[A2:.*]]: f64): +// CHECK: sparse_tensor.yield %[[A2]] : f64 +// CHECK: } +// CHECK: return %[[C1]] : f64 +// CHECK: } +func @sparse_binary(%arg0: f64, %arg1: f64) -> f64 { + %r = sparse_tensor.binary right_union %arg0, %arg1 : f64 to f64 { + ^bb0(%x: f64, %y: f64): + sparse_tensor.yield %x : f64 + } right={ + ^bb0(%y: f64): + sparse_tensor.yield %y : f64 + } + return %r : f64 +} + +// ----- +
+#SparseMatrix = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}> + +// CHECK-LABEL: func @sparse_unary( +// CHECK-SAME: %[[A:.*]]: f64) -> f64 { +// CHECK: %[[C1:.*]] = sparse_tensor.unary %[[A]] : f64 to f64 { +// CHECK: ^bb0(%[[A1:.*]]: f64): +// CHECK: sparse_tensor.yield %[[A1]] : f64 +// CHECK: } missing = { +// CHECK: %[[R:.*]] = arith.constant -1.000000e+00 : f64 +// CHECK: sparse_tensor.yield %[[R]] : f64 +// CHECK: } +// CHECK: return %[[C1]] : f64 +// CHECK: } +func @sparse_unary(%arg0: f64) -> f64 { + %r = sparse_tensor.unary %arg0 : f64 to f64 { + ^bb0(%x: f64): + sparse_tensor.yield %x : f64 + } missing={ + ^bb0: + %cf1 = arith.constant -1.0 : f64 + sparse_tensor.yield %cf1 : f64 + } + return %r : f64 +} + +// ----- + +#SparseMatrix = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}> + +// CHECK-LABEL: func @sparse_unary_to_i64( +// CHECK-SAME: %[[A:.*]]: f64) -> i64 { +// CHECK: %[[C1:.*]] = sparse_tensor.unary %[[A]] : f64 to i64 { +// CHECK: ^bb0(%[[A1:.*]]: f64): +// CHECK: %[[R:.*]] = arith.fptosi %[[A1]] : f64 to i64 +// CHECK: sparse_tensor.yield %[[R]] : i64 +// CHECK: } +// CHECK: return %[[C1]] : i64 +// CHECK: } +func @sparse_unary_to_i64(%arg0: f64) -> i64 { + %r = sparse_tensor.unary %arg0 : f64 to i64 { + ^bb0(%x: f64): + %ret = arith.fptosi %x : f64 to i64 + sparse_tensor.yield %ret : i64 + } + return %r : i64 +}