diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td --- a/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td +++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td @@ -100,7 +100,11 @@ corresponds to `Value` in the compiler) and a shape. Conceptually this is a tuple of a value (potentially unknown) and `shape.type`. The value and shape can either or both be unknown. If both the `value` and `shape` are known, - then the shape of `value` is conformant with `shape`. + then the shape of `value` is conformant with `shape`. That is, the shape of + the value conforms to the shape of the ValueShape, so that if we have + ValueShape `(value, shape)` then `join(shape_of(value), shape)` would be + error-free; in particular, if both shapes are statically known, + then they are equal. }]; } diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td --- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td +++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td @@ -432,6 +432,49 @@ let hasCanonicalizer = 1; } +def Shape_WithOp : Shape_Op<"with_shape", [NoSideEffect]> { + let summary = "Returns ValueShape with given shape"; + let description = [{ + Returns ValueShape with the shape updated to match the shape operand. That + is, a new ValueShape tuple is created with value equal to `operand`'s + value and shape equal to `shape`. If the ValueShape and given `shape` are + non-conformant, then the returned ValueShape will represent an error of + this mismatch. Similarly, if either input is in an error state, then an + error is propagated. + + Usage: + %0 = shape.with_shape %1, %2 : tensor<...>, !shape.shape + + This is used, for example, where one combines shape function calculations + and/or calls one shape function from another.
E.g., + + ```mlir + func @shape_foobah(%a: !shape.value_shape, + %b: !shape.value_shape, + %c: !shape.value_shape) -> !shape.shape { + %0 = call @shape_foo(%a, %b) : + (!shape.value_shape, !shape.value_shape) -> !shape.shape + %1 = shape.with_shape %b, %0 : !shape.value_shape, !shape.shape + %2 = call @shape_bah(%c, %1) : + (!shape.value_shape, !shape.value_shape) -> !shape.shape + return %2 : !shape.shape + } + ``` + + This op need not be a refinement of the shape. In non-error cases, the input + ValueShape's value and shape are conformant, as are the output's, but + the result may be less specified than `operand`'s shape, as `shape` is + merely used to construct the new ValueShape. If join behavior is desired, + then a join op should be used. + }]; + + let arguments = (ins AnyTypeOf<[AnyShaped, Shape_ValueShapeType]>:$operand, + Shape_ShapeType:$shape); + let results = (outs Shape_ValueShapeType:$result); + + let assemblyFormat = "operands attr-dict `:` type($operand) `,` type($shape)"; +} + def Shape_YieldOp : Shape_Op<"yield", [HasParent<"ReduceOp">, NoSideEffect, diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir --- a/mlir/test/Dialect/Shape/ops.mlir +++ b/mlir/test/Dialect/Shape/ops.mlir @@ -221,4 +221,17 @@ return %result : !shape.size } - +// Test invoking one shape function from another. shape_equal_shapes is +// merely a trivial helper function to be invoked from elsewhere.
+func @shape_equal_shapes(%lhs : !shape.value_shape, %rhs : !shape.value_shape) -> !shape.shape { + %lhs_shape = shape.shape_of %lhs : !shape.value_shape -> !shape.shape + %rhs_shape = shape.shape_of %rhs : !shape.value_shape -> !shape.shape + %joined = "shape.join"(%lhs_shape, %rhs_shape) : (!shape.shape, !shape.shape) -> !shape.shape + return %joined : !shape.shape +} +func @shape_with_shape(%src : !shape.value_shape, %dst : !shape.value_shape) -> !shape.shape { + %src_shape = shape.shape_of %src : !shape.value_shape -> !shape.shape + %reshaped = shape.with_shape %dst, %src_shape : !shape.value_shape, !shape.shape + %result = call @shape_equal_shapes(%src, %reshaped) : (!shape.value_shape, !shape.value_shape) -> !shape.shape + return %result : !shape.shape +}