# Review preamble. Leading text before the first `diff --git` line is ignored
# by `git apply` / `patch`.
#
# What the hunks below do:
#   * TosaOps.td: tosa.reduce_{all,any,max,min,prod,sum} remove their
#     `DeclareOpInterfaceMethods...` trait entry and use `InferTensorType`
#     instead. Each op also gains an extraClassDeclaration declaring
#     `static bool isCompatibleReturnTypes(TypeRange l, TypeRange r)`.
#   * TosaOps.cpp: ReduceInferReturnTypes gains a `Type inputType` parameter.
#     It attaches that element type to the pushed ShapedTypeComponents on both
#     the unranked path and the ranked path.
#     The REDUCE_SHAPE_INFER macro reads the element type of operand 0 and
#     passes it in. The macro also defines OP::isCompatibleReturnTypes. That
#     function returns true only when all of these hold:
#       - each range holds exactly one type;
#       - getElementTypeOrSelf(l[0]) == getElementTypeOrSelf(r[0]);
#       - verifyCompatibleShape(l[0], r[0]) succeeds.
#   * invalid.mlir: adds expected-error cases in which the declared result type
#     disagrees with the inferred one.
#
# NOTE(review): this copy of the patch looks damaged.
#   - Newlines are collapsed.
#   - `<...>` template arguments appear stripped, e.g.
#     `SmallVectorImpl &inferredReturnShapes`, `.cast()` and
#     `DeclareOpInterfaceMethods,`.
#   As shown, it will neither apply nor compile. The originals were presumably:
#   - SmallVectorImpl<ShapedTypeComponents>
#   - cast<ShapedType>() and cast<IntegerAttr>()
#   - DeclareOpInterfaceMethods<InferShapedTypeOpInterface, [...]>
#   TODO: confirm against the upstream commit before re-applying.
# NOTE(review): ReduceInferReturnTypes does `outputShape[axisVal] = 1` without
#   checking that 0 <= axisVal < rank. An out-of-range `axis` attribute would
#   index out of bounds. Consider returning failure() in that case, unless a
#   verifier elsewhere already rejects it; that is not visible here.
# NOTE(review): `operands.getType()[0].cast<...>()` assumes operand 0 is a
#   shaped type. This is presumably guaranteed by the ODS operand constraint;
#   confirm.
# NOTE(review): in test_reduce_sum_type_mismatch only %0 has an
#   expected-error. Verification presumably stops at the first failing op, so
#   %1..%3 are likely never checked. %3 also declares dim 2 as 2, where
#   inference gives 1. Consider removing those ops or splitting them out.
# NOTE(review): test_reduce_prod_type_mismatch declares tensor<2x3x4x5xf32>
#   while inference gives tensor<2x1x4x5xf32>. That is a shape mismatch, not
#   an element-type mismatch, so the test name is misleading.
# NOTE(review): the `// CHECK:` lines only matter if the RUN line (not shown)
#   pipes into FileCheck. reduce_all / reduce_any get no new tests.
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -1262,9 +1262,7 @@ // Operator: reduce_all //===----------------------------------------------------------------------===// def Tosa_ReduceAllOp : Tosa_Op<"reduce_all", [ - DeclareOpInterfaceMethods, - Pure]> { + InferTensorType, Pure]> { let summary = "Reduce All operator"; let description = [{ @@ -1281,15 +1279,19 @@ ); let hasFolder = 1; + + let extraClassDeclaration = [{ + /// Returns true when two result types are compatible for this op; + /// Method used by InferTypeOpInterface. + static bool isCompatibleReturnTypes(TypeRange l, TypeRange r); + }]; } //===----------------------------------------------------------------------===// // Operator: reduce_any //===----------------------------------------------------------------------===// def Tosa_ReduceAnyOp : Tosa_Op<"reduce_any", [ - DeclareOpInterfaceMethods, - Pure]> { + InferTensorType, Pure]> { let summary = "Reduce Any operator"; let description = [{ @@ -1306,15 +1308,19 @@ ); let hasFolder = 1; + + let extraClassDeclaration = [{ + /// Returns true when two result types are compatible for this op; + /// Method used by InferTypeOpInterface. + static bool isCompatibleReturnTypes(TypeRange l, TypeRange r); + }]; } //===----------------------------------------------------------------------===// // Operator: reduce_max //===----------------------------------------------------------------------===// def Tosa_ReduceMaxOp : Tosa_Op<"reduce_max", [ - DeclareOpInterfaceMethods, - Pure]> { + InferTensorType, Pure]> { let summary = "Reduce Max operator"; let description = [{ @@ -1331,15 +1337,19 @@ ); let hasFolder = 1; + + let extraClassDeclaration = [{ + /// Returns true when two result types are compatible for this op; + /// Method used by InferTypeOpInterface. 
+ static bool isCompatibleReturnTypes(TypeRange l, TypeRange r); + }]; } //===----------------------------------------------------------------------===// // Operator: reduce_min //===----------------------------------------------------------------------===// def Tosa_ReduceMinOp : Tosa_Op<"reduce_min", [ - DeclareOpInterfaceMethods, - Pure]> { + InferTensorType, Pure]> { let summary = "Reduce Min operator"; let description = [{ @@ -1356,15 +1366,19 @@ ); let hasFolder = 1; + + let extraClassDeclaration = [{ + /// Returns true when two result types are compatible for this op; + /// Method used by InferTypeOpInterface. + static bool isCompatibleReturnTypes(TypeRange l, TypeRange r); + }]; } //===----------------------------------------------------------------------===// // Operator: reduce_prod //===----------------------------------------------------------------------===// def Tosa_ReduceProdOp : Tosa_Op<"reduce_prod", [ - DeclareOpInterfaceMethods, - Pure]> { + InferTensorType, Pure]> { let summary = "Reduce Prod operator"; let description = [{ @@ -1381,15 +1395,19 @@ ); let hasFolder = 1; + + let extraClassDeclaration = [{ + /// Returns true when two result types are compatible for this op; + /// Method used by InferTypeOpInterface. + static bool isCompatibleReturnTypes(TypeRange l, TypeRange r); + }]; } //===----------------------------------------------------------------------===// // Operator: reduce_sum //===----------------------------------------------------------------------===// def Tosa_ReduceSumOp : Tosa_Op<"reduce_sum", [ - DeclareOpInterfaceMethods, - Pure]> { + InferTensorType, Pure]> { let summary = "Reduce Sum operator"; let description = [{ @@ -1406,6 +1424,12 @@ ); let hasFolder = 1; + + let extraClassDeclaration = [{ + /// Returns true when two result types are compatible for this op; + /// Method used by InferTypeOpInterface. 
+ static bool isCompatibleReturnTypes(TypeRange l, TypeRange r); + }]; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -913,10 +913,10 @@ } static LogicalResult ReduceInferReturnTypes( - ShapeAdaptor operandShape, IntegerAttr axis, + ShapeAdaptor operandShape, Type inputType, IntegerAttr axis, SmallVectorImpl &inferredReturnShapes) { if (!operandShape.hasRank()) { - inferredReturnShapes.push_back(ShapedTypeComponents()); + inferredReturnShapes.push_back(ShapedTypeComponents(inputType)); return success(); } @@ -924,7 +924,7 @@ operandShape.getDims(outputShape); int64_t axisVal = axis.getValue().getSExtValue(); outputShape[axisVal] = 1; - inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); + inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType)); return success(); } @@ -934,9 +934,19 @@ ValueShapeRange operands, DictionaryAttr attributes, \ RegionRange regions, \ SmallVectorImpl &inferredReturnShapes) { \ - return ReduceInferReturnTypes(operands.getShape(0), \ + Type inputType = \ + operands.getType()[0].cast().getElementType(); \ + return ReduceInferReturnTypes(operands.getShape(0), inputType, \ attributes.get("axis").cast(), \ inferredReturnShapes); \ + } \ + \ + bool OP::isCompatibleReturnTypes(TypeRange l, TypeRange r) { \ + if (l.size() != r.size() || l.size() != 1) \ + return false; \ + if (getElementTypeOrSelf(l[0]) != getElementTypeOrSelf(r[0])) \ + return false; \ + return succeeded(verifyCompatibleShape(l[0], r[0])); \ } REDUCE_SHAPE_INFER(tosa::ReduceAllOp) diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir --- a/mlir/test/Dialect/Tosa/invalid.mlir +++ b/mlir/test/Dialect/Tosa/invalid.mlir @@ -96,3 +96,49 @@ %2 = "tosa.fully_connected"(%1, %0, %arg1) : (tensor<273x3xf32>, 
tensor<2x3xf32>, tensor<2xf32>) -> tensor<273x2xf32> return %2 : tensor<273x2xf32> } + +// ----- + +// CHECK: @test_reduce_sum_type_mismatch +func.func @test_reduce_sum_type_mismatch(%arg0 : tensor<2x3x4x5xf32>) -> () { + // expected-error@+1 {{'tosa.reduce_sum' op inferred type(s) 'tensor<1x3x4x5xf32>' are incompatible with return type(s) of operation 'tensor<1x3x4x5xi32>'}} + %0 = "tosa.reduce_sum"(%arg0) {axis = 0 : i64} : (tensor<2x3x4x5xf32>) -> tensor<1x3x4x5xi32> + + %1 = "tosa.reduce_max"(%arg0) {axis = 3 : i64} : (tensor<2x3x4x5xf32>) -> tensor<2x3x4x1xi32> + + %2 = "tosa.reduce_min"(%arg0) {axis = 1 : i64} : (tensor<2x3x4x5xf32>) -> tensor<2x1x4x5xi32> + + %3 = "tosa.reduce_prod"(%arg0) {axis = 2 : i64} : (tensor<2x3x4x5xf32>) -> tensor<2x3x2x5xi32> + + return +} + +// ----- + +// CHECK: @test_reduce_max_type_mismatch +func.func @test_reduce_max_type_mismatch(%arg0 : tensor<2x3x4x5xf32>) -> () { + // expected-error@+1 {{'tosa.reduce_max' op inferred type(s) 'tensor<2x3x4x1xf32>' are incompatible with return type(s) of operation 'tensor<2x3x4x1xi32>'}} + %0 = "tosa.reduce_max"(%arg0) {axis = 3 : i64} : (tensor<2x3x4x5xf32>) -> tensor<2x3x4x1xi32> + + return +} + +// ----- + +// CHECK: @test_reduce_min_type_mismatch +func.func @test_reduce_min_type_mismatch(%arg0 : tensor<2x3x4x5xf32>) -> () { + // expected-error@+1 {{'tosa.reduce_min' op inferred type(s) 'tensor<2x1x4x5xf32>' are incompatible with return type(s) of operation 'tensor<2x1x4x5xi32>'}} + %0 = "tosa.reduce_min"(%arg0) {axis = 1 : i64} : (tensor<2x3x4x5xf32>) -> tensor<2x1x4x5xi32> + + return +} + +// ----- + +// CHECK: @test_reduce_prod_type_mismatch +func.func @test_reduce_prod_type_mismatch(%arg0 : tensor<2x3x4x5xf32>) -> () { + // expected-error@+1 {{'tosa.reduce_prod' op inferred type(s) 'tensor<2x1x4x5xf32>' are incompatible with return type(s) of operation 'tensor<2x3x4x5xf32>'}} + %0 = "tosa.reduce_prod"(%arg0) {axis = 1 : i64} : (tensor<2x3x4x5xf32>) -> tensor<2x3x4x5xf32> + + return +}