diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -181,6 +181,27 @@
   let assemblyFormat = "attr-dict $input `:` type($input)";
 }
 
+def Shape_IsBroadcastableOp : Shape_Op<"is_broadcastable", [Commutative]> {
+  let summary = "Determines if 2 shapes can be successfully broadcasted";
+  let description = [{
+    Given two input shapes or extent tensors, return a predicate specifying if
+    they are broadcastable. Broadcastability follows the same rules as those
+    documented for shape.broadcast.
+
+    Example:
+    ```mlir
+    %true = shape.is_broadcastable [2,2], [3,1,2]
+    %false = shape.is_broadcastable [2,2], [3,2]
+    ```
+  }];
+
+  let arguments = (ins Shape_ShapeOrExtentTensorType:$lhs,
+                       Shape_ShapeOrExtentTensorType:$rhs);
+  let results = (outs I1:$result);
+
+  let assemblyFormat = "$lhs `,` $rhs `:` type($lhs) `,` type($rhs) attr-dict";
+}
+
 def Shape_RankOp : Shape_Op<"rank", [NoSideEffect]> {
   let summary = "Gets the rank of a shape";
   let description = [{
diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -260,3 +260,17 @@
       : tensor<?xindex>, tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
   return %result : tensor<?xindex>
 }
+
+func @is_broadcastable_on_extent_tensors(%a : tensor<?xindex>,
+                                         %b : tensor<?xindex>) -> i1 {
+  %result = shape.is_broadcastable %a, %b
+      : tensor<?xindex>, tensor<?xindex>
+  return %result : i1
+}
+
+func @is_broadcastable_on_shapes(%a : !shape.shape,
+                                 %b : !shape.shape) -> i1 {
+  %result = shape.is_broadcastable %a, %b
+      : !shape.shape, !shape.shape
+  return %result : i1
+}