diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td --- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td +++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td @@ -70,28 +70,32 @@ def Shape_BroadcastOp : Shape_Op<"broadcast", [Commutative]> { let summary = "Returns the broadcasted output shape of two inputs"; let description = [{ - Computes the broadcasted output shape following: - 1. If any inputs are unranked, output is unranked; - 2. Else the input array with number of dimensions smaller than the max - input dimension, has 1’s prepended to its shapes and the output shape is - calculated as follows: - - output[i] = lhs[i] if lhs[i] == rhs[i] or rhs[i] is unknown/undefined - = rhs[i] if lhs[i] is unknown/undefined - = lhs[i] if rhs[i] == 1 - = rhs[i] if lhs[i] == 1 - = error if lhs[i] != rhs[i] - - Op has an optional string attribute for the error case where there is no - broadcastable output shape possible for the given inputs. + Returns the broadcasted shape for two input shapes or extent tensors. Both + operands and the result type can be of type `shape.shape` or + `tensor<?xindex>`. If the two operand shapes are of different rank, the + smaller one is padded with 1's from the left. The resulting broadcasted + shape is then defined as + + result[i] = lhs[i] if lhs[i] == rhs[i] + = lhs[i] if rhs[i] == 1 + = rhs[i] if lhs[i] == 1. + + If the resulting shape is undefined, i.e. if corresponding extents differ + from each other and neither is 1, the result is an error shape. + Likewise, error values are propagated if any of the operands holds an error + value. If the result type is an extent tensor (and therefore cannot hold + the error value), the behavior is undefined. The optional string attribute + can be used to describe the error case. 
}]; let arguments = (ins Shape_ShapeOrExtentTensorType:$lhs, Shape_ShapeOrExtentTensorType:$rhs, OptionalAttr<StrAttr>:$error); - let results = (outs Shape_ShapeType:$result); + let results = (outs Shape_ShapeOrExtentTensorType:$result); - let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs)"; + let assemblyFormat = [{ + $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result) + }]; let hasFolder = 1; } diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir --- a/mlir/test/Dialect/Shape/canonicalize.mlir +++ b/mlir/test/Dialect/Shape/canonicalize.mlir @@ -54,18 +54,30 @@ // CHECK: shape.const_shape [7, 2] : !shape.shape %0 = shape.const_shape [1, 2] : !shape.shape %1 = shape.const_shape [7, 1] : !shape.shape - %2 = shape.broadcast %0, %1 : !shape.shape, !shape.shape + %2 = shape.broadcast %0, %1 : !shape.shape, !shape.shape -> !shape.shape return %2 : !shape.shape } // ----- +// Basic case including extent tensors. +// CHECK-LABEL: @broadcast +func @broadcast() -> tensor<?xindex> { + // CHECK: shape.const_shape [7, 2] : tensor<?xindex> + %0 = shape.const_shape [1, 2] : !shape.shape + %1 = shape.const_shape [7, 1] : tensor<?xindex> + %2 = shape.broadcast %0, %1 : !shape.shape, tensor<?xindex> -> tensor<?xindex> + return %2 : tensor<?xindex> +} + +// ----- + // Rhs is a scalar.
// CHECK-LABEL: func @f func @f(%arg0 : !shape.shape) -> !shape.shape { // CHECK: return %arg0 %0 = shape.const_shape [] : !shape.shape - %1 = shape.broadcast %arg0, %0 : !shape.shape, !shape.shape + %1 = shape.broadcast %arg0, %0 : !shape.shape, !shape.shape -> !shape.shape return %1 : !shape.shape } @@ -76,7 +88,7 @@ func @f(%arg0 : !shape.shape) -> !shape.shape { // CHECK: return %arg0 %0 = shape.const_shape [] : !shape.shape - %1 = shape.broadcast %0, %arg0 : !shape.shape, !shape.shape + %1 = shape.broadcast %0, %arg0 : !shape.shape, !shape.shape -> !shape.shape return %1 : !shape.shape } @@ -89,7 +101,7 @@ // CHECK: return %[[CST]] %0 = shape.const_shape [] : !shape.shape %1 = shape.const_shape [1, 2, 3] : !shape.shape - %2 = shape.broadcast %0, %1 : !shape.shape, !shape.shape + %2 = shape.broadcast %0, %1 : !shape.shape, !shape.shape -> !shape.shape return %2 : !shape.shape } @@ -101,7 +113,7 @@ // CHECK: shape.broadcast %0 = shape.const_shape [2] : !shape.shape %1 = shape.const_shape [7] : !shape.shape - %2 = shape.broadcast %0, %1 : !shape.shape, !shape.shape + %2 = shape.broadcast %0, %1 : !shape.shape, !shape.shape -> !shape.shape return %2 : !shape.shape } diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir --- a/mlir/test/Dialect/Shape/ops.mlir +++ b/mlir/test/Dialect/Shape/ops.mlir @@ -49,7 +49,7 @@ func @test_broadcast_fixed() { %0 = shape.const_shape [10, 1, 57, 92] : !shape.shape %1 = shape.const_shape [4, 57, 92] : !shape.shape - %2 = shape.broadcast %0, %1 : !shape.shape, !shape.shape + %2 = shape.broadcast %0, %1 : !shape.shape, !shape.shape -> !shape.shape %3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape return }