Add the `shape.shape_eq` operation to the MLIR shape dialect.

Summary of the changes below:
  * ShapeBase.td: normalizes spacing in `Shape_ShapeOrSizeType` and adds two
    type constraints, `Shape_ExtentTensorType` (a 1-D tensor of `index`) and
    `Shape_ShapeOrExtentTensorType` (either `!shape.shape` or an extent tensor).
  * ShapeOps.td: defines `Shape_ShapeEqOp` (`shape.shape_eq`), which takes two
    shape-or-extent-tensor operands and produces an `i1` result.
  * ops.mlir: adds test functions covering shape/shape, tensor/tensor and
    mixed tensor/shape operands. Extent tensors are spelled `tensor<?xindex>`
    so that they satisfy `Shape_ExtentTensorType`.

diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td
@@ -104,7 +104,13 @@
   }];
 }
 
-def Shape_ShapeOrSizeType: AnyTypeOf<[Shape_SizeType, Shape_ShapeType],
+def Shape_ShapeOrSizeType : AnyTypeOf<[Shape_SizeType, Shape_ShapeType],
   "shape or size">;
 
+def Shape_ExtentTensorType : 1DTensorOf<[Index]>;
+
+def Shape_ShapeOrExtentTensorType : AnyTypeOf<[Shape_ShapeType,
+                                               Shape_ExtentTensorType],
+                                              "shape or extent tensor">;
+
 #endif // SHAPE_BASE_TD
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -138,6 +138,22 @@
   let hasFolder = 1;
 }
 
+def Shape_ShapeEqOp : Shape_Op<"shape_eq", [Commutative, NoSideEffect]> {
+  let summary = "Returns whether the input shapes or extent tensors are equal";
+  let description = [{
+    Takes two shape or extent tensor operands and determines whether they are
+    equal. When extent tensors are compared to shapes they are regarded as their
+    equivalent non-error shapes. Error shapes can be tested for equality like
+    any other shape value, meaning that the error value is equal to itself.
+  }];
+
+  let arguments = (ins Shape_ShapeOrExtentTensorType:$lhs,
+                       Shape_ShapeOrExtentTensorType:$rhs);
+  let results = (outs I1:$result);
+
+  let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs)";
+}
+
 def Shape_FromExtentsOp : Shape_Op<"from_extents", [NoSideEffect]> {
   let summary = "Creates a shape from extents";
   let description = [{
diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -116,3 +116,18 @@
   %rank = shape.rank %shape
   return %rank : !shape.size
 }
+
+func @shape_eq_on_shapes(%a : !shape.shape, %b : !shape.shape) -> i1 {
+  %result = shape.shape_eq %a, %b : !shape.shape, !shape.shape
+  return %result : i1
+}
+
+func @shape_eq_on_tensors(%a : tensor<?xindex>, %b : tensor<?xindex>) -> i1 {
+  %result = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex>
+  return %result : i1
+}
+
+func @shape_eq_on_mixed(%a : tensor<?xindex>, %b : !shape.shape) -> i1 {
+  %result = shape.shape_eq %a, %b : tensor<?xindex>, !shape.shape
+  return %result : i1
+}