diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -148,6 +148,21 @@
   let hasFolder = 1;
 }
 
+def Shape_SizeEqOp : Shape_Op<"size_eq", [Commutative, NoSideEffect]> {
+  let summary = "Returns whether the input sizes are equal";
+  let description = [{
+    Takes two size or index operands and determines whether they are equal.
+    Index values are regarded as equivalent to non-error size values. Error
+    values can be tested for equality like any other size value, meaning that
+    an error value is equal to itself.
+  }];
+
+  let arguments = (ins Shape_SizeOrIndexType:$lhs, Shape_SizeOrIndexType:$rhs);
+  let results = (outs I1:$result);
+
+  let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs)";
+}
+
 def Shape_FromExtentsOp : Shape_Op<"from_extents", [NoSideEffect]> {
   let summary = "Creates a shape from extents";
   let description = [{
diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -258,3 +258,18 @@
       : tensor, tensor, tensor -> tensor
   return %result : tensor
 }
+
+func @size_eq_on_sizes(%a : !shape.size, %b : !shape.size) -> i1 {
+  %result = shape.size_eq %a, %b : !shape.size, !shape.size
+  return %result : i1
+}
+
+func @size_eq_on_indices(%a : index, %b : index) -> i1 {
+  %result = shape.size_eq %a, %b : index, index
+  return %result : i1
+}
+
+func @size_eq_on_mixed(%a : !shape.size, %b : index) -> i1 {
+  %result = shape.size_eq %a, %b : !shape.size, index
+  return %result : i1
+}