diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -53,18 +53,23 @@ class Shape_Op<string mnemonic, list<OpTrait> traits = []> :
     Op<ShapeDialect, mnemonic, traits>;
 
 
-def Shape_AddOp : Shape_Op<"add", [Commutative, SameOperandsAndResultType]> {
-  let summary = "Addition of sizes";
+def Shape_AddOp : Shape_Op<"add", [Commutative, NoSideEffect]> {
+  let summary = "Addition of sizes and indices";
   let description = [{
-    Adds two valid sizes as follows:
-    * lhs + rhs = unknown if either lhs or rhs unknown;
-    * lhs + rhs = (int)lhs + (int)rhs if known;
+    Adds two sizes or indices. If either operand is an error it will be
+    propagated to the result. The operands can be of type `size` or `index`. If
+    at least one of the operands can hold an error, i.e. if it is of type `size`,
+    then also the result must be of type `size`.
   }];
 
-  let arguments = (ins Shape_SizeType:$lhs, Shape_SizeType:$rhs);
-  let results = (outs Shape_SizeType:$result);
+  let arguments = (ins Shape_SizeOrIndexType:$lhs, Shape_SizeOrIndexType:$rhs);
+  let results = (outs Shape_SizeOrIndexType:$result);
 
-  let assemblyFormat = "$lhs `,` $rhs attr-dict";
+  let assemblyFormat = [{
+    $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)
+  }];
+
+  let verifier = [{ return verifySizeOrIndexOp(*this); }];
 }
 
 def Shape_BroadcastOp : Shape_Op<"broadcast", [Commutative]> {
diff --git a/mlir/test/Dialect/Shape/invalid.mlir b/mlir/test/Dialect/Shape/invalid.mlir
--- a/mlir/test/Dialect/Shape/invalid.mlir
+++ b/mlir/test/Dialect/Shape/invalid.mlir
@@ -105,7 +105,7 @@
 
 // -----
 
-func @get_extent_error_possible(%arg : tensor<?xindex>) -> index {
+func @get_extent(%arg : tensor<?xindex>) -> index {
   %c0 = shape.const_size 0
   // expected-error@+1 {{if at least one of the operands can hold error values then the result must be of type `size` to propagate them}}
   %result = shape.get_extent %arg, %c0 : tensor<?xindex>, !shape.size -> index
@@ -114,7 +114,7 @@
 
 // -----
 
-func @mul_error_possible(%lhs : !shape.size, %rhs : index) -> index {
+func @mul(%lhs : !shape.size, %rhs : index) -> index {
   // expected-error@+1 {{if at least one of the operands can hold error values then the result must be of type `size` to propagate them}}
   %result = shape.mul %lhs, %rhs : !shape.size, index -> index
   return %result : index
@@ -122,9 +122,17 @@
 
 // -----
 
-func @num_elements_error_possible(%arg : !shape.shape) -> index {
+func @num_elements(%arg : !shape.shape) -> index {
   // expected-error@+1 {{if at least one of the operands can hold error values then the result must be of type `size` to propagate them}}
   %result = shape.num_elements %arg : !shape.shape -> index
   return %result : index
 }
 
+// -----
+
+func @add(%lhs : !shape.size, %rhs : index) -> index {
+  // expected-error@+1 {{if at least one of the operands can hold error values then the result must be of type `size` to propagate them}}
+  %result = shape.add %lhs, %rhs : !shape.size, index -> index
+  return %result : index
+}
+
diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -120,6 +120,15 @@
   return
 }
 
+func @add(%size_arg : !shape.size, %index_arg : index) {
+  %size_sum = shape.add %size_arg, %size_arg
+      : !shape.size, !shape.size -> !shape.size
+  %index_sum = shape.add %index_arg, %index_arg : index, index -> index
+  %mixed_sum = shape.add %size_arg, %index_arg
+      : !shape.size, index -> !shape.size
+  return
+}
+
 func @const_size() {
   // CHECK: %c1 = shape.const_size 1
   // CHECK: %c2 = shape.const_size 2