diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -7,7 +7,7 @@ //===----------------------------------------------------------------------===// // // This file defines the operation set for the TOSA dialect as defined in -// the TOSA specfication (https://developer.mlplatform.org/w/tosa/). +// the TOSA specification (https://developer.mlplatform.org/w/tosa/). // //===----------------------------------------------------------------------===// @@ -58,7 +58,7 @@ //===----------------------------------------------------------------------===// def Tosa_AvgPool2dOp : Tosa_Op<"avg_pool2d", [ DeclareOpInterfaceMethods, + ["inferReturnTypeComponents"]>, NoSideEffect]> { let summary = "Performs max pooling on the input."; @@ -275,6 +275,8 @@ let results = (outs Tosa_Tensor4D:$output ); + + let hasCanonicalizer = 1; } //===----------------------------------------------------------------------===// @@ -326,9 +328,9 @@ let description = [{ Clamp to an arbitrary minimum and maximum value. - Maximum and minimum values are specified as values in the range of the + Maximum and minimum values are specified as values in the range of the input type. - No zero point subtraction is done to the values, thus to clamp to the zero + No zero point subtraction is done to the values, thus to clamp to the zero point value, the zero point itself should be supplied as the minimum value. }]; @@ -488,7 +490,7 @@ let description = [{ Elementwise bitwise AND of input1 and input2. Axis of size 1 - will be broadcast as necessary. + will be broadcast as necessary. }]; let arguments = (ins @@ -1379,7 +1381,7 @@ let summary = "Concatenates tensors along one dimension."; let description = [{ - Concatenate a variadic amount of tensors along a given axis. No data + Concatenate a variadic amount of tensors along a given axis. 
No data conversion happens during a concat operation. }]; @@ -1405,7 +1407,7 @@ let summary = "Pads a tensor with value specified."; let description = [{ - Pads a tensor along borders of each dimension with pad_value. + Pads a tensor along borders of each dimension with pad_value. }]; let arguments = (ins @@ -1510,7 +1512,7 @@ //===----------------------------------------------------------------------===// def Tosa_TileOp: Tosa_Op<"tile", [ DeclareOpInterfaceMethods, + ["inferReturnTypeComponents"]>, NoSideEffect]> { let summary = "Tile operator"; @@ -1534,7 +1536,7 @@ //===----------------------------------------------------------------------===// def Tosa_TransposeOp : Tosa_Op<"transpose", [ DeclareOpInterfaceMethods, + ["inferReturnTypeComponents"]>, NoSideEffect]> { let summary = "Transpose operator"; @@ -1565,7 +1567,7 @@ //===----------------------------------------------------------------------===// def Tosa_GatherOp : Tosa_Op<"gather", [ DeclareOpInterfaceMethods, + ["inferReturnTypeComponents"]>, NoSideEffect]> { let summary = "Gather operation,"; @@ -1697,7 +1699,7 @@ //===----------------------------------------------------------------------===// // Operator: rescale //===----------------------------------------------------------------------===// -def Tosa_RescaleOp: Tosa_Op<"rescale", [NoSideEffect, +def Tosa_RescaleOp: Tosa_Op<"rescale", [NoSideEffect, DeclareOpInterfaceMethods]> { let summary = "Tosa rescale operator"; diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -614,6 +614,41 @@ results.insert(context); } +struct MaxPool2dIsNoOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tosa::MaxPool2dOp op, + PatternRewriter &rewriter) const override { + Value input = op.input(); + Value output = op.output(); + ShapedType inputType = input.getType().cast(); + ShapedType 
outputType = output.getType().cast(); + + if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) { + return failure(); + } + + // If dims 1 and 2 of both the input and output are 1, then this is a no op. + ArrayRef outputShape = outputType.getShape(); + if (outputShape[1] != 1 || outputShape[2] != 1) { + return failure(); + } + + ArrayRef inputShape = inputType.getShape(); + if (inputShape[1] != 1 || inputShape[2] != 1) { + return failure(); + } + + rewriter.replaceOp(op, input); + return success(); + } +}; + +void MaxPool2dOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert(context); +} + //===----------------------------------------------------------------------===// // Operator Folders. //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir --- a/mlir/test/Dialect/Tosa/canonicalize.mlir +++ b/mlir/test/Dialect/Tosa/canonicalize.mlir @@ -181,7 +181,17 @@ return %0 : tensor<4x10x10x6xf32> } -// ---- +// ----- + +// CHECK-LABEL: @max_pool2d_is_noop +func @max_pool2d_is_noop(%arg0: tensor<10x1x1x3xf32>) -> tensor<10x1x1x3xf32> { + // CHECK-NOT: "tosa.max_pool2d" + // CHECK: return %arg0 + %0 = "tosa.max_pool2d"(%arg0) {kernel = [1, 1], pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1]} : (tensor<10x1x1x3xf32>) -> tensor<10x1x1x3xf32> + return %0 : tensor<10x1x1x3xf32> +} + +// ----- // CHECK-LABEL: @pad_noop func @pad_noop(%arg0: tensor) -> tensor { @@ -191,7 +201,7 @@ return %1 : tensor } -// ---- +// ----- // CHECK-LABEL: @pad_determine_val_i32 func @pad_determine_val_i32(%arg0: tensor, %arg1 : tensor<2x2xi32>) -> tensor { @@ -202,7 +212,7 @@ return %1 : tensor } -// ---- +// ----- // CHECK-LABEL: @pad_determine_val_f32 func @pad_determine_val_f32(%arg0: tensor, %arg1 : tensor<2x2xi32>) -> tensor { @@ -213,7 +223,7 @@ return %1 : tensor } -// ---- +// ----- // CHECK-LABEL: 
@pad_determine_val_quant func @pad_determine_val_quant(%arg0: tensor, %arg1 : tensor<2x2xi32>) -> tensor {