diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td --- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td +++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td @@ -49,25 +49,24 @@ def Shape_BroadcastOp : Shape_Op<"broadcast", [Commutative]> { let summary = "Returns the broadcasted output shape of two inputs"; let description = [{ - Computes the broadcasted output shape following: - 1. If any inputs are unranked, output is unranked; - 2. Else the input array with number of dimensions smaller than the max - input dimension, has 1’s prepended to its shapes and the output shape is - calculated as follows: - - output[i] = lhs[i] if lhs[i] == rhs[i] or rhs[i] is unknown/undefined - = rhs[i] if lhs[i] is unknown/undefined - = lhs[i] if rhs[i] == 1 - = rhs[i] if lhs[i] == 1 - = error if lhs[i] != rhs[i] - - Op has an optional string attribute for the error case where there is no - broadcastable output shape possible for the given inputs. - - Op may also return an ExtentTensor, but this should only be done when this - is statically guaranteed to never fail, either because of a dependency on a - cstr_broadcastable operation or other details of the construction of the - program. + Returns the broadcasted shape for two input shapes or extent tensors. Both + operands can be of type `shape.shape` or `tensor<?xindex>`. The result is of + type `shape.shape` and, if both operands are tensors, may be of type + `tensor<?xindex>`. + + If the two operand shapes differ in rank, the lower-ranked one is padded + with 1's from the left. The resulting broadcasted shape is then defined as + + result[i] = lhs[i] if lhs[i] == rhs[i] + = lhs[i] if rhs[i] == 1 + = rhs[i] if lhs[i] == 1. + + In case the resulting shape is undefined, i.e. if corresponding extents are + different from each other but none is 1, the result is an error shape. + Likewise, error values are propagated if any of the operands holds an error + value.
If the result type is an extent tensor (and can therefore not hold + the error value) the behavior may be undefined. The optional string + attribute can be used to describe the error case. }]; let arguments = (ins Shape_ShapeOrExtentTensorType:$lhs, @@ -75,8 +74,10 @@ OptionalAttr<StrAttr>:$error); let results = (outs Shape_ShapeOrExtentTensorType:$result); - let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)"; + let assemblyFormat = [{ + $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result) + }]; let hasFolder = 1; let verifier = [{ return ::verifyShapeOrExtentTensorOp(*this); }]; diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir --- a/mlir/test/Dialect/Shape/canonicalize.mlir +++ b/mlir/test/Dialect/Shape/canonicalize.mlir @@ -60,6 +60,31 @@ // ----- +// Basic case including extent tensors. +// CHECK-LABEL: @broadcast +func @broadcast() -> tensor<?xindex> { + // CHECK: shape.const_shape [7, 2] : tensor<?xindex> + %0 = shape.const_shape [1, 2] : tensor<?xindex> + %1 = shape.const_shape [7, 1] : tensor<?xindex> + %2 = shape.broadcast %0, %1 + : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex> + return %2 : tensor<?xindex> +} + +// ----- + +// Extent tensor operands with a `!shape.shape` result. +// CHECK-LABEL: @broadcast +func @broadcast() -> !shape.shape { + // CHECK: shape.const_shape [7, 2] : !shape.shape + %0 = shape.const_shape [1, 2] : tensor<?xindex> + %1 = shape.const_shape [7, 1] : tensor<?xindex> + %2 = shape.broadcast %0, %1 : tensor<?xindex>, tensor<?xindex> -> !shape.shape + return %2 : !shape.shape +} + +// ----- + // Rhs is a scalar.
// CHECK-LABEL: func @f func @f(%arg0 : !shape.shape) -> !shape.shape { diff --git a/mlir/test/Dialect/Shape/invalid.mlir b/mlir/test/Dialect/Shape/invalid.mlir --- a/mlir/test/Dialect/Shape/invalid.mlir +++ b/mlir/test/Dialect/Shape/invalid.mlir @@ -138,17 +138,19 @@ // ----- -func @broadcast_error_possible(%arg0 : !shape.shape, %arg1 : !shape.shape) -> tensor<?xindex> { +func @broadcast(%arg0 : !shape.shape, %arg1 : !shape.shape) -> tensor<?xindex> { // expected-error@+1 {{if at least one of the operands can hold error values then the result must be of type `shape` to propagate them}} - %result = shape.broadcast %arg0, %arg1 : !shape.shape, !shape.shape -> tensor<?xindex> + %result = shape.broadcast %arg0, %arg1 + : !shape.shape, !shape.shape -> tensor<?xindex> return %result : tensor<?xindex> } // ----- -func @broadcast_error_possible(%arg0 : !shape.shape, %arg1 : tensor<?xindex>) -> tensor<?xindex> { +func @broadcast(%arg0 : !shape.shape, %arg1 : tensor<?xindex>) -> tensor<?xindex> { // expected-error@+1 {{if at least one of the operands can hold error values then the result must be of type `shape` to propagate them}} - %result = shape.broadcast %arg0, %arg1 : !shape.shape, tensor<?xindex> -> tensor<?xindex> + %result = shape.broadcast %arg0, %arg1 + : !shape.shape, tensor<?xindex> -> tensor<?xindex> return %result : tensor<?xindex> }