diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -134,6 +134,21 @@
   let hasFolder = 1;
 }
 
+def Shape_EqOp : Shape_Op<"eq", [Commutative, NoSideEffect, SameTypeOperands]> {
+  let summary = "Returns whether the input shapes or sizes are equal";
+  let description = [{
+    Takes either two shape operands or two size operands and determines whether
+    they are equal. Both operands must have the same type.
+    Shape operands are equal if their rank and all extents are equal. Size
+    operands are equal if their values are equal.
+  }];
+
+  let arguments = (ins Shape_ShapeOrSizeType:$lhs, Shape_ShapeOrSizeType:$rhs);
+  let results = (outs I1:$result);
+
+  let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs)";
+}
+
 def Shape_FromExtentsOp : Shape_Op<"from_extents", [NoSideEffect]> {
   let summary = "Creates a shape from extents";
   let description = [{
diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -111,3 +111,13 @@
   %0 = shape.from_extent_tensor %arg : tensor<?xindex>
   return %0 : !shape.shape
 }
+
+func @eq_shape(%a : !shape.shape, %b : !shape.shape) -> i1 {
+  %result = shape.eq %a, %b : !shape.shape
+  return %result : i1
+}
+
+func @eq_size(%a : !shape.size, %b : !shape.size) -> i1 {
+  %result = shape.eq %a, %b : !shape.size
+  return %result : i1
+}