diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -194,13 +194,13 @@
 def Shape_RankOp : Shape_Op<"rank", [NoSideEffect]> {
-  let summary = "Gets the rank of a shape";
+  let summary = "Gets the rank of a shape or extent tensor";
   let description = [{
-    Returns the rank of the shape, i.e. the number of extents.
+    Returns the rank of the shape or extent tensor, i.e. the number of extents.
   }];
 
-  let arguments = (ins Shape_ShapeType:$shape);
+  let arguments = (ins Shape_ShapeOrExtentTensorType:$shape);
   let results = (outs Shape_SizeType:$rank);
 
-  let assemblyFormat = "attr-dict $shape";
+  let assemblyFormat = "$shape `:` type($shape) attr-dict";
 
   let hasFolder = 1;
   let hasCanonicalizer = 1;
diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -124,7 +124,7 @@
   // CHECK-DAG: %[[C0:.*]] = constant 0 : index
   // CHECK-DAG: %[[RESULT:.*]] = dim %[[SHAPE]], %[[C0]]
   // CHECK-DAG: return %[[RESULT]] : index
-  %rank = shape.rank %shape
+  %rank = shape.rank %shape : !shape.shape
   return %rank : !shape.size
 }
 
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -499,7 +499,7 @@
   // CHECK-DAG: %[[RESULT:.*]] = shape.const_size 5
   // CHECK-DAG: return %[[RESULT]] : !shape.size
   %shape = shape.const_shape [3, 4, 5, 6, 7]
-  %rank = shape.rank %shape
+  %rank = shape.rank %shape : !shape.shape
   return %rank : !shape.size
 }
 
@@ -511,7 +511,7 @@
 func @dont_fold_rank(%shape : !shape.shape) -> !shape.size {
   // CHECK-DAG: %[[RESULT:.*]] = shape.rank %[[SHAPE]]
   // CHECK-DAG: return %[[RESULT]] : !shape.size
-  %rank = shape.rank %shape
+  %rank = shape.rank %shape : !shape.shape
   return %rank : !shape.size
 }
 
@@ -520,11 +520,11 @@
 // Canonicalize `rank` when shape is derived from ranked tensor.
 // CHECK-LABEL: @canonicalize_rank
 func @canonicalize_rank(%arg : tensor<1x2x?xf32>) -> !shape.size {
-// CHECK-DAG: %[[RESULT:.*]] = shape.const_size 3
-// CHECK-DAG: return %[[RESULT]] : !shape.size
-%shape = shape.shape_of %arg : tensor<1x2x?xf32>
-%rank = shape.rank %shape
-return %rank : !shape.size
+  // CHECK-DAG: %[[RESULT:.*]] = shape.const_size 3
+  // CHECK-DAG: return %[[RESULT]] : !shape.size
+  %shape = shape.shape_of %arg : tensor<1x2x?xf32>
+  %rank = shape.rank %shape : !shape.shape
+  return %rank : !shape.size
 }
 
 // -----
@@ -533,12 +533,12 @@
 // CHECK-LABEL: @dont_canonicalize_rank
 // CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>) -> !shape.size
 func @dont_canonicalize_rank(%arg : tensor<*xf32>) -> !shape.size {
-// CHECK-DAG: %[[SHAPE:.*]] = shape.shape_of %[[ARG]] : tensor<*xf32>
-// CHECK-DAG: %[[SIZE:.*]] = shape.rank %[[SHAPE]]
-// CHECK-DAG: return %[[SIZE]] : !shape.size
-%shape = shape.shape_of %arg : tensor<*xf32>
-%rank = shape.rank %shape
-return %rank : !shape.size
+  // CHECK-DAG: %[[SHAPE:.*]] = shape.shape_of %[[ARG]] : tensor<*xf32>
+  // CHECK-DAG: %[[SIZE:.*]] = shape.rank %[[SHAPE]]
+  // CHECK-DAG: return %[[SIZE]] : !shape.size
+  %shape = shape.shape_of %arg : tensor<*xf32>
+  %rank = shape.rank %shape : !shape.shape
+  return %rank : !shape.size
 }
 
 // Canonicalize redundant conversion from `index` to `size` and back.
diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -130,10 +130,15 @@
 }
 
 func @rank(%shape : !shape.shape) -> !shape.size {
-  %rank = shape.rank %shape
+  %rank = shape.rank %shape : !shape.shape
   return %rank : !shape.size
 }
 
+func @rank_on_extent_tensor(%shape : tensor<?xindex>) -> !shape.size {
+  %rank = shape.rank %shape : tensor<?xindex>
+  return %rank : !shape.size
+}
+
 func @shape_eq_on_shapes(%a : !shape.shape, %b : !shape.shape) -> i1 {
   %result = shape.shape_eq %a, %b : !shape.shape, !shape.shape
   return %result : i1