diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1262,9 +1262,7 @@
 // Operator: reduce_all
 //===----------------------------------------------------------------------===//
 def Tosa_ReduceAllOp : Tosa_Op<"reduce_all", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    Pure]> {
+    InferTensorType, Pure]> {
   let summary = "Reduce All operator";
 
   let description = [{
@@ -1281,15 +1279,19 @@
   );
 
   let hasFolder = 1;
+
+  let extraClassDeclaration = [{
+    /// Returns true when two result types are compatible for this op;
+    /// Method used by InferTypeOpInterface.
+    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+  }];
 }
 
 //===----------------------------------------------------------------------===//
 // Operator: reduce_any
 //===----------------------------------------------------------------------===//
 def Tosa_ReduceAnyOp : Tosa_Op<"reduce_any", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    Pure]> {
+    InferTensorType, Pure]> {
   let summary = "Reduce Any operator";
 
   let description = [{
@@ -1306,15 +1308,19 @@
   );
 
   let hasFolder = 1;
+
+  let extraClassDeclaration = [{
+    /// Returns true when two result types are compatible for this op;
+    /// Method used by InferTypeOpInterface.
+    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+  }];
 }
 
 //===----------------------------------------------------------------------===//
 // Operator: reduce_max
 //===----------------------------------------------------------------------===//
 def Tosa_ReduceMaxOp : Tosa_Op<"reduce_max", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    Pure]> {
+    InferTensorType, Pure]> {
   let summary = "Reduce Max operator";
 
   let description = [{
@@ -1331,15 +1337,19 @@
   );
 
   let hasFolder = 1;
+
+  let extraClassDeclaration = [{
+    /// Returns true when two result types are compatible for this op;
+    /// Method used by InferTypeOpInterface.
+    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+  }];
 }
 
 //===----------------------------------------------------------------------===//
 // Operator: reduce_min
 //===----------------------------------------------------------------------===//
 def Tosa_ReduceMinOp : Tosa_Op<"reduce_min", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    Pure]> {
+    InferTensorType, Pure]> {
   let summary = "Reduce Min operator";
 
   let description = [{
@@ -1356,15 +1366,19 @@
   );
 
   let hasFolder = 1;
+
+  let extraClassDeclaration = [{
+    /// Returns true when two result types are compatible for this op;
+    /// Method used by InferTypeOpInterface.
+    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+  }];
 }
 
 //===----------------------------------------------------------------------===//
 // Operator: reduce_prod
 //===----------------------------------------------------------------------===//
 def Tosa_ReduceProdOp : Tosa_Op<"reduce_prod", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    Pure]> {
+    InferTensorType, Pure]> {
   let summary = "Reduce Prod operator";
 
   let description = [{
@@ -1381,15 +1395,19 @@
   );
 
   let hasFolder = 1;
+
+  let extraClassDeclaration = [{
+    /// Returns true when two result types are compatible for this op;
+    /// Method used by InferTypeOpInterface.
+    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+  }];
 }
 
 //===----------------------------------------------------------------------===//
 // Operator: reduce_sum
 //===----------------------------------------------------------------------===//
 def Tosa_ReduceSumOp : Tosa_Op<"reduce_sum", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    Pure]> {
+    InferTensorType, Pure]> {
   let summary = "Reduce Sum operator";
 
   let description = [{
@@ -1406,6 +1424,12 @@
   );
 
   let hasFolder = 1;
+
+  let extraClassDeclaration = [{
+    /// Returns true when two result types are compatible for this op;
+    /// Method used by InferTypeOpInterface.
+    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+  }];
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -913,10 +913,14 @@
 }
 
 static LogicalResult ReduceInferReturnTypes(
-    ShapeAdaptor operandShape, IntegerAttr axis,
+    ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
-  if (!operandShape.hasRank()) {
-    inferredReturnShapes.push_back(ShapedTypeComponents());
+  int64_t axisVal = axis.getValue().getSExtValue();
+  // Without a rank, or with an axis outside [0, rank), only the element type
+  // can be inferred; indexing outputShape below would be out of bounds.
+  if (!operandShape.hasRank() || axisVal < 0 ||
+      axisVal >= operandShape.getRank()) {
+    inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
     return success();
   }
 
@@ -924,7 +928,6 @@
   operandShape.getDims(outputShape);
-  int64_t axisVal = axis.getValue().getSExtValue();
   outputShape[axisVal] = 1;
-  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
   return success();
 }
 
@@ -934,9 +937,19 @@
       ValueShapeRange operands, DictionaryAttr attributes,                     \
       RegionRange regions,                                                     \
       SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {           \
-    return ReduceInferReturnTypes(operands.getShape(0),                        \
+    Type inputType =                                                           \
+        operands.getType()[0].cast<TensorType>().getElementType();             \
+    return ReduceInferReturnTypes(operands.getShape(0), inputType,             \
                                   attributes.get("axis").cast<IntegerAttr>(),  \
                                   inferredReturnShapes);                       \
+  }                                                                            \
+                                                                               \
+  bool OP::isCompatibleReturnTypes(TypeRange l, TypeRange r) {                 \
+    if (l.size() != r.size() || l.size() != 1)                                 \
+      return false;                                                            \
+    if (getElementTypeOrSelf(l[0]) != getElementTypeOrSelf(r[0]))              \
+      return false;                                                            \
+    return succeeded(verifyCompatibleShape(l[0], r[0]));                       \
   }
 
 REDUCE_SHAPE_INFER(tosa::ReduceAllOp)