diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -179,6 +179,8 @@
   let results = (outs Shape_SizeType:$rank);
 
   let assemblyFormat = "attr-dict $shape";
+
+  let hasFolder = 1;
 }
 
 def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", [NoSideEffect]> {
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -437,6 +437,21 @@
   build(builder, result, shape, dimValue);
 }
 
+//===----------------------------------------------------------------------===//
+// RankOp
+//===----------------------------------------------------------------------===//
+
+// Folds to the rank (number of extents) when the shape operand is a constant
+// extent attribute; otherwise the op is left unchanged.
+OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
+  auto shape = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
+  if (!shape)
+    return {};
+  int64_t rank = shape.getNumElements();
+  Builder builder(getContext());
+  return builder.getIndexAttr(rank);
+}
+
 //===----------------------------------------------------------------------===//
 // NumElementsOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -442,3 +442,27 @@
   "consume.witness"(%0) : (!shape.witness) -> ()
   return
 }
+
+// -----
+
+// Fold `rank` based on constant shape.
+// CHECK-LABEL: @fold_rank
+func @fold_rank() -> !shape.size {
+  // CHECK-DAG: %[[RESULT:.*]] = shape.const_size 5
+  // CHECK-DAG: return %[[RESULT]] : !shape.size
+  %shape = shape.const_shape [3, 4, 5, 6, 7]
+  %rank = shape.rank %shape
+  return %rank : !shape.size
+}
+
+// -----
+
+// Do not fold `rank` if shape is dynamic.
+// CHECK-LABEL: @dont_fold_rank
+// CHECK-SAME: (%[[SHAPE:.*]]: !shape.shape) -> !shape.size
+func @dont_fold_rank(%shape : !shape.shape) -> !shape.size {
+  // CHECK-DAG: %[[RESULT:.*]] = shape.rank %[[SHAPE]]
+  // CHECK-DAG: return %[[RESULT]] : !shape.size
+  %rank = shape.rank %shape
+  return %rank : !shape.size
+}