diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1441,8 +1441,7 @@
 // Operator: concat
 //===----------------------------------------------------------------------===//
 def Tosa_ConcatOp : Tosa_Op<"concat", [
-    InferTensorType,
-    Pure]> {
+    InferTensorType, Pure]> {
   let summary = "Concatenates tensors along one dimension.";
 
   let description = [{
@@ -1503,9 +1502,7 @@
 // Operator: reshape
 //===----------------------------------------------------------------------===//
 def Tosa_ReshapeOp: Tosa_Op<"reshape", [
-  DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                            ["inferReturnTypeComponents"]>,
-  Pure]> {
+  InferTensorType, Pure]> {
   let summary = "Reshape operator";
 
   let description = [{
@@ -1526,6 +1523,12 @@
   let results = (outs
     Tosa_RankedTensor:$output
   );
+
+  let extraClassDeclaration = [{
+    /// Returns true when both ranges hold exactly one type with the same
+    /// element type (shapes are not compared). Used by InferTypeOpInterface.
+    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+  }];
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -674,19 +674,29 @@
   return success();
 }
 
+// `l` and `r` are compatible when each holds exactly one type and the two
+// element types match; the shapes are not compared.
+bool tosa::ReshapeOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
+  if (l.size() != r.size() || l.size() != 1)
+    return false;
+  return getElementTypeOrSelf(l[0]) == getElementTypeOrSelf(r[0]);
+}
+
 LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
   ReshapeOpAdaptor adaptor(operands, attributes);
   ShapeAdaptor inputShape = operands.getShape(0);
+  Type inputType = getElementTypeOrSelf(operands.getType()[0]);
   llvm::SmallVector<int64_t> newShapeValue =
       convertToMlirShape(adaptor.getNewShape());
 
   // We cannot infer from the total number of elements so we must take the
   // shape attribute as exact.
   if (!inputShape.hasRank() || !inputShape.hasStaticShape()) {
-    inferredReturnShapes.push_back(ShapedTypeComponents(newShapeValue));
+    inferredReturnShapes.push_back(
+        ShapedTypeComponents(newShapeValue, inputType));
     return success();
   }
 
@@ -707,7 +717,8 @@
       val = numElements / staticMul;
   }
 
-  inferredReturnShapes.push_back(ShapedTypeComponents(newShapeValue));
+  inferredReturnShapes.push_back(
+      ShapedTypeComponents(newShapeValue, inputType));
   return success();
 }
 
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -128,3 +128,11 @@
   %0 = "tosa.reduce_prod"(%arg0) {axis = 1 : i64} : (tensor<2x3x4x5xf32>) -> tensor<2x3x4x5xf32>
   return
 }
+
+// -----
+
+func.func @test_reshape_type_mismatch(%arg0 : tensor<13x21x3xf32>) -> () {
+  // expected-error@+1 {{'tosa.reshape' op inferred type(s) 'tensor<13x21x3x1xf32>' are incompatible with return type(s) of operation 'tensor<13x21x3x1xi32>'}}
+  %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 13, 21, 3, 1>} : (tensor<13x21x3xf32>) -> tensor<13x21x3x1xi32>
+  return
+}