diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td --- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td +++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td @@ -340,19 +340,19 @@ let summary = "Returns the number of elements for a given shape"; let description = [{ Returns the number of elements for a given shape which is the product of its - dimensions. - - ```mlir - %product = shape.mul %lhs, %rhs - ``` + extents. If the argument is of type `shape` then the result will be of type + `size` and potential errors will be propagated. Otherwise, if the argument + is an extent tensor `tensor<?xindex>` then the result will be of type + `index`. }]; - let arguments = (ins Shape_ShapeType:$shape); - let results = (outs Shape_SizeType:$result); + let arguments = (ins Shape_ShapeOrExtentTensorType:$shape); + let results = (outs Shape_SizeOrIndexType:$result); - let assemblyFormat = "$shape attr-dict"; + let assemblyFormat = "$shape `:` type($shape) `->` type($result) attr-dict"; let hasFolder = 1; + let verifier = [{ return ::verifySizeOrIndexOp(*this); }]; } def Shape_ReduceOp : Shape_Op<"reduce", diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir --- a/mlir/test/Dialect/Shape/canonicalize.mlir +++ b/mlir/test/Dialect/Shape/canonicalize.mlir @@ -217,7 +217,7 @@ // CHECK-NOT: shape.const_shape %shape = shape.const_shape [4, 5, 6] : !shape.shape // CHECK-NOT: shape.num_elements - %num_elements = shape.num_elements %shape + %num_elements = shape.num_elements %shape : !shape.shape -> !shape.size // CHECK: %[[NUM:.*]] = shape.const_size 120 // CHECK-NEXT: return %[[NUM]] : !shape.size return %num_elements : !shape.size @@ -229,7 +229,7 @@ // CHECK-LABEL: func @nonfoldable_num_elements func @nonfoldable_num_elements(%shape : !shape.shape) -> !shape.size { // CHECK-NOT: shape.const_{{.*}} - %num_elements = shape.num_elements %shape + %num_elements = shape.num_elements %shape : 
!shape.shape -> !shape.size return %num_elements : !shape.size } diff --git a/mlir/test/Dialect/Shape/invalid.mlir b/mlir/test/Dialect/Shape/invalid.mlir --- a/mlir/test/Dialect/Shape/invalid.mlir +++ b/mlir/test/Dialect/Shape/invalid.mlir @@ -170,3 +170,19 @@ return %result : index } +// ----- + +func @num_elements_error_free(%arg : tensor<?xindex>) -> !shape.size { + // expected-error@+1 {{if none of the operands can hold error values then the result must be of type `index`}} + %result = shape.num_elements %arg : tensor<?xindex> -> !shape.size + return %result : !shape.size +} + +// ----- + +func @num_elements_error_possible(%arg : !shape.shape) -> index { + // expected-error@+1 {{if at least one of the operands can hold error values then the result must be of type `size` to propagate them}} + %result = shape.num_elements %arg : !shape.shape -> index + return %result : index +} + diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir --- a/mlir/test/Dialect/Shape/ops.mlir +++ b/mlir/test/Dialect/Shape/ops.mlir @@ -204,3 +204,14 @@ return } +func @num_elements_extent_tensor(%arg : tensor<?xindex>) -> index { + %result = shape.num_elements %arg : tensor<?xindex> -> index + return %result : index +} + +func @num_elements_shape(%arg : !shape.shape) -> !shape.size { + %result = shape.num_elements %arg : !shape.shape -> !shape.size + return %result : !shape.size +} + + diff --git a/mlir/test/Dialect/Shape/shape-to-shape.mlir b/mlir/test/Dialect/Shape/shape-to-shape.mlir --- a/mlir/test/Dialect/Shape/shape-to-shape.mlir +++ b/mlir/test/Dialect/Shape/shape-to-shape.mlir @@ -1,9 +1,9 @@ // RUN: mlir-opt -shape-to-shape-lowering -split-input-file %s | FileCheck %s -// CHECK-LABEL: func @num_elements_to_reduce( -// CHECK-SAME: [[ARG:%.*]]: !shape.shape) -> !shape.size { +// CHECK-LABEL: func @num_elements_to_reduce +// CHECK-SAME: ([[ARG:%.*]]: !shape.shape) -> !shape.size func @num_elements_to_reduce(%shape : !shape.shape) -> !shape.size { - %num_elements = shape.num_elements %shape 
+ %num_elements = shape.num_elements %shape : !shape.shape -> !shape.size return %num_elements : !shape.size } // CHECK: [[C1:%.*]] = shape.const_size 1