diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1634,6 +1634,8 @@
   let results = (outs
     Tosa_Tensor4D:$output
   );
+
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/Quant/QuantOps.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Tosa/Utils/CoversionUtils.h"
 #include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
 #include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
 #include "mlir/IR/BuiltinTypes.h"
@@ -846,6 +847,50 @@
   return {};
 }
 
+// Folds a resize that is an identity: unit scale on both axes, zero offset,
+// zero border and matching input/result types. TOSA integer BILINEAR resize
+// scales its output by scale_y_n * scale_x_n, so for non-float element types
+// that mode only folds when both numerators are 1.
+OpFoldResult ResizeOp::fold(ArrayRef<Attribute> operands) {
+  // scale = [scale_y_n, scale_y_d, scale_x_n, scale_x_d]; offset and border
+  // are [y, x].
+  SmallVector<int64_t> scale, offset, border;
+  getValuesFromIntArrayAttribute(getScale(), scale);
+  getValuesFromIntArrayAttribute(getOffset(), offset);
+  getValuesFromIntArrayAttribute(getBorder(), border);
+
+  // Check unit scaling.
+  if (scale[0] != scale[1] || scale[2] != scale[3]) {
+    return {};
+  }
+
+  // There should be no offset.
+  if (offset[0] != 0 || offset[1] != 0) {
+    return {};
+  }
+
+  // There should be no border.
+  if (border[0] != 0 || border[1] != 0) {
+    return {};
+  }
+
+  // The input can only replace the result if the types are identical.
+  auto input = getInput();
+  if (input.getType() != getType()) {
+    return {};
+  }
+
+  // Integer BILINEAR resize weights samples in units of scale_n, so even a
+  // unit scale like [2, 2, 1, 1] multiplies values by scale_y_n * scale_x_n.
+  auto elementTy = input.getType().cast<ShapedType>().getElementType();
+  if (getMode() == "BILINEAR" && !elementTy.isa<FloatType>() &&
+      (scale[0] != 1 || scale[2] != 1)) {
+    return {};
+  }
+
+  return input;
+}
+
 OpFoldResult ReverseOp::fold(ArrayRef<Attribute> operands) {
   auto operand = getInput();
   auto operandTy = operand.getType().cast<ShapedType>();
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -408,3 +408,30 @@
   %1 = "tosa.reshape"(%0) {new_shape = [1]} : (tensor<1x1xi1>) -> tensor<1xi1>
   return %1 : tensor<1xi1>
 }
+
+// -----
+
+// CHECK-LABEL: @fold_resize_nearest
+func.func @fold_resize_nearest(%arg0 : tensor<1x15x13x1xi8>) -> tensor<1x15x13x1xi8> {
+  // CHECK: return %arg0
+  %resize = "tosa.resize"(%arg0) {mode = "NEAREST_NEIGHBOR", scale = [2, 2, 1, 1], offset = [0, 0], border = [0, 0]} : (tensor<1x15x13x1xi8>) -> tensor<1x15x13x1xi8>
+  return %resize : tensor<1x15x13x1xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_resize_bilinear
+func.func @fold_resize_bilinear(%arg0 : tensor<1x15x13x1xf32>) -> tensor<1x15x13x1xf32> {
+  // CHECK: return %arg0
+  %resize = "tosa.resize"(%arg0) {mode = "BILINEAR", scale = [2, 2, 1, 1], offset = [0, 0], border = [0, 0]} : (tensor<1x15x13x1xf32>) -> tensor<1x15x13x1xf32>
+  return %resize : tensor<1x15x13x1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @no_fold_resize_bilinear_int_scaled
+func.func @no_fold_resize_bilinear_int_scaled(%arg0 : tensor<1x15x13x1xi8>) -> tensor<1x15x13x1xi8> {
+  // CHECK: tosa.resize
+  %resize = "tosa.resize"(%arg0) {mode = "BILINEAR", scale = [2, 2, 1, 1], offset = [0, 0], border = [0, 0]} : (tensor<1x15x13x1xi8>) -> tensor<1x15x13x1xi8>
+  return %resize : tensor<1x15x13x1xi8>
+}