diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -417,6 +417,46 @@
   let hasCanonicalizer = 1;
 }
 
+def Shape_RefineOp : Shape_Op<"with_shape", [NoSideEffect]> {
+  let summary = "Returns ValueShape with given shape";
+  let description = [{
+    Returns ValueShape with the shape updated to match the shape operand. If the
+    ValueShape and given `shape` are non-conformant, then the returned
+    ValueShape will represent the error of the mismatch. Similarly, if either
+    input is in an error state, then an error is propagated.
+
+    Usage:
+      %0 = shape.with_shape %1, %2 : tensor<...>, !shape.shape
+
+    This is used, for example, where one combines shape function calculations
+    and/or calls one shape function from another. E.g.,
+
+    ```mlir
+    func @shape_foobah(%a: !shape.value_shape,
+                       %b: !shape.value_shape,
+                       %c: !shape.value_shape) -> !shape.shape {
+      %0 = call @shape_foo(%a, %b) :
+        (!shape.value_shape, !shape.value_shape) -> !shape.shape
+      %1 = shape.with_shape %b, %0 : !shape.value_shape, !shape.shape
+      %2 = call @shape_bah(%c, %1) :
+        (!shape.value_shape, !shape.value_shape) -> !shape.shape
+      return %2 : !shape.shape
+    }
+    ```
+
+    Note: even in non-error cases, where the input ValueShape is conformant and
+    the output is again conformant, the result of this op need not be a
+    refinement of the input ValueShape; the result may be less specified than
+    the operand. If join behavior is desired then a join op should be used.
+  }];
+
+  let arguments = (ins AnyTypeOf<[AnyShaped, Shape_ValueShapeType]>:$operand,
+                       Shape_ShapeType:$shape);
+  let results = (outs Shape_ValueShapeType:$result);
+
+  let assemblyFormat = "operands attr-dict `:` type($operand) `,` type($shape)";
+}
+
 def Shape_YieldOp : Shape_Op<"yield",
                              [HasParent<"ReduceOp">,
                               NoSideEffect,
diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -154,3 +154,18 @@
   %result = shape.shape_eq %a, %b : tensor, !shape.shape
   return %result : i1
 }
+
+// Joins the shapes of two ValueShapes using shape.join.
+func @shape_equal_shapes(%a : !shape.value_shape, %b : !shape.value_shape) -> !shape.shape {
+  %0 = shape.shape_of %a : !shape.value_shape
+  %1 = shape.shape_of %b : !shape.value_shape
+  %2 = "shape.join"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
+  return %2 : !shape.shape
+}
+// Attaches the shape of %a to %b via shape.with_shape, then forwards it.
+func @shape_with_shape(%a : !shape.value_shape, %b : !shape.value_shape) -> !shape.shape {
+  %0 = shape.shape_of %a : !shape.value_shape
+  %1 = shape.with_shape %b, %0 : !shape.value_shape, !shape.shape
+  %2 = call @shape_equal_shapes(%a, %1) : (!shape.value_shape, !shape.value_shape) -> !shape.shape
+  return %2 : !shape.shape
+}