diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td --- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td +++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td @@ -387,13 +387,52 @@ used to return an error to the user upon mismatch of dimensions. ```mlir - %c = shape.join %a, %b, error="" : !shape.shape + %c = shape.join %a, %b, error="" : !shape.shape, !shape.shape -> !shape.shape ``` }]; let arguments = (ins Shape_ShapeOrSizeType:$arg0, Shape_ShapeOrSizeType:$arg1, OptionalAttr:$error); let results = (outs Shape_ShapeOrSizeType:$result); + + let assemblyFormat = [{ + $arg0 `,` $arg1 (`,` `error` `=` $error^)? attr-dict `:` + type($arg0) `,` type($arg1) `->` type($result) + }]; +} + +def Shape_MaxOp : Shape_Op<"max", [Commutative, NoSideEffect]> { + let summary = "Elementwise maximum"; + let description = [{ + Computes the elementwise maximum of two shapes with equal ranks. If either + operand is an error, then an error will be propagated to the result. If the + input types mismatch or the ranks do not match, then the result is an + error. + }]; + + let arguments = (ins Shape_ShapeOrSizeType:$lhs, Shape_ShapeOrSizeType:$rhs); + let results = (outs Shape_ShapeOrSizeType:$result); + + let assemblyFormat = [{ + $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result) + }]; +} + +def Shape_MinOp : Shape_Op<"min", [Commutative, NoSideEffect]> { + let summary = "Elementwise minimum"; + let description = [{ + Computes the elementwise minimum of two shapes with equal ranks. If either + operand is an error, then an error will be propagated to the result. If the + input types mismatch or the ranks do not match, then the result is an + error. 
+ }]; + + let arguments = (ins Shape_ShapeOrSizeType:$lhs, Shape_ShapeOrSizeType:$rhs); + let results = (outs Shape_ShapeOrSizeType:$result); + + let assemblyFormat = [{ + $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result) + }]; } def Shape_MulOp : Shape_Op<"mul", [Commutative, NoSideEffect]> { diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir --- a/mlir/test/Dialect/Shape/ops.mlir +++ b/mlir/test/Dialect/Shape/ops.mlir @@ -115,7 +115,7 @@ } func @eq_on_extent_tensors(%lhs : tensor, - %rhs : tensor) { + %rhs : tensor) { %w0 = shape.cstr_eq %lhs, %rhs : tensor, tensor return } @@ -183,7 +183,6 @@ return %rank : index } - func @shape_eq_on_shapes(%a : !shape.shape, %b : !shape.shape) -> i1 { %result = shape.shape_eq %a, %b : !shape.shape, !shape.shape return %result : i1 @@ -289,3 +288,35 @@ : !shape.shape, !shape.shape return %result : i1 } + +func @shape_upper_bounded_by_constant(%a: !shape.shape) -> !shape.shape { + %0 = shape.const_shape [4, 57, 92] : !shape.shape + %1 = shape.max %a, %0 : !shape.shape, !shape.shape -> !shape.shape + %2 = shape.join %0, %1, error="exceeded element-wise upper bound" : + !shape.shape, !shape.shape -> !shape.shape + return %2 : !shape.shape +} + +func @shape_lower_bounded_by_constant(%a: !shape.shape) -> !shape.shape { + %0 = shape.const_shape [4, 57, 92] : !shape.shape + %1 = shape.min %a, %0 : !shape.shape, !shape.shape -> !shape.shape + %2 = shape.join %0, %1, error="lower bound element-wise exceeded" : + !shape.shape, !shape.shape -> !shape.shape + return %2 : !shape.shape +} + +func @size_upper_bounded_by_constant(%a: !shape.size) -> !shape.size { + %0 = shape.const_size 5 + %1 = shape.max %a, %0 : !shape.size, !shape.size -> !shape.size + %2 = shape.join %0, %1, error="exceeded element-wise upper bound" : + !shape.size, !shape.size -> !shape.size + return %2 : !shape.size +} + +func @size_lower_bounded_by_constant(%a: !shape.size) -> !shape.size { + %0 = shape.const_size 9 + 
 %1 = shape.min %a, %0 : !shape.size, !shape.size -> !shape.size + %2 = shape.join %0, %1, error="lower bound element-wise exceeded" : + !shape.size, !shape.size -> !shape.size + return %2 : !shape.size +}