diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
@@ -129,8 +129,8 @@
       return any_of(dSizes, [](int64_t dSize) { return isDynamic(dSize); });
     }
 
-    /// Return the number of elements present in the given shape. Asserts that
-    /// all dimensions are static.
+    /// Return the number of elements present in the given shape. Returns
+    /// kDynamic if at least one dimension is dynamic.
     static int64_t getNumElements(ArrayRef<int64_t> shape);
   }];
 
@@ -140,10 +140,9 @@
       return $_type.getShape().size();
     }
 
-    /// If it has static shape, return the number of elements. Otherwise, abort.
+    /// If it has static shape, return the number of elements. Returns kDynamic
+    /// if at least one dimension is dynamic.
     int64_t getNumElements() const {
-      assert(hasStaticShape()
-             && "cannot get element count of dynamically shaped type");
       return ::mlir::RankedShapedType::getNumElements($_type.getShape());
     }
 
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1228,13 +1228,6 @@
   setNameFn(getResult(), "reshape");
 }
 
-static int64_t getNumElements(RankedShapedType type) {
-  int64_t numElements = 1;
-  for (auto dim : type.getShape())
-    numElements *= dim;
-  return numElements;
-}
-
 LogicalResult ReshapeOp::verify() {
   TensorType operandType = getSource().getType().cast<TensorType>();
   TensorType resultType = getResult().getType().cast<TensorType>();
@@ -1251,7 +1244,8 @@
   if (resultRankedType) {
     if (operandRankedType && resultRankedType.hasStaticShape() &&
         operandRankedType.hasStaticShape()) {
-      if (getNumElements(operandRankedType) != getNumElements(resultRankedType))
+      if (operandRankedType.getNumElements() !=
+          resultRankedType.getNumElements())
         return emitOpError("source and destination tensor should have the "
                            "same number of elements");
     }
diff --git a/mlir/lib/IR/BuiltinTypeInterfaces.cpp b/mlir/lib/IR/BuiltinTypeInterfaces.cpp
--- a/mlir/lib/IR/BuiltinTypeInterfaces.cpp
+++ b/mlir/lib/IR/BuiltinTypeInterfaces.cpp
@@ -28,7 +28,8 @@
 int64_t RankedShapedType::getNumElements(ArrayRef<int64_t> shape) {
   int64_t num = 1;
   for (int64_t dim : shape) {
-    assert(!RankedShapedType::isDynamic(dim) && "expected only static dims");
+    if (RankedShapedType::isDynamic(dim))
+      return RankedShapedType::kDynamic;
     num *= dim;
     assert(num >= 0 && "integer overflow in element count computation");
   }