diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -918,6 +918,8 @@
   let results = (outs
     Tosa_Tensor:$output
   );
+
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -1027,3 +1027,15 @@
 
   return {};
 }
+
+// abs is idempotent element-wise: abs(abs(x)) == abs(x), so an abs whose
+// operand is already produced by tosa.abs folds to that operand.
+OpFoldResult tosa::AbsOp::fold(FoldAdaptor adaptor) {
+  auto input = getInput1();
+  // Only fold when the producer is tosa.abs itself; any other producer (or a
+  // block argument) may yield negative values that this op must still handle.
+  if (input.getDefiningOp<tosa::AbsOp>())
+    return input;
+
+  return {};
+}
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -536,3 +536,25 @@
   return %1 : tensor
 }
 
+// -----
+
+// CHECK-LABEL: @fold_abs_abs
+func.func @fold_abs_abs(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  // CHECK: %[[ABS:.*]] = "tosa.abs"(%arg{{.*}}) : (tensor<?x?xf32>) -> tensor<?x?xf32>
+  // CHECK: return %[[ABS]] : tensor<?x?xf32>
+  %0 = "tosa.abs"(%arg0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
+  %1 = "tosa.abs"(%0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @no_fold_abs_negate
+func.func @no_fold_abs_negate(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  // CHECK: %[[NEG:.*]] = "tosa.negate"(%arg{{.*}})
+  // CHECK: %[[ABS:.*]] = "tosa.abs"(%[[NEG]])
+  // CHECK: return %[[ABS]]
+  %0 = "tosa.negate"(%arg0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
+  %1 = "tosa.abs"(%0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}