Allow `shape.add` to operate on `index` operands as well as `!shape.size`.

The op's operands and result may each be `!shape.size` or `index`. The result
must be `!shape.size` when at least one operand is `!shape.size` (so that an
error value can propagate) and `index` when both operands are `index`. The two
new cases in invalid.mlir pin the corresponding verifier diagnostics, and the
new case in ops.mlir exercises the three accepted type combinations.

diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -53,18 +53,25 @@
 class Shape_Op<string mnemonic, list<OpTrait> traits = []> :
     Op<ShapeDialect, mnemonic, traits>;
 
-def Shape_AddOp : Shape_Op<"add", [Commutative, SameOperandsAndResultType]> {
-  let summary = "Addition of sizes";
+def Shape_AddOp : Shape_Op<"add", [Commutative, NoSideEffect]> {
+  let summary = "Addition of sizes and indices";
   let description = [{
-    Adds two valid sizes as follows:
-    * lhs + rhs = unknown if either lhs or rhs unknown;
-    * lhs + rhs = (int)lhs + (int)rhs if known;
+    Adds two sizes or indices. If either operand is an error it will be
+    propagated to the result. The operands can be of type `size` or `index`. If
+    at least one of the operands can hold an error, i.e. if it is of type `size`,
+    then also the result must be of type `size`. If error propagation is not
+    possible because both operands are of type `index` then the result must also
+    be of type `index`.
   }];
 
-  let arguments = (ins Shape_SizeType:$lhs, Shape_SizeType:$rhs);
-  let results = (outs Shape_SizeType:$result);
+  let arguments = (ins Shape_SizeOrIndexType:$lhs, Shape_SizeOrIndexType:$rhs);
+  let results = (outs Shape_SizeOrIndexType:$result);
 
-  let assemblyFormat = "$lhs `,` $rhs attr-dict";
+  let assemblyFormat = [{
+    $lhs `,` $rhs `:` type($lhs) `,` type($rhs) `->` type($result) attr-dict
+  }];
+
+  let verifier = [{ return verifySizeOrIndexOp(*this); }];
 }
 
 def Shape_BroadcastOp : Shape_Op<"broadcast", [Commutative]> {
diff --git a/mlir/test/Dialect/Shape/invalid.mlir b/mlir/test/Dialect/Shape/invalid.mlir
--- a/mlir/test/Dialect/Shape/invalid.mlir
+++ b/mlir/test/Dialect/Shape/invalid.mlir
@@ -154,3 +154,19 @@
   return %result : index
 }
 
+// -----
+
+func @add_error_free(%arg : index) -> !shape.size {
+  // expected-error@+1 {{if none of the operands can hold error values then the result must be of type `index`}}
+  %result = shape.add %arg, %arg : index, index -> !shape.size
+  return %result : !shape.size
+}
+
+// -----
+
+func @add_error_possible(%lhs : !shape.size, %rhs : index) -> index {
+  // expected-error@+1 {{if at least one of the operands can hold error values then the result must be of type `size` to propagate them}}
+  %result = shape.add %lhs, %rhs : !shape.size, index -> index
+  return %result : index
+}
+
diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -120,6 +120,15 @@
   return
 }
 
+func @add(%size_arg : !shape.size, %index_arg : index) {
+  %size_sum = shape.add %size_arg, %size_arg
+      : !shape.size, !shape.size -> !shape.size
+  %index_sum = shape.add %index_arg, %index_arg : index, index -> index
+  %mixed_sum = shape.add %size_arg, %index_arg
+      : !shape.size, index -> !shape.size
+  return
+}
+
 func @const_size() {
   // CHECK: %c1 = shape.const_size 1
   // CHECK: %c2 = shape.const_size 2