Rename the shape dialect op `shape.join` to `shape.meet`.

- ShapeOps.td: Shape_JoinOp becomes Shape_MeetOp (mnemonic "meet"), and its
  description examples are updated. The unchanged Shape_MaxOp definition is
  moved ahead of it, so the definitions read max, meet, min.
- Shape.cpp: JoinOp::inferReturnTypes and JoinOp::isCompatibleReturnTypes
  become MeetOp::..., and the section banner is renamed.
- ops.mlir: every use of shape.join in the tests becomes shape.meet.

diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -397,7 +397,34 @@
   let hasCanonicalizer = 1;
 }
 
-def Shape_JoinOp : Shape_Op<"join",
+def Shape_MaxOp : Shape_Op<"max",
+    [Commutative, NoSideEffect,
+     DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+  let summary = "Elementwise maximum";
+  let description = [{
+    Computes the elementwise maximum of two sizes or shapes with equal ranks.
+    If either operand is an error, then an error will be propagated to the
+    result. If the input types mismatch or the ranks do not match, then the
+    result is an error.
+  }];
+
+  let arguments = (ins Shape_ShapeOrSizeType:$lhs, Shape_ShapeOrSizeType:$rhs);
+  let results = (outs Shape_ShapeOrSizeType:$result);
+
+  let assemblyFormat = [{
+    $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)
+  }];
+
+  let hasFolder = 1;
+
+  let extraClassDeclaration = [{
+    // Returns when two result types are compatible for this op; method used by
+    // InferTypeOpInterface
+    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+  }];
+}
+
+def Shape_MeetOp : Shape_Op<"meet",
     [Commutative, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
   let summary = "Returns the least general shape.shape of its operands";
   let description = [{
@@ -408,21 +435,21 @@
     contradictory requirements. E.g., using pseudo code
 
     ```
-    shape.join([*], [*]) -> [*]
-    shape.join([*], [1, ?]) -> [1, ?]
-    shape.join([1, 2], [1, ?]) -> [1, 2]
-    shape.join([*], [1, 2]) -> [1, 2]
-    shape.join([], []) -> []
-    shape.join([], [*]) -> []
-    shape.join([], [?, ?]) -> [invalid]
-    shape.join([1, ?], [2, ?, ?]) -> [invalid]
+    shape.meet([*], [*]) -> [*]
+    shape.meet([*], [1, ?]) -> [1, ?]
+    shape.meet([1, 2], [1, ?]) -> [1, 2]
+    shape.meet([*], [1, 2]) -> [1, 2]
+    shape.meet([], []) -> []
+    shape.meet([], [*]) -> []
+    shape.meet([], [?, ?]) -> [invalid]
+    shape.meet([1, ?], [2, ?, ?]) -> [invalid]
     ```
 
-    `shape.join` also allows specifying an optional error string, that may be
+    `shape.meet` also allows specifying an optional error string, that may be
     used to return an error to the user upon mismatch of dimensions.
 
     ```mlir
-    %c = shape.join %a, %b, error="" : !shape.shape, !shape.shape -> !shape.shape
+    %c = shape.meet %a, %b, error="" : !shape.shape, !shape.shape -> !shape.shape
     ```
   }];
 
@@ -442,33 +469,6 @@
   }];
 }
 
-def Shape_MaxOp : Shape_Op<"max",
-    [Commutative, NoSideEffect,
-     DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
-  let summary = "Elementwise maximum";
-  let description = [{
-    Computes the elementwise maximum of two sizes or shapes with equal ranks.
-    If either operand is an error, then an error will be propagated to the
-    result. If the input types mismatch or the ranks do not match, then the
-    result is an error.
-  }];
-
-  let arguments = (ins Shape_ShapeOrSizeType:$lhs, Shape_ShapeOrSizeType:$rhs);
-  let results = (outs Shape_ShapeOrSizeType:$result);
-
-  let assemblyFormat = [{
-    $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)
-  }];
-
-  let hasFolder = 1;
-
-  let extraClassDeclaration = [{
-    // Returns when two result types are compatible for this op; method used by
-    // InferTypeOpInterface
-    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
-  }];
-}
-
 def Shape_MinOp : Shape_Op<"min",
     [Commutative, NoSideEffect,
      DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -1177,10 +1177,10 @@
 }
 
 //===----------------------------------------------------------------------===//
-// JoinOp
+// MeetOp
 //===----------------------------------------------------------------------===//
 
-LogicalResult mlir::shape::JoinOp::inferReturnTypes(
+LogicalResult mlir::shape::MeetOp::inferReturnTypes(
     MLIRContext *context, Optional<Location> location, ValueRange operands,
     DictionaryAttr attributes, RegionRange regions,
     SmallVectorImpl<Type> &inferredReturnTypes) {
@@ -1188,7 +1188,7 @@
   return success();
 }
 
-bool mlir::shape::JoinOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
+bool mlir::shape::MeetOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
   if (l.size() != 1 || r.size() != 1)
     return false;
   if (l == r)
diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -65,7 +65,7 @@
 func @test_shape_any_fixed() {
   %0 = shape.const_shape [4, 57, 92] : !shape.shape
   %1 = shape.const_shape [4, 57, 92] : !shape.shape
-  %2 = "shape.join"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
+  %2 = "shape.meet"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
   %3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
   return
 }
@@ -73,7 +73,7 @@
 func @test_shape_any_unknown() {
   %0 = shape.const_shape [4, -1, 92] : !shape.shape
   %1 = shape.const_shape [-1, 57, 92] : !shape.shape
-  %2 = "shape.join"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
+  %2 = "shape.meet"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
   %3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
   return
 }
@@ -81,7 +81,7 @@
 func @test_shape_any_fixed_mismatch() {
   %0 = shape.const_shape [4, 57, 92] : !shape.shape
   %1 = shape.const_shape [2, 57, 92] : !shape.shape
-  %2 = "shape.join"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
+  %2 = "shape.meet"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
   %3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
   return
 }
@@ -243,7 +243,7 @@
 func @shape_equal_shapes(%a : !shape.value_shape, %b : !shape.value_shape) -> !shape.shape {
   %0 = shape.shape_of %a : !shape.value_shape -> !shape.shape
   %1 = shape.shape_of %b : !shape.value_shape -> !shape.shape
-  %2 = "shape.join"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
+  %2 = "shape.meet"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
   return %2 : !shape.shape
 }
 func @shape_with_shape(%a : !shape.value_shape, %b : !shape.value_shape) -> !shape.shape {
@@ -293,7 +293,7 @@
 func @shape_upper_bounded_by_constant(%a: !shape.shape) -> !shape.shape {
   %0 = shape.const_shape [4, 57, 92] : !shape.shape
   %1 = shape.max %a, %0 : !shape.shape, !shape.shape -> !shape.shape
-  %2 = shape.join %0, %1, error="exceeded element-wise upper bound" :
+  %2 = shape.meet %0, %1, error="exceeded element-wise upper bound" :
     !shape.shape, !shape.shape -> !shape.shape
   return %2 : !shape.shape
 }
@@ -301,7 +301,7 @@
 func @shape_lower_bounded_by_constant(%a: !shape.shape) -> !shape.shape {
   %0 = shape.const_shape [4, 57, 92] : !shape.shape
   %1 = shape.min %a, %0 : !shape.shape, !shape.shape -> !shape.shape
-  %2 = shape.join %0, %1, error="lower bound element-wise exceeded" :
+  %2 = shape.meet %0, %1, error="lower bound element-wise exceeded" :
     !shape.shape, !shape.shape -> !shape.shape
   return %2 : !shape.shape
 }
@@ -309,7 +309,7 @@
 func @size_upper_bounded_by_constant(%a: !shape.size) -> !shape.size {
   %0 = shape.const_size 5
   %1 = shape.max %a, %0 : !shape.size, !shape.size -> !shape.size
-  %2 = shape.join %0, %1, error="exceeded element-wise upper bound" :
+  %2 = shape.meet %0, %1, error="exceeded element-wise upper bound" :
     !shape.size, !shape.size -> !shape.size
   return %2 : !shape.size
 }
@@ -317,7 +317,7 @@
 func @size_lower_bounded_by_constant(%a: !shape.size) -> !shape.size {
   %0 = shape.const_size 9
   %1 = shape.min %a, %0 : !shape.size, !shape.size -> !shape.size
-  %2 = shape.join %0, %1, error="lower bound element-wise exceeded" :
+  %2 = shape.meet %0, %1, error="lower bound element-wise exceeded" :
     !shape.size, !shape.size -> !shape.size
   return %2 : !shape.size
 }