diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -255,6 +255,24 @@
   let results = (outs Shape_SizeType:$result);
 }
 
+def Shape_NumElementsOp : Shape_Op<"num_elements", [
+    NoSideEffect,
+    DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+
+  let summary = "Returns the number of elements for a given shape";
+  let description = [{
+    Returns the number of elements for a given shape which is the product of its
+    dimensions.
+  }];
+
+  let arguments = (ins Shape_ShapeType:$shape);
+  let results = (outs Shape_SizeType:$result);
+
+  let assemblyFormat = "attr-dict $shape";
+
+  let hasFolder = 1;
+}
+
 def Shape_ReduceOp : Shape_Op<"reduce", []> {
   let summary = "Returns an expression reduced over a shape";
   let description = [{
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -318,6 +318,36 @@
   return builder.getI64TensorAttr(extents);
 }
 
+//===----------------------------------------------------------------------===//
+// NumElementsOp
+//===----------------------------------------------------------------------===//
+
+/// Folds `shape.num_elements` applied to a constant shape into the product of
+/// the shape's extents, returned as an index attribute.
+OpFoldResult NumElementsOp::fold(ArrayRef<Attribute> operands) {
+  // Fold only when the argument is a constant.
+  Attribute shape = operands[0];
+  if (!shape)
+    return {};
+
+  // NOTE(review): APInt multiplication wraps modulo 2^64, and `*=` expects the
+  // element APInts to be 64-bit wide as well -- confirm for all producers.
+  APInt product(64, 1);
+  for (auto value : shape.cast<DenseIntElementsAttr>())
+    product *= value;
+  Builder builder(getContext());
+  return builder.getIndexAttr(product.getLimitedValue());
+}
+
+/// The result of `shape.num_elements` is always a `!shape.size`.
+LogicalResult NumElementsOp::inferReturnTypes(
+    MLIRContext *context, Optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  inferredReturnTypes.push_back(SizeType::get(context));
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // ShapeOfOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -160,3 +160,25 @@
   %cs = shape.index_to_size %ci
   return %cs : !shape.size
 }
+
+// -----
+// Fold number of elements computation.
+// CHECK-LABEL: func @num_elements
+func @num_elements() -> !shape.size {
+  // CHECK-NOT: shape.const_shape
+  %shape = shape.const_shape [4, 5, 6]
+  // CHECK-NOT: shape.num_elements
+  %num_elements = shape.num_elements %shape
+  // CHECK: %[[NUM:.*]] = shape.const_size 120
+  // CHECK-NEXT: return %[[NUM]] : !shape.size
+  return %num_elements : !shape.size
+}
+
+// -----
+// No folding.
+// CHECK-LABEL: func @nonfoldable_num_elements
+func @nonfoldable_num_elements(%shape : !shape.shape) -> !shape.size {
+  // CHECK-NOT: shape.const_{{.*}}
+  %num_elements = shape.num_elements %shape
+  return %num_elements : !shape.size
+}