diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td --- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td +++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td @@ -89,11 +89,13 @@ let arguments = (ins Shape_ShapeOrExtentTensorType:$lhs, Shape_ShapeOrExtentTensorType:$rhs, OptionalAttr<StrAttr>:$error); - let results = (outs Shape_ShapeType:$result); + let results = (outs Shape_ShapeOrExtentTensorType:$result); - let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs)"; + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)"; let hasFolder = 1; + + let verifier = [{ return ::verifyShapeOrExtentTensorOp(*this); }]; } def Shape_ConstShapeOp : Shape_Op<"const_shape", [ConstantLike, NoSideEffect]> { diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir --- a/mlir/test/Dialect/Shape/canonicalize.mlir +++ b/mlir/test/Dialect/Shape/canonicalize.mlir @@ -54,7 +54,7 @@ // CHECK: shape.const_shape [7, 2] : !shape.shape %0 = shape.const_shape [1, 2] : !shape.shape %1 = shape.const_shape [7, 1] : !shape.shape - %2 = shape.broadcast %0, %1 : !shape.shape, !shape.shape + %2 = shape.broadcast %0, %1 : !shape.shape, !shape.shape -> !shape.shape return %2 : !shape.shape } @@ -65,7 +65,7 @@ func @f(%arg0 : !shape.shape) -> !shape.shape { // CHECK: return %arg0 %0 = shape.const_shape [] : !shape.shape - %1 = shape.broadcast %arg0, %0 : !shape.shape, !shape.shape + %1 = shape.broadcast %arg0, %0 : !shape.shape, !shape.shape -> !shape.shape return %1 : !shape.shape } @@ -76,7 +76,7 @@ func @f(%arg0 : !shape.shape) -> !shape.shape { // CHECK: return %arg0 %0 = shape.const_shape [] : !shape.shape - %1 = shape.broadcast %0, %arg0 : !shape.shape, !shape.shape + %1 = shape.broadcast %0, %arg0 : !shape.shape, !shape.shape -> !shape.shape return %1 : !shape.shape } @@ -89,7 +89,7 @@ // CHECK: return %[[CST]] %0 = shape.const_shape [] : 
!shape.shape %1 = shape.const_shape [1, 2, 3] : !shape.shape - %2 = shape.broadcast %0, %1 : !shape.shape, !shape.shape + %2 = shape.broadcast %0, %1 : !shape.shape, !shape.shape -> !shape.shape return %2 : !shape.shape } @@ -101,7 +101,7 @@ // CHECK: shape.broadcast %0 = shape.const_shape [2] : !shape.shape %1 = shape.const_shape [7] : !shape.shape - %2 = shape.broadcast %0, %1 : !shape.shape, !shape.shape + %2 = shape.broadcast %0, %1 : !shape.shape, !shape.shape -> !shape.shape return %2 : !shape.shape } diff --git a/mlir/test/Dialect/Shape/invalid.mlir b/mlir/test/Dialect/Shape/invalid.mlir --- a/mlir/test/Dialect/Shape/invalid.mlir +++ b/mlir/test/Dialect/Shape/invalid.mlir @@ -128,3 +128,19 @@ return %result : index } +// ----- + +func @broadcast_error_possible(%arg0 : !shape.shape, %arg1 : !shape.shape) -> tensor<?xindex> { + // expected-error@+1 {{if at least one of the operands can hold error values then the result must be of type `shape` to propagate them}} + %result = shape.broadcast %arg0, %arg1 : !shape.shape, !shape.shape -> tensor<?xindex> + return %result : tensor<?xindex> +} + + +// ----- + +func @broadcast_error_possible(%arg0 : !shape.shape, %arg1 : tensor<?xindex>) -> tensor<?xindex> { + // expected-error@+1 {{if at least one of the operands can hold error values then the result must be of type `shape` to propagate them}} + %result = shape.broadcast %arg0, %arg1 : !shape.shape, tensor<?xindex> -> tensor<?xindex> + return %result : tensor<?xindex> +} diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir --- a/mlir/test/Dialect/Shape/ops.mlir +++ b/mlir/test/Dialect/Shape/ops.mlir @@ -49,11 +49,18 @@ func @test_broadcast_fixed() { %0 = shape.const_shape [10, 1, 57, 92] : !shape.shape %1 = shape.const_shape [4, 57, 92] : !shape.shape - %2 = shape.broadcast %0, %1 : !shape.shape, !shape.shape + %2 = shape.broadcast %0, %1 : !shape.shape, !shape.shape -> !shape.shape %3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape return } +func @test_broadcast_extents() -> tensor<?xindex> { + %0 = 
shape.const_shape [10, 1, 57, 92] : tensor<?xindex> + %1 = shape.const_shape [4, 57, 92] : tensor<?xindex> + %2 = shape.broadcast %0, %1 : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex> + return %2 : tensor<?xindex> +} + func @test_shape_any_fixed() { %0 = shape.const_shape [4, 57, 92] : !shape.shape %1 = shape.const_shape [4, 57, 92] : !shape.shape