diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1262,9 +1262,7 @@
 // Operator: reduce_all
 //===----------------------------------------------------------------------===//
 def Tosa_ReduceAllOp : Tosa_Op<"reduce_all", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    Pure]> {
+    InferTensorType, Pure]> {
   let summary = "Reduce All operator";
 
   let description = [{
@@ -1281,15 +1279,19 @@
   );
 
   let hasFolder = 1;
+
+  let extraClassDeclaration = [{
+    /// Returns true when the result type ranges `l` and `r` are compatible for
+    /// this op. Method used by InferTypeOpInterface.
+    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+  }];
 }
 
 //===----------------------------------------------------------------------===//
 // Operator: reduce_any
 //===----------------------------------------------------------------------===//
 def Tosa_ReduceAnyOp : Tosa_Op<"reduce_any", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    Pure]> {
+    InferTensorType, Pure]> {
   let summary = "Reduce Any operator";
 
   let description = [{
@@ -1306,15 +1308,19 @@
   );
 
   let hasFolder = 1;
+
+  let extraClassDeclaration = [{
+    /// Returns true when the result type ranges `l` and `r` are compatible for
+    /// this op. Method used by InferTypeOpInterface.
+    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+  }];
 }
 
 //===----------------------------------------------------------------------===//
 // Operator: reduce_max
 //===----------------------------------------------------------------------===//
 def Tosa_ReduceMaxOp : Tosa_Op<"reduce_max", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    Pure]> {
+    InferTensorType, Pure]> {
   let summary = "Reduce Max operator";
 
   let description = [{
@@ -1331,15 +1337,19 @@
   );
 
   let hasFolder = 1;
+
+  let extraClassDeclaration = [{
+    /// Returns true when the result type ranges `l` and `r` are compatible for
+    /// this op. Method used by InferTypeOpInterface.
+    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+  }];
 }
 
 //===----------------------------------------------------------------------===//
 // Operator: reduce_min
 //===----------------------------------------------------------------------===//
 def Tosa_ReduceMinOp : Tosa_Op<"reduce_min", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    Pure]> {
+    InferTensorType, Pure]> {
   let summary = "Reduce Min operator";
 
   let description = [{
@@ -1356,15 +1366,19 @@
   );
 
   let hasFolder = 1;
+
+  let extraClassDeclaration = [{
+    /// Returns true when the result type ranges `l` and `r` are compatible for
+    /// this op. Method used by InferTypeOpInterface.
+    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+  }];
 }
 
 //===----------------------------------------------------------------------===//
 // Operator: reduce_prod
 //===----------------------------------------------------------------------===//
 def Tosa_ReduceProdOp : Tosa_Op<"reduce_prod", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    Pure]> {
+    InferTensorType, Pure]> {
   let summary = "Reduce Prod operator";
 
   let description = [{
@@ -1381,15 +1395,19 @@
   );
 
   let hasFolder = 1;
+
+  let extraClassDeclaration = [{
+    /// Returns true when the result type ranges `l` and `r` are compatible for
+    /// this op. Method used by InferTypeOpInterface.
+    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+  }];
 }
 
 //===----------------------------------------------------------------------===//
 // Operator: reduce_sum
 //===----------------------------------------------------------------------===//
 def Tosa_ReduceSumOp : Tosa_Op<"reduce_sum", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    Pure]> {
+    InferTensorType, Pure]> {
   let summary = "Reduce Sum operator";
 
   let description = [{
@@ -1406,6 +1424,12 @@
   );
 
   let hasFolder = 1;
+
+  let extraClassDeclaration = [{
+    /// Returns true when the result type ranges `l` and `r` are compatible for
+    /// this op. Method used by InferTypeOpInterface.
+    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+  }];
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -913,10 +913,10 @@
 }
 
 static LogicalResult ReduceInferReturnTypes(
-    ShapeAdaptor operandShape, IntegerAttr axis,
+    ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
   if (!operandShape.hasRank()) {
-    inferredReturnShapes.push_back(ShapedTypeComponents());
+    inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
     return success();
   }
 
@@ -924,7 +924,7 @@
   operandShape.getDims(outputShape);
   int64_t axisVal = axis.getValue().getSExtValue();
   outputShape[axisVal] = 1;
-  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
   return success();
 }
 
@@ -934,9 +934,19 @@
       ValueShapeRange operands, DictionaryAttr attributes,                     \
       RegionRange regions,                                                     \
       SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {           \
-    return ReduceInferReturnTypes(operands.getShape(0),                        \
+    Type inputType =                                                           \
+        operands.getType()[0].cast<TensorType>().getElementType();             \
+    return ReduceInferReturnTypes(operands.getShape(0), inputType,             \
                                   attributes.get("axis").cast<IntegerAttr>(),  \
                                   inferredReturnShapes);                       \
+  }                                                                            \
+                                                                               \
+  bool OP::isCompatibleReturnTypes(TypeRange l, TypeRange r) {                 \
+    if (l.size() != r.size() || l.size() != 1)                                 \
+      return false; /* each range must hold exactly one type */                \
+    if (getElementTypeOrSelf(l[0]) != getElementTypeOrSelf(r[0]))              \
+      return false; /* element types must be identical */                      \
+    return succeeded(verifyCompatibleShape(l[0], r[0]));                       \
   }
 
 REDUCE_SHAPE_INFER(tosa::ReduceAllOp)