[MLIR][Shape] Allow `shape.broadcast` to produce extent tensors

`shape.broadcast` accepts `!shape.shape` or `tensor<?xindex>` operands and now
also allows a `tensor<?xindex>` result. A tensor result is rejected by the new
verifier when an operand can hold an error value (see the added invalid.mlir
case). The assembly format now spells out the result type, so existing tests are
updated to the `: lhs-type, rhs-type -> result-type` form.

diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -70,29 +70,36 @@
 def Shape_BroadcastOp : Shape_Op<"broadcast", [Commutative]> {
   let summary = "Returns the broadcasted output shape of two inputs";
   let description = [{
-    Computes the broadcasted output shape following:
-    1. If any inputs are unranked, output is unranked;
-    2. Else the input array with number of dimensions smaller than the max
-       input dimension, has 1’s prepended to its shapes and the output shape is
-       calculated as follows:
-
-           output[i] = lhs[i] if lhs[i] == rhs[i] or rhs[i] is unknown/undefined
-                     = rhs[i] if lhs[i] is unknown/undefined
-                     = lhs[i] if rhs[i] == 1
-                     = rhs[i] if lhs[i] == 1
-                     = error if lhs[i] != rhs[i]
-
-    Op has an optional string attribute for the error case where there is no
-    broadcastable output shape possible for the given inputs.
+    Returns the broadcasted shape for two input shapes or extent tensors. Both
+    operands can be of type `shape.shape` or `tensor<?xindex>`. The result is of
+    type `shape.shape` and, if both operands are tensors, may be of type
+    `tensor<?xindex>`.
+
+    If the two operand shapes are of different rank the smaller one is padded
+    with 1's from the left. The resulting broadcasted shape is then defined as
+
+      result[i] = lhs[i] if lhs[i] == rhs[i]
+                = lhs[i] if rhs[i] == 1
+                = rhs[i] if lhs[i] == 1.
+
+    In case the resulting shape is undefined, i.e. if corresponding extents are
+    different from each other but none is 1, the result is an error shape.
+    Likewise error values are propagated if any of the operands holds an error
+    value. If the result type is an extent tensor (and can therefore not hold
+    the error value) the behavior may be undefined. The optional string
+    attribute can be used to describe the error case.
   }];
 
   let arguments = (ins Shape_ShapeOrExtentTensorType:$lhs,
                        Shape_ShapeOrExtentTensorType:$rhs,
                        OptionalAttr<StrAttr>:$error);
-  let results = (outs Shape_ShapeType:$result);
+  let results = (outs Shape_ShapeOrExtentTensorType:$result);
 
-  let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs)";
+  let assemblyFormat = [{
+    $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)
+  }];
 
+  let verifier = [{ return ::verifyShapeOrExtentTensorOp(*this); }];
   let hasFolder = 1;
 }
 
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -54,18 +54,31 @@
   // CHECK: shape.const_shape [7, 2] : !shape.shape
   %0 = shape.const_shape [1, 2] : !shape.shape
   %1 = shape.const_shape [7, 1] : !shape.shape
-  %2 = shape.broadcast %0, %1 : !shape.shape, !shape.shape
+  %2 = shape.broadcast %0, %1 : !shape.shape, !shape.shape -> !shape.shape
   return %2 : !shape.shape
 }
 
 // -----
 
+// Basic case including extent tensors.
+// CHECK-LABEL: @broadcast
+func @broadcast() -> tensor<?xindex> {
+  // CHECK: shape.const_shape [7, 2] : tensor<?xindex>
+  %0 = shape.const_shape [1, 2] : tensor<?xindex>
+  %1 = shape.const_shape [7, 1] : tensor<?xindex>
+  %2 = shape.broadcast %0, %1
+      : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
+  return %2 : tensor<?xindex>
+}
+
+// -----
+
 // Rhs is a scalar.
 // CHECK-LABEL: func @f
 func @f(%arg0 : !shape.shape) -> !shape.shape {
   // CHECK: return %arg0
   %0 = shape.const_shape [] : !shape.shape
-  %1 = shape.broadcast %arg0, %0 : !shape.shape, !shape.shape
+  %1 = shape.broadcast %arg0, %0 : !shape.shape, !shape.shape -> !shape.shape
   return %1 : !shape.shape
 }
 
@@ -76,7 +89,7 @@
 func @f(%arg0 : !shape.shape) -> !shape.shape {
   // CHECK: return %arg0
   %0 = shape.const_shape [] : !shape.shape
-  %1 = shape.broadcast %0, %arg0 : !shape.shape, !shape.shape
+  %1 = shape.broadcast %0, %arg0 : !shape.shape, !shape.shape -> !shape.shape
   return %1 : !shape.shape
 }
 
@@ -89,7 +102,7 @@
   // CHECK: return %[[CST]]
   %0 = shape.const_shape [] : !shape.shape
   %1 = shape.const_shape [1, 2, 3] : !shape.shape
-  %2 = shape.broadcast %0, %1 : !shape.shape, !shape.shape
+  %2 = shape.broadcast %0, %1 : !shape.shape, !shape.shape -> !shape.shape
   return %2 : !shape.shape
 }
 
@@ -101,7 +114,7 @@
   // CHECK: shape.broadcast
   %0 = shape.const_shape [2] : !shape.shape
   %1 = shape.const_shape [7] : !shape.shape
-  %2 = shape.broadcast %0, %1 : !shape.shape, !shape.shape
+  %2 = shape.broadcast %0, %1 : !shape.shape, !shape.shape -> !shape.shape
   return %2 : !shape.shape
 }
 
diff --git a/mlir/test/Dialect/Shape/invalid.mlir b/mlir/test/Dialect/Shape/invalid.mlir
--- a/mlir/test/Dialect/Shape/invalid.mlir
+++ b/mlir/test/Dialect/Shape/invalid.mlir
@@ -128,3 +128,12 @@
   return %result : index
 }
 
+// -----
+
+func @broadcast(%a : !shape.shape, %b : tensor<?xindex>) -> tensor<?xindex> {
+  // expected-error@+1 {{if at least one of the operands can hold error values then the result must be of type `shape` to propagate them}}
+  %result = shape.broadcast %a, %b
+      : !shape.shape, tensor<?xindex> -> tensor<?xindex>
+  return %result : tensor<?xindex>
+}
+
diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -49,7 +49,7 @@
 func @test_broadcast_fixed() {
   %0 = shape.const_shape [10, 1, 57, 92] : !shape.shape
   %1 = shape.const_shape [4, 57, 92] : !shape.shape
-  %2 = shape.broadcast %0, %1 : !shape.shape, !shape.shape
+  %2 = shape.broadcast %0, %1 : !shape.shape, !shape.shape -> !shape.shape
   %3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
   return
 }