Allow `shape.get_extent` to take either a `!shape.shape` or an extent tensor
as its `shape` operand. The custom assembly format now ends the operand list
with `: type($shape)` so the operand type is spelled out explicitly.

diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -228,17 +228,15 @@
 }
 
 def Shape_GetExtentOp : Shape_Op<"get_extent", [NoSideEffect]> {
-  let summary = "Gets the specified extent from a shape";
+  let summary = "Gets the specified extent from a shape or extent tensor";
   let description = [{
-    Gets the extent indexed by `dim` from `shape`.
-    If the shape is an error, it returns an error size.
+    Gets the extent indexed by `dim` from the `shape` operand. If the shape is
+    an error then it returns an error size.
   }];
-  let arguments = (ins
-    Shape_ShapeType:$shape,
-    Shape_SizeType:$dim
-  );
+  let arguments = (ins Shape_ShapeOrExtentTensorType:$shape,
+                       Shape_SizeType:$dim);
   let results = (outs Shape_SizeType:$extent);
-  let assemblyFormat = "$shape `,` $dim attr-dict";
+  let assemblyFormat = "$shape `,` $dim `:` type($shape) attr-dict";
 
   let builders = [
     // Builder that allows passing a constant dimension as a simple integer.
diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -139,7 +139,7 @@
   // CHECK: %[[RESULT:.*]] = dim %[[ARG]], %[[IDX]] : tensor<2x3xf32>
   // CHECK: return %[[RESULT]] : index
   %shape = shape.shape_of %arg : tensor<2x3xf32>
-  %result = shape.get_extent %shape, %idx
+  %result = shape.get_extent %shape, %idx : !shape.shape
   return %result : !shape.size
 }
 
@@ -154,7 +154,7 @@
   // CHECK: %[[RESULT:.*]] = extract_element %[[EXTENTS]][%[[IDX]]] : tensor<?xindex>
   // CHECK: return %[[RESULT]] : index
   %shape = shape.from_extent_tensor %extents : tensor<?xindex>
-  %result = shape.get_extent %shape, %idx
+  %result = shape.get_extent %shape, %idx : !shape.shape
   return %result : !shape.size
 }
 
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -239,9 +239,9 @@
 // CHECK-LABEL: func @basic
 func @basic() -> !shape.size {
   // CHECK: shape.const_size 2
-  %0 = shape.const_shape [0, 1, 2] : !shape.shape
+  %0 = shape.const_shape [0, 1, 2] : tensor<?xindex>
   %c2 = shape.const_size 2
-  %1 = shape.get_extent %0, %c2
+  %1 = shape.get_extent %0, %c2 : tensor<?xindex>
   return %1 : !shape.size
 }
 
@@ -252,9 +252,9 @@
 func @out_of_bounds() -> !shape.size {
   // CHECK: shape.const_shape
   // CHECK: shape.get_extent
-  %0 = shape.const_shape [0, 1, 2] : !shape.shape
+  %0 = shape.const_shape [0, 1, 2] : tensor<?xindex>
   %c3 = shape.const_size 3
-  %1 = shape.get_extent %0, %c3
+  %1 = shape.get_extent %0, %c3 : tensor<?xindex>
   return %1 : !shape.size
 }
 
@@ -262,10 +262,10 @@
 
 // Should not fold.
 // CHECK-LABEL: func @not_const
-func @not_const(%arg0: !shape.shape) -> !shape.size {
+func @not_const(%arg0: tensor<?xindex>) -> !shape.size {
   // CHECK: shape.get_extent
   %c3 = shape.const_size 3
-  %0 = shape.get_extent %arg0, %c3
+  %0 = shape.get_extent %arg0, %c3 : tensor<?xindex>
   return %0 : !shape.size
 }
 
diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -161,3 +161,15 @@
   %result = shape.shape_eq %a, %b : tensor<?xindex>, !shape.shape
   return %result : i1
 }
+
+func @get_extent_on_shape(%arg : !shape.shape) -> !shape.size {
+  %c0 = shape.const_size 0
+  %result = shape.get_extent %arg, %c0 : !shape.shape
+  return %result : !shape.size
+}
+
+func @get_extent_on_extent_tensor(%arg : tensor<?xindex>) -> !shape.size {
+  %c0 = shape.const_size 0
+  %result = shape.get_extent %arg, %c0 : tensor<?xindex>
+  return %result : !shape.size
+}