diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -179,6 +179,7 @@
   let results = (outs Shape_SizeType:$size);
 
   let assemblyFormat = "attr-dict $shape";
+  let hasFolder = 1;
 }
 
 def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", [NoSideEffect]> {
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -432,6 +432,25 @@
   build(builder, result, shape, dimValue);
 }
 
+//===----------------------------------------------------------------------===//
+// GetSizeOp
+//===----------------------------------------------------------------------===//
+
+/// Folds `shape.get_size` when its shape operand is a constant: the result is
+/// an index attribute holding the number of extents in that shape (e.g. 3 for
+/// the shape of a `tensor<2x3x4xf32>`). Returns a null result, leaving the op
+/// in place, when the operand is not a constant integer elements attribute.
+OpFoldResult GetSizeOp::fold(ArrayRef<Attribute> operands) {
+  auto shape = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
+  // A null or non-integer-elements operand means the shape is not statically
+  // known here, so there is nothing to fold.
+  if (!shape)
+    return {};
+  int64_t size = shape.getNumElements();
+  Builder builder(getContext());
+  return builder.getIndexAttr(size);
+}
+
 //===----------------------------------------------------------------------===//
 // NumElementsOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -442,3 +442,39 @@
   "consume.witness"(%0) : (!shape.witness) -> ()
   return
 }
+
+// -----
+
+// Fold `get_size` based on statically shaped tensor.
+// CHECK-LABEL: @fold_size
+func @fold_size(%arg : tensor<2x3x4xf32>) -> !shape.size {
+  // CHECK-DAG: %[[RESULT:.*]] = shape.const_size 3
+  // CHECK-DAG: return %[[RESULT]] : !shape.size
+  %shape = shape.shape_of %arg : tensor<2x3x4xf32>
+  %size = shape.get_size %shape
+  return %size : !shape.size
+}
+
+// -----
+
+// Fold `get_size` based on constant shape.
+// CHECK-LABEL: @fold_size
+func @fold_size() -> !shape.size {
+  // CHECK-DAG: %[[RESULT:.*]] = shape.const_size 5
+  // CHECK-DAG: return %[[RESULT]] : !shape.size
+  %shape = shape.const_shape [3, 4, 5, 6, 7]
+  %size = shape.get_size %shape
+  return %size : !shape.size
+}
+
+// -----
+
+// Do not fold `get_size` when the shape operand is not a constant.
+// CHECK-LABEL: @dont_fold_size
+// CHECK-SAME: (%[[SHAPE:.*]]: !shape.shape) -> !shape.size
+func @dont_fold_size(%shape : !shape.shape) -> !shape.size {
+  // CHECK-DAG: %[[RESULT:.*]] = shape.get_size %[[SHAPE]]
+  // CHECK-DAG: return %[[RESULT]] : !shape.size
+  %size = shape.get_size %shape
+  return %size : !shape.size
+}