diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1043,6 +1043,8 @@
   let results = (outs
     Tosa_Tensor:$output
   );
+
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -996,3 +996,20 @@
 
   return getInput1();
 }
+
+OpFoldResult tosa::LogOp::fold(FoldAdaptor adaptor) {
+  auto input = getInput1();
+  // Element-wise log(exp(x)) = x.
+  // NOTE(review): this identity is not exact in floating point (exp overflows
+  // to +inf / underflows to 0 for large |x|); confirm the relaxation is OK.
+  //
+  // Forward x only when its type matches the result type exactly: a folder
+  // must not change the type of the value being replaced (tosa.exp's operand
+  // may be less statically shaped than tosa.log's result, if verifiers allow).
+  if (auto op = input.getDefiningOp<tosa::ExpOp>()) {
+    if (op.getInput1().getType() == getType())
+      return op.getInput1();
+  }
+
+  return {};
+}
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -506,2 +506,12 @@
   return %1, %2 : tensor<1x6x12xf32>, tensor<1x3x12xf32>
 }
+
+// -----
+
+// CHECK-LABEL: @fold_log_exp
+func.func @fold_log_exp(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
+  // CHECK: return %arg{{.*}} : tensor<?x1xf32>
+  %0 = "tosa.exp"(%arg0) : (tensor<?x1xf32>) -> tensor<?x1xf32>
+  %1 = "tosa.log"(%0) : (tensor<?x1xf32>) -> tensor<?x1xf32>
+  return %1 : tensor<?x1xf32>
+}