diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1431,6 +1431,7 @@
 
   let hasCanonicalizer = 1;
   let hasFolder = 1;
+  let hasVerifier = 1;
 
   let arguments = (ins
     Tosa_Tensor:$input1,
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -864,10 +864,16 @@
            currRhsDim < rhsShape.size()) {
       if (lhsSize < rhsSize) {
         currLhsDim++;
-        lhsSize *= lhsShape[currLhsDim];
+        // The increment above can move past the last dimension when the two
+        // shapes cannot be aligned; only fold in an extent that exists.
+        if (currLhsDim < lhsShape.size()) {
+          lhsSize *= lhsShape[currLhsDim];
+        }
       } else {
         currRhsDim++;
-        rhsSize *= rhsShape[currRhsDim];
+        if (currRhsDim < rhsShape.size()) {
+          rhsSize *= rhsShape[currRhsDim];
+        }
       }
     }
     if (lhsSize == rhsSize) {
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -700,6 +700,24 @@
   return success();
 }
 
+/// Verifies that a reshape between two statically shaped types preserves the
+/// total number of elements. Unranked or dynamically shaped types are not
+/// checked here.
+LogicalResult tosa::ReshapeOp::verify() {
+  ShapedType inputType = getInput1().getType().cast<ShapedType>();
+  ShapedType outputType = getType().cast<ShapedType>();
+
+  if (inputType.hasStaticShape() && outputType.hasStaticShape()) {
+    int64_t inputElementsNum = inputType.getNumElements();
+    int64_t outputElementsNum = outputType.getNumElements();
+    if (inputElementsNum != outputElementsNum) {
+      return emitOpError() << "Cannot reshape " << inputElementsNum
+                           << " elements into " << outputElementsNum;
+    }
+  }
+  return success();
+}
+
 LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,