diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -999,6 +999,8 @@
   let results = (outs
     Tosa_Tensor:$output
   );
+
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -1006,3 +1006,16 @@
 
   return {};
 }
+
+OpFoldResult tosa::ExpOp::fold(FoldAdaptor adaptor) {
+  auto input = getInput1();
+  // Element-wise exp(log(x)) = x.
+  // NOTE(review): this identity holds only for non-negative (or NaN)
+  // elements; for x < 0, log(x) is NaN while the folded result is x. The
+  // fold therefore assumes tosa.log is only applied to non-negative inputs.
+  if (auto op = input.getDefiningOp<tosa::LogOp>()) {
+    return op.getInput1();
+  }
+
+  return {};
+}
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -515,3 +515,13 @@
   %1 = "tosa.log"(%0) : (tensor<?x1xf32>) -> tensor<?x1xf32>
   return %1 : tensor<?x1xf32>
 }
+
+// -----
+
+// CHECK-LABEL: @fold_exp_log
+func.func @fold_exp_log(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
+  // CHECK: return %arg{{.*}} : tensor<?x1xf32>
+  %0 = "tosa.log"(%arg0) : (tensor<?x1xf32>) -> tensor<?x1xf32>
+  %1 = "tosa.exp"(%0) : (tensor<?x1xf32>) -> tensor<?x1xf32>
+  return %1 : tensor<?x1xf32>
+}